1use anyhow::{anyhow, Result};
7use async_trait::async_trait;
8use redis::{cmd, Client};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use tokio::sync::RwLock;
14use tracing::{debug, info};
15
/// Tunables for the distributed cache (sharded Redis tier + in-process LRU tier).
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Redis endpoints; one client is created per URL and keys are sharded
    /// across them according to `sharding_strategy`.
    pub redis_urls: Vec<String>,
    /// TTL applied when a caller does not supply one.
    pub default_ttl: Duration,
    /// Upper bound on total cache size in bytes.
    /// NOTE(review): not enforced anywhere in this file — confirm intended use.
    pub max_cache_size: u64,
    /// Gzip-compress values before they are written to Redis.
    pub compression_enabled: bool,
    /// Encrypt values before they are written to Redis.
    /// NOTE(review): the bundled `AesEncryptionStrategy` is a no-op stub.
    pub encryption_enabled: bool,
    /// Treat the Redis backend as a cluster.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub cluster_mode: bool,
    /// How keys are mapped to Redis nodes.
    pub sharding_strategy: ShardingStrategy,
    /// Eviction policy selector.
    /// NOTE(review): the local tier is always LRU regardless of this value.
    pub eviction_policy: EvictionPolicy,
    /// Desired consistency guarantee.
    /// NOTE(review): not read anywhere in this file.
    pub consistency_level: ConsistencyLevel,
    /// Number of replicas per entry.
    /// NOTE(review): not read anywhere in this file.
    pub replication_factor: usize,
    /// Maximum number of entries in the in-process LRU cache.
    pub local_cache_size: usize,
    /// Enable prefetching.
    /// NOTE(review): not read anywhere in this file.
    pub prefetch_enabled: bool,
}
32
33impl Default for CacheConfig {
34 fn default() -> Self {
35 Self {
36 redis_urls: vec!["redis://localhost:6379".to_string()],
37 default_ttl: Duration::from_secs(3600),
38 max_cache_size: 1024 * 1024 * 1024, compression_enabled: true,
40 encryption_enabled: false,
41 cluster_mode: false,
42 sharding_strategy: ShardingStrategy::ConsistentHashing,
43 eviction_policy: EvictionPolicy::LRU,
44 consistency_level: ConsistencyLevel::Eventual,
45 replication_factor: 2,
46 local_cache_size: 10000,
47 prefetch_enabled: true,
48 }
49 }
50}
51
/// Strategy for mapping a cache key to one of the configured Redis nodes.
#[derive(Debug, Clone)]
pub enum ShardingStrategy {
    /// Hash-based placement intended to minimize remapping when nodes change.
    ConsistentHashing,
    /// Range-based partitioning.
    /// NOTE(review): unimplemented — `get_redis_client` falls back to node 0.
    Range,
    /// `hash(key) % node_count` placement.
    ModuloHash,
    /// Shard by query type.
    /// NOTE(review): unimplemented — falls back to node 0.
    QueryType,
    /// Shard by owning service.
    /// NOTE(review): unimplemented — falls back to node 0.
    ServiceAffinity,
}
61
/// Entry eviction policy selector.
/// NOTE(review): only stored in config; the local tier in this file is
/// always an `lru::LruCache` regardless of the chosen variant.
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Least-recently-used.
    LRU,
    /// Least-frequently-used.
    LFU,
    /// First-in-first-out.
    FIFO,
    /// Evict strictly by time-to-live expiry.
    TTL,
    /// Policy chosen dynamically by the implementation.
    Adaptive,
}
71
/// Desired consistency guarantee between cache replicas.
/// NOTE(review): only stored in config; not acted upon in this file.
#[derive(Debug, Clone)]
pub enum ConsistencyLevel {
    /// All replicas agree before a write is acknowledged.
    Strong,
    /// Replicas converge eventually.
    Eventual,
    /// Read-your-own-writes within a session.
    Session,
    /// Staleness bounded by some window.
    Bounded,
}
80
/// A single cached value plus bookkeeping metadata, as kept in the local
/// LRU tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Cache key this entry is stored under.
    pub key: String,
    /// Decoded (uncompressed, unencrypted) value bytes.
    pub value: Vec<u8>,
    /// When this entry was created locally.
    pub created_at: SystemTime,
    /// Absolute expiry; entries past this are evicted on access.
    pub expires_at: SystemTime,
    /// Number of reads.
    /// NOTE(review): set once at creation, never incremented on later hits.
    pub access_count: u64,
    /// Timestamp of the most recent access.
    pub last_accessed: SystemTime,
    /// Payload size in bytes.
    /// NOTE(review): `set` records the ENCODED size here while `value` holds
    /// the decoded bytes — confirm that is intended.
    pub size_bytes: usize,
    /// Invalidation tags. NOTE(review): always empty in this file.
    pub tags: Vec<String>,
    /// Free-form metadata. NOTE(review): always empty in this file.
    pub metadata: HashMap<String, String>,
}
94
/// Aggregate counters for cache activity.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Successful reads (served from either the local or the Redis tier).
    pub hits: u64,
    /// Reads that found nothing in either tier.
    pub misses: u64,
    /// Number of `set` operations.
    pub sets: u64,
    /// Number of `delete` operations.
    pub deletes: u64,
    /// Number of evictions.
    /// NOTE(review): never incremented anywhere in this file.
    pub evictions: u64,
    /// Sum of encoded payload sizes written.
    /// NOTE(review): only ever grows; deletes do not subtract.
    pub total_size_bytes: u64,
    /// Number of entries written.
    /// NOTE(review): only ever grows; deletes do not subtract.
    pub entry_count: u64,
    /// Rough running latency figure for `get`: updated as
    /// `(previous + latest) / 2`, i.e. a halving average, not a true mean.
    pub average_response_time: Duration,
}
107
/// Describes a cache-invalidation occurrence, e.g. for propagation or audit.
/// NOTE(review): defined but neither produced nor consumed in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationEvent {
    /// Specific keys that were invalidated.
    pub keys: Vec<String>,
    /// Tags whose associated entries were invalidated.
    pub tags: Vec<String>,
    /// When the invalidation happened.
    pub timestamp: SystemTime,
    /// Identifier of the component that triggered it.
    pub source: String,
    /// Why the invalidation happened.
    pub reason: InvalidationReason,
}
117
/// Reason attached to an [`InvalidationEvent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InvalidationReason {
    /// The GraphQL schema changed, invalidating planned queries.
    SchemaChange,
    /// Underlying data was updated.
    DataUpdate,
    /// Operator-initiated invalidation.
    Manual,
    /// The entry's time-to-live elapsed.
    TTLExpired,
    /// Entries dropped to relieve memory pressure.
    MemoryPressure,
    /// Entries dropped while recovering from an error.
    ErrorRecovery,
}
128
/// Identifies a concrete GraphQL query execution for caching purposes;
/// `cache_key()` and `tags()` are derived from these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryContext {
    /// Hash of the query document; part of the cache key.
    pub query_hash: String,
    /// Hash of the variable values; part of the cache key.
    pub variables_hash: String,
    /// Optional GraphQL operation name (not part of the cache key).
    pub operation_name: Option<String>,
    /// Requesting user, if known; contributes a `user:{id}` tag only.
    pub user_id: Option<String>,
    /// Services that resolve this query; part of the cache key.
    pub service_ids: Vec<String>,
    /// Schema version the query was planned against; part of the cache key.
    pub schema_version: String,
    /// Requested field names; each contributes a `field:{name}` tag.
    pub requested_fields: Vec<String>,
}
140
141impl QueryContext {
142 pub fn cache_key(&self) -> String {
144 format!(
145 "gql:{}:{}:{}:{}",
146 self.query_hash,
147 self.variables_hash,
148 self.schema_version,
149 self.service_ids.join(",")
150 )
151 }
152
153 pub fn tags(&self) -> Vec<String> {
155 let mut tags = vec![
156 format!("query:{}", self.query_hash),
157 format!("schema:{}", self.schema_version),
158 ];
159
160 for service_id in &self.service_ids {
161 tags.push(format!("service:{service_id}"));
162 }
163
164 for field in &self.requested_fields {
165 tags.push(format!("field:{field}"));
166 }
167
168 if let Some(user_id) = &self.user_id {
169 tags.push(format!("user:{user_id}"));
170 }
171
172 tags
173 }
174}
175
/// Minimal async interface for a distributed byte cache.
#[async_trait]
pub trait DistributedCache: Send + Sync {
    /// Returns the value stored under `key`, or `None` if absent or expired.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
    /// Stores `value` under `key` with `ttl` (implementation default if `None`).
    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;
    /// Removes `key` from every tier.
    async fn delete(&self, key: &str) -> Result<()>;
    /// Returns whether `key` currently exists.
    async fn exists(&self, key: &str) -> Result<bool>;
    /// Invalidates entries associated with any of `tags`; returns the count.
    async fn invalidate_by_tags(&self, tags: &[String]) -> Result<u64>;
    /// Returns a snapshot of the accumulated statistics.
    async fn get_stats(&self) -> Result<CacheStats>;
    /// Returns `Ok(true)` only if every backend node is reachable.
    async fn health_check(&self) -> Result<bool>;
    /// Removes all entries and resets statistics.
    async fn clear(&self) -> Result<()>;
}
188
/// Two-tier cache: a per-process LRU front cache backed by one or more
/// Redis nodes, with keys sharded across nodes by the configured strategy.
pub struct RedisDistributedCache {
    /// Static configuration the cache was built with.
    config: CacheConfig,
    /// One client per configured Redis URL; shard index chosen per key.
    redis_pool: Arc<RwLock<Vec<Client>>>,
    /// In-process LRU of decoded entries.
    local_cache: Arc<RwLock<lru::LruCache<String, CacheEntry>>>,
    /// Shared activity counters.
    stats: Arc<RwLock<CacheStats>>,
    /// Compression applied before writes to Redis, if enabled.
    compression: Option<Arc<dyn CompressionStrategy>>,
    /// Encryption applied before writes to Redis, if enabled.
    encryption: Option<Arc<dyn EncryptionStrategy>>,
}
198
199impl RedisDistributedCache {
200 pub async fn new(config: CacheConfig) -> Result<Self> {
202 let mut redis_clients = Vec::new();
203
204 for redis_url in &config.redis_urls {
205 let client = Client::open(redis_url.as_str())
206 .map_err(|e| anyhow!("Failed to create Redis client: {}", e))?;
207 redis_clients.push(client);
208 }
209
210 let local_cache = lru::LruCache::new(
211 std::num::NonZeroUsize::new(config.local_cache_size).unwrap_or(
212 std::num::NonZeroUsize::new(1000).expect("1000 is a valid NonZeroUsize"),
213 ),
214 );
215
216 let compression = if config.compression_enabled {
217 Some(Arc::new(GzipCompressionStrategy::new()) as Arc<dyn CompressionStrategy>)
218 } else {
219 None
220 };
221
222 let encryption = if config.encryption_enabled {
223 Some(Arc::new(AesEncryptionStrategy::new()) as Arc<dyn EncryptionStrategy>)
224 } else {
225 None
226 };
227
228 Ok(Self {
229 config,
230 redis_pool: Arc::new(RwLock::new(redis_clients)),
231 local_cache: Arc::new(RwLock::new(local_cache)),
232 stats: Arc::new(RwLock::new(CacheStats::default())),
233 compression,
234 encryption,
235 })
236 }
237
238 async fn get_redis_client(&self, key: &str) -> Result<Client> {
240 let clients = self.redis_pool.read().await;
241
242 if clients.is_empty() {
243 return Err(anyhow!("No Redis clients available"));
244 }
245
246 let index = match self.config.sharding_strategy {
247 ShardingStrategy::ConsistentHashing => self.consistent_hash(key, clients.len()),
248 ShardingStrategy::ModuloHash => self.modulo_hash(key, clients.len()),
249 _ => 0, };
251
252 Ok(clients[index].clone())
253 }
254
255 fn consistent_hash(&self, key: &str, num_nodes: usize) -> usize {
257 use std::collections::hash_map::DefaultHasher;
258 use std::hash::{Hash, Hasher};
259
260 let mut hasher = DefaultHasher::new();
261 key.hash(&mut hasher);
262 (hasher.finish() as usize) % num_nodes
263 }
264
265 fn modulo_hash(&self, key: &str, num_nodes: usize) -> usize {
267 use std::collections::hash_map::DefaultHasher;
268 use std::hash::{Hash, Hasher};
269
270 let mut hasher = DefaultHasher::new();
271 key.hash(&mut hasher);
272 (hasher.finish() as usize) % num_nodes
273 }
274
275 async fn process_data(&self, data: &[u8], encode: bool) -> Result<Vec<u8>> {
277 let mut processed_data = data.to_vec();
278
279 if encode {
280 if let Some(compression) = &self.compression {
282 processed_data = compression.compress(&processed_data).await?;
283 }
284
285 if let Some(encryption) = &self.encryption {
287 processed_data = encryption.encrypt(&processed_data).await?;
288 }
289 } else {
290 if let Some(encryption) = &self.encryption {
292 processed_data = encryption.decrypt(&processed_data).await?;
293 }
294
295 if let Some(compression) = &self.compression {
297 processed_data = compression.decompress(&processed_data).await?;
298 }
299 }
300
301 Ok(processed_data)
302 }
303
304 async fn update_stats<F>(&self, update_fn: F)
306 where
307 F: FnOnce(&mut CacheStats),
308 {
309 let mut stats = self.stats.write().await;
310 update_fn(&mut stats);
311 }
312}
313
#[async_trait]
impl DistributedCache for RedisDistributedCache {
    /// Looks `key` up in the local LRU first, then falls back to Redis.
    ///
    /// A Redis hit decodes the payload, re-populates the local tier, and is
    /// counted as a hit; only a miss in both tiers returns `Ok(None)`.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
        let start_time = std::time::Instant::now();

