1use async_graphql::Response;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12#[derive(Clone, Debug, Eq, PartialEq, Hash)]
14pub struct CacheKey {
15 pub operation_name: String,
17 pub query: String,
19 pub variables: String,
21}
22
23impl CacheKey {
24 pub fn new(operation_name: String, query: String, variables: String) -> Self {
26 Self {
27 operation_name,
28 query,
29 variables,
30 }
31 }
32
33 pub fn from_request(
35 operation_name: Option<String>,
36 query: String,
37 variables: serde_json::Value,
38 ) -> Self {
39 Self {
40 operation_name: operation_name.unwrap_or_default(),
41 query,
42 variables: variables.to_string(),
43 }
44 }
45}
46
47pub struct CachedResponse {
49 pub data: serde_json::Value,
51 pub errors: Vec<serde_json::Value>,
53 pub cached_at: Instant,
55 pub hit_count: usize,
57}
58
59impl CachedResponse {
60 pub fn to_response(&self) -> Response {
62 let mut response = Response::new(async_graphql::Value::Null);
64
65 if !self.data.is_null() {
67 response = Response::new(async_graphql::Value::Null);
70 }
71
72 response
73 }
74
75 pub fn from_response(response: &Response) -> Self {
77 let data = serde_json::Value::Null;
80 let errors = Vec::new();
81
82 Self {
83 data,
84 errors,
85 cached_at: Instant::now(),
86 hit_count: 0,
87 }
88 }
89}
90
/// Tunable limits for the response cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Entry count at which an insert triggers eviction of the oldest entry.
    pub max_entries: usize,
    /// How long an entry stays valid after it was cached.
    pub ttl: Duration,
    /// Whether hit / miss / eviction / size counters are maintained.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a 300-second TTL, statistics enabled.
    fn default() -> Self {
        const MAX_ENTRIES: usize = 1000;
        const TTL: Duration = Duration::from_secs(300);

        CacheConfig {
            max_entries: MAX_ENTRIES,
            ttl: TTL,
            enable_stats: true,
        }
    }
}
111
/// Counters describing how the cache has been used.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entry count recorded after the most recent size update.
    pub size: usize,
}

