dashmap_cache/
lib.rs

1use core::future::Future;
2use core::hash::Hash;
3use dashmap::{DashMap, DashSet};
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::marker::{Send, Sync};
7use std::pin::Pin;
8#[derive(Clone, Debug)]
9pub struct DashmapCache {
10    inner: DashMap<Vec<u8>, Vec<u8>>,
11    tags: DashMap<String, DashSet<Vec<u8>>>,
12}
13
14#[derive(Debug)]
15pub enum CacheError {
16    Decode(rmp_serde::decode::Error),
17    Encode(rmp_serde::encode::Error),
18}
19
20impl From<rmp_serde::decode::Error> for CacheError {
21    fn from(value: rmp_serde::decode::Error) -> Self {
22        Self::Decode(value)
23    }
24}
25
26impl From<rmp_serde::encode::Error> for CacheError {
27    fn from(value: rmp_serde::encode::Error) -> Self {
28        Self::Encode(value)
29    }
30}
31
32impl<'a> DashmapCache {
33    pub fn new() -> Self {
34        let inner = DashMap::new();
35        Self {
36            inner,
37            tags: DashMap::new(),
38        }
39    }
40
41    fn insert(&self, tags: &Vec<String>, key: Vec<u8>, val: Vec<u8>) -> Option<Vec<u8>> {
42        for tag in tags {
43            if !self.tags.contains_key(tag) {
44                let dash = DashSet::new();
45                dash.insert(key.clone());
46                self.tags.insert(tag.to_owned(), dash);
47            } else {
48                self.tags.alter(tag, |_k, ex_tags| {
49                    ex_tags.insert(key.clone());
50                    ex_tags
51                })
52            }
53        }
54        self.inner.insert(key, val)
55    }
56
57    /// Atomic operation to replace a cached entry by a new computation value
58    pub fn refresh_cache<F, A, V>(
59        &self,
60        invalidate_keys: &Vec<String>,
61        closure: F,
62        arg: A,
63    ) -> Result<V, CacheError>
64    where
65        F: Fn(&A) -> V,
66        A: Hash + Sync + Send + Eq + Serialize,
67        V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
68    {
69        let arg_bytes = rmp_serde::to_vec(&arg)?;
70        let val = closure(&arg);
71        let val_bytes = rmp_serde::to_vec(&val)?;
72        self.insert(invalidate_keys, arg_bytes, val_bytes);
73        Ok(val)
74    }
75
76    /// Computes a signature for arg
77    /// If already present in the cache, returns directly associated return value
78    /// Otherwise, compute a new return value and fills the cache with it
79    /// It is recommended to use a call enum and dispatch in the same closure for the same cache if the input types or values are susceptible to overlap.
80    pub fn cached<F, A, V>(
81        &self,
82        invalidate_keys: &Vec<String>,
83        closure: F,
84        arg: A,
85    ) -> Result<V, CacheError>
86    where
87        F: Fn(&A) -> V,
88        A: Hash + Sync + Send + Eq + Serialize,
89        V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
90    {
91        let arg_bytes = rmp_serde::to_vec(&arg)?;
92
93        match self.inner.get(&arg_bytes) {
94            None => {
95                let val = closure(&arg);
96                let val_bytes = rmp_serde::to_vec(&val)?;
97                self.insert(invalidate_keys, arg_bytes, val_bytes);
98                Ok(val)
99            }
100            Some(val) => {
101                let ret_val = rmp_serde::from_slice::<V>(&val)?;
102                Ok(ret_val.to_owned())
103            }
104        }
105    }
106
107    /// Async version of cached()
108    pub async fn async_cached<F, A, V>(
109        &self,
110        invalidate_keys: &Vec<String>,
111        closure: F,
112        arg: A,
113    ) -> Result<V, CacheError>
114    where
115        F: Fn(&A) -> Pin<Box<dyn Future<Output = V>>>,
116        A: Hash + Sync + Send + Eq + Serialize,
117        V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
118    {
119        let arg_bytes = rmp_serde::to_vec(&arg)?;
120
121        match self.inner.get(&arg_bytes) {
122            None => {
123                let val = closure(&arg).await;
124                let val_bytes = rmp_serde::to_vec(&val)?;
125                self.insert(invalidate_keys, arg_bytes, val_bytes);
126                Ok(val)
127            }
128            Some(val) => {
129                let ret_val = rmp_serde::from_slice::<V>(&val)?;
130                Ok(ret_val.to_owned())
131            }
132        }
133    }
134
135    /// Tokio version of cached()
136    #[cfg(feature = "tokio")]
137    pub async fn tokio_cached<F, A, V>(
138        &self,
139        invalidate_keys: &Vec<String>,
140        closure: F,
141        arg: A,
142    ) -> Result<V, CacheError>
143    where
144        F: Fn(&A) -> tokio::task::JoinHandle<V>,
145        A: Hash + Sync + Send + Eq + Serialize,
146        V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
147    {
148        let arg_bytes = rmp_serde::to_vec(&arg)?;
149
150        match self.inner.get(&arg_bytes) {
151            None => {
152                let val = closure(&arg).await.unwrap();
153                let val_bytes = rmp_serde::to_vec(&val)?;
154                self.insert(invalidate_keys, arg_bytes, val_bytes);
155                Ok(val)
156            }
157            Some(val) => {
158                let ret_val = rmp_serde::from_slice::<V>(&val)?;
159                Ok(ret_val.to_owned())
160            }
161        }
162    }
163
164    pub fn invalidate(&self, tag: &str) {
165        let hashes = self.tags.get(tag);
166        if hashes.is_some() {
167            self.tags.remove(tag);
168            for hsh in hashes.unwrap().clone() {
169                self.inner.remove(&hsh);
170            }
171        }
172    }
173}