Skip to main content

agentic_forge_core/cache/
invalidation.rs

//! Cache invalidation on mutation: when a key changes, the keys registered as
//! depending on it are evicted from the cache as well.
2
3use super::lru::Cache;
4use std::collections::HashSet;
5use std::hash::Hash;
6use std::sync::RwLock;
7
/// Records dependencies between cache keys so that invalidating one key can
/// also invalidate the keys registered as depending on it.
///
/// The dependency table lives behind an [`RwLock`], so it is read and
/// modified through a shared reference.
pub struct CacheInvalidator<K>
where
    K: Hash + Eq + Clone,
{
    /// Maps a depended-upon key to the set of keys that depend on it.
    dependencies: RwLock<std::collections::HashMap<K, HashSet<K>>>,
}
11
12impl<K: Hash + Eq + Clone> CacheInvalidator<K> {
13    pub fn new() -> Self {
14        Self {
15            dependencies: RwLock::new(std::collections::HashMap::new()),
16        }
17    }
18
19    pub fn register_dependency(&self, key: K, depends_on: K) {
20        let mut deps = self.dependencies.write().unwrap();
21        deps.entry(depends_on).or_default().insert(key);
22    }
23
24    pub fn invalidate_cascade<V: Clone>(&self, changed_key: &K, cache: &Cache<K, V>) -> usize {
25        let mut invalidated = 0;
26        cache.invalidate(changed_key);
27        invalidated += 1;
28
29        let deps = self.dependencies.read().unwrap();
30        if let Some(dependents) = deps.get(changed_key) {
31            for dependent in dependents {
32                if cache.invalidate(dependent) {
33                    invalidated += 1;
34                }
35            }
36        }
37        invalidated
38    }
39
40    pub fn clear(&self) {
41        self.dependencies.write().unwrap().clear();
42    }
43}
44
45impl<K: Hash + Eq + Clone> Default for CacheInvalidator<K> {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds the owned `String` used for keys and values in these tests.
    fn s(text: &str) -> String {
        String::from(text)
    }

    /// Invalidating a key also evicts a key registered as depending on it.
    #[test]
    fn test_invalidate_cascade() {
        let cache = Cache::<String, String>::new(100, Duration::from_secs(60));
        let invalidator = CacheInvalidator::<String>::new();

        cache.insert(s("parent"), s("p_val"));
        cache.insert(s("child"), s("c_val"));
        invalidator.register_dependency(s("child"), s("parent"));

        let evicted = invalidator.invalidate_cascade(&s("parent"), &cache);

        assert_eq!(evicted, 2);
        assert!(cache.get(&s("parent")).is_none());
        assert!(cache.get(&s("child")).is_none());
    }

    /// A key with no registered dependents yields a count of exactly one.
    #[test]
    fn test_invalidate_no_dependents() {
        let cache = Cache::<String, String>::new(100, Duration::from_secs(60));
        let invalidator = CacheInvalidator::<String>::new();

        cache.insert(s("solo"), s("val"));

        let evicted = invalidator.invalidate_cascade(&s("solo"), &cache);
        assert_eq!(evicted, 1);
    }

    /// Once `clear` has run, earlier registrations no longer cascade.
    #[test]
    fn test_invalidator_clear() {
        let invalidator = CacheInvalidator::<String>::new();
        invalidator.register_dependency(s("a"), s("b"));
        invalidator.clear();

        let cache = Cache::<String, String>::new(10, Duration::from_secs(60));
        cache.insert(s("b"), s("val"));
        cache.insert(s("a"), s("val"));

        // Only "b" is counted; "a" is no longer linked to it.
        let evicted = invalidator.invalidate_cascade(&s("b"), &cache);
        assert_eq!(evicted, 1);
    }
}