//! Distributed Caching with Redis Integration
//!
//! This module provides high-performance distributed caching for GraphQL queries
//! with Redis backend, intelligent cache strategies, and federation support.

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use redis::{cmd, Client};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;
use tracing::{debug, info};

/// Configuration for the distributed cache layer.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Redis endpoints; one client is created per URL.
    pub redis_urls: Vec<String>,
    /// TTL applied when a `set` call does not supply one.
    pub default_ttl: Duration,
    /// Maximum total cache size in bytes.
    /// NOTE(review): not enforced anywhere in this module — confirm intent.
    pub max_cache_size: u64,
    /// Gzip-compress values before they are written to Redis.
    pub compression_enabled: bool,
    /// Encrypt values before they are written to Redis.
    /// WARNING: the only available strategy is a no-op stub (see
    /// `AesEncryptionStrategy`).
    pub encryption_enabled: bool,
    /// Whether Redis is deployed in cluster mode (not read in this module).
    pub cluster_mode: bool,
    /// How keys are distributed across the Redis clients.
    pub sharding_strategy: ShardingStrategy,
    /// Eviction policy hint (the local tier is always LRU regardless).
    pub eviction_policy: EvictionPolicy,
    /// Desired consistency guarantee (not read in this module).
    pub consistency_level: ConsistencyLevel,
    /// Number of replicas per entry (not read in this module).
    pub replication_factor: usize,
    /// Maximum number of entries in the in-process LRU cache.
    pub local_cache_size: usize,
    /// Enable predictive prefetching (not read in this module).
    pub prefetch_enabled: bool,
}
32
33impl Default for CacheConfig {
34    fn default() -> Self {
35        Self {
36            redis_urls: vec!["redis://localhost:6379".to_string()],
37            default_ttl: Duration::from_secs(3600),
38            max_cache_size: 1024 * 1024 * 1024, // 1GB
39            compression_enabled: true,
40            encryption_enabled: false,
41            cluster_mode: false,
42            sharding_strategy: ShardingStrategy::ConsistentHashing,
43            eviction_policy: EvictionPolicy::LRU,
44            consistency_level: ConsistencyLevel::Eventual,
45            replication_factor: 2,
46            local_cache_size: 10000,
47            prefetch_enabled: true,
48        }
49    }
50}
51
/// Strategies for distributing keys across Redis nodes.
///
/// Only `ConsistentHashing` and `ModuloHash` are acted on by
/// `RedisDistributedCache`; the remaining variants currently fall back to
/// the first node.
#[derive(Debug, Clone)]
pub enum ShardingStrategy {
    /// Hash-based placement intended to minimize remapping when nodes change.
    ConsistentHashing,
    /// Partition the key space into contiguous ranges (not yet implemented).
    Range,
    /// `hash(key) % node_count`.
    ModuloHash,
    /// Shard by GraphQL query type (not yet implemented).
    QueryType,
    /// Keep keys near the service that owns them (not yet implemented).
    ServiceAffinity,
}
61
/// Cache eviction policies.
///
/// NOTE(review): configuration hint only — nothing in this module switches
/// on the chosen policy; the local tier is always `lru::LruCache`.
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Least recently used.
    LRU,
    /// Least frequently used.
    LFU,
    /// First in, first out.
    FIFO,
    /// Evict strictly by time-to-live.
    TTL,
    /// Adapt between policies based on workload.
    Adaptive,
}
71
/// Consistency levels for distributed caching.
///
/// NOTE(review): declared for configuration but not consulted by the
/// current Redis implementation — confirm intended enforcement point.
#[derive(Debug, Clone)]
pub enum ConsistencyLevel {
    /// Reads observe the latest committed write.
    Strong,
    /// Replicas converge eventually; stale reads possible.
    Eventual,
    /// Consistency within a client session.
    Session,
    /// Staleness bounded by a configured window.
    Bounded,
}
80
/// A single cached value plus its bookkeeping metadata.
///
/// Used by the in-process LRU tier; `value` holds the decoded
/// (decompressed/decrypted) bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Cache key this entry is stored under.
    pub key: String,
    /// Decoded payload bytes.
    pub value: Vec<u8>,
    /// When the entry was created.
    pub created_at: SystemTime,
    /// Absolute expiry; entries past this instant are treated as misses.
    pub expires_at: SystemTime,
    /// Read counter. NOTE(review): not updated on local-cache hits.
    pub access_count: u64,
    /// Timestamp of the most recent access.
    pub last_accessed: SystemTime,
    /// Payload size in bytes (see producers for which encoding is measured).
    pub size_bytes: usize,
    /// Invalidation tags (currently left empty by the Redis cache).
    pub tags: Vec<String>,
    /// Free-form metadata (currently unused).
    pub metadata: HashMap<String, String>,
}
94
/// Cumulative cache operation counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from either tier.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Successful `set` operations.
    pub sets: u64,
    /// Successful `delete` operations.
    pub deletes: u64,
    /// Entries evicted. NOTE(review): never incremented in this module.
    pub evictions: u64,
    /// Total bytes written, measured as stored in Redis.
    pub total_size_bytes: u64,
    /// Number of entries written (not decremented on delete).
    pub entry_count: u64,
    /// Smoothed per-lookup latency estimate.
    pub average_response_time: Duration,
}
107
/// Record describing a cache invalidation.
///
/// NOTE(review): defined (and serializable) for invalidation propagation,
/// but not yet produced or consumed anywhere in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationEvent {
    /// Exact keys to invalidate.
    pub keys: Vec<String>,
    /// Tags whose associated entries should be invalidated.
    pub tags: Vec<String>,
    /// When the invalidation was issued.
    pub timestamp: SystemTime,
    /// Identifier of the component that triggered it.
    pub source: String,
    /// Why the invalidation happened.
    pub reason: InvalidationReason,
}
117
/// Reasons for cache invalidation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InvalidationReason {
    /// The GraphQL schema changed.
    SchemaChange,
    /// Underlying data was updated.
    DataUpdate,
    /// Operator-initiated invalidation.
    Manual,
    /// Entry time-to-live elapsed.
    TTLExpired,
    /// Evicted to relieve memory pressure.
    MemoryPressure,
    /// Flushed while recovering from an error.
    ErrorRecovery,
}
128
/// Identifying information for a GraphQL query execution, used to derive
/// cache keys and invalidation tags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryContext {
    /// Hash of the query text.
    pub query_hash: String,
    /// Hash of the operation variables.
    pub variables_hash: String,
    /// GraphQL operation name, when the document declares one.
    pub operation_name: Option<String>,
    /// Authenticated user, when per-user tagging is wanted.
    pub user_id: Option<String>,
    /// Federated services contributing to this query.
    pub service_ids: Vec<String>,
    /// Schema version the query was validated against.
    pub schema_version: String,
    /// Fields requested by the query (used for field-level tags).
    pub requested_fields: Vec<String>,
}
140
141impl QueryContext {
142    /// Generate cache key from query context
143    pub fn cache_key(&self) -> String {
144        format!(
145            "gql:{}:{}:{}:{}",
146            self.query_hash,
147            self.variables_hash,
148            self.schema_version,
149            self.service_ids.join(",")
150        )
151    }
152
153    /// Generate tags for cache invalidation
154    pub fn tags(&self) -> Vec<String> {
155        let mut tags = vec![
156            format!("query:{}", self.query_hash),
157            format!("schema:{}", self.schema_version),
158        ];
159
160        for service_id in &self.service_ids {
161            tags.push(format!("service:{service_id}"));
162        }
163
164        for field in &self.requested_fields {
165            tags.push(format!("field:{field}"));
166        }
167
168        if let Some(user_id) = &self.user_id {
169            tags.push(format!("user:{user_id}"));
170        }
171
172        tags
173    }
174}
175
/// Abstraction over a distributed cache backend.
///
/// Implementations must be shareable across tasks (`Send + Sync`).
#[async_trait]
pub trait DistributedCache: Send + Sync {
    /// Fetch the raw bytes stored under `key`, if present and unexpired.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
    /// Store `value` under `key`; `None` TTL means the backend default.
    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;
    /// Remove `key` from all cache tiers.
    async fn delete(&self, key: &str) -> Result<()>;
    /// Whether `key` currently resolves to a live entry.
    async fn exists(&self, key: &str) -> Result<bool>;
    /// Invalidate entries associated with any of `tags`; returns the
    /// number of entries removed.
    async fn invalidate_by_tags(&self, tags: &[String]) -> Result<u64>;
    /// Snapshot of cumulative cache statistics.
    async fn get_stats(&self) -> Result<CacheStats>;
    /// `true` when all backend nodes are reachable.
    async fn health_check(&self) -> Result<bool>;
    /// Drop every entry and reset statistics.
    async fn clear(&self) -> Result<()>;
}
188
/// Redis-based distributed cache implementation
///
/// Two-tier design: a sharded pool of Redis clients as the shared tier,
/// fronted by an in-process `lru::LruCache` for hot entries.
pub struct RedisDistributedCache {
    // Static configuration (TTL, sharding, feature toggles).
    config: CacheConfig,
    // One client per configured Redis URL; selected per key by sharding.
    redis_pool: Arc<RwLock<Vec<Client>>>,
    // In-process hot tier holding decoded entries.
    local_cache: Arc<RwLock<lru::LruCache<String, CacheEntry>>>,
    // Cumulative operation counters.
    stats: Arc<RwLock<CacheStats>>,
    // Optional encode/decode pipeline stages (None = feature disabled).
    compression: Option<Arc<dyn CompressionStrategy>>,
    encryption: Option<Arc<dyn EncryptionStrategy>>,
}
198
199impl RedisDistributedCache {
200    /// Create a new Redis-based distributed cache
201    pub async fn new(config: CacheConfig) -> Result<Self> {
202        let mut redis_clients = Vec::new();
203
204        for redis_url in &config.redis_urls {
205            let client = Client::open(redis_url.as_str())
206                .map_err(|e| anyhow!("Failed to create Redis client: {}", e))?;
207            redis_clients.push(client);
208        }
209
210        let local_cache = lru::LruCache::new(
211            std::num::NonZeroUsize::new(config.local_cache_size).unwrap_or(
212                std::num::NonZeroUsize::new(1000).expect("1000 is a valid NonZeroUsize"),
213            ),
214        );
215
216        let compression = if config.compression_enabled {
217            Some(Arc::new(GzipCompressionStrategy::new()) as Arc<dyn CompressionStrategy>)
218        } else {
219            None
220        };
221
222        let encryption = if config.encryption_enabled {
223            Some(Arc::new(AesEncryptionStrategy::new()) as Arc<dyn EncryptionStrategy>)
224        } else {
225            None
226        };
227
228        Ok(Self {
229            config,
230            redis_pool: Arc::new(RwLock::new(redis_clients)),
231            local_cache: Arc::new(RwLock::new(local_cache)),
232            stats: Arc::new(RwLock::new(CacheStats::default())),
233            compression,
234            encryption,
235        })
236    }
237
238    /// Get Redis client for a given key
239    async fn get_redis_client(&self, key: &str) -> Result<Client> {
240        let clients = self.redis_pool.read().await;
241
242        if clients.is_empty() {
243            return Err(anyhow!("No Redis clients available"));
244        }
245
246        let index = match self.config.sharding_strategy {
247            ShardingStrategy::ConsistentHashing => self.consistent_hash(key, clients.len()),
248            ShardingStrategy::ModuloHash => self.modulo_hash(key, clients.len()),
249            _ => 0, // Default to first client for other strategies
250        };
251
252        Ok(clients[index].clone())
253    }
254
255    /// Consistent hashing for key distribution
256    fn consistent_hash(&self, key: &str, num_nodes: usize) -> usize {
257        use std::collections::hash_map::DefaultHasher;
258        use std::hash::{Hash, Hasher};
259
260        let mut hasher = DefaultHasher::new();
261        key.hash(&mut hasher);
262        (hasher.finish() as usize) % num_nodes
263    }
264
265    /// Simple modulo hashing
266    fn modulo_hash(&self, key: &str, num_nodes: usize) -> usize {
267        use std::collections::hash_map::DefaultHasher;
268        use std::hash::{Hash, Hasher};
269
270        let mut hasher = DefaultHasher::new();
271        key.hash(&mut hasher);
272        (hasher.finish() as usize) % num_nodes
273    }
274
275    /// Process data through compression/encryption pipeline
276    async fn process_data(&self, data: &[u8], encode: bool) -> Result<Vec<u8>> {
277        let mut processed_data = data.to_vec();
278
279        if encode {
280            // Apply compression first
281            if let Some(compression) = &self.compression {
282                processed_data = compression.compress(&processed_data).await?;
283            }
284
285            // Then encryption
286            if let Some(encryption) = &self.encryption {
287                processed_data = encryption.encrypt(&processed_data).await?;
288            }
289        } else {
290            // Reverse order for decoding: decrypt first
291            if let Some(encryption) = &self.encryption {
292                processed_data = encryption.decrypt(&processed_data).await?;
293            }
294
295            // Then decompress
296            if let Some(compression) = &self.compression {
297                processed_data = compression.decompress(&processed_data).await?;
298            }
299        }
300
301        Ok(processed_data)
302    }
303
304    /// Update cache statistics
305    async fn update_stats<F>(&self, update_fn: F)
306    where
307        F: FnOnce(&mut CacheStats),
308    {
309        let mut stats = self.stats.write().await;
310        update_fn(&mut stats);
311    }
312}
313
314#[async_trait]
315impl DistributedCache for RedisDistributedCache {
316    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
317        let start_time = std::time::Instant::now();
318
319        // Check local cache first
320        {
321            let mut local_cache = self.local_cache.write().await;
322            if let Some(entry) = local_cache.get(key) {
323                if entry.expires_at > SystemTime::now() {
324                    self.update_stats(|stats| {
325                        stats.hits += 1;
326                        stats.average_response_time =
327                            (stats.average_response_time + start_time.elapsed()) / 2;
328                    })
329                    .await;
330
331                    return Ok(Some(entry.value.clone()));
332                } else {
333                    // Entry expired, remove it
334                    local_cache.pop(key);
335                }
336            }
337        }
338
339        // Check Redis
340        let client = self.get_redis_client(key).await?;
341        let mut connection = client
342            .get_multiplexed_async_connection()
343            .await
344            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;
345
346        let redis_result: Option<Vec<u8>> = cmd("GET")
347            .arg(key)
348            .query_async(&mut connection)
349            .await
350            .map_err(|e| anyhow!("Redis GET failed: {}", e))?;
351
352        if let Some(raw_data) = redis_result {
353            // Process data (decrypt/decompress)
354            let processed_data = self.process_data(&raw_data, false).await?;
355
356            // Store in local cache
357            let entry = CacheEntry {
358                key: key.to_string(),
359                value: processed_data.clone(),
360                created_at: SystemTime::now(),
361                expires_at: SystemTime::now() + self.config.default_ttl,
362                access_count: 1,
363                last_accessed: SystemTime::now(),
364                size_bytes: processed_data.len(),
365                tags: Vec::new(),
366                metadata: HashMap::new(),
367            };
368
369            {
370                let mut local_cache = self.local_cache.write().await;
371                local_cache.put(key.to_string(), entry);
372            }
373
374            self.update_stats(|stats| {
375                stats.hits += 1;
376                stats.average_response_time =
377                    (stats.average_response_time + start_time.elapsed()) / 2;
378            })
379            .await;
380
381            Ok(Some(processed_data))
382        } else {
383            self.update_stats(|stats| {
384                stats.misses += 1;
385                stats.average_response_time =
386                    (stats.average_response_time + start_time.elapsed()) / 2;
387            })
388            .await;
389
390            Ok(None)
391        }
392    }
393
394    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
395        let ttl = ttl.unwrap_or(self.config.default_ttl);
396
397        // Process data (compress/encrypt)
398        let processed_data = self.process_data(&value, true).await?;
399
400        // Store in Redis
401        let client = self.get_redis_client(key).await?;
402        let mut connection = client
403            .get_multiplexed_async_connection()
404            .await
405            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;
406
407        cmd("SETEX")
408            .arg(key)
409            .arg(ttl.as_secs())
410            .arg(&processed_data)
411            .exec_async(&mut connection)
412            .await
413            .map_err(|e| anyhow!("Redis SETEX failed: {}", e))?;
414
415        // Store in local cache
416        let entry = CacheEntry {
417            key: key.to_string(),
418            value,
419            created_at: SystemTime::now(),
420            expires_at: SystemTime::now() + ttl,
421            access_count: 0,
422            last_accessed: SystemTime::now(),
423            size_bytes: processed_data.len(),
424            tags: Vec::new(),
425            metadata: HashMap::new(),
426        };
427
428        {
429            let mut local_cache = self.local_cache.write().await;
430            local_cache.put(key.to_string(), entry);
431        }
432
433        self.update_stats(|stats| {
434            stats.sets += 1;
435            stats.total_size_bytes += processed_data.len() as u64;
436            stats.entry_count += 1;
437        })
438        .await;
439
440        Ok(())
441    }
442
443    async fn delete(&self, key: &str) -> Result<()> {
444        // Remove from local cache
445        {
446            let mut local_cache = self.local_cache.write().await;
447            local_cache.pop(key);
448        }
449
450        // Remove from Redis
451        let client = self.get_redis_client(key).await?;
452        let mut connection = client
453            .get_multiplexed_async_connection()
454            .await
455            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;
456
457        cmd("DEL")
458            .arg(key)
459            .query_async::<()>(&mut connection)
460            .await
461            .map_err(|e| anyhow!("Redis DEL failed: {}", e))?;
462
463        self.update_stats(|stats| {
464            stats.deletes += 1;
465        })
466        .await;
467
468        Ok(())
469    }
470
471    async fn exists(&self, key: &str) -> Result<bool> {
472        // Check local cache first
473        {
474            let mut local_cache = self.local_cache.write().await;
475            if let Some(entry) = local_cache.get(key) {
476                if entry.expires_at > SystemTime::now() {
477                    return Ok(true);
478                } else {
479                    local_cache.pop(key);
480                }
481            }
482        }
483
484        // Check Redis
485        let client = self.get_redis_client(key).await?;
486        let mut connection = client
487            .get_multiplexed_async_connection()
488            .await
489            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;
490
491        let exists: bool = cmd("EXISTS")
492            .arg(key)
493            .query_async(&mut connection)
494            .await
495            .map_err(|e| anyhow!("Redis EXISTS failed: {}", e))?;
496
497        Ok(exists)
498    }
499
500    async fn invalidate_by_tags(&self, tags: &[String]) -> Result<u64> {
501        // This is a simplified implementation
502        // A production implementation would use Redis sets to track keys by tags
503        let mut invalidated = 0;
504
505        for tag in tags {
506            // Create a pattern to match keys with this tag
507            let pattern = format!("*{tag}*");
508
509            let clients = self.redis_pool.read().await;
510            for client in clients.iter() {
511                let mut connection = client.get_multiplexed_async_connection().await?;
512
513                let keys: Vec<String> = cmd("KEYS")
514                    .arg(&pattern)
515                    .query_async(&mut connection)
516                    .await?;
517
518                for key in keys {
519                    self.delete(&key).await?;
520                    invalidated += 1;
521                }
522            }
523        }
524
525        Ok(invalidated)
526    }
527
528    async fn get_stats(&self) -> Result<CacheStats> {
529        Ok(self.stats.read().await.clone())
530    }
531
532    async fn health_check(&self) -> Result<bool> {
533        let clients = self.redis_pool.read().await;
534
535        for client in clients.iter() {
536            match client.get_multiplexed_async_connection().await {
537                Ok(mut connection) => {
538                    let result: Result<String, _> = cmd("PING").query_async(&mut connection).await;
539                    if result.is_err() {
540                        return Ok(false);
541                    }
542                }
543                Err(_) => return Ok(false),
544            }
545        }
546
547        Ok(true)
548    }
549
550    async fn clear(&self) -> Result<()> {
551        // Clear local cache
552        {
553            let mut local_cache = self.local_cache.write().await;
554            local_cache.clear();
555        }
556
557        // Clear Redis
558        let clients = self.redis_pool.read().await;
559        for client in clients.iter() {
560            let mut connection = client.get_multiplexed_async_connection().await?;
561            cmd("FLUSHDB").query_async::<()>(&mut connection).await?;
562        }
563
564        // Reset stats
565        {
566            let mut stats = self.stats.write().await;
567            *stats = CacheStats::default();
568        }
569
570        Ok(())
571    }
572}
573
/// Pluggable (de)compression applied to values before they reach Redis.
#[async_trait]
pub trait CompressionStrategy: Send + Sync {
    /// Compress `data`; the output must round-trip through [`Self::decompress`].
    async fn compress(&self, data: &[u8]) -> Result<Vec<u8>>;
    /// Inverse of [`Self::compress`].
    async fn decompress(&self, data: &[u8]) -> Result<Vec<u8>>;
}
580
/// Gzip compression strategy backed by the `flate2` crate.
pub struct GzipCompressionStrategy;