impl CacheStats {
    /// Share of lookups that were hits, as a percentage in `0.0..=100.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, so the result is
    /// never NaN.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => (self.hits as f64 / lookups as f64) * 100.0,
        }
    }
}
136
137pub struct ResponseCache {
139 cache: Arc<RwLock<HashMap<CacheKey, CachedResponse>>>,
141 config: CacheConfig,
143 stats: Arc<RwLock<CacheStats>>,
145}
146
147impl ResponseCache {
148 pub fn new(config: CacheConfig) -> Self {
150 Self {
151 cache: Arc::new(RwLock::new(HashMap::new())),
152 config,
153 stats: Arc::new(RwLock::new(CacheStats::default())),
154 }
155 }
156
157 pub fn default() -> Self {
159 Self::new(CacheConfig::default())
160 }
161
162 pub fn get(&self, key: &CacheKey) -> Option<Response> {
164 let mut cache = self.cache.write();
165
166 if let Some(cached) = cache.get_mut(key) {
167 if cached.cached_at.elapsed() > self.config.ttl {
169 cache.remove(key);
170 self.record_miss();
171 return None;
172 }
173
174 cached.hit_count += 1;
176 self.record_hit();
177
178 Some(cached.to_response())
180 } else {
181 self.record_miss();
182 None
183 }
184 }
185
186 pub fn put(&self, key: CacheKey, response: Response) {
188 let mut cache = self.cache.write();
189
190 if cache.len() >= self.config.max_entries {
192 if let Some(oldest_key) = self.find_oldest_key(&cache) {
193 cache.remove(&oldest_key);
194 self.record_eviction();
195 }
196 }
197
198 let cached_response = CachedResponse::from_response(&response);
200
201 cache.insert(key, cached_response);
202
203 self.update_size(cache.len());
204 }
205
206 pub fn clear(&self) {
208 let mut cache = self.cache.write();
209 cache.clear();
210 self.update_size(0);
211 }
212
213 pub fn clear_expired(&self) {
215 let mut cache = self.cache.write();
216 let ttl = self.config.ttl;
217
218 cache.retain(|_, cached| cached.cached_at.elapsed() <= ttl);
219 self.update_size(cache.len());
220 }
221
222 pub fn stats(&self) -> CacheStats {
224 self.stats.read().clone()
225 }
226
227 pub fn reset_stats(&self) {
229 let mut stats = self.stats.write();
230 *stats = CacheStats::default();
231 }
232
233 fn find_oldest_key(&self, cache: &HashMap<CacheKey, CachedResponse>) -> Option<CacheKey> {
236 cache
237 .iter()
238 .min_by_key(|(_, cached)| cached.cached_at)
239 .map(|(key, _)| key.clone())
240 }
241
242 fn record_hit(&self) {
243 if self.config.enable_stats {
244 let mut stats = self.stats.write();
245 stats.hits += 1;
246 }
247 }
248
249 fn record_miss(&self) {
250 if self.config.enable_stats {
251 let mut stats = self.stats.write();
252 stats.misses += 1;
253 }
254 }
255
256 fn record_eviction(&self) {
257 if self.config.enable_stats {
258 let mut stats = self.stats.write();
259 stats.evictions += 1;
260 }
261 }
262
263 fn update_size(&self, size: usize) {
264 if self.config.enable_stats {
265 let mut stats = self.stats.write();
266 stats.size = size;
267 }
268 }
269}
270
271pub struct CacheMiddleware {
273 cache: Arc<ResponseCache>,
274 cacheable_operations: Option<Vec<String>>,
276}
277
278impl CacheMiddleware {
279 pub fn new(cache: Arc<ResponseCache>) -> Self {
281 Self {
282 cache,
283 cacheable_operations: None,
284 }
285 }
286
287 pub fn with_operations(mut self, operations: Vec<String>) -> Self {
289 self.cacheable_operations = Some(operations);
290 self
291 }
292
293 pub fn should_cache(&self, operation_name: Option<&str>) -> bool {
295 match &self.cacheable_operations {
296 None => true, Some(ops) => {
298 operation_name.map(|name| ops.contains(&name.to_string())).unwrap_or(false)
299 }
300 }
301 }
302
303 pub fn get_cached(&self, key: &CacheKey) -> Option<Response> {
305 self.cache.get(key)
306 }
307
308 pub fn cache_response(&self, key: CacheKey, response: Response) {
310 self.cache.put(key, response);
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use async_graphql::Value;

    /// Key with the given operation name and fixed query / variables.
    fn key_for(operation: &str) -> CacheKey {
        CacheKey::new(operation.to_string(), "query".to_string(), "{}".to_string())
    }

    /// Response carrying a null payload.
    fn null_response() -> Response {
        Response::new(Value::Null)
    }

    #[test]
    fn test_cache_key_creation() {
        let operation = "getUser".to_string();
        let query = "query { user { id } }".to_string();
        let variables = r#"{"id": "123"}"#.to_string();

        let key = CacheKey::new(operation, query, variables);

        assert_eq!(key.operation_name, "getUser");
    }

    #[test]
    fn test_cache_key_from_request() {
        let variables = serde_json::json!({"id": "123"});

        let key = CacheKey::from_request(
            Some("getUser".to_string()),
            "query { user { id } }".to_string(),
            variables,
        );

        assert_eq!(key.operation_name, "getUser");
        assert!(key.variables.contains("123"));
    }

    #[test]
    fn test_cache_config_default() {
        let CacheConfig {
            max_entries,
            ttl,
            enable_stats,
        } = CacheConfig::default();

        assert_eq!(max_entries, 1000);
        assert_eq!(ttl, Duration::from_secs(300));
        assert!(enable_stats);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            ..CacheStats::default()
        };

        assert_eq!(stats.hit_rate(), 80.0);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = ResponseCache::default();
        let key = key_for("test");

        cache.put(key.clone(), null_response());

        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let cache = ResponseCache::default();

        assert!(cache.get(&key_for("nonexistent")).is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = ResponseCache::default();
        let key = key_for("test");

        cache.put(key.clone(), null_response());
        assert!(cache.get(&key).is_some());

        cache.clear();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = ResponseCache::default();
        let key = key_for("test");

        // First lookup happens before anything is stored: one miss.
        let _ = cache.get(&key);

        // Store, then look up again: one hit.
        cache.put(key.clone(), null_response());
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
    }

    #[test]
    fn test_cache_middleware_should_cache() {
        let middleware = CacheMiddleware::new(Arc::new(ResponseCache::default()));

        assert!(middleware.should_cache(Some("getUser")));
        assert!(middleware.should_cache(None));
    }

    #[test]
    fn test_cache_middleware_with_specific_operations() {
        let allowed = vec!["getUser".to_string(), "getProduct".to_string()];
        let middleware =
            CacheMiddleware::new(Arc::new(ResponseCache::default())).with_operations(allowed);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(!middleware.should_cache(Some("createUser")));
    }

    #[test]
    fn test_cache_eviction() {
        let cache = ResponseCache::new(CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            enable_stats: true,
        });

        for i in 0..3 {
            cache.put(key_for(&format!("op{}", i)), null_response());
            // Space the inserts out so each entry has a distinct timestamp.
            std::thread::sleep(Duration::from_millis(10));
        }

        let stats = cache.stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.evictions, 1);
    }
}