mockforge_graphql/
cache.rs

1//! Response caching and memoization for GraphQL operations
2//!
3//! Provides intelligent caching of GraphQL responses to improve performance.
4
5use async_graphql::Response;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
/// Cache key for GraphQL operations.
///
/// Equality and hashing are derived over the raw strings, so two requests share a cache
/// entry only when all three fields are byte-for-byte identical. Queries that differ only in
/// whitespace or formatting therefore get separate entries.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct CacheKey {
    /// Operation name (`CacheKey::from_request` stores an empty string when none was given)
    pub operation_name: String,
    /// Query string, compared verbatim
    pub query: String,
    /// Variables as JSON string for hashing
    pub variables: String,
}
22
23impl CacheKey {
24    /// Create a new cache key
25    pub fn new(operation_name: String, query: String, variables: String) -> Self {
26        Self {
27            operation_name,
28            query,
29            variables,
30        }
31    }
32
33    /// Create from GraphQL request components
34    pub fn from_request(
35        operation_name: Option<String>,
36        query: String,
37        variables: serde_json::Value,
38    ) -> Self {
39        Self {
40            operation_name: operation_name.unwrap_or_default(),
41            query,
42            variables: variables.to_string(),
43        }
44    }
45}
46
47/// Cached response with metadata
48pub struct CachedResponse {
49    /// The GraphQL response data (as serde_json::Value for easy serialization)
50    pub data: serde_json::Value,
51    /// Any errors in the response
52    pub errors: Vec<CachedError>,
53    /// Extensions from the response
54    pub extensions: Option<serde_json::Value>,
55    /// When this was cached
56    pub cached_at: Instant,
57    /// Number of cache hits
58    pub hit_count: usize,
59}
60
/// Serializable snapshot of an `async_graphql::ServerError`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CachedError {
    /// Error message
    pub message: String,
    /// Error locations in the query
    pub locations: Vec<CachedErrorLocation>,
    /// Path to the field that caused the error.
    ///
    /// Field names are JSON strings and list indices are JSON integers. The value is `None`
    /// when the original error had an empty path.
    pub path: Option<Vec<serde_json::Value>>,
    /// Additional error extensions as a JSON value, if any were captured
    pub extensions: Option<serde_json::Value>,
}
73
/// Error location in the query, mirroring the `line`/`column` pair of `async_graphql::Pos`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CachedErrorLocation {
    /// Line number in the query (1-indexed)
    pub line: usize,
    /// Column number in the query (1-indexed)
    pub column: usize,
}
82
83impl CachedResponse {
84    /// Convert to GraphQL Response
85    pub fn to_response(&self) -> Response {
86        // Convert serde_json::Value back to async_graphql::Value
87        let graphql_value = json_to_graphql_value(&self.data);
88        let mut response = Response::new(graphql_value);
89
90        // Restore errors
91        for cached_error in &self.errors {
92            let mut server_error =
93                async_graphql::ServerError::new(cached_error.message.clone(), None);
94
95            // Restore locations
96            server_error.locations = cached_error
97                .locations
98                .iter()
99                .map(|loc| async_graphql::Pos {
100                    line: loc.line,
101                    column: loc.column,
102                })
103                .collect();
104
105            // Restore path
106            if let Some(path) = &cached_error.path {
107                server_error.path = path
108                    .iter()
109                    .filter_map(|v| match v {
110                        serde_json::Value::String(s) => {
111                            Some(async_graphql::PathSegment::Field(s.clone()))
112                        }
113                        serde_json::Value::Number(n) => {
114                            n.as_u64().map(|i| async_graphql::PathSegment::Index(i as usize))
115                        }
116                        _ => None,
117                    })
118                    .collect();
119            }
120
121            response.errors.push(server_error);
122        }
123
124        // Restore extensions
125        if let Some(ext) = &self.extensions {
126            if let serde_json::Value::Object(map) = ext {
127                for (key, value) in map {
128                    response.extensions.insert(key.clone(), json_to_graphql_value(value));
129                }
130            }
131        }
132
133        response
134    }
135
136    /// Create from GraphQL Response
137    pub fn from_response(response: &Response) -> Self {
138        // Convert async_graphql::Value to serde_json::Value
139        let data = graphql_value_to_json(&response.data);
140
141        // Convert errors
142        let errors: Vec<CachedError> = response
143            .errors
144            .iter()
145            .map(|e| CachedError {
146                message: e.message.clone(),
147                locations: e
148                    .locations
149                    .iter()
150                    .map(|loc| CachedErrorLocation {
151                        line: loc.line,
152                        column: loc.column,
153                    })
154                    .collect(),
155                path: if e.path.is_empty() {
156                    None
157                } else {
158                    Some(
159                        e.path
160                            .iter()
161                            .map(|seg| match seg {
162                                async_graphql::PathSegment::Field(s) => {
163                                    serde_json::Value::String(s.clone())
164                                }
165                                async_graphql::PathSegment::Index(i) => {
166                                    serde_json::Value::Number((*i as u64).into())
167                                }
168                            })
169                            .collect(),
170                    )
171                },
172                extensions: None, // ServerError extensions are not easily accessible
173            })
174            .collect();
175
176        // Convert extensions
177        let extensions = if response.extensions.is_empty() {
178            None
179        } else {
180            let mut map = serde_json::Map::new();
181            for (key, value) in &response.extensions {
182                map.insert(key.clone(), graphql_value_to_json(value));
183            }
184            Some(serde_json::Value::Object(map))
185        };
186
187        Self {
188            data,
189            errors,
190            extensions,
191            cached_at: Instant::now(),
192            hit_count: 0,
193        }
194    }
195}
196
197/// Convert async_graphql::Value to serde_json::Value
198fn graphql_value_to_json(value: &async_graphql::Value) -> serde_json::Value {
199    match value {
200        async_graphql::Value::Null => serde_json::Value::Null,
201        async_graphql::Value::Number(n) => {
202            if let Some(i) = n.as_i64() {
203                serde_json::Value::Number(i.into())
204            } else if let Some(u) = n.as_u64() {
205                serde_json::Value::Number(u.into())
206            } else if let Some(f) = n.as_f64() {
207                serde_json::Number::from_f64(f)
208                    .map(serde_json::Value::Number)
209                    .unwrap_or(serde_json::Value::Null)
210            } else {
211                serde_json::Value::Null
212            }
213        }
214        async_graphql::Value::String(s) => serde_json::Value::String(s.clone()),
215        async_graphql::Value::Boolean(b) => serde_json::Value::Bool(*b),
216        async_graphql::Value::List(arr) => {
217            serde_json::Value::Array(arr.iter().map(graphql_value_to_json).collect())
218        }
219        async_graphql::Value::Object(obj) => {
220            let map: serde_json::Map<String, serde_json::Value> =
221                obj.iter().map(|(k, v)| (k.to_string(), graphql_value_to_json(v))).collect();
222            serde_json::Value::Object(map)
223        }
224        async_graphql::Value::Enum(e) => serde_json::Value::String(e.to_string()),
225        async_graphql::Value::Binary(b) => {
226            use base64::Engine;
227            serde_json::Value::String(base64::engine::general_purpose::STANDARD.encode(b))
228        }
229    }
230}
231
232/// Convert serde_json::Value to async_graphql::Value
233fn json_to_graphql_value(value: &serde_json::Value) -> async_graphql::Value {
234    match value {
235        serde_json::Value::Null => async_graphql::Value::Null,
236        serde_json::Value::Bool(b) => async_graphql::Value::Boolean(*b),
237        serde_json::Value::Number(n) => {
238            if let Some(i) = n.as_i64() {
239                async_graphql::Value::Number(i.into())
240            } else if let Some(u) = n.as_u64() {
241                async_graphql::Value::Number(u.into())
242            } else if let Some(f) = n.as_f64() {
243                async_graphql::Value::Number(
244                    async_graphql::Number::from_f64(f).unwrap_or_else(|| 0.into()),
245                )
246            } else {
247                async_graphql::Value::Null
248            }
249        }
250        serde_json::Value::String(s) => async_graphql::Value::String(s.clone()),
251        serde_json::Value::Array(arr) => {
252            async_graphql::Value::List(arr.iter().map(json_to_graphql_value).collect())
253        }
254        serde_json::Value::Object(obj) => {
255            let map: async_graphql::indexmap::IndexMap<async_graphql::Name, async_graphql::Value> =
256                obj.iter()
257                    .filter_map(|(k, v)| {
258                        // GraphQL names must match [_A-Za-z][_0-9A-Za-z]*
259                        let is_valid =
260                            k.chars().next().is_some_and(|c| c == '_' || c.is_ascii_alphabetic())
261                                && k.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
262                        if is_valid {
263                            Some((async_graphql::Name::new(k), json_to_graphql_value(v)))
264                        } else {
265                            None
266                        }
267                    })
268                    .collect();
269            async_graphql::Value::Object(map)
270        }
271    }
272}
273
/// Settings controlling the size, freshness and bookkeeping of a response cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Maximum number of entries kept before the oldest one is evicted
    pub max_entries: usize,
    /// How long a cached response stays valid after it was stored
    pub ttl: Duration,
    /// Whether hit/miss/eviction/size counters are maintained
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a five-minute TTL, and statistics enabled.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        CacheConfig {
            enable_stats: true,
            ttl: five_minutes,
            max_entries: 1_000,
        }
    }
}
294
/// Counters describing how a response cache has been used.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Total cache hits
    pub hits: u64,
    /// Total cache misses
    pub misses: u64,
    /// Number of evictions
    pub evictions: u64,
    /// Current cache size
    pub size: usize,
}

