1use std::collections::HashMap;
2use std::hash::Hash;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::{Arc, RwLock};
5
6use chrono::{DateTime, Duration, Utc};
7use flexbuffers::{FlexbufferSerializer, Reader};
8use serde::Serialize;
9
10pub use crate::result::Error;
11
/// A single value held by the in-memory layer of [`Cache`].
pub struct CacheEntry<V> {
    /// The cached value; handed out to callers as a cheap `Arc` clone.
    pub value: Arc<V>,

    /// Expiration deadline as Unix seconds. Atomic so that a reader holding
    /// only the map's read lock can push the deadline forward on a hit.
    pub expiration: AtomicU64,
}
21
22pub struct Cache<K, V>
24where
25 K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
26 V: Send + Clone + for<'de> serde::Deserialize<'de> + serde::Serialize,
27{
28 pub(crate) in_memory: Arc<RwLock<HashMap<K, CacheEntry<V>>>>,
29
30 pub(crate) content: sled::Tree,
32
33 pub(crate) expiry: sled::Tree,
35
36 pub(crate) memory_duration: Duration,
38
39 pub(crate) disk_duration: Duration,
41}
42impl<K, V> Cache<K, V>
43where
44 K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
45 V: Send + Clone + for<'de> serde::Deserialize<'de> + serde::Serialize,
46{
47 pub async fn get_or_insert_infallible<F>(&self, key: &K, thunk: F) -> Result<Arc<V>, Error<()>>
51 where
52 F: std::future::Future<Output = V>,
53 {
54 self.get_or_insert::<_, ()>(key, async { Ok(thunk.await) })
55 .await
56 }
57
58 pub async fn get_or_insert<F, E>(&self, key: &K, thunk: F) -> Result<Arc<V>, Error<E>>
62 where
63 F: std::future::Future<Output = Result<V, E>>,
64 {
65 let memory_expiration = Utc::now() + self.memory_duration;
66
67 let mut key_serializer = FlexbufferSerializer::new();
69 key.serialize(&mut key_serializer).unwrap(); let key_bin = key_serializer.take_buffer();
71
72 match self.get_at(key, &key_bin, memory_expiration) {
73 Ok(Some(found)) => return Ok(found),
74 Ok(None) => {}
75 Err(Error::Database(err)) => return Err(Error::Database(err)),
76 Err(e) => panic!("We shouldn't have any other error here {:?}", e),
77 }
78
79 let data = thunk.await.map_err(Error::Client)?;
81 let result = Arc::new(data);
82
83 self.store_in_memory_cache(key, &result, memory_expiration);
85
86 let disk_expiration = Utc::now() + self.disk_duration;
87
88 self.store_in_disk_cache(&key_bin, &result, disk_expiration)
90 .map_err(Error::Database)?;
91
92 Ok(result)
93 }
94
95 pub fn get(&self, key: &K) -> Result<Option<Arc<V>>, Error<()>> {
97 let mut key_serializer = FlexbufferSerializer::new();
99 key.serialize(&mut key_serializer).unwrap(); let key_bin = key_serializer.take_buffer();
101
102 self.get_at(key, &key_bin, Utc::now() + self.memory_duration)
103 }
104 fn get_at(
105 &self,
106 key: &K,
107 key_bin: &[u8],
108 memory_expiration: DateTime<Utc>,
109 ) -> Result<Option<Arc<V>>, Error<()>> {
110 {
111 let read_lock = self.in_memory.read().unwrap();
113 if let Some(found) = read_lock.get(key) {
114 found
115 .expiration
116 .store(memory_expiration.timestamp() as u64, Ordering::Relaxed);
117 return Ok(Some(found.value.clone()));
119 }
120 }
121 debug!(target: "disk-cache", "Value not found in memory");
122
123 {
124 if let Some(value_bin) = self.content.get(&key_bin).map_err(Error::Database)? {
126 debug!(target: "disk-cache", "Value was in disk cache");
127 let reader = Reader::get_root(&value_bin).unwrap();
129 if let Ok(value) = V::deserialize(reader) {
130 debug!(target: "disk-cache", "Value deserialized");
131
132 let result = Arc::new(value);
133
134 self.store_in_memory_cache(key, &result, memory_expiration);
136
137 return Ok(Some(result));
141 }
142
143 }
146 }
147
148 debug!(target: "disk-cache", "Value not found on disk");
149 Ok(None)
150 }
151
152 fn store_in_memory_cache(&self, key: &K, value: &Arc<V>, expiration: DateTime<Utc>) {
156 debug!(target: "disk-cache", "Adding value to memory cache");
157 let mut write_lock = self.in_memory.write().unwrap();
158 let entry = CacheEntry {
159 value: value.clone(),
160 expiration: AtomicU64::new(expiration.timestamp() as u64),
161 };
162 write_lock.insert(key.clone(), entry);
163 }
164
165 fn store_in_disk_cache(
169 &self,
170 key: &[u8],
171 value: &Arc<V>,
172 expiration: DateTime<Utc>,
173 ) -> Result<(), sled::Error> {
174 debug!(target: "disk-", "Adding value to disk cache");
175 let mut value_serializer = FlexbufferSerializer::new();
176 value.serialize(&mut value_serializer).unwrap();
177 let entry_bin = value_serializer.take_buffer();
178
179 self.content.insert(key, entry_bin)?;
180 self.expiry
181 .insert(u64_to_bytes(expiration.timestamp() as u64), key)?;
182 Ok(())
183 }
184
185 pub fn cleanup_expired_from_memory_cache(&self) {
186 cleanup_memory_cache(&self.in_memory)
187 }
188
189 pub fn cleanup_expired_disk_cache(&self) {
190 cleanup_disk_cache::<K, V>(&self.expiry, &self.content)
191 }
192}
193
194pub fn cleanup_memory_cache<K, V>(memory_cache: &Arc<RwLock<HashMap<K, CacheEntry<V>>>>)
198where
199 K: Eq + Hash + Clone,
200{
201 let now = Utc::now().timestamp() as u64;
202 {
203 let mut write_lock = memory_cache.write().unwrap();
204 write_lock.retain(|_, v| v.expiration.load(Ordering::Relaxed) > now)
205 }
206}
207
208pub fn cleanup_disk_cache<K, V>(expiry: &sled::Tree, content: &sled::Tree)
210where
211 K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
212{
213 let now = Utc::now();
214 let mut batch = sled::Batch::default();
215 for cursor in expiry.range(u64_to_bytes(0)..u64_to_bytes(now.timestamp() as u64)) {
216 let (ts, k) = cursor.unwrap(); debug_assert!(bytes_to_u64(&ts) <= now.timestamp() as u64);
218 batch.remove(k);
219 }
220 content.apply_batch(batch).unwrap(); }
222
/// Decodes the first eight bytes of `bytes` as a big-endian `u64`.
///
/// Bytes past the eighth are ignored, so this can read the timestamp prefix
/// of a longer expiry-index key.
///
/// # Panics
///
/// Panics if `bytes` is shorter than eight bytes.
fn bytes_to_u64(bytes: &[u8]) -> u64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&bytes[..8]);
    u64::from_be_bytes(prefix)
}
233
/// Encodes `value` as eight big-endian bytes.
///
/// With big-endian order, comparing the encodings byte by byte gives the same
/// result as comparing the numbers. The expiry-index range scans in this file
/// rely on that.
fn u64_to_bytes(value: u64) -> [u8; 8] {
    value.to_be_bytes()
}
246
/// Checks that `u64_to_bytes`/`bytes_to_u64` round-trip, that the encoding
/// preserves numeric order (the expiry range scan depends on it), and that
/// decoding ignores bytes after the first eight.
#[test]
fn test_bytes_to_u64() {
    // Roughly geometric sweep across the whole u64 range.
    let mut i: u128 = 0;
    while i <= u64::MAX as u128 {
        let bytes = u64_to_bytes(i as u64);
        let num = bytes_to_u64(&bytes);
        assert_eq!(num, i as u64);
        i = (i + 1) * 7;
    }

    // Boundaries the sweep above never lands on exactly, in ascending order.
    let ordered = [
        0u64,
        1,
        255,
        256,
        65_535,
        65_536,
        1 << 32,
        u64::MAX - 1,
        u64::MAX,
    ];
    for &value in ordered.iter() {
        assert_eq!(bytes_to_u64(&u64_to_bytes(value)), value);
    }

    // Byte-wise ordering of the encodings must match numeric ordering.
    for pair in ordered.windows(2) {
        assert!(u64_to_bytes(pair[0]) < u64_to_bytes(pair[1]));
    }

    // A timestamp prefix followed by a key suffix still decodes to the timestamp.
    let mut composite = u64_to_bytes(42).to_vec();
    composite.extend_from_slice(b"key");
    assert_eq!(bytes_to_u64(&composite), 42);
}