1use async_graphql::Response;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
/// Identity of a cached GraphQL response.
///
/// Two requests share a cache entry only when all three components are
/// byte-for-byte equal; no normalization happens at this level.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Operation name (the empty string when `from_request` received `None`).
    pub operation_name: String,
    /// Raw query document text.
    pub query: String,
    /// Variables, already serialized to a string.
    pub variables: String,
}
22
23impl CacheKey {
24 pub fn new(operation_name: String, query: String, variables: String) -> Self {
26 Self {
27 operation_name,
28 query,
29 variables,
30 }
31 }
32
33 pub fn from_request(
35 operation_name: Option<String>,
36 query: String,
37 variables: serde_json::Value,
38 ) -> Self {
39 Self {
40 operation_name: operation_name.unwrap_or_default(),
41 query,
42 variables: variables.to_string(),
43 }
44 }
45}
46
47pub struct CachedResponse {
49 pub data: serde_json::Value,
51 pub errors: Vec<CachedError>,
53 pub extensions: Option<serde_json::Value>,
55 pub cached_at: Instant,
57 pub hit_count: usize,
59}
60
61#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
63pub struct CachedError {
64 pub message: String,
66 pub locations: Vec<CachedErrorLocation>,
68 pub path: Option<Vec<serde_json::Value>>,
70 pub extensions: Option<serde_json::Value>,
72}
73
74#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
76pub struct CachedErrorLocation {
77 pub line: usize,
79 pub column: usize,
81}
82
83impl CachedResponse {
84 pub fn to_response(&self) -> Response {
86 let graphql_value = json_to_graphql_value(&self.data);
88 let mut response = Response::new(graphql_value);
89
90 for cached_error in &self.errors {
92 let mut server_error =
93 async_graphql::ServerError::new(cached_error.message.clone(), None);
94
95 server_error.locations = cached_error
97 .locations
98 .iter()
99 .map(|loc| async_graphql::Pos {
100 line: loc.line,
101 column: loc.column,
102 })
103 .collect();
104
105 if let Some(path) = &cached_error.path {
107 server_error.path = path
108 .iter()
109 .filter_map(|v| match v {
110 serde_json::Value::String(s) => {
111 Some(async_graphql::PathSegment::Field(s.clone()))
112 }
113 serde_json::Value::Number(n) => {
114 n.as_u64().map(|i| async_graphql::PathSegment::Index(i as usize))
115 }
116 _ => None,
117 })
118 .collect();
119 }
120
121 response.errors.push(server_error);
122 }
123
124 if let Some(serde_json::Value::Object(map)) = &self.extensions {
126 for (key, value) in map {
127 response.extensions.insert(key.clone(), json_to_graphql_value(value));
128 }
129 }
130
131 response
132 }
133
134 pub fn from_response(response: &Response) -> Self {
136 let data = graphql_value_to_json(&response.data);
138
139 let errors: Vec<CachedError> = response
141 .errors
142 .iter()
143 .map(|e| CachedError {
144 message: e.message.clone(),
145 locations: e
146 .locations
147 .iter()
148 .map(|loc| CachedErrorLocation {
149 line: loc.line,
150 column: loc.column,
151 })
152 .collect(),
153 path: if e.path.is_empty() {
154 None
155 } else {
156 Some(
157 e.path
158 .iter()
159 .map(|seg| match seg {
160 async_graphql::PathSegment::Field(s) => {
161 serde_json::Value::String(s.clone())
162 }
163 async_graphql::PathSegment::Index(i) => {
164 serde_json::Value::Number((*i as u64).into())
165 }
166 })
167 .collect(),
168 )
169 },
170 extensions: None, })
172 .collect();
173
174 let extensions = if response.extensions.is_empty() {
176 None
177 } else {
178 let mut map = serde_json::Map::new();
179 for (key, value) in &response.extensions {
180 map.insert(key.clone(), graphql_value_to_json(value));
181 }
182 Some(serde_json::Value::Object(map))
183 };
184
185 Self {
186 data,
187 errors,
188 extensions,
189 cached_at: Instant::now(),
190 hit_count: 0,
191 }
192 }
193}
194
195fn graphql_value_to_json(value: &async_graphql::Value) -> serde_json::Value {
197 match value {
198 async_graphql::Value::Null => serde_json::Value::Null,
199 async_graphql::Value::Number(n) => {
200 if let Some(i) = n.as_i64() {
201 serde_json::Value::Number(i.into())
202 } else if let Some(u) = n.as_u64() {
203 serde_json::Value::Number(u.into())
204 } else if let Some(f) = n.as_f64() {
205 serde_json::Number::from_f64(f)
206 .map(serde_json::Value::Number)
207 .unwrap_or(serde_json::Value::Null)
208 } else {
209 serde_json::Value::Null
210 }
211 }
212 async_graphql::Value::String(s) => serde_json::Value::String(s.clone()),
213 async_graphql::Value::Boolean(b) => serde_json::Value::Bool(*b),
214 async_graphql::Value::List(arr) => {
215 serde_json::Value::Array(arr.iter().map(graphql_value_to_json).collect())
216 }
217 async_graphql::Value::Object(obj) => {
218 let map: serde_json::Map<String, serde_json::Value> =
219 obj.iter().map(|(k, v)| (k.to_string(), graphql_value_to_json(v))).collect();
220 serde_json::Value::Object(map)
221 }
222 async_graphql::Value::Enum(e) => serde_json::Value::String(e.to_string()),
223 async_graphql::Value::Binary(b) => {
224 use base64::Engine;
225 serde_json::Value::String(base64::engine::general_purpose::STANDARD.encode(b))
226 }
227 }
228}
229
230fn json_to_graphql_value(value: &serde_json::Value) -> async_graphql::Value {
232 match value {
233 serde_json::Value::Null => async_graphql::Value::Null,
234 serde_json::Value::Bool(b) => async_graphql::Value::Boolean(*b),
235 serde_json::Value::Number(n) => {
236 if let Some(i) = n.as_i64() {
237 async_graphql::Value::Number(i.into())
238 } else if let Some(u) = n.as_u64() {
239 async_graphql::Value::Number(u.into())
240 } else if let Some(f) = n.as_f64() {
241 async_graphql::Value::Number(
242 async_graphql::Number::from_f64(f).unwrap_or_else(|| 0.into()),
243 )
244 } else {
245 async_graphql::Value::Null
246 }
247 }
248 serde_json::Value::String(s) => async_graphql::Value::String(s.clone()),
249 serde_json::Value::Array(arr) => {
250 async_graphql::Value::List(arr.iter().map(json_to_graphql_value).collect())
251 }
252 serde_json::Value::Object(obj) => {
253 let map: indexmap::IndexMap<async_graphql::Name, async_graphql::Value> = obj
254 .iter()
255 .filter_map(|(k, v)| {
256 let is_valid =
258 k.chars().next().is_some_and(|c| c == '_' || c.is_ascii_alphabetic())
259 && k.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
260 if is_valid {
261 Some((async_graphql::Name::new(k), json_to_graphql_value(v)))
262 } else {
263 None
264 }
265 })
266 .collect();
267 async_graphql::Value::Object(map)
268 }
269 }
270}
271
/// Settings that control a [`ResponseCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Capacity limit; when it is reached, `ResponseCache::put` evicts the oldest entry.
    pub max_entries: usize,
    /// How long an entry remains servable after it was stored.
    pub ttl: Duration,
    /// Whether hit/miss/eviction/size counters are maintained.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a five-minute TTL, statistics enabled.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        CacheConfig {
            enable_stats: true,
            ttl: five_minutes,
            max_entries: 1_000,
        }
    }
}
292
/// Counters describing how effective the cache has been.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached response.
    pub hits: u64,
    /// Lookups that found nothing, or found an expired entry.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entry count as of the last size update.
    pub size: usize,
}

