1use std::collections::HashMap;
2use std::hash::{Hash, Hasher};
3use std::sync::{Arc, RwLock};
4use std::time::{Duration, Instant};
5
6use serde_json::Value;
7
8pub struct QueryCache {
10 entries: RwLock<HashMap<CacheKey, CacheEntry>>,
11 max_entries: usize,
12}
13
/// Identity of one cached call: which function, which arguments, and which auth scope.
#[derive(Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Name of the cached function, stored verbatim.
    function_name: String,
    /// Structural hash of the call's JSON arguments.
    args_hash: u64,
    /// Hash of the auth scope the result was produced for.
    auth_scope_hash: u64,
}
20
21struct CacheEntry {
22 value: Arc<Value>,
23 expires_at: Instant,
24 created_at: Instant,
25}
26
27impl QueryCache {
28 pub fn new() -> Self {
30 Self::with_max_entries(10_000)
31 }
32
33 pub fn with_max_entries(max_entries: usize) -> Self {
35 Self {
36 entries: RwLock::new(HashMap::new()),
37 max_entries,
38 }
39 }
40
41 pub fn get(
43 &self,
44 function_name: &str,
45 args: &Value,
46 auth_scope: Option<&str>,
47 ) -> Option<Arc<Value>> {
48 let key = self.make_key(function_name, args, auth_scope);
49
50 let entries = self.entries.read().ok()?;
51 let entry = entries.get(&key)?;
52
53 if Instant::now() < entry.expires_at {
54 Some(Arc::clone(&entry.value))
55 } else {
56 None
57 }
58 }
59
60 pub fn set(
62 &self,
63 function_name: &str,
64 args: &Value,
65 auth_scope: Option<&str>,
66 value: Value,
67 ttl: Duration,
68 ) {
69 let key = self.make_key(function_name, args, auth_scope);
70 let now = Instant::now();
71
72 let entry = CacheEntry {
73 value: Arc::new(value),
74 expires_at: now + ttl,
75 created_at: now,
76 };
77
78 if let Ok(mut entries) = self.entries.write() {
79 if entries.len() >= self.max_entries {
81 self.evict_expired(&mut entries);
82 }
83
84 if entries.len() >= self.max_entries {
86 self.evict_oldest(&mut entries, self.max_entries / 10);
87 }
88
89 entries.insert(key, entry);
90 }
91 }
92
93 pub fn invalidate(&self, function_name: &str, args: &Value) {
95 let key = self.make_key(function_name, args, None);
96 if let Ok(mut entries) = self.entries.write() {
97 entries.retain(|k, _| {
98 !(k.function_name == key.function_name && k.args_hash == key.args_hash)
99 });
100 }
101 }
102
103 pub fn invalidate_function(&self, function_name: &str) {
105 if let Ok(mut entries) = self.entries.write() {
106 entries.retain(|k, _| k.function_name != function_name);
107 }
108 }
109
110 pub fn clear(&self) {
112 if let Ok(mut entries) = self.entries.write() {
113 entries.clear();
114 }
115 }
116
117 pub fn len(&self) -> usize {
119 self.entries.read().map(|e| e.len()).unwrap_or(0)
120 }
121
122 pub fn is_empty(&self) -> bool {
124 self.len() == 0
125 }
126
127 fn make_key(&self, function_name: &str, args: &Value, auth_scope: Option<&str>) -> CacheKey {
128 CacheKey {
129 function_name: function_name.to_string(),
130 args_hash: hash_value(args),
131 auth_scope_hash: hash_str(auth_scope.unwrap_or("")),
132 }
133 }
134
135 fn evict_expired(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
136 let now = Instant::now();
137 entries.retain(|_, v| v.expires_at > now);
138 }
139
140 fn evict_oldest(&self, entries: &mut HashMap<CacheKey, CacheEntry>, count: usize) {
141 let mut oldest: Vec<_> = entries
142 .iter()
143 .map(|(k, v)| (k.clone(), v.created_at))
144 .collect();
145
146 oldest.sort_by_key(|(_, t)| *t);
147
148 for (key, _) in oldest.into_iter().take(count) {
149 entries.remove(&key);
150 }
151 }
152}
153
154impl Default for QueryCache {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160fn hash_value(value: &Value) -> u64 {
161 let mut hasher = std::collections::hash_map::DefaultHasher::new();
162 hash_value_recursive(value, &mut hasher);
163 hasher.finish()
164}
165
/// 64-bit hash of a string, computed with the standard library's `DefaultHasher`.
fn hash_str(value: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::default();
    Hash::hash(value, &mut state);
    state.finish()
}
171
172fn hash_value_recursive<H: Hasher>(value: &Value, hasher: &mut H) {
173 match value {
174 Value::Null => 0u8.hash(hasher),
175 Value::Bool(b) => {
176 1u8.hash(hasher);
177 b.hash(hasher);
178 }
179 Value::Number(n) => {
180 2u8.hash(hasher);
181 n.to_string().hash(hasher);
182 }
183 Value::String(s) => {
184 3u8.hash(hasher);
185 s.hash(hasher);
186 }
187 Value::Array(arr) => {
188 4u8.hash(hasher);
189 arr.len().hash(hasher);
190 for v in arr {
191 hash_value_recursive(v, hasher);
192 }
193 }
194 Value::Object(obj) => {
195 5u8.hash(hasher);
196 obj.len().hash(hasher);
197 let mut keys: Vec<_> = obj.keys().collect();
199 keys.sort();
200 for key in keys {
201 key.hash(hasher);
202 hash_value_recursive(&obj[key], hasher);
203 }
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Long enough that no entry expires while a test runs.
    const TTL: Duration = Duration::from_secs(60);

    #[test]
    fn test_cache_set_get() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});
        let expected = json!({"name": "test"});

        cache.set("get_user", &args, Some("user:1"), expected.clone(), TTL);

        let cached = cache.get("get_user", &args, Some("user:1"));
        assert_eq!(cached.as_deref(), Some(&expected));
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::new();
        let lookup = cache.get("get_user", &json!({"id": 123}), Some("user:1"));
        assert_eq!(lookup, None);
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = QueryCache::new();
        let args = json!({"id": 123});

        cache.set("get_user", &args, Some("user:1"), json!({"name": "test"}), TTL);
        cache.invalidate("get_user", &args);

        assert_eq!(cache.get("get_user", &args, Some("user:1")), None);
    }

    #[test]
    fn test_cache_invalidate_function() {
        let cache = QueryCache::new();
        let scope = Some("user:1");
        let first = json!({"id": 1});
        let second = json!({"id": 2});
        let no_args = json!({});

        cache.set("get_user", &first, scope, json!({"name": "a"}), TTL);
        cache.set("get_user", &second, scope, json!({"name": "b"}), TTL);
        cache.set("list_users", &no_args, scope, json!([]), TTL);

        cache.invalidate_function("get_user");

        // Only the invalidated function's entries are gone.
        assert_eq!(cache.get("get_user", &first, scope), None);
        assert_eq!(cache.get("get_user", &second, scope), None);
        assert!(cache.get("list_users", &no_args, scope).is_some());
    }

    #[test]
    fn test_hash_consistency() {
        // Object key order must not change the hash.
        let forward = json!({"a": 1, "b": 2});
        let reversed = json!({"b": 2, "a": 1});
        assert_eq!(hash_value(&forward), hash_value(&reversed));
    }

    #[test]
    fn test_cache_isolation_by_auth_scope() {
        let cache = QueryCache::new();
        let args = json!({"id": 1});

        cache.set(
            "get_profile",
            &args,
            Some("subject:user-a"),
            json!({"name": "Alice"}),
            TTL,
        );

        // Another subject must not see user-a's cached result.
        assert!(cache.get("get_profile", &args, Some("subject:user-b")).is_none());
        assert!(cache.get("get_profile", &args, Some("subject:user-a")).is_some());
    }
}