1use core::future::Future;
2use core::hash::Hash;
3use dashmap::{DashMap, DashSet};
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::marker::{Send, Sync};
7use std::pin::Pin;
8#[derive(Clone, Debug)]
9pub struct DashmapCache {
10 inner: DashMap<Vec<u8>, Vec<u8>>,
11 tags: DashMap<String, DashSet<Vec<u8>>>,
12}
13
14#[derive(Debug)]
15pub enum CacheError {
16 Decode(rmp_serde::decode::Error),
17 Encode(rmp_serde::encode::Error),
18}
19
20impl From<rmp_serde::decode::Error> for CacheError {
21 fn from(value: rmp_serde::decode::Error) -> Self {
22 Self::Decode(value)
23 }
24}
25
26impl From<rmp_serde::encode::Error> for CacheError {
27 fn from(value: rmp_serde::encode::Error) -> Self {
28 Self::Encode(value)
29 }
30}
31
32impl<'a> DashmapCache {
33 pub fn new() -> Self {
34 let inner = DashMap::new();
35 Self {
36 inner,
37 tags: DashMap::new(),
38 }
39 }
40
41 fn insert(&self, tags: &Vec<String>, key: Vec<u8>, val: Vec<u8>) -> Option<Vec<u8>> {
42 for tag in tags {
43 if !self.tags.contains_key(tag) {
44 let dash = DashSet::new();
45 dash.insert(key.clone());
46 self.tags.insert(tag.to_owned(), dash);
47 } else {
48 self.tags.alter(tag, |_k, ex_tags| {
49 ex_tags.insert(key.clone());
50 ex_tags
51 })
52 }
53 }
54 self.inner.insert(key, val)
55 }
56
57 pub fn refresh_cache<F, A, V>(
59 &self,
60 invalidate_keys: &Vec<String>,
61 closure: F,
62 arg: A,
63 ) -> Result<V, CacheError>
64 where
65 F: Fn(&A) -> V,
66 A: Hash + Sync + Send + Eq + Serialize,
67 V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
68 {
69 let arg_bytes = rmp_serde::to_vec(&arg)?;
70 let val = closure(&arg);
71 let val_bytes = rmp_serde::to_vec(&val)?;
72 self.insert(invalidate_keys, arg_bytes, val_bytes);
73 Ok(val)
74 }
75
76 pub fn cached<F, A, V>(
81 &self,
82 invalidate_keys: &Vec<String>,
83 closure: F,
84 arg: A,
85 ) -> Result<V, CacheError>
86 where
87 F: Fn(&A) -> V,
88 A: Hash + Sync + Send + Eq + Serialize,
89 V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
90 {
91 let arg_bytes = rmp_serde::to_vec(&arg)?;
92
93 match self.inner.get(&arg_bytes) {
94 None => {
95 let val = closure(&arg);
96 let val_bytes = rmp_serde::to_vec(&val)?;
97 self.insert(invalidate_keys, arg_bytes, val_bytes);
98 Ok(val)
99 }
100 Some(val) => {
101 let ret_val = rmp_serde::from_slice::<V>(&val)?;
102 Ok(ret_val.to_owned())
103 }
104 }
105 }
106
107 pub async fn async_cached<F, A, V>(
109 &self,
110 invalidate_keys: &Vec<String>,
111 closure: F,
112 arg: A,
113 ) -> Result<V, CacheError>
114 where
115 F: Fn(&A) -> Pin<Box<dyn Future<Output = V>>>,
116 A: Hash + Sync + Send + Eq + Serialize,
117 V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
118 {
119 let arg_bytes = rmp_serde::to_vec(&arg)?;
120
121 match self.inner.get(&arg_bytes) {
122 None => {
123 let val = closure(&arg).await;
124 let val_bytes = rmp_serde::to_vec(&val)?;
125 self.insert(invalidate_keys, arg_bytes, val_bytes);
126 Ok(val)
127 }
128 Some(val) => {
129 let ret_val = rmp_serde::from_slice::<V>(&val)?;
130 Ok(ret_val.to_owned())
131 }
132 }
133 }
134
135 #[cfg(feature = "tokio")]
137 pub async fn tokio_cached<F, A, V>(
138 &self,
139 invalidate_keys: &Vec<String>,
140 closure: F,
141 arg: A,
142 ) -> Result<V, CacheError>
143 where
144 F: Fn(&A) -> tokio::task::JoinHandle<V>,
145 A: Hash + Sync + Send + Eq + Serialize,
146 V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
147 {
148 let arg_bytes = rmp_serde::to_vec(&arg)?;
149
150 match self.inner.get(&arg_bytes) {
151 None => {
152 let val = closure(&arg).await.unwrap();
153 let val_bytes = rmp_serde::to_vec(&val)?;
154 self.insert(invalidate_keys, arg_bytes, val_bytes);
155 Ok(val)
156 }
157 Some(val) => {
158 let ret_val = rmp_serde::from_slice::<V>(&val)?;
159 Ok(ret_val.to_owned())
160 }
161 }
162 }
163
164 pub fn invalidate(&self, tag: &str) {
165 let hashes = self.tags.get(tag);
166 if hashes.is_some() {
167 self.tags.remove(tag);
168 for hsh in hashes.unwrap().clone() {
169 self.inner.remove(&hsh);
170 }
171 }
172 }
173}