remember_this/
cache.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::{Arc, RwLock};
5
6use chrono::{DateTime, Duration, Utc};
7use flexbuffers::{FlexbufferSerializer, Reader};
8use serde::Serialize;
9
10pub use crate::result::Error;
11
/// A cached value together with its expiration time.
///
/// Entries of this type are the values of the in-memory map held by [`Cache`].
pub struct CacheEntry<V> {
    /// The cached value; lookups hand out clones of this `Arc`.
    pub value: Arc<V>,

    /// A number of seconds since the epoch
    ///
    /// This value may be removed after `expiration`. The field is atomic so that
    /// the deadline can be pushed back through a shared reference, i.e. while
    /// only a read lock on the enclosing map is held.
    pub expiration: AtomicU64,
}
21
/// Persisting queried values to the disk across sessions.
///
/// Lookups consult the in-memory map first and fall back to the on-disk
/// `content` tree; values found on disk are copied back into memory.
pub struct Cache<K, V>
where
    K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
    V: Send + Clone + for<'de> serde::Deserialize<'de> + serde::Serialize,
{
    /// The data cached in memory as a K -> entry mapping, shared behind a lock.
    pub(crate) in_memory: Arc<RwLock<HashMap<K, CacheEntry<V>>>>,

    /// The data cached to the disk as a K -> V mapping.
    ///
    /// Both keys and values are stored in their FlexBuffers-serialized form.
    pub(crate) content: sled::Tree,

    /// The expiration dates as a seconds: u64 -> K mapping.
    ///
    /// The seconds are encoded as 8 big-endian bytes, so that the byte order of
    /// the tree keys matches the numeric order of the timestamps.
    pub(crate) expiry: sled::Tree,

    /// How long data should remain in-memory.
    pub(crate) memory_duration: Duration,

    /// How long data should remain on-disk.
    pub(crate) disk_duration: Duration,
}
42impl<K, V> Cache<K, V>
43where
44    K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
45    V: Send + Clone + for<'de> serde::Deserialize<'de> + serde::Serialize,
46{
47    /// Get a value from the cache.
48    ///
49    /// If this value is not in the cache, compute the thunk and insert the value.
50    pub async fn get_or_insert_infallible<F>(&self, key: &K, thunk: F) -> Result<Arc<V>, Error<()>>
51    where
52        F: std::future::Future<Output = V>,
53    {
54        self.get_or_insert::<_, ()>(key, async { Ok(thunk.await) })
55            .await
56    }
57
58    /// Get a value from the cache.
59    ///
60    /// If this value is not in the cache, compute the thunk and insert the value.
61    pub async fn get_or_insert<F, E>(&self, key: &K, thunk: F) -> Result<Arc<V>, Error<E>>
62    where
63        F: std::future::Future<Output = Result<V, E>>,
64    {
65        let memory_expiration = Utc::now() + self.memory_duration;
66
67        // Prepare binary key for disk cache access.
68        let mut key_serializer = FlexbufferSerializer::new();
69        key.serialize(&mut key_serializer).unwrap(); // We assume that in-memory serialization always succeeds.
70        let key_bin = key_serializer.take_buffer();
71
72        match self.get_at(key, &key_bin, memory_expiration) {
73            Ok(Some(found)) => return Ok(found),
74            Ok(None) => {}
75            Err(Error::Database(err)) => return Err(Error::Database(err)),
76            Err(e) => panic!("We shouldn't have any other error here {:?}", e),
77        }
78
79        // Not in cache. Unthunk `thunk`
80        let data = thunk.await.map_err(Error::Client)?;
81        let result = Arc::new(data);
82
83        // Store in memory.
84        self.store_in_memory_cache(key, &result, memory_expiration);
85
86        let disk_expiration = Utc::now() + self.disk_duration;
87
88        // Store in cache.
89        self.store_in_disk_cache(&key_bin, &result, disk_expiration)
90            .map_err(Error::Database)?;
91
92        Ok(result)
93    }
94
95    /// Get a value from the cache.
96    pub fn get(&self, key: &K) -> Result<Option<Arc<V>>, Error<()>> {
97        // Prepare binary key for disk cache access.
98        let mut key_serializer = FlexbufferSerializer::new();
99        key.serialize(&mut key_serializer).unwrap(); // We assume that in-memory serialization always succeeds.
100        let key_bin = key_serializer.take_buffer();
101
102        self.get_at(key, &key_bin, Utc::now() + self.memory_duration)
103    }
104    fn get_at(
105        &self,
106        key: &K,
107        key_bin: &[u8],
108        memory_expiration: DateTime<Utc>,
109    ) -> Result<Option<Arc<V>>, Error<()>> {
110        {
111            // Fetch from in-memory cache.
112            let read_lock = self.in_memory.read().unwrap();
113            if let Some(found) = read_lock.get(key) {
114                found
115                    .expiration
116                    .store(memory_expiration.timestamp() as u64, Ordering::Relaxed);
117                // FIXME: Postpone expiry on disk.
118                return Ok(Some(found.value.clone()));
119            }
120        }
121        debug!(target: "disk-cache", "Value not found in memory");
122
123        {
124            // Fetch from disk cache.
125            if let Some(value_bin) = self.content.get(&key_bin).map_err(Error::Database)? {
126                debug!(target: "disk-cache", "Value was in disk cache");
127                // Found in cache.
128                let reader = Reader::get_root(&value_bin).unwrap();
129                if let Ok(value) = V::deserialize(reader) {
130                    debug!(target: "disk-cache", "Value deserialized");
131
132                    let result = Arc::new(value);
133
134                    // Store back in memory.
135                    self.store_in_memory_cache(key, &result, memory_expiration);
136
137                    // FIXME: Postpone expiration on disk
138
139                    // Finally, return.
140                    return Ok(Some(result));
141                }
142
143                // If we reach this stage, deserialization failed, either because of disk corruption (unlikely)
144                // or because the format has changed (more likely). In either case, ignore and overwrite data.
145            }
146        }
147
148        debug!(target: "disk-cache", "Value not found on disk");
149        Ok(None)
150    }
151
152    /// Store in the memory cache.
153    ///
154    /// Schedule a task to cleanup from memory.
155    fn store_in_memory_cache(&self, key: &K, value: &Arc<V>, expiration: DateTime<Utc>) {
156        debug!(target: "disk-cache", "Adding value to memory cache");
157        let mut write_lock = self.in_memory.write().unwrap();
158        let entry = CacheEntry {
159            value: value.clone(),
160            expiration: AtomicU64::new(expiration.timestamp() as u64),
161        };
162        write_lock.insert(key.clone(), entry);
163    }
164
165    /// Store in the memory cache.
166    ///
167    /// Schedule a task to cleanup from disk.
168    fn store_in_disk_cache(
169        &self,
170        key: &[u8],
171        value: &Arc<V>,
172        expiration: DateTime<Utc>,
173    ) -> Result<(), sled::Error> {
174        debug!(target: "disk-", "Adding value to disk cache");
175        let mut value_serializer = FlexbufferSerializer::new();
176        value.serialize(&mut value_serializer).unwrap();
177        let entry_bin = value_serializer.take_buffer();
178
179        self.content.insert(key, entry_bin)?;
180        self.expiry
181            .insert(u64_to_bytes(expiration.timestamp() as u64), key)?;
182        Ok(())
183    }
184
185    pub fn cleanup_expired_from_memory_cache(&self) {
186        cleanup_memory_cache(&self.in_memory)
187    }
188
189    pub fn cleanup_expired_disk_cache(&self) {
190        cleanup_disk_cache::<K, V>(&self.expiry, &self.content)
191    }
192}
193
194// Internal functions.
195
196/// Remove all values from memory that have nothing to do here anymore.
197pub fn cleanup_memory_cache<K, V>(memory_cache: &Arc<RwLock<HashMap<K, CacheEntry<V>>>>)
198where
199    K: Eq + Hash + Clone,
200{
201    let now = Utc::now().timestamp() as u64;
202    {
203        let mut write_lock = memory_cache.write().unwrap();
204        write_lock.retain(|_, v| v.expiration.load(Ordering::Relaxed) > now)
205    }
206}
207
208/// Remove all values from disk cache that have nothing to do here anymore.
209pub fn cleanup_disk_cache<K, V>(expiry: &sled::Tree, content: &sled::Tree)
210where
211    K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
212{
213    let now = Utc::now();
214    let mut batch = sled::Batch::default();
215    for cursor in expiry.range(u64_to_bytes(0)..u64_to_bytes(now.timestamp() as u64)) {
216        let (ts, k) = cursor.unwrap(); // FIXME: Handle errors
217        debug_assert!(bytes_to_u64(&ts) <= now.timestamp() as u64);
218        batch.remove(k);
219    }
220    content.apply_batch(batch).unwrap(); // FIXME: Handle errors
221}
222
/// Decode the first 8 bytes of `bytes` as a big-endian `u64`.
///
/// Any bytes beyond the eighth are ignored.
///
/// # Panics
///
/// Panics if `bytes` holds fewer than 8 bytes.
fn bytes_to_u64(bytes: &[u8]) -> u64 {
    bytes[..8]
        .iter()
        .fold(0, |acc, &byte| (acc << 8) | u64::from(byte))
}
233
/// Encode `value` as 8 bytes, most significant byte first.
///
/// With this layout, comparing two encodings byte by byte gives the same order
/// as comparing the numbers themselves.
fn u64_to_bytes(value: u64) -> [u8; 8] {
    value.to_be_bytes()
}
246
/// Round-trip check: encoding a number and decoding it again yields the number.
///
/// Samples start at 0 and grow as `(i + 1) * 7`; they are generated as `u128` so
/// that the sequence can step past `u64::MAX` without overflowing.
#[test]
fn test_bytes_to_u64() {
    let limit = std::u64::MAX as u128;
    let samples = std::iter::successors(Some(0u128), |&i| Some((i + 1) * 7))
        .take_while(|&i| i <= limit);
    for sample in samples {
        let expected = sample as u64;
        let decoded = bytes_to_u64(&u64_to_bytes(expected));
        assert_eq!(decoded, expected);
    }
}
256}