cache_loader_async/
backing.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3#[cfg(feature = "lru-cache")]
4use lru::LruCache;
5#[cfg(feature = "ttl-cache")]
6use std::collections::VecDeque;
7use std::fmt::Debug;
8#[cfg(feature = "ttl-cache")]
9use std::marker::PhantomData;
10use thiserror::Error;
11#[cfg(feature = "ttl-cache")]
12use std::ops::Add;
13#[cfg(feature = "ttl-cache")]
14use tokio::time::{Instant, Duration};
15
16pub trait CacheBacking<K, V>
17    where K: Eq + Hash + Sized + Clone + Send,
18          V: Sized + Clone + Send {
19    type Meta: Clone + Send;
20
21    fn get_mut(&mut self, key: &K) -> Result<Option<&mut V>, BackingError>;
22    fn get(&mut self, key: &K) -> Result<Option<&V>, BackingError>;
23    fn set(&mut self, key: K, value: V, meta: Option<Self::Meta>) -> Result<Option<V>, BackingError>;
24    fn remove(&mut self, key: &K) -> Result<Option<V>, BackingError>;
25    fn contains_key(&mut self, key: &K) -> Result<bool, BackingError>;
26    fn remove_if(&mut self, predicate: Box<dyn Fn((&K, &V)) -> bool + Send + Sync + 'static>) -> Result<Vec<(K, V)>, BackingError>;
27    fn clear(&mut self) -> Result<(), BackingError>;
28}
29
30#[derive(Debug, Clone, Error)]
31pub enum BackingError {
32    #[error(transparent)]
33    TtlError(#[from] TtlError),
34}
35
/// Empty metadata type used as `Meta` by backings that take no per-entry options.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoMeta {}
38
/// Cache backing that keeps its entries in an [`LruCache`].
#[cfg(feature = "lru-cache")]
pub struct LruCacheBacking<K, V> {
    /// The wrapped LRU store that all operations are forwarded to.
    lru: LruCache<K, V>,
}
43
/// [`CacheBacking`] implementation that forwards each operation to the wrapped [`LruCache`].
#[cfg(feature = "lru-cache")]
impl<K, V> CacheBacking<K, V> for LruCacheBacking<K, V>
where
    K: Eq + Hash + Sized + Clone + Send,
    V: Sized + Clone + Send,
{
    type Meta = NoMeta;

    /// Mutable lookup, forwarded to `LruCache::get_mut`.
    fn get_mut(&mut self, key: &K) -> Result<Option<&mut V>, BackingError> {
        Ok(self.lru.get_mut(key))
    }

    /// Shared lookup, forwarded to `LruCache::get`.
    fn get(&mut self, key: &K) -> Result<Option<&V>, BackingError> {
        Ok(self.lru.get(key))
    }

    /// Stores `value` under `key` via `LruCache::put` and returns whatever `put` returns.
    ///
    /// The metadata argument is ignored; this backing has no per-entry options.
    fn set(&mut self, key: K, value: V, _meta: Option<Self::Meta>) -> Result<Option<V>, BackingError> {
        Ok(self.lru.put(key, value))
    }

    /// Removes `key` via `LruCache::pop` and returns the removed value, if any.
    fn remove(&mut self, key: &K) -> Result<Option<V>, BackingError> {
        Ok(self.lru.pop(key))
    }

    /// Membership test, forwarded to `LruCache::contains`.
    fn contains_key(&mut self, key: &K) -> Result<bool, BackingError> {
        // `key` is already a `&K`; cloning it only to borrow the clone is wasted work.
        Ok(self.lru.contains(key))
    }

    /// Removes every entry accepted by `predicate` and returns the removed pairs.
    fn remove_if(&mut self, predicate: Box<dyn Fn((&K, &V)) -> bool + Send + Sync>) -> Result<Vec<(K, V)>, BackingError> {
        // The iterator borrows the cache, so matching keys are collected first
        // and the entries are removed in a second pass.
        let keys: Vec<K> = self.lru.iter()
            .filter(|&(key, value)| predicate((key, value)))
            .map(|(key, _)| key.clone())
            .collect();

        let mut removed = Vec::with_capacity(keys.len());
        for key in keys {
            // Invariant: every collected key was just observed in the cache.
            let value = self.lru.pop(&key).expect("LRU value is empty");
            removed.push((key, value));
        }
        Ok(removed)
    }

    /// Removes all entries via `LruCache::clear`.
    fn clear(&mut self) -> Result<(), BackingError> {
        self.lru.clear();
        Ok(())
    }
}
95
/// Constructors for [`LruCacheBacking`].
#[cfg(feature = "lru-cache")]
impl<K, V> LruCacheBacking<K, V>
where
    K: Eq + Hash + Sized + Clone + Send,
    V: Sized + Clone + Send,
{
    /// Creates a backing around `LruCache::new(size)`.
    pub fn new(size: usize) -> LruCacheBacking<K, V> {
        let lru = LruCache::new(size);
        Self { lru }
    }

    /// Creates a backing around `LruCache::unbounded()`.
    pub fn unbounded() -> LruCacheBacking<K, V> {
        let lru = LruCache::unbounded();
        Self { lru }
    }
}
113
/// Cache backing that adds a time-to-live on top of another backing `B`.
///
/// Values are stored in `B` together with their expiry instant, and a queue of
/// `(key, expiry)` entries is kept alongside to find expired keys.
#[cfg(feature = "ttl-cache")]
pub struct TtlCacheBacking<K, V, B>
where
    K: Clone + Eq + Hash + Send,
    V: Clone + Sized + Send,
    B: CacheBacking<K, (V, Instant)>,
{
    /// Marker for `V`, which otherwise only appears in the bound on `B`.
    phantom: PhantomData<V>,
    /// Time-to-live applied when `set` is called without metadata.
    ttl: Duration,
    /// Expiry bookkeeping; new entries are inserted at a binary-searched position by expiry.
    expiry_queue: VecDeque<TTlEntry<K>>,
    /// Inner backing holding each value paired with its expiry instant.
    map: B,
}
125
/// One slot of the expiry queue: a key together with the instant at which it expires.
#[cfg(feature = "ttl-cache")]
struct TTlEntry<K> {
    /// Key of the cached entry this slot refers to.
    key: K,
    /// Instant from which the entry counts as expired.
    expiry: Instant,
}
132
/// Builds a queue entry from a `(key, expiry)` pair.
#[cfg(feature = "ttl-cache")]
impl<K> From<(K, Instant)> for TTlEntry<K> {
    fn from((key, expiry): (K, Instant)) -> Self {
        TTlEntry { key, expiry }
    }
}
142
143#[derive(Debug, Clone, Error)]
144pub enum TtlError {
145    #[error("The expiry for key not found")]
146    ExpiryNotFound,
147    #[error("No key for expiry matched key")]
148    ExpiryKeyNotFound,
149}
150
/// Per-entry metadata for [`TtlCacheBacking`].
#[cfg(feature = "ttl-cache")]
#[derive(Clone, Copy, Debug)]
pub struct TtlMeta {
    /// Time-to-live for the entry; used instead of the backing's default when supplied to `set`.
    pub ttl: Duration,
}
156
/// Wraps a bare [`Duration`] as the `ttl` of a [`TtlMeta`].
#[cfg(feature = "ttl-cache")]
impl From<Duration> for TtlMeta {
    fn from(ttl: Duration) -> Self {
        TtlMeta { ttl }
    }
}
163
/// [`CacheBacking`] implementation that expires entries after a time-to-live.
///
/// Every operation except `clear` first purges expired entries, so callers
/// never observe a value whose expiry instant has passed.
#[cfg(feature = "ttl-cache")]
impl<K, V, B> CacheBacking<K, V> for TtlCacheBacking<K, V, B>
where
    K: Clone + Eq + Hash + Send + 'static,
    V: Clone + Sized + Send + 'static,
    B: CacheBacking<K, (V, Instant)>,
{
    type Meta = TtlMeta;

    /// Purges expired entries, then returns a mutable reference to the value under `key`.
    fn get_mut(&mut self, key: &K) -> Result<Option<&mut V>, BackingError> {
        self.remove_old()?;
        Ok(self.map.get_mut(key)?.map(|(value, _)| value))
    }

    /// Purges expired entries, then returns a shared reference to the value under `key`.
    fn get(&mut self, key: &K) -> Result<Option<&V>, BackingError> {
        self.remove_old()?;
        Ok(self.map.get(key)?.map(|(value, _)| value))
    }

    /// Stores `value` under `key`, expiring after `meta.ttl` or, without
    /// metadata, after the backing's default TTL.
    ///
    /// Returns the value previously stored under `key`, if any.
    fn set(&mut self, key: K, value: V, meta: Option<Self::Meta>) -> Result<Option<V>, BackingError> {
        self.remove_old()?;
        let ttl = meta.map_or(self.ttl, |meta| meta.ttl);
        let expiry = Instant::now().add(ttl);
        // `key` is owned and not needed afterwards, so it is moved rather than cloned.
        self.replace(key, value, expiry)
    }

    /// Purges expired entries, then removes `key` and its expiry-queue entry.
    fn remove(&mut self, key: &K) -> Result<Option<V>, BackingError> {
        self.remove_old()?;
        self.remove_key(key)
    }

    /// Purges expired entries, then asks the inner backing whether `key` is stored.
    fn contains_key(&mut self, key: &K) -> Result<bool, BackingError> {
        self.remove_old()?;
        // Use the inner backing's own membership test instead of a full lookup.
        self.map.contains_key(key)
    }

    /// Removes every live entry accepted by `predicate` and returns the removed pairs.
    fn remove_if(&mut self, predicate: Box<dyn Fn((&K, &V)) -> bool + Send + Sync>) -> Result<Vec<(K, V)>, BackingError> {
        // Purge first, like every other accessor, so expired entries are never
        // offered to the predicate or handed back to the caller.
        self.remove_old()?;
        let values = self.map.remove_if(Box::new(move |(key, (value, _))| predicate((key, value))))?;
        if values.is_empty() {
            return Ok(Vec::new());
        }
        {
            // One pass over the queue with a set lookup, instead of one full
            // queue scan per removed key.
            let removed_keys: std::collections::HashSet<&K> =
                values.iter().map(|(key, _)| key).collect();
            self.expiry_queue.retain(|entry| !removed_keys.contains(&entry.key));
        }
        Ok(values
            .into_iter()
            .map(|(key, (value, _))| (key, value))
            .collect())
    }

    /// Drops all expiry bookkeeping and clears the inner backing.
    fn clear(&mut self) -> Result<(), BackingError> {
        self.expiry_queue.clear();
        self.map.clear()?;
        Ok(())
    }
}
223
/// Constructor for a [`TtlCacheBacking`] backed by a [`HashMapBacking`].
#[cfg(feature = "ttl-cache")]
impl<K, V> TtlCacheBacking<K, V, HashMapBacking<K, (V, Instant)>>
where
    K: Eq + Hash + Sized + Clone + Send,
    V: Sized + Clone + Send,
{
    /// Creates an empty TTL backing over a fresh `HashMapBacking`, using `ttl`
    /// as the default time-to-live.
    pub fn new(ttl: Duration) -> TtlCacheBacking<K, V, HashMapBacking<K, (V, Instant)>> {
        Self {
            phantom: PhantomData,
            ttl,
            expiry_queue: VecDeque::new(),
            map: HashMapBacking::new(),
        }
    }
}
238
/// Constructor over an arbitrary inner backing, plus the expiry bookkeeping helpers.
#[cfg(feature = "ttl-cache")]
impl<K, V, B> TtlCacheBacking<K, V, B>
where
    K: Eq + Hash + Sized + Clone + Send,
    V: Sized + Clone + Send,
    B: CacheBacking<K, (V, Instant)>,
{
    /// Creates an empty TTL backing that stores its values in `backing` and
    /// uses `ttl` as the default time-to-live.
    pub fn with_backing(ttl: Duration, backing: B) -> TtlCacheBacking<K, V, B> {
        TtlCacheBacking {
            phantom: Default::default(),
            ttl,
            map: backing,
            expiry_queue: VecDeque::new(),
        }
    }

    /// Removes every entry whose expiry instant is not after "now".
    ///
    /// Walks the queue from the front and stops at the first entry that has
    /// not expired yet.
    fn remove_old(&mut self) -> Result<(), BackingError> {
        // NOTE(review): the inner backing is purged by key alone; if `B` can drop
        // entries on its own, a stale queue slot could remove a newer value — confirm.
        let now = Instant::now();
        while let Some(front) = self.expiry_queue.front() {
            if now < front.expiry {
                break;
            }
            if let Some(expired) = self.expiry_queue.pop_front() {
                self.map.remove(&expired.key)?;
            }
        }
        Ok(())
    }

    /// Stores `(value, expiry)` under `key` and records the new expiry in the queue.
    ///
    /// Returns the previously stored value, if any. A bookkeeping error for the
    /// old entry is returned only after the new queue entry has been inserted.
    fn replace(&mut self, key: K, value: V, expiry: Instant) -> Result<Option<V>, BackingError> {
        let previous = self.map.set(key.clone(), (value, expiry), None)?;
        let result = self.cleanup_expiry(previous, &key);
        let position = match self.expiry_queue.binary_search_by_key(&expiry, |entry| entry.expiry) {
            // Equal expiry already present: insert right behind the found slot.
            Ok(found) => found + 1,
            Err(insert_at) => insert_at,
        };
        self.expiry_queue.insert(position, (key, expiry).into());
        result
    }

    /// Removes `key` from the inner backing and drops its queue entry.
    fn remove_key(&mut self, key: &K) -> Result<Option<V>, BackingError> {
        let entry = self.map.remove(key)?;
        self.cleanup_expiry(entry, key)
    }

    /// Drops the queue entry belonging to a value that just left the inner backing.
    ///
    /// Returns the bare value, or `None` when `entry` is `None`.
    ///
    /// # Errors
    ///
    /// * [`TtlError::ExpiryNotFound`] if no queue entry has the stored expiry.
    /// * [`TtlError::ExpiryKeyNotFound`] if entries with that expiry exist but
    ///   none of them belongs to `key`.
    fn cleanup_expiry(&mut self, entry: Option<(V, Instant)>, key: &K) -> Result<Option<V>, BackingError> {
        let (value, old_expiry) = match entry {
            Some(pair) => pair,
            None => return Ok(None),
        };
        let found = self
            .expiry_queue
            .binary_search_by_key(&old_expiry, |entry| entry.expiry)
            .map_err(|_| TtlError::ExpiryNotFound)?;
        let index = self
            .expiry_index_on_key_eq(found, &old_expiry, key)
            .ok_or(TtlError::ExpiryKeyNotFound)?;
        self.expiry_queue.remove(index);
        Ok(Some(value))
    }

    /// Finds the queue index of `key` among the entries sharing `expiry`.
    ///
    /// `idx` is a binary-search hit for `expiry`. It is checked first, then the
    /// run of equal expiries before it, then the run after it. Returns `None`
    /// when no entry in that run belongs to `key` (or `idx` is out of range).
    fn expiry_index_on_key_eq(&self, idx: usize, expiry: &Instant, key: &K) -> Option<usize> {
        let queue = &self.expiry_queue;
        if queue.get(idx)?.key == *key {
            return Some(idx);
        }

        let before = (0..idx)
            .rev()
            .take_while(|&i| queue[i].expiry == *expiry)
            .find(|&i| queue[i].key == *key);
        if before.is_some() {
            return before;
        }

        // The range ends at `len`, so the scan can never index one past the end
        // of the queue (the previous hand-rolled loop did, and panicked).
        (idx + 1..queue.len())
            .take_while(|&i| queue[i].expiry == *expiry)
            .find(|&i| queue[i].key == *key)
    }
}
337
/// Cache backing that stores its entries in a plain [`HashMap`].
pub struct HashMapBacking<K, V> {
    /// The wrapped map holding the cached key/value pairs.
    map: HashMap<K, V>,
}
341
342impl<
343    K: Eq + Hash + Sized + Clone + Send,
344    V: Sized + Clone + Send
345> CacheBacking<K, V> for HashMapBacking<K, V> {
346    type Meta = NoMeta;
347
348    fn get_mut(&mut self, key: &K) -> Result<Option<&mut V>, BackingError> {
349        Ok(self.map.get_mut(key))
350    }
351
352    fn get(&mut self, key: &K) -> Result<Option<&V>, BackingError> {
353        Ok(self.map.get(key))
354    }
355
356    fn set(&mut self, key: K, value: V, _meta: Option<Self::Meta>) -> Result<Option<V>, BackingError> {
357        Ok(self.map.insert(key, value))
358    }
359
360    fn remove(&mut self, key: &K) -> Result<Option<V>, BackingError> {
361        Ok(self.map.remove(key))
362    }
363
364    fn contains_key(&mut self, key: &K) -> Result<bool, BackingError> {
365        Ok(self.map.contains_key(key))
366    }
367
368    fn remove_if(&mut self, predicate: Box<dyn Fn((&K, &V)) -> bool + Send + Sync>) -> Result<Vec<(K, V)>, BackingError> {
369        let removed = self.map.iter()
370            .filter(|(k, v)| predicate((k, v)))
371            .map(|(k, v)| (k.clone(), v.clone()))
372            .collect::<Vec<(K, V)>>();
373
374        for (k, _) in removed.iter() {
375            self.map.remove(k);
376        }
377        Ok(removed)
378    }
379
380    fn clear(&mut self) -> Result<(), BackingError> {
381        self.map.clear();
382        Ok(())
383    }
384}
385
386impl<K, V> HashMapBacking<K, V> {
387    pub fn new() -> HashMapBacking<K, V> {
388        HashMapBacking {
389            map: Default::default()
390        }
391    }
392
393    pub fn construct(map: HashMap<K, V>) -> HashMapBacking<K, V> {
394        HashMapBacking {
395            map
396        }
397    }
398}