// forge_runtime/function/cache.rs

1use std::collections::HashMap;
2use std::hash::{Hash, Hasher};
3use std::sync::RwLock;
4use std::time::{Duration, Instant};
5
6use serde_json::Value;
7
8/// A simple in-memory cache for query results.
9pub struct QueryCache {
10    entries: RwLock<HashMap<CacheKey, CacheEntry>>,
11    max_entries: usize,
12}
13
14#[derive(Clone, Eq, PartialEq, Hash)]
15struct CacheKey {
16    function_name: String,
17    args_hash: u64,
18}
19
20struct CacheEntry {
21    value: Value,
22    expires_at: Instant,
23    created_at: Instant,
24}
25
26impl QueryCache {
27    /// Create a new query cache with default settings.
28    pub fn new() -> Self {
29        Self::with_max_entries(10_000)
30    }
31
32    /// Create a new query cache with a maximum number of entries.
33    pub fn with_max_entries(max_entries: usize) -> Self {
34        Self {
35            entries: RwLock::new(HashMap::new()),
36            max_entries,
37        }
38    }
39
40    /// Get a cached value if it exists and hasn't expired.
41    pub fn get(&self, function_name: &str, args: &Value) -> Option<Value> {
42        let key = self.make_key(function_name, args);
43
44        let entries = self.entries.read().ok()?;
45        let entry = entries.get(&key)?;
46
47        if Instant::now() < entry.expires_at {
48            Some(entry.value.clone())
49        } else {
50            None
51        }
52    }
53
54    /// Set a cached value with a TTL.
55    pub fn set(&self, function_name: &str, args: &Value, value: Value, ttl: Duration) {
56        let key = self.make_key(function_name, args);
57        let now = Instant::now();
58
59        let entry = CacheEntry {
60            value,
61            expires_at: now + ttl,
62            created_at: now,
63        };
64
65        if let Ok(mut entries) = self.entries.write() {
66            // Evict expired entries if we're at capacity
67            if entries.len() >= self.max_entries {
68                self.evict_expired(&mut entries);
69            }
70
71            // If still at capacity, evict oldest entries
72            if entries.len() >= self.max_entries {
73                self.evict_oldest(&mut entries, self.max_entries / 10);
74            }
75
76            entries.insert(key, entry);
77        }
78    }
79
80    /// Invalidate a specific cache entry.
81    pub fn invalidate(&self, function_name: &str, args: &Value) {
82        let key = self.make_key(function_name, args);
83        if let Ok(mut entries) = self.entries.write() {
84            entries.remove(&key);
85        }
86    }
87
88    /// Invalidate all entries for a function.
89    pub fn invalidate_function(&self, function_name: &str) {
90        if let Ok(mut entries) = self.entries.write() {
91            entries.retain(|k, _| k.function_name != function_name);
92        }
93    }
94
95    /// Clear the entire cache.
96    pub fn clear(&self) {
97        if let Ok(mut entries) = self.entries.write() {
98            entries.clear();
99        }
100    }
101
102    /// Get the number of cached entries.
103    pub fn len(&self) -> usize {
104        self.entries.read().map(|e| e.len()).unwrap_or(0)
105    }
106
107    /// Check if the cache is empty.
108    pub fn is_empty(&self) -> bool {
109        self.len() == 0
110    }
111
112    fn make_key(&self, function_name: &str, args: &Value) -> CacheKey {
113        CacheKey {
114            function_name: function_name.to_string(),
115            args_hash: hash_value(args),
116        }
117    }
118
119    fn evict_expired(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
120        let now = Instant::now();
121        entries.retain(|_, v| v.expires_at > now);
122    }
123
124    fn evict_oldest(&self, entries: &mut HashMap<CacheKey, CacheEntry>, count: usize) {
125        let mut oldest: Vec<_> = entries
126            .iter()
127            .map(|(k, v)| (k.clone(), v.created_at))
128            .collect();
129
130        oldest.sort_by_key(|(_, t)| *t);
131
132        for (key, _) in oldest.into_iter().take(count) {
133            entries.remove(&key);
134        }
135    }
136}
137
138impl Default for QueryCache {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144fn hash_value(value: &Value) -> u64 {
145    let mut hasher = std::collections::hash_map::DefaultHasher::new();
146    hash_value_recursive(value, &mut hasher);
147    hasher.finish()
148}
149
150fn hash_value_recursive<H: Hasher>(value: &Value, hasher: &mut H) {
151    match value {
152        Value::Null => 0u8.hash(hasher),
153        Value::Bool(b) => {
154            1u8.hash(hasher);
155            b.hash(hasher);
156        }
157        Value::Number(n) => {
158            2u8.hash(hasher);
159            n.to_string().hash(hasher);
160        }
161        Value::String(s) => {
162            3u8.hash(hasher);
163            s.hash(hasher);
164        }
165        Value::Array(arr) => {
166            4u8.hash(hasher);
167            arr.len().hash(hasher);
168            for v in arr {
169                hash_value_recursive(v, hasher);
170            }
171        }
172        Value::Object(obj) => {
173            5u8.hash(hasher);
174            obj.len().hash(hasher);
175            // Sort keys for consistent hashing
176            let mut keys: Vec<_> = obj.keys().collect();
177            keys.sort();
178            for key in keys {
179                key.hash(hasher);
180                hash_value_recursive(&obj[key], hasher);
181            }
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// TTL long enough that nothing expires while a test runs.
    const TTL: Duration = Duration::from_secs(60);

    #[test]
    fn test_cache_set_get() {
        let cache = QueryCache::new();
        let user_args = json!({"id": 123});
        let expected = json!({"name": "test"});

        cache.set("get_user", &user_args, expected.clone(), TTL);

        assert_eq!(cache.get("get_user", &user_args), Some(expected));
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::new();

        // Nothing has been stored, so any lookup misses.
        assert_eq!(cache.get("get_user", &json!({"id": 123})), None);
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = QueryCache::new();
        let user_args = json!({"id": 123});

        cache.set("get_user", &user_args, json!({"name": "test"}), TTL);
        cache.invalidate("get_user", &user_args);

        assert_eq!(cache.get("get_user", &user_args), None);
    }

    #[test]
    fn test_cache_invalidate_function() {
        let cache = QueryCache::new();
        let per_user = [
            (json!({"id": 1}), json!({"name": "a"})),
            (json!({"id": 2}), json!({"name": "b"})),
        ];

        for (user_args, user) in &per_user {
            cache.set("get_user", user_args, user.clone(), TTL);
        }
        cache.set("list_users", &json!({}), json!([]), TTL);

        cache.invalidate_function("get_user");

        for (user_args, _) in &per_user {
            assert_eq!(cache.get("get_user", user_args), None);
        }
        // Entries belonging to other functions survive.
        assert!(cache.get("list_users", &json!({})).is_some());
    }

    #[test]
    fn test_hash_consistency() {
        // Object key order must not change the hash.
        let forward = json!({"a": 1, "b": 2});
        let reversed = json!({"b": 2, "a": 1});

        assert_eq!(hash_value(&forward), hash_value(&reversed));
    }
}