1use parking_lot::RwLock;
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
/// A single invalidation trigger: a tag, an event name, or a dependency name.
///
/// Each variant carries the trigger's key as an owned string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum InvalidationStrategy {
    /// Trigger keyed by a tag.
    Tag(String),
    /// Trigger keyed by an event name.
    Event(String),
    /// Trigger keyed by a dependency name.
    Dependency(String),
}
43
/// The keys under which a cache is indexed by `InvalidationRegistry::register`.
#[derive(Clone, Debug)]
pub struct InvalidationMetadata {
    /// Tags the cache is registered under.
    pub tags: Vec<String>,
    /// Event names the cache is registered under.
    pub events: Vec<String>,
    /// Dependency names the cache is registered under.
    pub dependencies: Vec<String>,
}

impl InvalidationMetadata {
    /// Bundles the given tags, events, and dependencies into one metadata record.
    pub fn new(tags: Vec<String>, events: Vec<String>, dependencies: Vec<String>) -> Self {
        InvalidationMetadata {
            tags,
            events,
            dependencies,
        }
    }

    /// Returns `true` when no tag, event, or dependency is listed at all.
    pub fn is_empty(&self) -> bool {
        [&self.tags, &self.events, &self.dependencies]
            .iter()
            .all(|list| list.is_empty())
    }
}
70
71pub struct InvalidationRegistry {
76 tag_to_caches: RwLock<HashMap<String, HashSet<String>>>,
78 event_to_caches: RwLock<HashMap<String, HashSet<String>>>,
80 dependency_to_caches: RwLock<HashMap<String, HashSet<String>>>,
82 cache_metadata: RwLock<HashMap<String, InvalidationMetadata>>,
84 clear_callbacks: RwLock<HashMap<String, Arc<dyn Fn() + Send + Sync>>>,
86 invalidation_check_callbacks:
89 RwLock<HashMap<String, Arc<dyn Fn(&dyn Fn(&str) -> bool) + Send + Sync>>>,
90}
91
92impl InvalidationRegistry {
93 fn new() -> Self {
95 Self {
96 tag_to_caches: RwLock::new(HashMap::new()),
97 event_to_caches: RwLock::new(HashMap::new()),
98 dependency_to_caches: RwLock::new(HashMap::new()),
99 cache_metadata: RwLock::new(HashMap::new()),
100 clear_callbacks: RwLock::new(HashMap::new()),
101 invalidation_check_callbacks: RwLock::new(HashMap::new()),
102 }
103 }
104
105 pub fn global() -> &'static InvalidationRegistry {
107 static INSTANCE: std::sync::OnceLock<InvalidationRegistry> = std::sync::OnceLock::new();
108 INSTANCE.get_or_init(InvalidationRegistry::new)
109 }
110
111 pub fn register(&self, cache_name: &str, metadata: InvalidationMetadata) {
118 {
120 let mut tag_map = self.tag_to_caches.write();
121 for tag in &metadata.tags {
122 tag_map
123 .entry(tag.clone())
124 .or_insert_with(HashSet::new)
125 .insert(cache_name.to_string());
126 }
127 }
128
129 {
131 let mut event_map = self.event_to_caches.write();
132 for event in &metadata.events {
133 event_map
134 .entry(event.clone())
135 .or_insert_with(HashSet::new)
136 .insert(cache_name.to_string());
137 }
138 }
139
140 {
142 let mut dep_map = self.dependency_to_caches.write();
143 for dep in &metadata.dependencies {
144 dep_map
145 .entry(dep.clone())
146 .or_insert_with(HashSet::new)
147 .insert(cache_name.to_string());
148 }
149 }
150
151 self.cache_metadata
153 .write()
154 .insert(cache_name.to_string(), metadata);
155 }
156
157 pub fn register_callback<F>(&self, cache_name: &str, callback: F)
166 where
167 F: Fn() + Send + Sync + 'static,
168 {
169 self.clear_callbacks
170 .write()
171 .insert(cache_name.to_string(), Arc::new(callback));
172 }
173
174 pub fn register_invalidation_callback<F>(&self, cache_name: &str, callback: F)
185 where
186 F: Fn(&dyn Fn(&str) -> bool) + Send + Sync + 'static,
187 {
188 self.invalidation_check_callbacks
189 .write()
190 .insert(cache_name.to_string(), Arc::new(callback));
191 }
192
193 pub fn invalidate_by_tag(&self, tag: &str) -> usize {
203 let cache_names = self
204 .tag_to_caches
205 .read()
206 .get(tag)
207 .cloned()
208 .unwrap_or_default();
209
210 self.invalidate_caches(&cache_names)
211 }
212
213 pub fn invalidate_by_event(&self, event: &str) -> usize {
223 let cache_names = self
224 .event_to_caches
225 .read()
226 .get(event)
227 .cloned()
228 .unwrap_or_default();
229
230 self.invalidate_caches(&cache_names)
231 }
232
233 pub fn invalidate_by_dependency(&self, dependency: &str) -> usize {
243 let cache_names = self
244 .dependency_to_caches
245 .read()
246 .get(dependency)
247 .cloned()
248 .unwrap_or_default();
249
250 self.invalidate_caches(&cache_names)
251 }
252
253 pub fn invalidate_cache(&self, cache_name: &str) -> bool {
263 if let Some(callback) = self.clear_callbacks.read().get(cache_name) {
264 callback();
265 true
266 } else {
267 false
268 }
269 }
270
271 fn invalidate_caches(&self, cache_names: &HashSet<String>) -> usize {
281 let callbacks = self.clear_callbacks.read();
282 let mut count = 0;
283
284 for name in cache_names {
285 if let Some(callback) = callbacks.get(name) {
286 callback();
287 count += 1;
288 }
289 }
290
291 count
292 }
293
294 pub fn get_caches_by_tag(&self, tag: &str) -> Vec<String> {
296 self.tag_to_caches
297 .read()
298 .get(tag)
299 .map(|set| set.iter().cloned().collect())
300 .unwrap_or_default()
301 }
302
303 pub fn get_caches_by_event(&self, event: &str) -> Vec<String> {
305 self.event_to_caches
306 .read()
307 .get(event)
308 .map(|set| set.iter().cloned().collect())
309 .unwrap_or_default()
310 }
311
312 pub fn get_dependent_caches(&self, dependency: &str) -> Vec<String> {
314 self.dependency_to_caches
315 .read()
316 .get(dependency)
317 .map(|set| set.iter().cloned().collect())
318 .unwrap_or_default()
319 }
320
321 pub fn invalidate_with<F>(&self, cache_name: &str, predicate: F) -> bool
345 where
346 F: Fn(&str) -> bool,
347 {
348 if let Some(callback) = self.invalidation_check_callbacks.read().get(cache_name) {
349 callback(&predicate);
350 true
351 } else {
352 false
353 }
354 }
355
356 pub fn invalidate_all_with<F>(&self, predicate: F) -> usize
379 where
380 F: Fn(&str, &str) -> bool,
381 {
382 let callbacks = self.invalidation_check_callbacks.read();
383 let mut count = 0;
384
385 for (cache_name, callback) in callbacks.iter() {
386 let cache_name_clone = cache_name.clone();
387 callback(&|key: &str| predicate(&cache_name_clone, key));
388 count += 1;
389 }
390
391 count
392 }
393
394 pub fn clear(&self) {
396 self.tag_to_caches.write().clear();
397 self.event_to_caches.write().clear();
398 self.dependency_to_caches.write().clear();
399 self.cache_metadata.write().clear();
400 self.clear_callbacks.write().clear();
401 self.invalidation_check_callbacks.write().clear();
402 }
403}
404
405impl Default for InvalidationRegistry {
406 fn default() -> Self {
407 Self::new()
408 }
409}
410
411pub fn invalidate_by_tag(tag: &str) -> usize {
430 InvalidationRegistry::global().invalidate_by_tag(tag)
431}
432
433pub fn invalidate_by_event(event: &str) -> usize {
452 InvalidationRegistry::global().invalidate_by_event(event)
453}
454
455pub fn invalidate_by_dependency(dependency: &str) -> usize {
474 InvalidationRegistry::global().invalidate_by_dependency(dependency)
475}
476
477pub fn invalidate_cache(cache_name: &str) -> bool {
498 InvalidationRegistry::global().invalidate_cache(cache_name)
499}
500
501pub fn invalidate_with<F>(cache_name: &str, predicate: F) -> bool
525where
526 F: Fn(&str) -> bool,
527{
528 InvalidationRegistry::global().invalidate_with(cache_name, predicate)
529}
530
531pub fn invalidate_all_with<F>(predicate: F) -> usize
554where
555 F: Fn(&str, &str) -> bool,
556{
557 InvalidationRegistry::global().invalidate_all_with(predicate)
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Returns a fresh counter together with a clear callback that increments it per call.
    fn counting_callback() -> (Arc<AtomicUsize>, impl Fn() + Send + Sync + 'static) {
        let counter = Arc::new(AtomicUsize::new(0));
        let handle = Arc::clone(&counter);
        let callback = move || {
            handle.fetch_add(1, Ordering::SeqCst);
        };
        (counter, callback)
    }

    /// Returns a shared list plus a conditional-invalidation callback.
    ///
    /// On each call, the callback resets the list to the subset of `keys` accepted by the
    /// predicate, preserving the order of `keys`.
    fn recording_callback(
        keys: &'static [&'static str],
    ) -> (
        Arc<Mutex<Vec<String>>>,
        impl Fn(&dyn Fn(&str) -> bool) + Send + Sync + 'static,
    ) {
        let removed = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&removed);
        let callback = move |check_fn: &dyn Fn(&str) -> bool| {
            let mut removed = sink.lock().unwrap();
            removed.clear();
            removed.extend(
                keys.iter()
                    .copied()
                    .filter(|key| check_fn(key))
                    .map(str::to_string),
            );
        };
        (removed, callback)
    }

    #[test]
    fn test_tag_based_invalidation() {
        let registry = InvalidationRegistry::new();
        let (counter1, callback1) = counting_callback();
        let (counter2, callback2) = counting_callback();

        for name in ["cache1", "cache2"] {
            registry.register(
                name,
                InvalidationMetadata::new(vec!["user_data".to_string()], vec![], vec![]),
            );
        }
        registry.register_callback("cache1", callback1);
        registry.register_callback("cache2", callback2);

        assert_eq!(registry.invalidate_by_tag("user_data"), 2);
        assert_eq!(counter1.load(Ordering::SeqCst), 1);
        assert_eq!(counter2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_event_based_invalidation() {
        let registry = InvalidationRegistry::new();
        let (counter, callback) = counting_callback();

        registry.register(
            "cache1",
            InvalidationMetadata::new(vec![], vec!["user_updated".to_string()], vec![]),
        );
        registry.register_callback("cache1", callback);

        assert_eq!(registry.invalidate_by_event("user_updated"), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_dependency_based_invalidation() {
        let registry = InvalidationRegistry::new();
        let (counter, callback) = counting_callback();

        registry.register(
            "cache1",
            InvalidationMetadata::new(vec![], vec![], vec!["get_user".to_string()]),
        );
        registry.register_callback("cache1", callback);

        assert_eq!(registry.invalidate_by_dependency("get_user"), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_caches_by_tag() {
        let registry = InvalidationRegistry::new();
        for name in ["cache1", "cache2"] {
            registry.register(
                name,
                InvalidationMetadata::new(vec!["tag1".to_string()], vec![], vec![]),
            );
        }

        // HashSet iteration order is unspecified, so sort before comparing.
        let mut caches = registry.get_caches_by_tag("tag1");
        caches.sort();
        assert_eq!(caches, ["cache1", "cache2"]);
    }

    #[test]
    fn test_invalidate_specific_cache() {
        let registry = InvalidationRegistry::new();
        let (counter, callback) = counting_callback();
        registry.register_callback("cache1", callback);

        assert!(registry.invalidate_cache("cache1"));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(!registry.invalidate_cache("cache2"));
    }

    #[test]
    fn test_clear_registry() {
        let registry = InvalidationRegistry::new();
        registry.register("cache1", InvalidationMetadata::new(vec![], vec![], vec![]));
        registry.clear();
        assert!(registry.cache_metadata.read().is_empty());
    }

    #[test]
    fn test_conditional_invalidation() {
        let registry = InvalidationRegistry::new();
        let (removed_keys, callback) =
            recording_callback(&["key1", "key2", "key100", "key500", "key1001"]);
        registry.register_invalidation_callback("cache1", callback);

        assert!(registry.invalidate_with("cache1", |key: &str| {
            key.strip_prefix("key")
                .and_then(|s| s.parse::<u64>().ok())
                .map(|n| n > 100)
                .unwrap_or(false)
        }));

        assert_eq!(*removed_keys.lock().unwrap(), ["key500", "key1001"]);
    }

    #[test]
    fn test_conditional_invalidation_nonexistent_cache() {
        let registry = InvalidationRegistry::new();
        assert!(!registry.invalidate_with("nonexistent", |_key: &str| true));
    }

    #[test]
    fn test_invalidate_all_with_check_function() {
        let registry = InvalidationRegistry::new();
        let (cache1_removed, callback1) = recording_callback(&["1", "2", "3", "4", "5"]);
        let (cache2_removed, callback2) = recording_callback(&["10", "20", "30"]);
        registry.register_invalidation_callback("cache1", callback1);
        registry.register_invalidation_callback("cache2", callback2);

        let count = registry.invalidate_all_with(|_cache_name: &str, key: &str| {
            key.parse::<u64>().unwrap_or(0) >= 3
        });

        assert_eq!(count, 2);
        assert_eq!(*cache1_removed.lock().unwrap(), ["3", "4", "5"]);
        assert_eq!(*cache2_removed.lock().unwrap(), ["10", "20", "30"]);
    }

    #[test]
    fn test_complex_conditional_checks() {
        let registry = InvalidationRegistry::new();
        let (removed_keys, callback) =
            recording_callback(&["user_10", "user_20", "user_30", "user_40", "user_50"]);
        registry.register_invalidation_callback("cache1", callback);

        registry.invalidate_with("cache1", |key: &str| {
            key.strip_prefix("user_")
                .and_then(|s| s.parse::<u64>().ok())
                .map(|n| n % 20 == 0)
                .unwrap_or(false)
        });

        assert_eq!(*removed_keys.lock().unwrap(), ["user_20", "user_40"]);
    }
}