1use crate::{CacheBackend, CacheError, CacheResult, CacheTag, CacheKey};
7use async_trait::async_trait;
8use dashmap::DashMap;
9use std::{
10 collections::{HashMap, HashSet},
11 sync::Arc,
12 time::Duration,
13};
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TaggedEntry {
19 pub key: CacheKey,
21
22 pub tags: HashSet<CacheTag>,
24
25 pub created_at: chrono::DateTime<chrono::Utc>,
27
28 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
30}
31
32#[async_trait]
34pub trait TagRegistry: Send + Sync {
35 async fn tag_key(&self, key: &str, tags: &[&str]) -> CacheResult<()>;
37
38 async fn untag_key(&self, key: &str, tags: &[&str]) -> CacheResult<()>;
40
41 async fn get_keys_by_tag(&self, tag: &str) -> CacheResult<Vec<String>>;
43
44 async fn get_tags_for_key(&self, key: &str) -> CacheResult<Vec<String>>;
46
47 async fn remove_key(&self, key: &str) -> CacheResult<()>;
49
50 async fn clear(&self) -> CacheResult<()>;
52}
53
54#[derive(Debug)]
56pub struct MemoryTagRegistry {
57 tag_to_keys: DashMap<CacheTag, HashSet<CacheKey>>,
59
60 key_to_tags: DashMap<CacheKey, HashSet<CacheTag>>,
62}
63
64impl MemoryTagRegistry {
65 pub fn new() -> Self {
66 Self {
67 tag_to_keys: DashMap::new(),
68 key_to_tags: DashMap::new(),
69 }
70 }
71}
72
73impl Default for MemoryTagRegistry {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79#[async_trait]
80impl TagRegistry for MemoryTagRegistry {
81 async fn tag_key(&self, key: &str, tags: &[&str]) -> CacheResult<()> {
82 let key = key.to_string();
83
84 for tag in tags {
85 let tag = tag.to_string();
86
87 self.tag_to_keys
89 .entry(tag.clone())
90 .or_insert_with(HashSet::new)
91 .insert(key.clone());
92
93 self.key_to_tags
95 .entry(key.clone())
96 .or_insert_with(HashSet::new)
97 .insert(tag);
98 }
99
100 Ok(())
101 }
102
103 async fn untag_key(&self, key: &str, tags: &[&str]) -> CacheResult<()> {
104 let key = key.to_string();
105
106 for tag in tags {
107 let tag = tag.to_string();
108
109 if let Some(mut tag_keys) = self.tag_to_keys.get_mut(&tag) {
111 tag_keys.remove(&key);
112 if tag_keys.is_empty() {
113 drop(tag_keys);
114 self.tag_to_keys.remove(&tag);
115 }
116 }
117
118 if let Some(mut key_tags) = self.key_to_tags.get_mut(&key) {
120 key_tags.remove(&tag);
121 if key_tags.is_empty() {
122 drop(key_tags);
123 self.key_to_tags.remove(&key);
124 }
125 }
126 }
127
128 Ok(())
129 }
130
131 async fn get_keys_by_tag(&self, tag: &str) -> CacheResult<Vec<String>> {
132 if let Some(keys) = self.tag_to_keys.get(tag) {
133 Ok(keys.iter().cloned().collect())
134 } else {
135 Ok(vec![])
136 }
137 }
138
139 async fn get_tags_for_key(&self, key: &str) -> CacheResult<Vec<String>> {
140 if let Some(tags) = self.key_to_tags.get(key) {
141 Ok(tags.iter().cloned().collect())
142 } else {
143 Ok(vec![])
144 }
145 }
146
147 async fn remove_key(&self, key: &str) -> CacheResult<()> {
148 let key = key.to_string();
149
150 if let Some((_, tags)) = self.key_to_tags.remove(&key) {
152 for tag in tags {
154 if let Some(mut tag_keys) = self.tag_to_keys.get_mut(&tag) {
155 tag_keys.remove(&key);
156 if tag_keys.is_empty() {
157 drop(tag_keys);
158 self.tag_to_keys.remove(&tag);
159 }
160 }
161 }
162 }
163
164 Ok(())
165 }
166
167 async fn clear(&self) -> CacheResult<()> {
168 self.tag_to_keys.clear();
169 self.key_to_tags.clear();
170 Ok(())
171 }
172}
173
174pub struct TaggedCache<B, R>
176where
177 B: CacheBackend,
178 R: TagRegistry,
179{
180 backend: B,
181 registry: R,
182}
183
184impl<B, R> TaggedCache<B, R>
185where
186 B: CacheBackend,
187 R: TagRegistry,
188{
189 pub fn new(backend: B, registry: R) -> Self {
190 Self { backend, registry }
191 }
192
193 pub async fn put_with_tags(
195 &self,
196 key: &str,
197 value: Vec<u8>,
198 ttl: Option<Duration>,
199 tags: &[&str],
200 ) -> CacheResult<()> {
201 self.backend.put(key, value, ttl).await?;
203
204 if !tags.is_empty() {
206 self.registry.tag_key(key, tags).await?;
207 }
208
209 Ok(())
210 }
211
212 pub async fn forget_by_tag(&self, tag: &str) -> CacheResult<Vec<String>> {
214 let keys = self.registry.get_keys_by_tag(tag).await?;
215
216 if keys.is_empty() {
217 return Ok(Vec::new());
218 }
219
220 let mut removed_keys = Vec::new();
222 for key in keys {
223 let was_removed = self.backend.forget(&key).await?;
225
226 self.registry.remove_key(&key).await?;
228
229 if was_removed {
230 removed_keys.push(key);
231 }
232 }
233
234 Ok(removed_keys)
235 }
236
237 pub async fn forget_by_tags(&self, tags: &[&str]) -> CacheResult<Vec<String>> {
239 let mut all_keys = HashSet::new();
240
241 for tag in tags {
243 let keys = self.registry.get_keys_by_tag(tag).await?;
244 all_keys.extend(keys);
245 }
246
247 if all_keys.is_empty() {
248 return Ok(Vec::new());
249 }
250
251 let mut removed_keys = Vec::new();
253 for key in all_keys {
254 let was_removed = self.backend.forget(&key).await?;
256
257 self.registry.remove_key(&key).await?;
259
260 if was_removed {
261 removed_keys.push(key);
262 }
263 }
264
265 Ok(removed_keys)
266 }
267
268 pub async fn keys_by_tag(&self, tag: &str) -> CacheResult<Vec<String>> {
270 self.registry.get_keys_by_tag(tag).await
271 }
272
273 pub async fn tags_for_key(&self, key: &str) -> CacheResult<Vec<String>> {
275 self.registry.get_tags_for_key(key).await
276 }
277
278 pub async fn tag_existing(&self, key: &str, tags: &[&str]) -> CacheResult<()> {
280 if self.backend.exists(key).await? {
282 self.registry.tag_key(key, tags).await
283 } else {
284 Err(CacheError::KeyNotFound(key.to_string()))
285 }
286 }
287
288 pub async fn untag(&self, key: &str, tags: &[&str]) -> CacheResult<()> {
290 self.registry.untag_key(key, tags).await
291 }
292
293 pub async fn tagged_stats(&self) -> CacheResult<TaggedCacheStats> {
295 let base_stats = self.backend.stats().await?;
296
297 Ok(TaggedCacheStats {
300 base_stats,
301 total_tags: 0,
302 tagged_keys: 0,
303 })
304 }
305}
306
307#[derive(Debug, Clone)]
309pub struct TaggedCacheStats {
310 pub base_stats: crate::CacheStats,
311 pub total_tags: u64,
312 pub tagged_keys: u64,
313}
314
315#[async_trait]
316impl<B, R> CacheBackend for TaggedCache<B, R>
317where
318 B: CacheBackend,
319 R: TagRegistry,
320{
321 async fn get(&self, key: &str) -> CacheResult<Option<Vec<u8>>> {
322 self.backend.get(key).await
323 }
324
325 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> CacheResult<()> {
326 self.backend.put(key, value, ttl).await
327 }
328
329 async fn forget(&self, key: &str) -> CacheResult<bool> {
330 let result = self.backend.forget(key).await?;
331
332 if result {
333 self.registry.remove_key(key).await?;
335 }
336
337 Ok(result)
338 }
339
340 async fn exists(&self, key: &str) -> CacheResult<bool> {
341 self.backend.exists(key).await
342 }
343
344 async fn flush(&self) -> CacheResult<()> {
345 self.backend.flush().await?;
347
348 self.registry.clear().await?;
350
351 Ok(())
352 }
353
354 async fn get_many(&self, keys: &[&str]) -> CacheResult<Vec<Option<Vec<u8>>>> {
355 self.backend.get_many(keys).await
356 }
357
358 async fn put_many(&self, entries: &[(&str, Vec<u8>, Option<Duration>)]) -> CacheResult<()> {
359 self.backend.put_many(entries).await
360 }
361
362 async fn stats(&self) -> CacheResult<crate::CacheStats> {
363 self.backend.stats().await
364 }
365}
366
367pub struct TaggedCacheManager<B, R>
369where
370 B: CacheBackend,
371 R: TagRegistry,
372{
373 cache: TaggedCache<B, R>,
374}
375
376impl<B, R> TaggedCacheManager<B, R>
377where
378 B: CacheBackend,
379 R: TagRegistry,
380{
381 pub fn new(backend: B, registry: R) -> Self {
382 Self {
383 cache: TaggedCache::new(backend, registry),
384 }
385 }
386
387 pub async fn remember_with_tags<T, F, Fut>(
389 &self,
390 key: &str,
391 ttl: Duration,
392 tags: &[&str],
393 compute: F,
394 ) -> CacheResult<T>
395 where
396 T: serde::Serialize + serde::de::DeserializeOwned,
397 F: FnOnce() -> Fut,
398 Fut: std::future::Future<Output = T>,
399 {
400 if let Some(cached_bytes) = self.cache.get(key).await? {
402 let value: T = serde_json::from_slice(&cached_bytes)
403 .map_err(CacheError::Serialization)?;
404 return Ok(value);
405 }
406
407 let value = compute().await;
409
410 let bytes = serde_json::to_vec(&value).map_err(CacheError::Serialization)?;
412 self.cache.put_with_tags(key, bytes, Some(ttl), tags).await?;
413
414 Ok(value)
415 }
416
417 pub async fn invalidate_by_tags(&self, tags: &[&str]) -> CacheResult<u32> {
419 let removed_keys = self.cache.forget_by_tags(tags).await?;
420 Ok(removed_keys.len() as u32)
421 }
422
423 pub async fn add_tags(&self, key: &str, tags: &[&str]) -> CacheResult<()> {
425 self.cache.tag_existing(key, tags).await
426 }
427
428 pub async fn remove_tags(&self, key: &str, tags: &[&str]) -> CacheResult<()> {
429 self.cache.untag(key, tags).await
430 }
431
432 pub async fn get_key_tags(&self, key: &str) -> CacheResult<Vec<String>> {
433 self.cache.tags_for_key(key).await
434 }
435
436 pub async fn get_tagged_keys(&self, tag: &str) -> CacheResult<Vec<String>> {
437 self.cache.keys_by_tag(tag).await
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::MemoryBackend;
    use crate::config::CacheConfig;
    use std::time::Duration;

    /// Time-to-live used for every entry the tests write.
    const TTL: Duration = Duration::from_secs(60);

    /// True when `items` contains an element equal to `wanted`.
    fn holds(items: &[String], wanted: &str) -> bool {
        items.iter().any(|item| item == wanted)
    }

    #[tokio::test]
    async fn test_memory_tag_registry() {
        let reg = MemoryTagRegistry::new();

        let fixtures: [(&str, &[&str]); 3] = [
            ("user:1", &["users", "active"]),
            ("user:2", &["users"]),
            ("post:1", &["posts", "active"]),
        ];
        for &(key, tags) in fixtures.iter() {
            reg.tag_key(key, tags).await.unwrap();
        }

        // Forward lookups: tag -> keys.
        let users = reg.get_keys_by_tag("users").await.unwrap();
        assert_eq!(users.len(), 2);
        assert!(holds(&users, "user:1"));
        assert!(holds(&users, "user:2"));

        let active = reg.get_keys_by_tag("active").await.unwrap();
        assert_eq!(active.len(), 2);
        assert!(holds(&active, "user:1"));
        assert!(holds(&active, "post:1"));

        // Reverse lookup: key -> tags.
        let user1_tags = reg.get_tags_for_key("user:1").await.unwrap();
        assert_eq!(user1_tags.len(), 2);
        assert!(holds(&user1_tags, "users"));
        assert!(holds(&user1_tags, "active"));

        // Untagging must update both directions.
        reg.untag_key("user:1", &["active"]).await.unwrap();
        let user1_tags = reg.get_tags_for_key("user:1").await.unwrap();
        assert_eq!(user1_tags.len(), 1);
        assert!(holds(&user1_tags, "users"));

        let active = reg.get_keys_by_tag("active").await.unwrap();
        assert_eq!(active.len(), 1);
        assert!(holds(&active, "post:1"));

        // Removing a key detaches it from every tag it carried.
        reg.remove_key("user:2").await.unwrap();
        let users = reg.get_keys_by_tag("users").await.unwrap();
        assert_eq!(users.len(), 1);
        assert!(holds(&users, "user:1"));
    }

    #[tokio::test]
    async fn test_tagged_cache() {
        let cache = TaggedCache::new(
            MemoryBackend::new(CacheConfig::default()),
            MemoryTagRegistry::new(),
        );

        let seed: [(&str, &[u8], &[&str]); 3] = [
            ("user:1", b"user data", &["users", "active"]),
            ("user:2", b"user data 2", &["users"]),
            ("post:1", b"post data", &["posts", "active"]),
        ];
        for &(key, payload, tags) in seed.iter() {
            cache
                .put_with_tags(key, payload.to_vec(), Some(TTL), tags)
                .await
                .unwrap();
        }

        assert_eq!(cache.get("user:1").await.unwrap(), Some(b"user data".to_vec()));
        assert_eq!(cache.keys_by_tag("users").await.unwrap().len(), 2);
        assert_eq!(cache.keys_by_tag("active").await.unwrap().len(), 2);

        let removed = cache.forget_by_tag("active").await.unwrap();
        assert_eq!(removed.len(), 2);

        // Only the entries tagged "active" are gone.
        assert_eq!(cache.get("user:1").await.unwrap(), None);
        assert_eq!(cache.get("post:1").await.unwrap(), None);
        assert_eq!(
            cache.get("user:2").await.unwrap(),
            Some(b"user data 2".to_vec())
        );
    }

    #[tokio::test]
    async fn test_tagged_cache_manager() {
        let manager = TaggedCacheManager::new(
            MemoryBackend::new(CacheConfig::default()),
            MemoryTagRegistry::new(),
        );
        let tags = ["expensive", "computations"];
        let mut call_count = 0;

        // Miss: the closure runs and its result is cached.
        let first = manager
            .remember_with_tags("expensive:1", TTL, &tags, || async {
                call_count += 1;
                format!("result_{}", call_count)
            })
            .await
            .unwrap();
        assert_eq!(first, "result_1");

        // Hit: the cached value comes back and the closure is skipped.
        let second = manager
            .remember_with_tags("expensive:1", TTL, &tags, || async {
                call_count += 1;
                format!("result_{}", call_count)
            })
            .await
            .unwrap();
        assert_eq!(second, "result_1");
        assert_eq!(call_count, 1);

        let invalidated = manager.invalidate_by_tags(&["expensive"]).await.unwrap();
        assert_eq!(invalidated, 1);

        // After invalidation the closure runs again.
        let third = manager
            .remember_with_tags("expensive:1", TTL, &tags, || async {
                call_count += 1;
                format!("result_{}", call_count)
            })
            .await
            .unwrap();
        assert_eq!(third, "result_2");
        assert_eq!(call_count, 2);
    }

    #[tokio::test]
    async fn test_forget_by_tag_selective_removal() {
        let cache = TaggedCache::new(
            MemoryBackend::new(CacheConfig::default()),
            MemoryTagRegistry::new(),
        );

        for n in 1..=3 {
            let key = format!("key{}", n);
            let payload = format!("data{}", n).into_bytes();
            cache
                .put_with_tags(&key, payload, Some(TTL), &["tag1"])
                .await
                .unwrap();
        }

        // Evict key2 behind the registry's back so its tag entry goes stale.
        cache.backend.forget("key2").await.unwrap();

        let removed = cache.forget_by_tag("tag1").await.unwrap();

        // Only keys the backend actually held are reported as removed.
        assert_eq!(removed.len(), 2);
        assert!(holds(&removed, "key1"));
        assert!(holds(&removed, "key3"));
        assert!(!holds(&removed, "key2"));
    }
}