impl Default for GzipCompressionStrategy {
    fn default() -> Self {
        Self::new()
    }
}

impl GzipCompressionStrategy {
    /// Stateless constructor; gzip needs no configuration here.
    pub fn new() -> Self {
        Self
    }
}
595
596#[async_trait]
597impl CompressionStrategy for GzipCompressionStrategy {
598    async fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
599        use flate2::{write::GzEncoder, Compression};
600        use std::io::Write;
601
602        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
603        encoder.write_all(data)?;
604        Ok(encoder.finish()?)
605    }
606
607    async fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
608        use flate2::read::GzDecoder;
609        use std::io::Read;
610
611        let mut decoder = GzDecoder::new(data);
612        let mut decompressed = Vec::new();
613        decoder.read_to_end(&mut decompressed)?;
614        Ok(decompressed)
615    }
616}
617
/// Pluggable encryption applied to values before they reach Redis.
#[async_trait]
pub trait EncryptionStrategy: Send + Sync {
    /// Encrypt `data`; the output must round-trip through [`Self::decrypt`].
    async fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>>;
    /// Inverse of [`Self::encrypt`].
    async fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>>;
}
624
/// AES encryption strategy (stub implementation).
///
/// WARNING: the trait impl below is a pass-through no-op — see the
/// security note there before enabling `encryption_enabled`.
pub struct AesEncryptionStrategy;

