forge_runtime/function/
cache.rs1use std::collections::HashMap;
2use std::hash::{Hash, Hasher};
3use std::sync::RwLock;
4use std::time::{Duration, Instant};
5
6use serde_json::Value;
7
8pub struct QueryCache {
10 entries: RwLock<HashMap<CacheKey, CacheEntry>>,
11 max_entries: usize,
12}
13
/// Identity of one cached call: the function's name plus a digest of its
/// arguments.
///
/// Only the 64-bit digest of the arguments is kept, not the arguments
/// themselves, so two argument values with the same digest map to the same key.
#[derive(Clone, Eq, PartialEq, Hash)]
struct CacheKey {
    /// Name of the function whose result is cached.
    function_name: String,
    /// 64-bit hash of the call's JSON arguments.
    args_hash: u64,
}
19
20struct CacheEntry {
21 value: Value,
22 expires_at: Instant,
23 created_at: Instant,
24}
25
26impl QueryCache {
27 pub fn new() -> Self {
29 Self::with_max_entries(10_000)
30 }
31
32 pub fn with_max_entries(max_entries: usize) -> Self {
34 Self {
35 entries: RwLock::new(HashMap::new()),
36 max_entries,
37 }
38 }
39
40 pub fn get(&self, function_name: &str, args: &Value) -> Option<Value> {
42 let key = self.make_key(function_name, args);
43
44 let entries = self.entries.read().ok()?;
45 let entry = entries.get(&key)?;
46
47 if Instant::now() < entry.expires_at {
48 Some(entry.value.clone())
49 } else {
50 None
51 }
52 }
53
54 pub fn set(&self, function_name: &str, args: &Value, value: Value, ttl: Duration) {
56 let key = self.make_key(function_name, args);
57 let now = Instant::now();
58
59 let entry = CacheEntry {
60 value,
61 expires_at: now + ttl,
62 created_at: now,
63 };
64
65 if let Ok(mut entries) = self.entries.write() {
66 if entries.len() >= self.max_entries {
68 self.evict_expired(&mut entries);
69 }
70
71 if entries.len() >= self.max_entries {
73 self.evict_oldest(&mut entries, self.max_entries / 10);
74 }
75
76 entries.insert(key, entry);
77 }
78 }
79
80 pub fn invalidate(&self, function_name: &str, args: &Value) {
82 let key = self.make_key(function_name, args);
83 if let Ok(mut entries) = self.entries.write() {
84 entries.remove(&key);
85 }
86 }
87
88 pub fn invalidate_function(&self, function_name: &str) {
90 if let Ok(mut entries) = self.entries.write() {
91 entries.retain(|k, _| k.function_name != function_name);
92 }
93 }
94
95 pub fn clear(&self) {
97 if let Ok(mut entries) = self.entries.write() {
98 entries.clear();
99 }
100 }
101
102 pub fn len(&self) -> usize {
104 self.entries.read().map(|e| e.len()).unwrap_or(0)
105 }
106
107 pub fn is_empty(&self) -> bool {
109 self.len() == 0
110 }
111
112 fn make_key(&self, function_name: &str, args: &Value) -> CacheKey {
113 CacheKey {
114 function_name: function_name.to_string(),
115 args_hash: hash_value(args),
116 }
117 }
118
119 fn evict_expired(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
120 let now = Instant::now();
121 entries.retain(|_, v| v.expires_at > now);
122 }
123
124 fn evict_oldest(&self, entries: &mut HashMap<CacheKey, CacheEntry>, count: usize) {
125 let mut oldest: Vec<_> = entries
126 .iter()
127 .map(|(k, v)| (k.clone(), v.created_at))
128 .collect();
129
130 oldest.sort_by_key(|(_, t)| *t);
131
132 for (key, _) in oldest.into_iter().take(count) {
133 entries.remove(&key);
134 }
135 }
136}
137
138impl Default for QueryCache {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144fn hash_value(value: &Value) -> u64 {
145 let mut hasher = std::collections::hash_map::DefaultHasher::new();
146 hash_value_recursive(value, &mut hasher);
147 hasher.finish()
148}
149
150fn hash_value_recursive<H: Hasher>(value: &Value, hasher: &mut H) {
151 match value {
152 Value::Null => 0u8.hash(hasher),
153 Value::Bool(b) => {
154 1u8.hash(hasher);
155 b.hash(hasher);
156 }
157 Value::Number(n) => {
158 2u8.hash(hasher);
159 n.to_string().hash(hasher);
160 }
161 Value::String(s) => {
162 3u8.hash(hasher);
163 s.hash(hasher);
164 }
165 Value::Array(arr) => {
166 4u8.hash(hasher);
167 arr.len().hash(hasher);
168 for v in arr {
169 hash_value_recursive(v, hasher);
170 }
171 }
172 Value::Object(obj) => {
173 5u8.hash(hasher);
174 obj.len().hash(hasher);
175 let mut keys: Vec<_> = obj.keys().collect();
177 keys.sort();
178 for key in keys {
179 key.hash(hasher);
180 hash_value_recursive(&obj[key], hasher);
181 }
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Lifetime given to every entry stored by these tests.
    const TTL: Duration = Duration::from_secs(60);

    /// A stored value is returned unchanged by a lookup with the same key.
    #[test]
    fn test_cache_set_get() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});
        let expected = json!({"name": "test"});

        cache.set("get_user", &args, expected.clone(), TTL);

        assert_eq!(cache.get("get_user", &args), Some(expected));
    }

    /// Looking up a key that was never stored yields `None`.
    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});

        assert_eq!(cache.get("get_user", &args), None);
    }

    /// Invalidating a single key removes exactly that entry.
    #[test]
    fn test_cache_invalidate() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});

        cache.set("get_user", &args, json!({"name": "test"}), TTL);
        cache.invalidate("get_user", &args);

        assert_eq!(cache.get("get_user", &args), None);
    }

    /// Invalidating a function drops all of its entries and leaves other
    /// functions untouched.
    #[test]
    fn test_cache_invalidate_function() {
        let cache = QueryCache::new();
        let args1 = json!({"id": 1});
        let args2 = json!({"id": 2});
        let no_args = json!({});

        cache.set("get_user", &args1, json!({"name": "a"}), TTL);
        cache.set("get_user", &args2, json!({"name": "b"}), TTL);
        cache.set("list_users", &no_args, json!([]), TTL);

        cache.invalidate_function("get_user");

        assert_eq!(cache.get("get_user", &args1), None);
        assert_eq!(cache.get("get_user", &args2), None);
        assert!(cache.get("list_users", &no_args).is_some());
    }

    /// Object key order does not influence the argument hash.
    #[test]
    fn test_hash_consistency() {
        let forward = json!({"a": 1, "b": 2});
        let reversed = json!({"b": 2, "a": 1});

        assert_eq!(hash_value(&forward), hash_value(&reversed));
    }
}