1use crate::cache::{CacheConfig, CacheStore};
4use crate::error::Error;
5use serde::{de::DeserializeOwned, Serialize};
6use std::future::Future;
7use std::sync::Arc;
8use std::time::Duration;
9
/// A cache handle whose keys are namespaced by a fixed list of tags.
///
/// Every key written through this handle is stored under a key derived from
/// the tag list, and is also recorded in one store-side "tag set" per tag so
/// that the whole group can later be flushed by tag.
pub struct TaggedCache {
    // Backing store; held behind an `Arc` so several handles can share it.
    store: Arc<dyn CacheStore>,
    // Tags applied to every key; their order is part of the derived key.
    tags: Vec<String>,
    // Supplies the optional key `prefix` and the `default_ttl` used by `put_default`.
    config: CacheConfig,
}
16
17impl TaggedCache {
18 pub(crate) fn new(store: Arc<dyn CacheStore>, tags: Vec<String>, config: CacheConfig) -> Self {
20 Self {
21 store,
22 tags,
23 config,
24 }
25 }
26
27 fn tagged_key(&self, key: &str) -> String {
29 let tag_prefix: String = self.tags.join(":");
30 if self.config.prefix.is_empty() {
31 format!("tag:{tag_prefix}:{key}")
32 } else {
33 format!("{}:tag:{}:{}", self.config.prefix, tag_prefix, key)
34 }
35 }
36
37 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
39 let key = self.tagged_key(key);
40 match self.store.get_raw(&key).await? {
41 Some(bytes) => {
42 let value = serde_json::from_slice(&bytes)
43 .map_err(|e| Error::deserialization(e.to_string()))?;
44 Ok(Some(value))
45 }
46 None => Ok(None),
47 }
48 }
49
50 pub async fn put<T: Serialize>(
52 &self,
53 key: &str,
54 value: &T,
55 ttl: Duration,
56 ) -> Result<(), Error> {
57 let tagged_key = self.tagged_key(key);
58 let bytes = serde_json::to_vec(value).map_err(|e| Error::serialization(e.to_string()))?;
59
60 self.store.put_raw(&tagged_key, bytes, ttl).await?;
62
63 for tag in &self.tags {
65 let tag_set_key = format!("tag_set:{tag}");
66 self.store.tag_add(&tag_set_key, &tagged_key).await?;
67 }
68
69 Ok(())
70 }
71
72 pub async fn put_default<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
74 self.put(key, value, self.config.default_ttl).await
75 }
76
77 pub async fn forever<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
79 self.put(key, value, Duration::from_secs(315_360_000)).await
80 }
81
82 pub async fn has(&self, key: &str) -> Result<bool, Error> {
84 let key = self.tagged_key(key);
85 self.store.has(&key).await
86 }
87
88 pub async fn forget(&self, key: &str) -> Result<bool, Error> {
90 let key = self.tagged_key(key);
91 self.store.forget(&key).await
92 }
93
94 pub async fn flush(&self) -> Result<(), Error> {
96 for tag in &self.tags {
97 let tag_set_key = format!("tag_set:{tag}");
98 self.store.tag_flush(&tag_set_key).await?;
99 }
100 Ok(())
101 }
102
103 pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, f: F) -> Result<T, Error>
105 where
106 T: Serialize + DeserializeOwned,
107 F: FnOnce() -> Fut,
108 Fut: Future<Output = T>,
109 {
110 if let Some(value) = self.get(key).await? {
111 return Ok(value);
112 }
113
114 let value = f().await;
115 self.put(key, &value, ttl).await?;
116 Ok(value)
117 }
118
119 pub async fn remember_forever<T, F, Fut>(&self, key: &str, f: F) -> Result<T, Error>
121 where
122 T: Serialize + DeserializeOwned,
123 F: FnOnce() -> Fut,
124 Fut: Future<Output = T>,
125 {
126 self.remember(key, Duration::from_secs(315_360_000), f)
127 .await
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stores::MemoryStore;

    #[tokio::test]
    async fn test_tagged_cache_put_get() {
        let store = Arc::new(MemoryStore::new());
        let cache = TaggedCache::new(store, vec!["users".to_string()], CacheConfig::default());

        cache
            .put("user:1", &"Alice", Duration::from_secs(60))
            .await
            .unwrap();

        let value: Option<String> = cache.get("user:1").await.unwrap();
        assert_eq!(value, Some("Alice".to_string()));
    }

    #[tokio::test]
    async fn test_tagged_cache_get_missing_returns_none() {
        let store = Arc::new(MemoryStore::new());
        let cache = TaggedCache::new(store, vec!["users".to_string()], CacheConfig::default());

        let value: Option<String> = cache.get("user:404").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_tagged_cache_forget() {
        let store = Arc::new(MemoryStore::new());
        let cache = TaggedCache::new(store, vec!["users".to_string()], CacheConfig::default());

        cache
            .put("user:1", &"Alice", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.has("user:1").await.unwrap());

        cache.forget("user:1").await.unwrap();
        assert!(!cache.has("user:1").await.unwrap());
    }

    #[tokio::test]
    async fn test_tagged_cache_forever() {
        let store = Arc::new(MemoryStore::new());
        let cache = TaggedCache::new(store, vec!["config".to_string()], CacheConfig::default());

        cache.forever("site:name", &"Example").await.unwrap();

        let value: Option<String> = cache.get("site:name").await.unwrap();
        assert_eq!(value, Some("Example".to_string()));
    }

    #[tokio::test]
    async fn test_tagged_cache_tags_namespace_keys() {
        let store = Arc::new(MemoryStore::new());
        let users = TaggedCache::new(
            store.clone(),
            vec!["users".to_string()],
            CacheConfig::default(),
        );
        let posts = TaggedCache::new(store, vec!["posts".to_string()], CacheConfig::default());

        users
            .put("id:1", &"Alice", Duration::from_secs(60))
            .await
            .unwrap();

        // The same logical key under a different tag list is a different entry.
        assert!(users.has("id:1").await.unwrap());
        assert!(!posts.has("id:1").await.unwrap());
    }

    #[tokio::test]
    async fn test_tagged_cache_flush() {
        let store = Arc::new(MemoryStore::new());
        let cache = TaggedCache::new(
            store.clone(),
            vec!["users".to_string()],
            CacheConfig::default(),
        );

        cache
            .put("user:1", &"Alice", Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .put("user:2", &"Bob", Duration::from_secs(60))
            .await
            .unwrap();

        assert!(cache.has("user:1").await.unwrap());
        assert!(cache.has("user:2").await.unwrap());

        cache.flush().await.unwrap();

        assert!(!cache.has("user:1").await.unwrap());
        assert!(!cache.has("user:2").await.unwrap());
    }

    #[tokio::test]
    async fn test_tagged_cache_remember() {
        let store = Arc::new(MemoryStore::new());
        let cache = TaggedCache::new(store, vec!["data".to_string()], CacheConfig::default());

        let mut call_count = 0;

        let value: i32 = cache
            .remember("computed", Duration::from_secs(60), || async {
                call_count += 1;
                42
            })
            .await
            .unwrap();

        assert_eq!(value, 42);

        // Second call must hit the cache: the closure's 99 is never produced.
        let value2: i32 = cache
            .remember("computed", Duration::from_secs(60), || async {
                call_count += 1;
                99
            })
            .await
            .unwrap();

        assert_eq!(value2, 42);
        // Only the first (missing) lookup may have run its closure.
        assert_eq!(call_count, 1);
    }

    #[tokio::test]
    async fn test_tagged_cache_multiple_tags() {
        let store = Arc::new(MemoryStore::new());

        let cache = TaggedCache::new(
            store.clone(),
            vec!["users".to_string(), "admins".to_string()],
            CacheConfig::default(),
        );

        cache
            .put("admin:1", &"Super Admin", Duration::from_secs(60))
            .await
            .unwrap();

        let users_cache = TaggedCache::new(
            store.clone(),
            vec!["users".to_string()],
            CacheConfig::default(),
        );

        // Flushing a single overlapping tag removes entries written under a
        // tag list that includes it.
        users_cache.flush().await.unwrap();

        assert!(!cache.has("admin:1").await.unwrap());
    }
}