        // Tier 1: in-process LRU. A write lock is needed because
        // `LruCache::get` mutates recency order, and expired entries are
        // evicted in place.
        {
            let mut local_cache = self.local_cache.write().await;
            if let Some(entry) = local_cache.get(key) {
                if entry.expires_at > SystemTime::now() {
                    // NOTE(review): the stats lock is awaited while the local
                    // cache write guard is still held — fine today (the two
                    // locks are never taken in the opposite order here), but
                    // watch lock ordering if more locks are added.
                    self.update_stats(|stats| {
                        stats.hits += 1;
                        // Halving average, not a true mean: each new sample
                        // carries 50% weight.
                        stats.average_response_time =
                            (stats.average_response_time + start_time.elapsed()) / 2;
                    })
                    .await;

                    return Ok(Some(entry.value.clone()));
                } else {
                    // Expired locally; drop before consulting Redis.
                    local_cache.pop(key);
                }
            }
        }

        // Tier 2: Redis, sharded by key.
        let client = self.get_redis_client(key).await?;
        let mut connection = client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;

        let redis_result: Option<Vec<u8>> = cmd("GET")
            .arg(key)
            .query_async(&mut connection)
            .await
            .map_err(|e| anyhow!("Redis GET failed: {}", e))?;

        if let Some(raw_data) = redis_result {
            // Decode (decrypt + decompress) the stored payload.
            let processed_data = self.process_data(&raw_data, false).await?;

            // Re-populate the local tier. NOTE(review): the local expiry is
            // reset to `default_ttl` from now, which may outlive the
            // remaining TTL of the Redis key — confirm acceptable.
            let entry = CacheEntry {
                key: key.to_string(),
                value: processed_data.clone(),
                created_at: SystemTime::now(),
                expires_at: SystemTime::now() + self.config.default_ttl,
                access_count: 1,
                last_accessed: SystemTime::now(),
                size_bytes: processed_data.len(),
                tags: Vec::new(),
                metadata: HashMap::new(),
            };

            {
                let mut local_cache = self.local_cache.write().await;
                local_cache.put(key.to_string(), entry);
            }

            self.update_stats(|stats| {
                stats.hits += 1;
                stats.average_response_time =
                    (stats.average_response_time + start_time.elapsed()) / 2;
            })
            .await;

            Ok(Some(processed_data))
        } else {
            self.update_stats(|stats| {
                stats.misses += 1;
                stats.average_response_time =
                    (stats.average_response_time + start_time.elapsed()) / 2;
            })
            .await;

            Ok(None)
        }
    }

    /// Encodes `value` and stores it in Redis with `ttl` (falling back to
    /// `config.default_ttl`), mirroring the DECODED value into the local LRU.
    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
        let ttl = ttl.unwrap_or(self.config.default_ttl);

        // Compress and/or encrypt before the bytes leave the process.
        let processed_data = self.process_data(&value, true).await?;

        let client = self.get_redis_client(key).await?;
        let mut connection = client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;

        // SETEX = SET with a TTL in whole seconds.
        cmd("SETEX")
            .arg(key)
            .arg(ttl.as_secs())
            .arg(&processed_data)
            .exec_async(&mut connection)
            .await
            .map_err(|e| anyhow!("Redis SETEX failed: {}", e))?;

        // The local copy keeps the ORIGINAL (unencoded) bytes so reads skip
        // decode work. NOTE(review): `size_bytes` records the encoded size,
        // not `value.len()` — confirm that is intended.
        let entry = CacheEntry {
            key: key.to_string(),
            value,
            created_at: SystemTime::now(),
            expires_at: SystemTime::now() + ttl,
            access_count: 0,
            last_accessed: SystemTime::now(),
            size_bytes: processed_data.len(),
            tags: Vec::new(),
            metadata: HashMap::new(),
        };

        {
            let mut local_cache = self.local_cache.write().await;
            local_cache.put(key.to_string(), entry);
        }

        self.update_stats(|stats| {
            stats.sets += 1;
            stats.total_size_bytes += processed_data.len() as u64;
            stats.entry_count += 1;
        })
        .await;

        Ok(())
    }

    /// Removes `key` from the local tier and from its Redis shard.
    async fn delete(&self, key: &str) -> Result<()> {
        // Evict locally first so a concurrent local read cannot serve the
        // value after the Redis copy is gone.
        {
            let mut local_cache = self.local_cache.write().await;
            local_cache.pop(key);
        }

        let client = self.get_redis_client(key).await?;
        let mut connection = client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;

        cmd("DEL")
            .arg(key)
            .query_async::<()>(&mut connection)
            .await
            .map_err(|e| anyhow!("Redis DEL failed: {}", e))?;

        self.update_stats(|stats| {
            stats.deletes += 1;
        })
        .await;

        Ok(())
    }

    /// Checks the local tier first (evicting an expired entry on the way),
    /// then falls back to Redis `EXISTS`.
    async fn exists(&self, key: &str) -> Result<bool> {
        {
            let mut local_cache = self.local_cache.write().await;
            if let Some(entry) = local_cache.get(key) {
                if entry.expires_at > SystemTime::now() {
                    return Ok(true);
                } else {
                    local_cache.pop(key);
                }
            }
        }

        let client = self.get_redis_client(key).await?;
        let mut connection = client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| anyhow!("Failed to get Redis connection: {}", e))?;

        let exists: bool = cmd("EXISTS")
            .arg(key)
            .query_async(&mut connection)
            .await
            .map_err(|e| anyhow!("Redis EXISTS failed: {}", e))?;

        Ok(exists)
    }

    /// Deletes keys matching `*{tag}*` on every shard and returns the count.
    ///
    /// NOTE(review): `KEYS` is O(N) and blocks the Redis server — prefer
    /// `SCAN` in production. Also, tags are matched as key SUBSTRINGS, but
    /// `QueryContext::cache_key` does not embed tag text in keys — verify
    /// this actually matches the entries it is meant to invalidate.
    async fn invalidate_by_tags(&self, tags: &[String]) -> Result<u64> {
        let mut invalidated = 0;

        for tag in tags {
            let pattern = format!("*{tag}*");

            let clients = self.redis_pool.read().await;
            for client in clients.iter() {
                let mut connection = client.get_multiplexed_async_connection().await?;

                let keys: Vec<String> = cmd("KEYS")
                    .arg(&pattern)
                    .query_async(&mut connection)
                    .await?;

                for key in keys {
                    // Routes through `delete` so the local tier is also
                    // cleaned and the delete counter is updated.
                    self.delete(&key).await?;
                    invalidated += 1;
                }
            }
        }

        Ok(invalidated)
    }

    /// Returns a point-in-time copy of the counters.
    async fn get_stats(&self) -> Result<CacheStats> {
        Ok(self.stats.read().await.clone())
    }

    /// Pings every Redis node; returns `Ok(false)` on the first node that
    /// fails to connect or to answer PING.
    async fn health_check(&self) -> Result<bool> {
        let clients = self.redis_pool.read().await;

        for client in clients.iter() {
            match client.get_multiplexed_async_connection().await {
                Ok(mut connection) => {
                    let result: Result<String, _> = cmd("PING").query_async(&mut connection).await;
                    if result.is_err() {
                        return Ok(false);
                    }
                }
                Err(_) => return Ok(false),
            }
        }

        Ok(true)
    }

    /// Empties the local tier, `FLUSHDB`s every Redis node, and resets stats.
    ///
    /// NOTE(review): FLUSHDB wipes the ENTIRE selected Redis database,
    /// including keys not owned by this cache — confirm the DBs are dedicated.
    async fn clear(&self) -> Result<()> {
        {
            let mut local_cache = self.local_cache.write().await;
            local_cache.clear();
        }

        let clients = self.redis_pool.read().await;
        for client in clients.iter() {
            let mut connection = client.get_multiplexed_async_connection().await?;
            cmd("FLUSHDB").query_async::<()>(&mut connection).await?;
        }

        {
            let mut stats = self.stats.write().await;
            *stats = CacheStats::default();
        }

        Ok(())
    }
}
573
/// Pluggable compression for payloads written to the distributed tier.
#[async_trait]
pub trait CompressionStrategy: Send + Sync {
    /// Compresses `data`, returning the encoded bytes.
    async fn compress(&self, data: &[u8]) -> Result<Vec<u8>>;
    /// Reverses `compress`, returning the original bytes.
    async fn decompress(&self, data: &[u8]) -> Result<Vec<u8>>;
}
580
/// Gzip-based [`CompressionStrategy`] (flate2, default compression level).
pub struct GzipCompressionStrategy;
583
impl Default for GzipCompressionStrategy {
    /// Equivalent to [`GzipCompressionStrategy::new`]; the type is stateless.
    fn default() -> Self {
        Self::new()
    }
}
589
impl GzipCompressionStrategy {
    /// Creates the (stateless) gzip strategy.
    pub fn new() -> Self {
        Self
    }
}
595
596#[async_trait]
597impl CompressionStrategy for GzipCompressionStrategy {
598 async fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
599 use flate2::{write::GzEncoder, Compression};
600 use std::io::Write;
601
602 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
603 encoder.write_all(data)?;
604 Ok(encoder.finish()?)
605 }
606
607 async fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
608 use flate2::read::GzDecoder;
609 use std::io::Read;
610
611 let mut decoder = GzDecoder::new(data);
612 let mut decompressed = Vec::new();
613 decoder.read_to_end(&mut decompressed)?;
614 Ok(decompressed)
615 }
616}
617
/// Pluggable encryption for payloads written to the distributed tier.
#[async_trait]
pub trait EncryptionStrategy: Send + Sync {
    /// Encrypts `data`, returning ciphertext.
    async fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>>;
    /// Reverses `encrypt`, returning plaintext.
    async fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>>;
}
624
/// SECURITY: despite the name, this is a placeholder — its trait impl below
/// returns data unchanged and performs NO encryption.
pub struct AesEncryptionStrategy;
627
impl Default for AesEncryptionStrategy {
    /// Equivalent to [`AesEncryptionStrategy::new`]; the type is stateless.
    fn default() -> Self {
        Self::new()
    }
}
633
impl AesEncryptionStrategy {
    /// Creates the (stateless) placeholder strategy.
    pub fn new() -> Self {
        Self
    }
}
639
#[async_trait]
impl EncryptionStrategy for AesEncryptionStrategy {
    /// SECURITY: identity transform — returns the input bytes unchanged.
    /// No AES (or any) encryption is performed. Do not set
    /// `encryption_enabled = true` expecting confidentiality until a real
    /// implementation replaces this stub.
    async fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(data.to_vec())
    }

    /// Identity pass-through matching `encrypt`; performs no decryption.
    async fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(data.to_vec())
    }
}
652
/// High-level cache for serialized GraphQL query results, layered on a
/// [`DistributedCache`] implementation (Redis-backed by default).
#[allow(dead_code)]
pub struct GraphQLQueryCache {
    /// Backing distributed cache.
    cache: Arc<dyn DistributedCache>,
    /// Configuration the cache was built with.
    config: CacheConfig,
}
659
660impl GraphQLQueryCache {
661 pub async fn new(config: CacheConfig) -> Result<Self> {
663 let cache = Arc::new(RedisDistributedCache::new(config.clone()).await?);
664
665 Ok(Self { cache, config })
666 }
667
668 pub async fn cache_query_result(
670 &self,
671 context: &QueryContext,
672 result: &serde_json::Value,
673 ttl: Option<Duration>,
674 ) -> Result<()> {
675 let key = context.cache_key();
676 let value = serde_json::to_vec(result)?;
677
678 self.cache.set(&key, value, ttl).await?;
679
680 info!("Cached GraphQL query result: {}", key);
681 Ok(())
682 }
683
684 pub async fn get_cached_result(
686 &self,
687 context: &QueryContext,
688 ) -> Result<Option<serde_json::Value>> {
689 let key = context.cache_key();
690
691 if let Some(cached_data) = self.cache.get(&key).await? {
692 let result: serde_json::Value = serde_json::from_slice(&cached_data)?;
693 debug!("Cache hit for GraphQL query: {}", key);
694 return Ok(Some(result));
695 }
696
697 debug!("Cache miss for GraphQL query: {}", key);
698 Ok(None)
699 }
700
701 pub async fn invalidate_on_schema_change(&self, schema_version: &str) -> Result<u64> {
703 let tags = vec![format!("schema:{}", schema_version)];
704 self.cache.invalidate_by_tags(&tags).await
705 }
706
707 pub async fn invalidate_for_services(&self, service_ids: &[String]) -> Result<u64> {
709 let tags: Vec<String> = service_ids
710 .iter()
711 .map(|id| format!("service:{id}"))
712 .collect();
713 self.cache.invalidate_by_tags(&tags).await
714 }
715
716 pub async fn get_stats(&self) -> Result<CacheStats> {
718 self.cache.get_stats().await
719 }
720
721 pub async fn health_check(&self) -> Result<bool> {
723 self.cache.health_check().await
724 }
725
726 pub async fn raw_get(&self, key: &str) -> Result<Option<Vec<u8>>> {
728 self.cache.get(key).await
729 }
730
731 pub async fn raw_set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
733 self.cache.set(key, value, ttl).await
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture for the key/tag tests.
    fn sample_context() -> QueryContext {
        QueryContext {
            query_hash: "abc123".to_string(),
            variables_hash: "def456".to_string(),
            operation_name: Some("GetUser".to_string()),
            user_id: Some("user123".to_string()),
            service_ids: vec!["service1".to_string(), "service2".to_string()],
            schema_version: "v1.0".to_string(),
            requested_fields: vec!["name".to_string(), "email".to_string()],
        }
    }

    /// The cache key must pin every input that affects the result: query
    /// hash, variables hash, schema version, and the participating services.
    #[tokio::test]
    async fn test_query_context_cache_key() {
        let context = sample_context();

        let cache_key = context.cache_key();
        assert!(cache_key.contains("abc123"));
        assert!(cache_key.contains("def456"));
        assert!(cache_key.contains("v1.0"));
        // Previously untested: the service-id component of the key.
        assert!(cache_key.contains("service1,service2"));
    }

    /// Tags must cover query, schema, services, fields, and the user.
    #[tokio::test]
    async fn test_query_context_tags() {
        let tags = sample_context().tags();

        assert!(tags.contains(&"query:abc123".to_string()));
        assert!(tags.contains(&"schema:v1.0".to_string()));
        assert!(tags.contains(&"service:service1".to_string()));
        assert!(tags.contains(&"service:service2".to_string()));
        assert!(tags.contains(&"field:name".to_string()));
        assert!(tags.contains(&"user:user123".to_string()));
    }

    /// Compression must round-trip losslessly and actually shrink
    /// repetitive data; empty input must also round-trip.
    #[tokio::test]
    async fn test_gzip_compression() {
        let compression = GzipCompressionStrategy::new();
        let original_data = b"This is a test string for compression. ".repeat(100);

        let compressed = compression
            .compress(&original_data)
            .await
            .expect("should succeed");
        let decompressed = compression
            .decompress(&compressed)
            .await
            .expect("should succeed");

        assert_eq!(original_data.as_slice(), decompressed.as_slice());
        assert!(compressed.len() < original_data.len());

        // Edge case previously untested: empty input.
        let empty = compression.compress(&[]).await.expect("should succeed");
        let restored = compression.decompress(&empty).await.expect("should succeed");
        assert!(restored.is_empty());
    }
}