Skip to main content

forge_runtime/function/cache.rs

1use std::collections::HashMap;
2use std::hash::{Hash, Hasher};
3use std::sync::{Arc, RwLock};
4use std::time::{Duration, Instant};
5
6use serde_json::Value;
7
8/// A simple in-memory cache for query results.
9pub struct QueryCache {
10    entries: RwLock<HashMap<CacheKey, CacheEntry>>,
11    max_entries: usize,
12}
13
/// Lookup key for cached query results.
///
/// The arguments and auth scope are stored only as 64-bit hashes, not as the original
/// values, so two distinct inputs whose hashes collide would map to the same key.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CacheKey {
    /// Name of the cached function.
    function_name: String,
    /// Hash of the call arguments.
    args_hash: u64,
    /// Hash of the caller's auth scope.
    auth_scope_hash: u64,
}
20
21struct CacheEntry {
22    value: Arc<Value>,
23    expires_at: Instant,
24    created_at: Instant,
25}
26
27impl QueryCache {
28    /// Create a new query cache with default settings.
29    pub fn new() -> Self {
30        Self::with_max_entries(10_000)
31    }
32
33    /// Create a new query cache with a maximum number of entries.
34    pub fn with_max_entries(max_entries: usize) -> Self {
35        Self {
36            entries: RwLock::new(HashMap::new()),
37            max_entries,
38        }
39    }
40
41    /// Get a cached value if it exists and hasn't expired.
42    pub fn get(
43        &self,
44        function_name: &str,
45        args: &Value,
46        auth_scope: Option<&str>,
47    ) -> Option<Arc<Value>> {
48        let key = self.make_key(function_name, args, auth_scope);
49
50        let entries = self.entries.read().ok()?;
51        let entry = entries.get(&key)?;
52
53        if Instant::now() < entry.expires_at {
54            Some(Arc::clone(&entry.value))
55        } else {
56            None
57        }
58    }
59
60    /// Set a cached value with a TTL.
61    pub fn set(
62        &self,
63        function_name: &str,
64        args: &Value,
65        auth_scope: Option<&str>,
66        value: Value,
67        ttl: Duration,
68    ) {
69        let key = self.make_key(function_name, args, auth_scope);
70        let now = Instant::now();
71
72        let entry = CacheEntry {
73            value: Arc::new(value),
74            expires_at: now + ttl,
75            created_at: now,
76        };
77
78        if let Ok(mut entries) = self.entries.write() {
79            // Evict expired entries if we're at capacity
80            if entries.len() >= self.max_entries {
81                self.evict_expired(&mut entries);
82            }
83
84            // If still at capacity, evict oldest entries
85            if entries.len() >= self.max_entries {
86                self.evict_oldest(&mut entries, self.max_entries / 10);
87            }
88
89            entries.insert(key, entry);
90        }
91    }
92
93    /// Invalidate a specific cache entry.
94    pub fn invalidate(&self, function_name: &str, args: &Value) {
95        let key = self.make_key(function_name, args, None);
96        if let Ok(mut entries) = self.entries.write() {
97            entries.retain(|k, _| {
98                !(k.function_name == key.function_name && k.args_hash == key.args_hash)
99            });
100        }
101    }
102
103    /// Invalidate all entries for a function.
104    pub fn invalidate_function(&self, function_name: &str) {
105        if let Ok(mut entries) = self.entries.write() {
106            entries.retain(|k, _| k.function_name != function_name);
107        }
108    }
109
110    /// Clear the entire cache.
111    pub fn clear(&self) {
112        if let Ok(mut entries) = self.entries.write() {
113            entries.clear();
114        }
115    }
116
117    /// Get the number of cached entries.
118    pub fn len(&self) -> usize {
119        self.entries.read().map(|e| e.len()).unwrap_or(0)
120    }
121
122    /// Check if the cache is empty.
123    pub fn is_empty(&self) -> bool {
124        self.len() == 0
125    }
126
127    fn make_key(&self, function_name: &str, args: &Value, auth_scope: Option<&str>) -> CacheKey {
128        CacheKey {
129            function_name: function_name.to_string(),
130            args_hash: hash_value(args),
131            auth_scope_hash: hash_str(auth_scope.unwrap_or("")),
132        }
133    }
134
135    fn evict_expired(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
136        let now = Instant::now();
137        entries.retain(|_, v| v.expires_at > now);
138    }
139
140    fn evict_oldest(&self, entries: &mut HashMap<CacheKey, CacheEntry>, count: usize) {
141        let mut oldest: Vec<_> = entries
142            .iter()
143            .map(|(k, v)| (k.clone(), v.created_at))
144            .collect();
145
146        oldest.sort_by_key(|(_, t)| *t);
147
148        for (key, _) in oldest.into_iter().take(count) {
149            entries.remove(&key);
150        }
151    }
152}
153
154impl Default for QueryCache {
155    fn default() -> Self {
156        Self::new()
157    }
158}
159
160fn hash_value(value: &Value) -> u64 {
161    let mut hasher = std::collections::hash_map::DefaultHasher::new();
162    hash_value_recursive(value, &mut hasher);
163    hasher.finish()
164}
165
/// Hash a string slice into a `u64` using the standard library's `DefaultHasher`.
fn hash_str(value: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::default();
    Hash::hash(value, &mut state);
    Hasher::finish(&state)
}
171
172fn hash_value_recursive<H: Hasher>(value: &Value, hasher: &mut H) {
173    match value {
174        Value::Null => 0u8.hash(hasher),
175        Value::Bool(b) => {
176            1u8.hash(hasher);
177            b.hash(hasher);
178        }
179        Value::Number(n) => {
180            2u8.hash(hasher);
181            n.to_string().hash(hasher);
182        }
183        Value::String(s) => {
184            3u8.hash(hasher);
185            s.hash(hasher);
186        }
187        Value::Array(arr) => {
188            4u8.hash(hasher);
189            arr.len().hash(hasher);
190            for v in arr {
191                hash_value_recursive(v, hasher);
192            }
193        }
194        Value::Object(obj) => {
195            5u8.hash(hasher);
196            obj.len().hash(hasher);
197            // Sort keys for consistent hashing
198            let mut keys: Vec<_> = obj.keys().collect();
199            keys.sort();
200            for key in keys {
201                key.hash(hasher);
202                hash_value_recursive(&obj[key], hasher);
203            }
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_cache_set_get() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});
        let value = json!({"name": "test"});

        cache.set(
            "get_user",
            &args,
            Some("user:1"),
            value.clone(),
            Duration::from_secs(60),
        );

        let result = cache.get("get_user", &args, Some("user:1"));
        assert_eq!(result.as_deref(), Some(&value));
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});

        let result = cache.get("get_user", &args, Some("user:1"));
        assert_eq!(result, None);
    }

    #[test]
    fn test_cache_expired_entry_is_miss() {
        let cache = QueryCache::new();
        let args = json!({"id": 1});

        cache.set(
            "get_user",
            &args,
            Some("user:1"),
            json!({"name": "a"}),
            Duration::from_secs(0),
        );

        // A zero TTL expires immediately: the entry is never served.
        assert_eq!(cache.get("get_user", &args, Some("user:1")), None);
    }

    #[test]
    fn test_cache_overwrite_replaces_value() {
        let cache = QueryCache::new();
        let args = json!({"id": 1});

        cache.set("get_user", &args, None, json!(1), Duration::from_secs(60));
        cache.set("get_user", &args, None, json!(2), Duration::from_secs(60));

        assert_eq!(cache.len(), 1);
        assert_eq!(
            cache.get("get_user", &args, None).as_deref(),
            Some(&json!(2))
        );
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});
        let value = json!({"name": "test"});

        cache.set(
            "get_user",
            &args,
            Some("user:1"),
            value,
            Duration::from_secs(60),
        );
        cache.invalidate("get_user", &args);

        let result = cache.get("get_user", &args, Some("user:1"));
        assert_eq!(result, None);
    }

    #[test]
    fn test_cache_invalidate_covers_all_auth_scopes() {
        let cache = QueryCache::new();
        let args = json!({"id": 7});

        cache.set("get_user", &args, Some("user:1"), json!("a"), Duration::from_secs(60));
        cache.set("get_user", &args, Some("user:2"), json!("b"), Duration::from_secs(60));
        cache.invalidate("get_user", &args);

        assert_eq!(cache.get("get_user", &args, Some("user:1")), None);
        assert_eq!(cache.get("get_user", &args, Some("user:2")), None);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_invalidate_function() {
        let cache = QueryCache::new();
        let args1 = json!({"id": 1});
        let args2 = json!({"id": 2});

        cache.set(
            "get_user",
            &args1,
            Some("user:1"),
            json!({"name": "a"}),
            Duration::from_secs(60),
        );
        cache.set(
            "get_user",
            &args2,
            Some("user:1"),
            json!({"name": "b"}),
            Duration::from_secs(60),
        );
        cache.set(
            "list_users",
            &json!({}),
            Some("user:1"),
            json!([]),
            Duration::from_secs(60),
        );

        cache.invalidate_function("get_user");

        assert_eq!(cache.get("get_user", &args1, Some("user:1")), None);
        assert_eq!(cache.get("get_user", &args2, Some("user:1")), None);
        assert!(
            cache
                .get("list_users", &json!({}), Some("user:1"))
                .is_some()
        );
    }

    #[test]
    fn test_cache_stays_within_capacity() {
        let cache = QueryCache::with_max_entries(20);

        for i in 0..50 {
            cache.set(
                "get_user",
                &json!({"id": i}),
                None,
                json!(i),
                Duration::from_secs(60),
            );
            assert!(cache.len() <= 20);
        }

        // The most recent insert always survives eviction.
        assert!(cache.get("get_user", &json!({"id": 49}), None).is_some());
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryCache::new();
        assert!(cache.is_empty());

        cache.set("get_user", &json!({"id": 1}), None, json!(1), Duration::from_secs(60));
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_hash_consistency() {
        let v1 = json!({"a": 1, "b": 2});
        let v2 = json!({"b": 2, "a": 1});

        // Object keys should be sorted for consistent hashing
        assert_eq!(hash_value(&v1), hash_value(&v2));
    }

    #[test]
    fn test_cache_isolation_by_auth_scope() {
        let cache = QueryCache::new();
        let args = json!({"id": 1});

        cache.set(
            "get_profile",
            &args,
            Some("subject:user-a"),
            json!({"name": "Alice"}),
            Duration::from_secs(60),
        );

        assert!(
            cache
                .get("get_profile", &args, Some("subject:user-b"))
                .is_none()
        );
        assert!(
            cache
                .get("get_profile", &args, Some("subject:user-a"))
                .is_some()
        );
    }
}