Skip to main content

agentic_evolve_core/cache/
invalidation.rs

1//! Cache invalidation with dependency tracking and cascade support.
2
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::sync::RwLock;
6
7use serde::{Deserialize, Serialize};
8
/// Tracks dependencies between cache keys and supports cascade invalidation.
///
/// When key A depends on key B, invalidating B also invalidates A.
///
/// The relationships are stored as a forward adjacency map behind an
/// [`RwLock`], so the tracker is mutated through `&self`.
///
/// `Debug` is derived so the tracker can be logged or embedded in other
/// `Debug` types; the derive only applies when `K: Debug`.
#[derive(Debug)]
pub struct CacheInvalidator<K> {
    /// Map from a key to the set of keys that depend on it.
    dependents: RwLock<HashMap<K, HashSet<K>>>,
}
16
/// A record of which keys were invalidated in a cascade.
///
/// Values of this type are produced by `CacheInvalidator::cascade`, which
/// walks the dependency graph starting at `root` and lists every key it
/// reaches exactly once.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationResult<K> {
    /// The key that was directly invalidated.
    pub root: K,
    /// All keys that were cascade-invalidated (including root).
    pub invalidated: Vec<K>,
}
25
26impl<K> CacheInvalidator<K>
27where
28    K: Eq + Hash + Clone,
29{
30    /// Create a new invalidator with no dependencies.
31    pub fn new() -> Self {
32        Self {
33            dependents: RwLock::new(HashMap::new()),
34        }
35    }
36
37    /// Register that `dependent` depends on `dependency`.
38    ///
39    /// When `dependency` is invalidated, `dependent` will also be invalidated.
40    pub fn add_dependency(&self, dependency: K, dependent: K) {
41        let mut map = self.dependents.write().unwrap();
42        map.entry(dependency).or_default().insert(dependent);
43    }
44
45    /// Remove a specific dependency relationship.
46    pub fn remove_dependency(&self, dependency: &K, dependent: &K) {
47        let mut map = self.dependents.write().unwrap();
48        if let Some(deps) = map.get_mut(dependency) {
49            deps.remove(dependent);
50            if deps.is_empty() {
51                map.remove(dependency);
52            }
53        }
54    }
55
56    /// Compute the full set of keys that should be invalidated when `root` is
57    /// invalidated, following the dependency graph transitively.
58    pub fn cascade(&self, root: &K) -> InvalidationResult<K> {
59        let map = self.dependents.read().unwrap();
60        let mut visited = HashSet::new();
61        let mut stack = vec![root.clone()];
62
63        while let Some(key) = stack.pop() {
64            if visited.contains(&key) {
65                continue;
66            }
67            visited.insert(key.clone());
68            if let Some(deps) = map.get(&key) {
69                for dep in deps {
70                    if !visited.contains(dep) {
71                        stack.push(dep.clone());
72                    }
73                }
74            }
75        }
76
77        InvalidationResult {
78            root: root.clone(),
79            invalidated: visited.into_iter().collect(),
80        }
81    }
82
83    /// Clear all dependency tracking.
84    pub fn clear(&self) {
85        self.dependents.write().unwrap().clear();
86    }
87
88    /// Number of keys that have dependents registered.
89    pub fn dependency_count(&self) -> usize {
90        self.dependents.read().unwrap().len()
91    }
92
93    /// Check if a key has any dependents.
94    pub fn has_dependents(&self, key: &K) -> bool {
95        self.dependents
96            .read()
97            .unwrap()
98            .get(key)
99            .is_some_and(|deps| !deps.is_empty())
100    }
101}
102
103impl<K: Eq + Hash + Clone> Default for CacheInvalidator<K> {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an owned `String` key from a literal.
    fn k(name: &str) -> String {
        name.to_string()
    }

    /// A key with no registered dependents cascades to itself only.
    #[test]
    fn cascade_single_key() {
        let invalidator = CacheInvalidator::<String>::new();
        let outcome = invalidator.cascade(&k("root"));
        assert_eq!(outcome.invalidated.len(), 1);
        assert!(outcome.invalidated.contains(&k("root")));
    }

    /// Invalidation propagates through a chain a -> b -> c.
    #[test]
    fn cascade_follows_dependencies() {
        let invalidator = CacheInvalidator::new();
        invalidator.add_dependency(k("a"), k("b"));
        invalidator.add_dependency(k("b"), k("c"));
        let outcome = invalidator.cascade(&k("a"));
        assert_eq!(outcome.invalidated.len(), 3);
        for expected in ["a", "b", "c"] {
            assert!(outcome.invalidated.contains(&k(expected)));
        }
    }

    /// A two-key cycle terminates and reports each key once.
    #[test]
    fn cascade_handles_cycles() {
        let invalidator = CacheInvalidator::new();
        invalidator.add_dependency(k("x"), k("y"));
        invalidator.add_dependency(k("y"), k("x"));
        let outcome = invalidator.cascade(&k("x"));
        assert_eq!(outcome.invalidated.len(), 2);
    }

    /// Removing the only dependent leaves the key without dependents.
    #[test]
    fn remove_dependency_works() {
        let invalidator = CacheInvalidator::new();
        let (parent, child) = (k("a"), k("b"));
        invalidator.add_dependency(parent.clone(), child.clone());
        assert!(invalidator.has_dependents(&parent));
        invalidator.remove_dependency(&parent, &child);
        assert!(!invalidator.has_dependents(&parent));
    }

    /// `clear` drops every registered relationship.
    #[test]
    fn clear_removes_all() {
        let invalidator = CacheInvalidator::new();
        invalidator.add_dependency(k("a"), k("b"));
        invalidator.add_dependency(k("c"), k("d"));
        invalidator.clear();
        assert_eq!(invalidator.dependency_count(), 0);
    }

    /// The `Default` instance starts with no relationships.
    #[test]
    fn default_is_empty() {
        let invalidator = CacheInvalidator::<String>::default();
        assert_eq!(invalidator.dependency_count(), 0);
    }
}