1use async_graphql::Response;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
/// Identity of a cached GraphQL response.
///
/// Two requests hit the same cache entry only when all three components are
/// byte-for-byte equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Operation name; `from_request` stores an empty string when none was given.
    pub operation_name: String,
    /// Raw query document text.
    pub query: String,
    /// Request variables, already serialized to a string.
    pub variables: String,
}
22
23impl CacheKey {
24 pub fn new(operation_name: String, query: String, variables: String) -> Self {
26 Self {
27 operation_name,
28 query,
29 variables,
30 }
31 }
32
33 pub fn from_request(
35 operation_name: Option<String>,
36 query: String,
37 variables: serde_json::Value,
38 ) -> Self {
39 Self {
40 operation_name: operation_name.unwrap_or_default(),
41 query,
42 variables: variables.to_string(),
43 }
44 }
45}
46
47pub struct CachedResponse {
49 pub data: serde_json::Value,
51 pub errors: Vec<CachedError>,
53 pub extensions: Option<serde_json::Value>,
55 pub cached_at: Instant,
57 pub hit_count: usize,
59}
60
61#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
63pub struct CachedError {
64 pub message: String,
66 pub locations: Vec<CachedErrorLocation>,
68 pub path: Option<Vec<serde_json::Value>>,
70 pub extensions: Option<serde_json::Value>,
72}
73
74#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
76pub struct CachedErrorLocation {
77 pub line: usize,
79 pub column: usize,
81}
82
83impl CachedResponse {
84 pub fn to_response(&self) -> Response {
86 let graphql_value = json_to_graphql_value(&self.data);
88 let mut response = Response::new(graphql_value);
89
90 for cached_error in &self.errors {
92 let mut server_error =
93 async_graphql::ServerError::new(cached_error.message.clone(), None);
94
95 server_error.locations = cached_error
97 .locations
98 .iter()
99 .map(|loc| async_graphql::Pos {
100 line: loc.line,
101 column: loc.column,
102 })
103 .collect();
104
105 if let Some(path) = &cached_error.path {
107 server_error.path = path
108 .iter()
109 .filter_map(|v| match v {
110 serde_json::Value::String(s) => {
111 Some(async_graphql::PathSegment::Field(s.clone()))
112 }
113 serde_json::Value::Number(n) => {
114 n.as_u64().map(|i| async_graphql::PathSegment::Index(i as usize))
115 }
116 _ => None,
117 })
118 .collect();
119 }
120
121 response.errors.push(server_error);
122 }
123
124 if let Some(ext) = &self.extensions {
126 if let serde_json::Value::Object(map) = ext {
127 for (key, value) in map {
128 response.extensions.insert(key.clone(), json_to_graphql_value(value));
129 }
130 }
131 }
132
133 response
134 }
135
136 pub fn from_response(response: &Response) -> Self {
138 let data = graphql_value_to_json(&response.data);
140
141 let errors: Vec<CachedError> = response
143 .errors
144 .iter()
145 .map(|e| CachedError {
146 message: e.message.clone(),
147 locations: e
148 .locations
149 .iter()
150 .map(|loc| CachedErrorLocation {
151 line: loc.line,
152 column: loc.column,
153 })
154 .collect(),
155 path: if e.path.is_empty() {
156 None
157 } else {
158 Some(
159 e.path
160 .iter()
161 .map(|seg| match seg {
162 async_graphql::PathSegment::Field(s) => {
163 serde_json::Value::String(s.clone())
164 }
165 async_graphql::PathSegment::Index(i) => {
166 serde_json::Value::Number((*i as u64).into())
167 }
168 })
169 .collect(),
170 )
171 },
172 extensions: None, })
174 .collect();
175
176 let extensions = if response.extensions.is_empty() {
178 None
179 } else {
180 let mut map = serde_json::Map::new();
181 for (key, value) in &response.extensions {
182 map.insert(key.clone(), graphql_value_to_json(value));
183 }
184 Some(serde_json::Value::Object(map))
185 };
186
187 Self {
188 data,
189 errors,
190 extensions,
191 cached_at: Instant::now(),
192 hit_count: 0,
193 }
194 }
195}
196
197fn graphql_value_to_json(value: &async_graphql::Value) -> serde_json::Value {
199 match value {
200 async_graphql::Value::Null => serde_json::Value::Null,
201 async_graphql::Value::Number(n) => {
202 if let Some(i) = n.as_i64() {
203 serde_json::Value::Number(i.into())
204 } else if let Some(u) = n.as_u64() {
205 serde_json::Value::Number(u.into())
206 } else if let Some(f) = n.as_f64() {
207 serde_json::Number::from_f64(f)
208 .map(serde_json::Value::Number)
209 .unwrap_or(serde_json::Value::Null)
210 } else {
211 serde_json::Value::Null
212 }
213 }
214 async_graphql::Value::String(s) => serde_json::Value::String(s.clone()),
215 async_graphql::Value::Boolean(b) => serde_json::Value::Bool(*b),
216 async_graphql::Value::List(arr) => {
217 serde_json::Value::Array(arr.iter().map(graphql_value_to_json).collect())
218 }
219 async_graphql::Value::Object(obj) => {
220 let map: serde_json::Map<String, serde_json::Value> =
221 obj.iter().map(|(k, v)| (k.to_string(), graphql_value_to_json(v))).collect();
222 serde_json::Value::Object(map)
223 }
224 async_graphql::Value::Enum(e) => serde_json::Value::String(e.to_string()),
225 async_graphql::Value::Binary(b) => {
226 use base64::Engine;
227 serde_json::Value::String(base64::engine::general_purpose::STANDARD.encode(b))
228 }
229 }
230}
231
232fn json_to_graphql_value(value: &serde_json::Value) -> async_graphql::Value {
234 match value {
235 serde_json::Value::Null => async_graphql::Value::Null,
236 serde_json::Value::Bool(b) => async_graphql::Value::Boolean(*b),
237 serde_json::Value::Number(n) => {
238 if let Some(i) = n.as_i64() {
239 async_graphql::Value::Number(i.into())
240 } else if let Some(u) = n.as_u64() {
241 async_graphql::Value::Number(u.into())
242 } else if let Some(f) = n.as_f64() {
243 async_graphql::Value::Number(
244 async_graphql::Number::from_f64(f).unwrap_or_else(|| 0.into()),
245 )
246 } else {
247 async_graphql::Value::Null
248 }
249 }
250 serde_json::Value::String(s) => async_graphql::Value::String(s.clone()),
251 serde_json::Value::Array(arr) => {
252 async_graphql::Value::List(arr.iter().map(json_to_graphql_value).collect())
253 }
254 serde_json::Value::Object(obj) => {
255 let map: async_graphql::indexmap::IndexMap<async_graphql::Name, async_graphql::Value> =
256 obj.iter()
257 .filter_map(|(k, v)| {
258 let is_valid =
260 k.chars().next().is_some_and(|c| c == '_' || c.is_ascii_alphabetic())
261 && k.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
262 if is_valid {
263 Some((async_graphql::Name::new(k), json_to_graphql_value(v)))
264 } else {
265 None
266 }
267 })
268 .collect();
269 async_graphql::Value::Object(map)
270 }
271 }
272}
273
/// Tuning knobs for a response cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries held at once.
    pub max_entries: usize,
    /// How long an entry stays valid after it was stored.
    pub ttl: Duration,
    /// Whether hit/miss/eviction/size counters are maintained.
    pub enable_stats: bool,
}
284
285impl Default for CacheConfig {
286 fn default() -> Self {
287 Self {
288 max_entries: 1000,
289 ttl: Duration::from_secs(300), enable_stats: true,
291 }
292 }
293}
294
/// Counters describing cache activity; all start at zero.
#[derive(Default, Debug, Clone)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found nothing usable (absent or expired).
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Number of entries currently stored.
    pub size: usize,
}
307
308impl CacheStats {
309 pub fn hit_rate(&self) -> f64 {
311 let total = self.hits + self.misses;
312 if total == 0 {
313 0.0
314 } else {
315 (self.hits as f64 / total as f64) * 100.0
316 }
317 }
318}
319
320pub struct ResponseCache {
322 cache: Arc<RwLock<HashMap<CacheKey, CachedResponse>>>,
324 config: CacheConfig,
326 stats: Arc<RwLock<CacheStats>>,
328}
329
330impl ResponseCache {
331 pub fn new(config: CacheConfig) -> Self {
333 Self {
334 cache: Arc::new(RwLock::new(HashMap::new())),
335 config,
336 stats: Arc::new(RwLock::new(CacheStats::default())),
337 }
338 }
339
340 pub fn default() -> Self {
342 Self::new(CacheConfig::default())
343 }
344
345 pub fn get(&self, key: &CacheKey) -> Option<Response> {
347 let mut cache = self.cache.write();
348
349 if let Some(cached) = cache.get_mut(key) {
350 if cached.cached_at.elapsed() > self.config.ttl {
352 cache.remove(key);
353 self.record_miss();
354 return None;
355 }
356
357 cached.hit_count += 1;
359 self.record_hit();
360
361 Some(cached.to_response())
363 } else {
364 self.record_miss();
365 None
366 }
367 }
368
369 pub fn put(&self, key: CacheKey, response: Response) {
371 let mut cache = self.cache.write();
372
373 if cache.len() >= self.config.max_entries {
375 if let Some(oldest_key) = self.find_oldest_key(&cache) {
376 cache.remove(&oldest_key);
377 self.record_eviction();
378 }
379 }
380
381 let cached_response = CachedResponse::from_response(&response);
383
384 cache.insert(key, cached_response);
385
386 self.update_size(cache.len());
387 }
388
389 pub fn clear(&self) {
391 let mut cache = self.cache.write();
392 cache.clear();
393 self.update_size(0);
394 }
395
396 pub fn clear_expired(&self) {
398 let mut cache = self.cache.write();
399 let ttl = self.config.ttl;
400
401 cache.retain(|_, cached| cached.cached_at.elapsed() <= ttl);
402 self.update_size(cache.len());
403 }
404
405 pub fn stats(&self) -> CacheStats {
407 self.stats.read().clone()
408 }
409
410 pub fn reset_stats(&self) {
412 let mut stats = self.stats.write();
413 *stats = CacheStats::default();
414 }
415
416 fn find_oldest_key(&self, cache: &HashMap<CacheKey, CachedResponse>) -> Option<CacheKey> {
419 cache
420 .iter()
421 .min_by_key(|(_, cached)| cached.cached_at)
422 .map(|(key, _)| key.clone())
423 }
424
425 fn record_hit(&self) {
426 if self.config.enable_stats {
427 let mut stats = self.stats.write();
428 stats.hits += 1;
429 }
430 }
431
432 fn record_miss(&self) {
433 if self.config.enable_stats {
434 let mut stats = self.stats.write();
435 stats.misses += 1;
436 }
437 }
438
439 fn record_eviction(&self) {
440 if self.config.enable_stats {
441 let mut stats = self.stats.write();
442 stats.evictions += 1;
443 }
444 }
445
446 fn update_size(&self, size: usize) {
447 if self.config.enable_stats {
448 let mut stats = self.stats.write();
449 stats.size = size;
450 }
451 }
452}
453
454pub struct CacheMiddleware {
456 cache: Arc<ResponseCache>,
457 cacheable_operations: Option<Vec<String>>,
459}
460
461impl CacheMiddleware {
462 pub fn new(cache: Arc<ResponseCache>) -> Self {
464 Self {
465 cache,
466 cacheable_operations: None,
467 }
468 }
469
470 pub fn with_operations(mut self, operations: Vec<String>) -> Self {
472 self.cacheable_operations = Some(operations);
473 self
474 }
475
476 pub fn should_cache(&self, operation_name: Option<&str>) -> bool {
478 match &self.cacheable_operations {
479 None => true, Some(ops) => {
481 operation_name.map(|name| ops.contains(&name.to_string())).unwrap_or(false)
482 }
483 }
484 }
485
486 pub fn get_cached(&self, key: &CacheKey) -> Option<Response> {
488 self.cache.get(key)
489 }
490
491 pub fn cache_response(&self, key: CacheKey, response: Response) {
493 self.cache.put(key, response);
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use async_graphql::Value;

    /// Key that differs from other test keys only by operation name.
    fn test_key(operation_name: &str) -> CacheKey {
        CacheKey::new(operation_name.to_string(), "query".to_string(), "{}".to_string())
    }

    #[test]
    fn test_cache_key_creation() {
        let key = CacheKey::new(
            "getUser".to_string(),
            "query { user { id } }".to_string(),
            r#"{"id": "123"}"#.to_string(),
        );

        assert_eq!(key.operation_name, "getUser");
        assert_eq!(key.query, "query { user { id } }");
        assert_eq!(key.variables, r#"{"id": "123"}"#);
    }

    #[test]
    fn test_cache_key_from_request() {
        let key = CacheKey::from_request(
            Some("getUser".to_string()),
            "query { user { id } }".to_string(),
            serde_json::json!({"id": "123"}),
        );

        assert_eq!(key.operation_name, "getUser");
        assert!(key.variables.contains("123"));

        // A missing operation name maps to the empty string.
        let anonymous =
            CacheKey::from_request(None, "{ a }".to_string(), serde_json::json!({}));
        assert_eq!(anonymous.operation_name, "");
        assert_eq!(anonymous.variables, "{}");
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.ttl, Duration::from_secs(300));
        assert!(config.enable_stats);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            ..Default::default()
        };
        assert_eq!(stats.hit_rate(), 80.0);
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = ResponseCache::default();
        let key = test_key("test");

        cache.put(key.clone(), Response::new(Value::Null));

        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let cache = ResponseCache::default();
        assert!(cache.get(&test_key("nonexistent")).is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = ResponseCache::default();
        let key = test_key("test");

        cache.put(key.clone(), Response::new(Value::Null));
        assert!(cache.get(&key).is_some());

        cache.clear();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = ResponseCache::default();
        let key = test_key("test");

        let _ = cache.get(&key);

        cache.put(key.clone(), Response::new(Value::Null));
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
    }

    #[test]
    fn test_reset_stats() {
        let cache = ResponseCache::default();
        let _ = cache.get(&test_key("missing"));
        assert_eq!(cache.stats().misses, 1);

        cache.reset_stats();
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.size, 0);
    }

    #[test]
    fn test_cache_middleware_should_cache() {
        let cache = Arc::new(ResponseCache::default());
        let middleware = CacheMiddleware::new(cache);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(middleware.should_cache(None));
    }

    #[test]
    fn test_cache_middleware_with_specific_operations() {
        let cache = Arc::new(ResponseCache::default());
        let middleware = CacheMiddleware::new(cache)
            .with_operations(vec!["getUser".to_string(), "getProduct".to_string()]);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(!middleware.should_cache(Some("createUser")));
        // With an allow-list, anonymous operations are not cacheable.
        assert!(!middleware.should_cache(None));
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            enable_stats: true,
        };
        let cache = ResponseCache::new(config);

        // No sleeps needed: the assertions hold regardless of which entry is
        // chosen as the oldest.
        for i in 0..3 {
            cache.put(test_key(&format!("op{}", i)), Response::new(Value::Null));
        }

        let stats = cache.stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.evictions, 1);
    }

    #[test]
    fn test_value_json_round_trip() {
        let mut fields = async_graphql::indexmap::IndexMap::new();
        fields.insert(async_graphql::Name::new("id"), Value::Number(42.into()));
        fields.insert(async_graphql::Name::new("name"), Value::String("Ada".to_string()));
        fields.insert(async_graphql::Name::new("active"), Value::Boolean(true));
        fields.insert(
            async_graphql::Name::new("score"),
            Value::Number(async_graphql::Number::from_f64(1.5).expect("1.5 is finite")),
        );
        fields.insert(
            async_graphql::Name::new("tags"),
            Value::List(vec![Value::String("a".to_string()), Value::Null]),
        );
        let original = Value::Object(fields);

        let restored = json_to_graphql_value(&graphql_value_to_json(&original));
        assert_eq!(restored, original);
    }

    #[test]
    fn test_cached_response_preserves_error_location_and_path() {
        let mut error = async_graphql::ServerError::new(
            "boom",
            Some(async_graphql::Pos { line: 3, column: 7 }),
        );
        error.path = vec![
            async_graphql::PathSegment::Field("user".to_string()),
            async_graphql::PathSegment::Index(2),
        ];
        let mut response = Response::new(Value::Null);
        response.errors.push(error);

        let restored = CachedResponse::from_response(&response).to_response();
        assert_eq!(restored.errors.len(), 1);

        let restored_error = &restored.errors[0];
        assert_eq!(restored_error.message, "boom");
        assert_eq!(restored_error.locations, vec![async_graphql::Pos { line: 3, column: 7 }]);
        assert!(matches!(
            restored_error.path.as_slice(),
            [async_graphql::PathSegment::Field(field), async_graphql::PathSegment::Index(2)]
                if field == "user"
        ));
    }
}