Skip to main content

mockforge_graphql/
cache.rs

1//! Response caching and memoization for GraphQL operations
2//!
3//! Provides intelligent caching of GraphQL responses to improve performance.
4
5use async_graphql::Response;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
/// Identity of a cacheable GraphQL request.
///
/// Two requests share a cache entry only when all three components compare
/// equal as plain strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Name of the operation being executed (empty when the request had none)
    pub operation_name: String,
    /// Full text of the query document
    pub query: String,
    /// Request variables, pre-serialized to a JSON string so they can be hashed
    pub variables: String,
}
22
23impl CacheKey {
24    /// Create a new cache key
25    pub fn new(operation_name: String, query: String, variables: String) -> Self {
26        Self {
27            operation_name,
28            query,
29            variables,
30        }
31    }
32
33    /// Create from GraphQL request components
34    pub fn from_request(
35        operation_name: Option<String>,
36        query: String,
37        variables: serde_json::Value,
38    ) -> Self {
39        Self {
40            operation_name: operation_name.unwrap_or_default(),
41            query,
42            variables: variables.to_string(),
43        }
44    }
45}
46
47/// Cached response with metadata
48pub struct CachedResponse {
49    /// The GraphQL response data (as serde_json::Value for easy serialization)
50    pub data: serde_json::Value,
51    /// Any errors in the response
52    pub errors: Vec<CachedError>,
53    /// Extensions from the response
54    pub extensions: Option<serde_json::Value>,
55    /// When this was cached
56    pub cached_at: Instant,
57    /// Number of cache hits
58    pub hit_count: usize,
59}
60
61/// Cached error representation
62#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
63pub struct CachedError {
64    /// Error message
65    pub message: String,
66    /// Error locations in the query
67    pub locations: Vec<CachedErrorLocation>,
68    /// Path to the field that caused the error
69    pub path: Option<Vec<serde_json::Value>>,
70    /// Additional error extensions
71    pub extensions: Option<serde_json::Value>,
72}
73
74/// Error location in the query
75#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
76pub struct CachedErrorLocation {
77    /// Line number in the query (1-indexed)
78    pub line: usize,
79    /// Column number in the query (1-indexed)
80    pub column: usize,
81}
82
83impl CachedResponse {
84    /// Convert to GraphQL Response
85    pub fn to_response(&self) -> Response {
86        // Convert serde_json::Value back to async_graphql::Value
87        let graphql_value = json_to_graphql_value(&self.data);
88        let mut response = Response::new(graphql_value);
89
90        // Restore errors
91        for cached_error in &self.errors {
92            let mut server_error =
93                async_graphql::ServerError::new(cached_error.message.clone(), None);
94
95            // Restore locations
96            server_error.locations = cached_error
97                .locations
98                .iter()
99                .map(|loc| async_graphql::Pos {
100                    line: loc.line,
101                    column: loc.column,
102                })
103                .collect();
104
105            // Restore path
106            if let Some(path) = &cached_error.path {
107                server_error.path = path
108                    .iter()
109                    .filter_map(|v| match v {
110                        serde_json::Value::String(s) => {
111                            Some(async_graphql::PathSegment::Field(s.clone()))
112                        }
113                        serde_json::Value::Number(n) => {
114                            n.as_u64().map(|i| async_graphql::PathSegment::Index(i as usize))
115                        }
116                        _ => None,
117                    })
118                    .collect();
119            }
120
121            response.errors.push(server_error);
122        }
123
124        // Restore extensions
125        if let Some(serde_json::Value::Object(map)) = &self.extensions {
126            for (key, value) in map {
127                response.extensions.insert(key.clone(), json_to_graphql_value(value));
128            }
129        }
130
131        response
132    }
133
134    /// Create from GraphQL Response
135    pub fn from_response(response: &Response) -> Self {
136        // Convert async_graphql::Value to serde_json::Value
137        let data = graphql_value_to_json(&response.data);
138
139        // Convert errors
140        let errors: Vec<CachedError> = response
141            .errors
142            .iter()
143            .map(|e| CachedError {
144                message: e.message.clone(),
145                locations: e
146                    .locations
147                    .iter()
148                    .map(|loc| CachedErrorLocation {
149                        line: loc.line,
150                        column: loc.column,
151                    })
152                    .collect(),
153                path: if e.path.is_empty() {
154                    None
155                } else {
156                    Some(
157                        e.path
158                            .iter()
159                            .map(|seg| match seg {
160                                async_graphql::PathSegment::Field(s) => {
161                                    serde_json::Value::String(s.clone())
162                                }
163                                async_graphql::PathSegment::Index(i) => {
164                                    serde_json::Value::Number((*i as u64).into())
165                                }
166                            })
167                            .collect(),
168                    )
169                },
170                extensions: None, // ServerError extensions are not easily accessible
171            })
172            .collect();
173
174        // Convert extensions
175        let extensions = if response.extensions.is_empty() {
176            None
177        } else {
178            let mut map = serde_json::Map::new();
179            for (key, value) in &response.extensions {
180                map.insert(key.clone(), graphql_value_to_json(value));
181            }
182            Some(serde_json::Value::Object(map))
183        };
184
185        Self {
186            data,
187            errors,
188            extensions,
189            cached_at: Instant::now(),
190            hit_count: 0,
191        }
192    }
193}
194
195/// Convert async_graphql::Value to serde_json::Value
196fn graphql_value_to_json(value: &async_graphql::Value) -> serde_json::Value {
197    match value {
198        async_graphql::Value::Null => serde_json::Value::Null,
199        async_graphql::Value::Number(n) => {
200            if let Some(i) = n.as_i64() {
201                serde_json::Value::Number(i.into())
202            } else if let Some(u) = n.as_u64() {
203                serde_json::Value::Number(u.into())
204            } else if let Some(f) = n.as_f64() {
205                serde_json::Number::from_f64(f)
206                    .map(serde_json::Value::Number)
207                    .unwrap_or(serde_json::Value::Null)
208            } else {
209                serde_json::Value::Null
210            }
211        }
212        async_graphql::Value::String(s) => serde_json::Value::String(s.clone()),
213        async_graphql::Value::Boolean(b) => serde_json::Value::Bool(*b),
214        async_graphql::Value::List(arr) => {
215            serde_json::Value::Array(arr.iter().map(graphql_value_to_json).collect())
216        }
217        async_graphql::Value::Object(obj) => {
218            let map: serde_json::Map<String, serde_json::Value> =
219                obj.iter().map(|(k, v)| (k.to_string(), graphql_value_to_json(v))).collect();
220            serde_json::Value::Object(map)
221        }
222        async_graphql::Value::Enum(e) => serde_json::Value::String(e.to_string()),
223        async_graphql::Value::Binary(b) => {
224            use base64::Engine;
225            serde_json::Value::String(base64::engine::general_purpose::STANDARD.encode(b))
226        }
227    }
228}
229
230/// Convert serde_json::Value to async_graphql::Value
231fn json_to_graphql_value(value: &serde_json::Value) -> async_graphql::Value {
232    match value {
233        serde_json::Value::Null => async_graphql::Value::Null,
234        serde_json::Value::Bool(b) => async_graphql::Value::Boolean(*b),
235        serde_json::Value::Number(n) => {
236            if let Some(i) = n.as_i64() {
237                async_graphql::Value::Number(i.into())
238            } else if let Some(u) = n.as_u64() {
239                async_graphql::Value::Number(u.into())
240            } else if let Some(f) = n.as_f64() {
241                async_graphql::Value::Number(
242                    async_graphql::Number::from_f64(f).unwrap_or_else(|| 0.into()),
243                )
244            } else {
245                async_graphql::Value::Null
246            }
247        }
248        serde_json::Value::String(s) => async_graphql::Value::String(s.clone()),
249        serde_json::Value::Array(arr) => {
250            async_graphql::Value::List(arr.iter().map(json_to_graphql_value).collect())
251        }
252        serde_json::Value::Object(obj) => {
253            let map: indexmap::IndexMap<async_graphql::Name, async_graphql::Value> = obj
254                .iter()
255                .filter_map(|(k, v)| {
256                    // GraphQL names must match [_A-Za-z][_0-9A-Za-z]*
257                    let is_valid =
258                        k.chars().next().is_some_and(|c| c == '_' || c.is_ascii_alphabetic())
259                            && k.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
260                    if is_valid {
261                        Some((async_graphql::Name::new(k), json_to_graphql_value(v)))
262                    } else {
263                        None
264                    }
265                })
266                .collect();
267            async_graphql::Value::Object(map)
268        }
269    }
270}
271
/// Tunables controlling capacity, expiry and statistics of a response cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of cached responses
    pub max_entries: usize,
    /// How long an entry stays valid after it is stored
    pub ttl: Duration,
    /// Whether hit/miss/eviction/size counters are maintained
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a five-minute TTL, statistics enabled.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        CacheConfig {
            max_entries: 1_000,
            ttl: FIVE_MINUTES,
            enable_stats: true,
        }
    }
}
292
/// Counters describing cache activity.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached response
    pub hits: u64,
    /// Lookups that found nothing usable
    pub misses: u64,
    /// Entries dropped to make room for new ones
    pub evictions: u64,
    /// Entry count as last recorded by the cache
    pub size: usize,
}