impl Default for AesEncryptionStrategy {
    fn default() -> Self {
        Self::new()
    }
}

impl AesEncryptionStrategy {
    /// Stateless constructor; a real implementation would take key material.
    pub fn new() -> Self {
        Self
    }
}
639
#[async_trait]
impl EncryptionStrategy for AesEncryptionStrategy {
    /// SECURITY: stub — returns the plaintext unchanged. With
    /// `encryption_enabled = true` the cache still writes plaintext to
    /// Redis. Replace with real AES (e.g. AES-256-GCM with proper key
    /// management) before relying on this for confidentiality.
    async fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
        // Stub implementation - would use actual AES encryption
        Ok(data.to_vec())
    }

    /// SECURITY: stub — pass-through inverse of the no-op `encrypt`.
    async fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
        // Stub implementation - would use actual AES decryption
        Ok(data.to_vec())
    }
}
652
/// GraphQL query cache manager
///
/// High-level facade mapping GraphQL query contexts onto the distributed
/// cache: JSON (de)serialization, key/tag derivation, and invalidation.
// allow(dead_code): `config` is retained after construction but not yet read.
#[allow(dead_code)]
pub struct GraphQLQueryCache {
    // Shared cache backend (Redis-backed; see `new`).
    cache: Arc<dyn DistributedCache>,
    // Configuration the backend was built from.
    config: CacheConfig,
}
659
660impl GraphQLQueryCache {
661    /// Create a new GraphQL query cache
662    pub async fn new(config: CacheConfig) -> Result<Self> {
663        let cache = Arc::new(RedisDistributedCache::new(config.clone()).await?);
664
665        Ok(Self { cache, config })
666    }
667
668    /// Cache a GraphQL query result
669    pub async fn cache_query_result(
670        &self,
671        context: &QueryContext,
672        result: &serde_json::Value,
673        ttl: Option<Duration>,
674    ) -> Result<()> {
675        let key = context.cache_key();
676        let value = serde_json::to_vec(result)?;
677
678        self.cache.set(&key, value, ttl).await?;
679
680        info!("Cached GraphQL query result: {}", key);
681        Ok(())
682    }
683
684    /// Get cached GraphQL query result
685    pub async fn get_cached_result(
686        &self,
687        context: &QueryContext,
688    ) -> Result<Option<serde_json::Value>> {
689        let key = context.cache_key();
690
691        if let Some(cached_data) = self.cache.get(&key).await? {
692            let result: serde_json::Value = serde_json::from_slice(&cached_data)?;
693            debug!("Cache hit for GraphQL query: {}", key);
694            return Ok(Some(result));
695        }
696
697        debug!("Cache miss for GraphQL query: {}", key);
698        Ok(None)
699    }
700
701    /// Invalidate cache entries based on schema changes
702    pub async fn invalidate_on_schema_change(&self, schema_version: &str) -> Result<u64> {
703        let tags = vec![format!("schema:{}", schema_version)];
704        self.cache.invalidate_by_tags(&tags).await
705    }
706
707    /// Invalidate cache entries for specific services
708    pub async fn invalidate_for_services(&self, service_ids: &[String]) -> Result<u64> {
709        let tags: Vec<String> = service_ids
710            .iter()
711            .map(|id| format!("service:{id}"))
712            .collect();
713        self.cache.invalidate_by_tags(&tags).await
714    }
715
716    /// Get cache statistics
717    pub async fn get_stats(&self) -> Result<CacheStats> {
718        self.cache.get_stats().await
719    }
720
721    /// Health check
722    pub async fn health_check(&self) -> Result<bool> {
723        self.cache.health_check().await
724    }
725
726    /// Raw cache get for internal use
727    pub async fn raw_get(&self, key: &str) -> Result<Option<Vec<u8>>> {
728        self.cache.get(key).await
729    }
730
731    /// Raw cache set for internal use  
732    pub async fn raw_set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
733        self.cache.set(key, value, ttl).await
734    }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// `cache_key` must embed the query hash, variables hash, and schema
    /// version. Plain `#[test]`: `cache_key` is synchronous, so the tokio
    /// runtime the original `#[tokio::test]` spun up was unnecessary.
    #[test]
    fn test_query_context_cache_key() {
        let context = QueryContext {
            query_hash: "abc123".to_string(),
            variables_hash: "def456".to_string(),
            operation_name: Some("GetUser".to_string()),
            user_id: Some("user123".to_string()),
            service_ids: vec!["service1".to_string(), "service2".to_string()],
            schema_version: "v1.0".to_string(),
            requested_fields: vec!["name".to_string(), "email".to_string()],
        };

        let cache_key = context.cache_key();
        assert!(cache_key.contains("abc123"));
        assert!(cache_key.contains("def456"));
        assert!(cache_key.contains("v1.0"));
    }

    /// Round-trip through gzip and confirm the payload actually shrank.
    #[tokio::test]
    async fn test_gzip_compression() {
        let compression = GzipCompressionStrategy::new();
        // Use a larger, more repetitive string that will actually compress well
        let original_data = b"This is a test string for compression. ".repeat(100);

        let compressed = compression
            .compress(&original_data)
            .await
            .expect("should succeed");
        let decompressed = compression
            .decompress(&compressed)
            .await
            .expect("should succeed");

        assert_eq!(original_data.as_slice(), decompressed.as_slice());
        assert!(compressed.len() < original_data.len()); // Should be compressed
    }
}