impl CacheStats {
    /// Share of lookups that were hits, as a percentage between `0.0` and `100.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => {
                let fraction = self.hits as f64 / lookups as f64;
                fraction * 100.0
            }
        }
    }
}
319
/// Response cache for GraphQL operations.
///
/// Entries are keyed by [`CacheKey`], expire after `config.ttl`, and are bounded by
/// `config.max_entries`. When the bound is hit, the entry with the oldest `cached_at` is evicted.
/// Storage and statistics sit behind separate `parking_lot::RwLock`s. Methods that touch both
/// acquire the storage lock first.
pub struct ResponseCache {
    /// Cache storage
    cache: Arc<RwLock<HashMap<CacheKey, CachedResponse>>>,
    /// Cache configuration
    config: CacheConfig,
    /// Cache statistics (only updated when `config.enable_stats` is true)
    stats: Arc<RwLock<CacheStats>>,
}
329
330impl ResponseCache {
331    /// Create a new response cache
332    pub fn new(config: CacheConfig) -> Self {
333        Self {
334            cache: Arc::new(RwLock::new(HashMap::new())),
335            config,
336            stats: Arc::new(RwLock::new(CacheStats::default())),
337        }
338    }
339
340    /// Create with default configuration
341    pub fn default() -> Self {
342        Self::new(CacheConfig::default())
343    }
344
345    /// Get a cached response
346    pub fn get(&self, key: &CacheKey) -> Option<Response> {
347        let mut cache = self.cache.write();
348
349        if let Some(cached) = cache.get_mut(key) {
350            // Check if TTL expired
351            if cached.cached_at.elapsed() > self.config.ttl {
352                cache.remove(key);
353                self.record_miss();
354                return None;
355            }
356
357            // Update hit count
358            cached.hit_count += 1;
359            self.record_hit();
360
361            // Convert cached response to GraphQL Response
362            Some(cached.to_response())
363        } else {
364            self.record_miss();
365            None
366        }
367    }
368
369    /// Put a response in the cache
370    pub fn put(&self, key: CacheKey, response: Response) {
371        let mut cache = self.cache.write();
372
373        // Evict oldest entry if at capacity
374        if cache.len() >= self.config.max_entries {
375            if let Some(oldest_key) = self.find_oldest_key(&cache) {
376                cache.remove(&oldest_key);
377                self.record_eviction();
378            }
379        }
380
381        // Convert response to cached format
382        let cached_response = CachedResponse::from_response(&response);
383
384        cache.insert(key, cached_response);
385
386        self.update_size(cache.len());
387    }
388
389    /// Clear all cached responses
390    pub fn clear(&self) {
391        let mut cache = self.cache.write();
392        cache.clear();
393        self.update_size(0);
394    }
395
396    /// Clear expired entries
397    pub fn clear_expired(&self) {
398        let mut cache = self.cache.write();
399        let ttl = self.config.ttl;
400
401        cache.retain(|_, cached| cached.cached_at.elapsed() <= ttl);
402        self.update_size(cache.len());
403    }
404
405    /// Get cache statistics
406    pub fn stats(&self) -> CacheStats {
407        self.stats.read().clone()
408    }
409
410    /// Reset statistics
411    pub fn reset_stats(&self) {
412        let mut stats = self.stats.write();
413        *stats = CacheStats::default();
414    }
415
416    // Private helper methods
417
418    fn find_oldest_key(&self, cache: &HashMap<CacheKey, CachedResponse>) -> Option<CacheKey> {
419        cache
420            .iter()
421            .min_by_key(|(_, cached)| cached.cached_at)
422            .map(|(key, _)| key.clone())
423    }
424
425    fn record_hit(&self) {
426        if self.config.enable_stats {
427            let mut stats = self.stats.write();
428            stats.hits += 1;
429        }
430    }
431
432    fn record_miss(&self) {
433        if self.config.enable_stats {
434            let mut stats = self.stats.write();
435            stats.misses += 1;
436        }
437    }
438
439    fn record_eviction(&self) {
440        if self.config.enable_stats {
441            let mut stats = self.stats.write();
442            stats.evictions += 1;
443        }
444    }
445
446    fn update_size(&self, size: usize) {
447        if self.config.enable_stats {
448            let mut stats = self.stats.write();
449            stats.size = size;
450        }
451    }
452}
453
/// Cache middleware for automatic caching.
///
/// Wraps a shared [`ResponseCache`] and an optional allow-list of operation names that
/// decides which requests are eligible for caching.
pub struct CacheMiddleware {
    /// Shared cache that lookups and stores are delegated to
    cache: Arc<ResponseCache>,
    /// Operations to cache (None = cache all)
    cacheable_operations: Option<Vec<String>>,
}
460
461impl CacheMiddleware {
462    /// Create new cache middleware
463    pub fn new(cache: Arc<ResponseCache>) -> Self {
464        Self {
465            cache,
466            cacheable_operations: None,
467        }
468    }
469
470    /// Set specific operations to cache
471    pub fn with_operations(mut self, operations: Vec<String>) -> Self {
472        self.cacheable_operations = Some(operations);
473        self
474    }
475
476    /// Check if an operation should be cached
477    pub fn should_cache(&self, operation_name: Option<&str>) -> bool {
478        match &self.cacheable_operations {
479            None => true, // Cache everything
480            Some(ops) => {
481                operation_name.map(|name| ops.contains(&name.to_string())).unwrap_or(false)
482            }
483        }
484    }
485
486    /// Get cached response if available
487    pub fn get_cached(&self, key: &CacheKey) -> Option<Response> {
488        self.cache.get(key)
489    }
490
491    /// Cache a response
492    pub fn cache_response(&self, key: CacheKey, response: Response) {
493        self.cache.put(key, response);
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use async_graphql::Value;

    #[test]
    fn test_cache_key_creation() {
        let key = CacheKey::new(
            "getUser".to_string(),
            "query { user { id } }".to_string(),
            r#"{"id": "123"}"#.to_string(),
        );

        assert_eq!(key.operation_name, "getUser");
    }

    #[test]
    fn test_cache_key_from_request() {
        let key = CacheKey::from_request(
            Some("getUser".to_string()),
            "query { user { id } }".to_string(),
            serde_json::json!({"id": "123"}),
        );

        assert_eq!(key.operation_name, "getUser");
        assert!(key.variables.contains("123"));
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.ttl, Duration::from_secs(300));
        assert!(config.enable_stats);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            ..CacheStats::default()
        };

        assert_eq!(stats.hit_rate(), 80.0);
    }

    #[test]
    fn test_cache_stats_hit_rate_without_lookups() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = ResponseCache::default();
        let key = CacheKey::new("test".to_string(), "query".to_string(), "{}".to_string());
        let response = Response::new(Value::Null);

        cache.put(key.clone(), response);
        let cached = cache.get(&key);

        assert!(cached.is_some());
    }

    #[test]
    fn test_cached_response_preserves_data() {
        let mut fields = async_graphql::indexmap::IndexMap::new();
        fields.insert(async_graphql::Name::new("id"), Value::String("123".to_string()));
        fields.insert(async_graphql::Name::new("age"), Value::Number(42.into()));
        let data = Value::Object(fields);

        let cache = ResponseCache::default();
        let key = CacheKey::new("getUser".to_string(), "query".to_string(), "{}".to_string());
        cache.put(key.clone(), Response::new(data.clone()));

        let cached = cache.get(&key).expect("entry should be cached");
        assert_eq!(cached.data, data);
    }

    #[test]
    fn test_cache_miss() {
        let cache = ResponseCache::default();
        let key = CacheKey::new("nonexistent".to_string(), "query".to_string(), "{}".to_string());

        let cached = cache.get(&key);
        assert!(cached.is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = ResponseCache::default();
        let key = CacheKey::new("test".to_string(), "query".to_string(), "{}".to_string());
        let response = Response::new(Value::Null);

        cache.put(key.clone(), response);
        assert!(cache.get(&key).is_some());

        cache.clear();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = ResponseCache::default();
        let key = CacheKey::new("test".to_string(), "query".to_string(), "{}".to_string());

        // Miss
        let _ = cache.get(&key);

        // Put and hit
        let response = Response::new(Value::Null);
        cache.put(key.clone(), response);
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
    }

    #[test]
    fn test_cache_middleware_should_cache() {
        let cache = Arc::new(ResponseCache::default());
        let middleware = CacheMiddleware::new(cache);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(middleware.should_cache(None));
    }

    #[test]
    fn test_cache_middleware_with_specific_operations() {
        let cache = Arc::new(ResponseCache::default());
        let middleware = CacheMiddleware::new(cache)
            .with_operations(vec!["getUser".to_string(), "getProduct".to_string()]);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(!middleware.should_cache(Some("createUser")));
        // Anonymous operations cannot match an allow-list.
        assert!(!middleware.should_cache(None));
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            enable_stats: true,
        };
        let cache = ResponseCache::new(config);
        let key_for =
            |i: usize| CacheKey::new(format!("op{}", i), "query".to_string(), "{}".to_string());

        // Add 3 entries (should evict the oldest)
        for i in 0..3 {
            cache.put(key_for(i), Response::new(Value::Null));
            std::thread::sleep(Duration::from_millis(10)); // Ensure different timestamps
        }

        let stats = cache.stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.evictions, 1);

        // The first-inserted entry is the one that was evicted.
        assert!(cache.get(&key_for(0)).is_none());
        assert!(cache.get(&key_for(1)).is_some());
        assert!(cache.get(&key_for(2)).is_some());
    }
}