impl CacheStats {
    /// Share of lookups that were hits, as a percentage in `0.0..=100.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.hits + self.misses;
        if lookups == 0 {
            return 0.0;
        }
        let fraction = self.hits as f64 / lookups as f64;
        fraction * 100.0
    }
}
317
318/// Response cache for GraphQL operations
319pub struct ResponseCache {
320    /// Cache storage
321    cache: Arc<RwLock<HashMap<CacheKey, CachedResponse>>>,
322    /// Cache configuration
323    config: CacheConfig,
324    /// Cache statistics
325    stats: Arc<RwLock<CacheStats>>,
326}
327
328impl ResponseCache {
329    /// Create a new response cache
330    pub fn new(config: CacheConfig) -> Self {
331        Self {
332            cache: Arc::new(RwLock::new(HashMap::new())),
333            config,
334            stats: Arc::new(RwLock::new(CacheStats::default())),
335        }
336    }
337
338    /// Get a cached response
339    pub fn get(&self, key: &CacheKey) -> Option<Response> {
340        let mut cache = self.cache.write();
341
342        if let Some(cached) = cache.get_mut(key) {
343            // Check if TTL expired
344            if cached.cached_at.elapsed() > self.config.ttl {
345                cache.remove(key);
346                self.record_miss();
347                return None;
348            }
349
350            // Update hit count
351            cached.hit_count += 1;
352            self.record_hit();
353
354            // Convert cached response to GraphQL Response
355            Some(cached.to_response())
356        } else {
357            self.record_miss();
358            None
359        }
360    }
361
362    /// Put a response in the cache
363    pub fn put(&self, key: CacheKey, response: Response) {
364        let mut cache = self.cache.write();
365
366        // Evict oldest entry if at capacity
367        if cache.len() >= self.config.max_entries {
368            if let Some(oldest_key) = self.find_oldest_key(&cache) {
369                cache.remove(&oldest_key);
370                self.record_eviction();
371            }
372        }
373
374        // Convert response to cached format
375        let cached_response = CachedResponse::from_response(&response);
376
377        cache.insert(key, cached_response);
378
379        self.update_size(cache.len());
380    }
381
382    /// Clear all cached responses
383    pub fn clear(&self) {
384        let mut cache = self.cache.write();
385        cache.clear();
386        self.update_size(0);
387    }
388
389    /// Clear expired entries
390    pub fn clear_expired(&self) {
391        let mut cache = self.cache.write();
392        let ttl = self.config.ttl;
393
394        cache.retain(|_, cached| cached.cached_at.elapsed() <= ttl);
395        self.update_size(cache.len());
396    }
397
398    /// Get cache statistics
399    pub fn stats(&self) -> CacheStats {
400        self.stats.read().clone()
401    }
402
403    /// Reset statistics
404    pub fn reset_stats(&self) {
405        let mut stats = self.stats.write();
406        *stats = CacheStats::default();
407    }
408
409    // Private helper methods
410
411    fn find_oldest_key(&self, cache: &HashMap<CacheKey, CachedResponse>) -> Option<CacheKey> {
412        cache
413            .iter()
414            .min_by_key(|(_, cached)| cached.cached_at)
415            .map(|(key, _)| key.clone())
416    }
417
418    fn record_hit(&self) {
419        if self.config.enable_stats {
420            let mut stats = self.stats.write();
421            stats.hits += 1;
422        }
423    }
424
425    fn record_miss(&self) {
426        if self.config.enable_stats {
427            let mut stats = self.stats.write();
428            stats.misses += 1;
429        }
430    }
431
432    fn record_eviction(&self) {
433        if self.config.enable_stats {
434            let mut stats = self.stats.write();
435            stats.evictions += 1;
436        }
437    }
438
439    fn update_size(&self, size: usize) {
440        if self.config.enable_stats {
441            let mut stats = self.stats.write();
442            stats.size = size;
443        }
444    }
445}
446
447impl Default for ResponseCache {
448    fn default() -> Self {
449        Self::new(CacheConfig::default())
450    }
451}
452
453/// Cache middleware for automatic caching
454pub struct CacheMiddleware {
455    cache: Arc<ResponseCache>,
456    /// Operations to cache (None = cache all)
457    cacheable_operations: Option<Vec<String>>,
458}
459
460impl CacheMiddleware {
461    /// Create new cache middleware
462    pub fn new(cache: Arc<ResponseCache>) -> Self {
463        Self {
464            cache,
465            cacheable_operations: None,
466        }
467    }
468
469    /// Set specific operations to cache
470    pub fn with_operations(mut self, operations: Vec<String>) -> Self {
471        self.cacheable_operations = Some(operations);
472        self
473    }
474
475    /// Check if an operation should be cached
476    pub fn should_cache(&self, operation_name: Option<&str>) -> bool {
477        match &self.cacheable_operations {
478            None => true, // Cache everything
479            Some(ops) => {
480                operation_name.map(|name| ops.contains(&name.to_string())).unwrap_or(false)
481            }
482        }
483    }
484
485    /// Get cached response if available
486    pub fn get_cached(&self, key: &CacheKey) -> Option<Response> {
487        self.cache.get(key)
488    }
489
490    /// Cache a response
491    pub fn cache_response(&self, key: CacheKey, response: Response) {
492        self.cache.put(key, response);
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use async_graphql::Value;

    /// Key with the given operation name and a fixed query/variables pair.
    fn key_for(operation: &str) -> CacheKey {
        CacheKey::new(operation.to_string(), "query".to_string(), "{}".to_string())
    }

    /// A response whose data is `null`.
    fn null_response() -> Response {
        Response::new(Value::Null)
    }

    #[test]
    fn test_cache_key_creation() {
        let key = CacheKey::new(
            String::from("getUser"),
            String::from("query { user { id } }"),
            String::from(r#"{"id": "123"}"#),
        );

        assert_eq!(key.operation_name, "getUser");
    }

    #[test]
    fn test_cache_key_from_request() {
        let variables = serde_json::json!({ "id": "123" });
        let key = CacheKey::from_request(
            Some(String::from("getUser")),
            String::from("query { user { id } }"),
            variables,
        );

        assert_eq!(key.operation_name, "getUser");
        assert!(key.variables.contains("123"));
    }

    #[test]
    fn test_cache_config_default() {
        let defaults = CacheConfig::default();
        assert_eq!(defaults.max_entries, 1000);
        assert_eq!(defaults.ttl, Duration::from_secs(300));
        assert!(defaults.enable_stats);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            ..CacheStats::default()
        };

        assert_eq!(stats.hit_rate(), 80.0);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = ResponseCache::default();
        let key = key_for("test");

        cache.put(key.clone(), null_response());

        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let cache = ResponseCache::default();

        assert!(cache.get(&key_for("nonexistent")).is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = ResponseCache::default();
        let key = key_for("test");

        cache.put(key.clone(), null_response());
        assert!(cache.get(&key).is_some());

        cache.clear();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = ResponseCache::default();
        let key = key_for("test");

        // One lookup before anything is stored (miss) ...
        let _ = cache.get(&key);

        // ... then one after storing (hit).
        cache.put(key.clone(), null_response());
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses, stats.size), (1, 1, 1));
    }

    #[test]
    fn test_cache_middleware_should_cache() {
        let middleware = CacheMiddleware::new(Arc::new(ResponseCache::default()));

        assert!(middleware.should_cache(Some("getUser")));
        assert!(middleware.should_cache(None));
    }

    #[test]
    fn test_cache_middleware_with_specific_operations() {
        let allowed = vec![String::from("getUser"), String::from("getProduct")];
        let middleware =
            CacheMiddleware::new(Arc::new(ResponseCache::default())).with_operations(allowed);

        assert!(middleware.should_cache(Some("getUser")));
        assert!(!middleware.should_cache(Some("createUser")));
    }

    #[test]
    fn test_cache_eviction() {
        let cache = ResponseCache::new(CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            enable_stats: true,
        });

        // Three distinct keys into a two-slot cache: exactly one eviction.
        for op in ["op0", "op1", "op2"] {
            cache.put(key_for(op), null_response());
            // Space the inserts out so their `cached_at` timestamps differ.
            std::thread::sleep(Duration::from_millis(10));
        }

        let stats = cache.stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.evictions, 1);
    }
}