impl CacheStats {
    /// Share of lookups that were hits, as a percentage in `0.0..=100.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => 100.0 * (self.hits as f64 / lookups as f64),
        }
    }
}
317
318pub struct ResponseCache {
320 cache: Arc<RwLock<HashMap<CacheKey, CachedResponse>>>,
322 config: CacheConfig,
324 stats: Arc<RwLock<CacheStats>>,
326}
327
328impl ResponseCache {
329 pub fn new(config: CacheConfig) -> Self {
331 Self {
332 cache: Arc::new(RwLock::new(HashMap::new())),
333 config,
334 stats: Arc::new(RwLock::new(CacheStats::default())),
335 }
336 }
337
338 pub fn get(&self, key: &CacheKey) -> Option<Response> {
340 let mut cache = self.cache.write();
341
342 if let Some(cached) = cache.get_mut(key) {
343 if cached.cached_at.elapsed() > self.config.ttl {
345 cache.remove(key);
346 self.record_miss();
347 return None;
348 }
349
350 cached.hit_count += 1;
352 self.record_hit();
353
354 Some(cached.to_response())
356 } else {
357 self.record_miss();
358 None
359 }
360 }
361
362 pub fn put(&self, key: CacheKey, response: Response) {
364 let mut cache = self.cache.write();
365
366 if cache.len() >= self.config.max_entries {
368 if let Some(oldest_key) = self.find_oldest_key(&cache) {
369 cache.remove(&oldest_key);
370 self.record_eviction();
371 }
372 }
373
374 let cached_response = CachedResponse::from_response(&response);
376
377 cache.insert(key, cached_response);
378
379 self.update_size(cache.len());
380 }
381
382 pub fn clear(&self) {
384 let mut cache = self.cache.write();
385 cache.clear();
386 self.update_size(0);
387 }
388
389 pub fn clear_expired(&self) {
391 let mut cache = self.cache.write();
392 let ttl = self.config.ttl;
393
394 cache.retain(|_, cached| cached.cached_at.elapsed() <= ttl);
395 self.update_size(cache.len());
396 }
397
398 pub fn stats(&self) -> CacheStats {
400 self.stats.read().clone()
401 }
402
403 pub fn reset_stats(&self) {
405 let mut stats = self.stats.write();
406 *stats = CacheStats::default();
407 }
408
409 fn find_oldest_key(&self, cache: &HashMap<CacheKey, CachedResponse>) -> Option<CacheKey> {
412 cache
413 .iter()
414 .min_by_key(|(_, cached)| cached.cached_at)
415 .map(|(key, _)| key.clone())
416 }
417
418 fn record_hit(&self) {
419 if self.config.enable_stats {
420 let mut stats = self.stats.write();
421 stats.hits += 1;
422 }
423 }
424
425 fn record_miss(&self) {
426 if self.config.enable_stats {
427 let mut stats = self.stats.write();
428 stats.misses += 1;
429 }
430 }
431
432 fn record_eviction(&self) {
433 if self.config.enable_stats {
434 let mut stats = self.stats.write();
435 stats.evictions += 1;
436 }
437 }
438
439 fn update_size(&self, size: usize) {
440 if self.config.enable_stats {
441 let mut stats = self.stats.write();
442 stats.size = size;
443 }
444 }
445}
446
447impl Default for ResponseCache {
448 fn default() -> Self {
449 Self::new(CacheConfig::default())
450 }
451}
452
453pub struct CacheMiddleware {
455 cache: Arc<ResponseCache>,
456 cacheable_operations: Option<Vec<String>>,
458}
459
460impl CacheMiddleware {
461 pub fn new(cache: Arc<ResponseCache>) -> Self {
463 Self {
464 cache,
465 cacheable_operations: None,
466 }
467 }
468
469 pub fn with_operations(mut self, operations: Vec<String>) -> Self {
471 self.cacheable_operations = Some(operations);
472 self
473 }
474
475 pub fn should_cache(&self, operation_name: Option<&str>) -> bool {
477 match &self.cacheable_operations {
478 None => true, Some(ops) => {
480 operation_name.map(|name| ops.contains(&name.to_string())).unwrap_or(false)
481 }
482 }
483 }
484
485 pub fn get_cached(&self, key: &CacheKey) -> Option<Response> {
487 self.cache.get(key)
488 }
489
490 pub fn cache_response(&self, key: CacheKey, response: Response) {
492 self.cache.put(key, response);
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use async_graphql::Value;

    /// Key with the given operation name and a fixed query/variables pair.
    fn key_for(operation: &str) -> CacheKey {
        CacheKey::new(operation.to_string(), "query".to_string(), "{}".to_string())
    }

    /// Response with `null` data, enough to exercise storage.
    fn null_response() -> Response {
        Response::new(Value::Null)
    }

    #[test]
    fn test_cache_key_creation() {
        let built = CacheKey::new(
            String::from("getUser"),
            String::from("query { user { id } }"),
            String::from(r#"{"id": "123"}"#),
        );

        assert_eq!(built.operation_name, "getUser");
    }

    #[test]
    fn test_cache_key_from_request() {
        let variables = serde_json::json!({ "id": "123" });
        let built = CacheKey::from_request(
            Some(String::from("getUser")),
            String::from("query { user { id } }"),
            variables,
        );

        assert_eq!(built.operation_name, "getUser");
        assert!(built.variables.contains("123"));
    }

    #[test]
    fn test_cache_config_default() {
        let defaults = CacheConfig::default();

        assert_eq!(defaults.max_entries, 1000);
        assert_eq!(defaults.ttl, Duration::from_secs(300));
        assert!(defaults.enable_stats);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let counters = CacheStats {
            hits: 80,
            misses: 20,
            ..CacheStats::default()
        };

        assert_eq!(counters.hit_rate(), 80.0);
    }

    #[test]
    fn test_cache_put_and_get() {
        let store = ResponseCache::default();
        let key = key_for("test");

        store.put(key.clone(), null_response());

        assert!(store.get(&key).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let store = ResponseCache::default();

        assert!(store.get(&key_for("nonexistent")).is_none());
    }

    #[test]
    fn test_cache_clear() {
        let store = ResponseCache::default();
        let key = key_for("test");

        store.put(key.clone(), null_response());
        assert!(store.get(&key).is_some());

        store.clear();
        assert!(store.get(&key).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let store = ResponseCache::default();
        let key = key_for("test");

        // Lookup before any insert: recorded as a miss.
        let _ = store.get(&key);

        // Insert, then look up again: recorded as a hit.
        store.put(key.clone(), null_response());
        let _ = store.get(&key);

        let snapshot = store.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.size, 1);
    }

    #[test]
    fn test_cache_middleware_should_cache() {
        let middleware = CacheMiddleware::new(Arc::new(ResponseCache::default()));

        for name in [Some("getUser"), None] {
            assert!(middleware.should_cache(name));
        }
    }

    #[test]
    fn test_cache_middleware_with_specific_operations() {
        let allowed: Vec<String> = ["getUser", "getProduct"]
            .iter()
            .map(|op| op.to_string())
            .collect();
        let middleware =
            CacheMiddleware::new(Arc::new(ResponseCache::default())).with_operations(allowed);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(!middleware.should_cache(Some("createUser")));
    }

    #[test]
    fn test_cache_eviction() {
        let store = ResponseCache::new(CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            enable_stats: true,
        });

        for index in 0..3 {
            store.put(key_for(&format!("op{}", index)), null_response());
            // Distinct timestamps make the "oldest" entry unambiguous.
            std::thread::sleep(Duration::from_millis(10));
        }

        let snapshot = store.stats();
        assert_eq!(snapshot.size, 2);
        assert_eq!(snapshot.evictions, 1);
    }
}