kaccy_db/
cache_invalidation.rs

//! Cache invalidation patterns for distributed systems.
//!
//! This module provides:
//! - Redis pub/sub for cross-instance invalidation
//! - Tag-based invalidation
//! - Cascade invalidation rules

8use futures::StreamExt;
9use parking_lot::RwLock;
10use redis::Client;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tracing::{debug, error, info, warn};
16
17use crate::cache::RedisCache;
18use crate::error::{DbError, Result};
19
20/// Invalidation event
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct InvalidationEvent {
23    /// Type of invalidation
24    pub event_type: InvalidationType,
25    /// Keys to invalidate
26    pub keys: Vec<String>,
27    /// Tags to invalidate
28    pub tags: Vec<String>,
29    /// Timestamp of the event
30    pub timestamp: chrono::DateTime<chrono::Utc>,
31    /// Source instance that triggered the invalidation
32    pub source_instance: String,
33}
34
35/// Type of invalidation
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum InvalidationType {
38    /// Invalidate specific keys
39    Keys,
40    /// Invalidate all keys with a tag
41    Tags,
42    /// Invalidate with cascade rules
43    Cascade,
44    /// Invalidate all keys matching a pattern
45    Pattern,
46}
47
/// Settings that control an [`InvalidationManager`].
#[derive(Clone, Debug)]
pub struct InvalidationConfig {
    /// Name of the Redis pub/sub channel invalidation events travel on.
    pub pubsub_channel: String,
    /// Unique id of this instance; the subscriber ignores events carrying it so
    /// an instance does not re-process its own invalidations.
    pub instance_id: String,
    /// Whether invalidating a tag also follows the registered cascade rules.
    pub enable_cascade: bool,
    /// Depth at which cascade-rule recursion stops (with a warning).
    pub max_cascade_depth: usize,
}
60
61impl Default for InvalidationConfig {
62    fn default() -> Self {
63        Self {
64            pubsub_channel: "cache:invalidation".to_string(),
65            instance_id: uuid::Uuid::new_v4().to_string(),
66            enable_cascade: true,
67            max_cascade_depth: 5,
68        }
69    }
70}
71
72/// Tag registry for managing key-tag relationships
73#[derive(Debug, Clone)]
74pub struct TagRegistry {
75    /// Map of tags to keys
76    tags: Arc<RwLock<HashMap<String, HashSet<String>>>>,
77    /// Map of keys to tags
78    keys: Arc<RwLock<HashMap<String, HashSet<String>>>>,
79}
80
81impl Default for TagRegistry {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl TagRegistry {
88    /// Create a new tag registry
89    pub fn new() -> Self {
90        Self {
91            tags: Arc::new(RwLock::new(HashMap::new())),
92            keys: Arc::new(RwLock::new(HashMap::new())),
93        }
94    }
95
96    /// Register a key with tags
97    pub fn register(&self, key: String, tags: Vec<String>) {
98        let mut tag_map = self.tags.write();
99        let mut key_map = self.keys.write();
100
101        for tag in &tags {
102            tag_map.entry(tag.clone()).or_default().insert(key.clone());
103        }
104
105        key_map.insert(key, tags.into_iter().collect());
106    }
107
108    /// Get all keys for a tag
109    pub fn get_keys_for_tag(&self, tag: &str) -> Vec<String> {
110        self.tags
111            .read()
112            .get(tag)
113            .map(|keys| keys.iter().cloned().collect())
114            .unwrap_or_default()
115    }
116
117    /// Get all tags for a key
118    pub fn get_tags_for_key(&self, key: &str) -> Vec<String> {
119        self.keys
120            .read()
121            .get(key)
122            .map(|tags| tags.iter().cloned().collect())
123            .unwrap_or_default()
124    }
125
126    /// Unregister a key
127    pub fn unregister(&self, key: &str) {
128        let mut key_map = self.keys.write();
129        if let Some(tags) = key_map.remove(key) {
130            let mut tag_map = self.tags.write();
131            for tag in tags {
132                if let Some(keys) = tag_map.get_mut(&tag) {
133                    keys.remove(key);
134                }
135            }
136        }
137    }
138}
139
/// Rule stating that invalidating `source_tag` should also invalidate every
/// tag listed in `target_tags`.
#[derive(Clone, Debug)]
pub struct CascadeRule {
    /// Tag whose invalidation fires this rule.
    pub source_tag: String,
    /// Tags invalidated in turn when the rule fires.
    pub target_tags: Vec<String>,
}
148
149/// Invalidation manager
150pub struct InvalidationManager {
151    cache: Arc<RedisCache>,
152    config: InvalidationConfig,
153    tag_registry: TagRegistry,
154    cascade_rules: Arc<RwLock<Vec<CascadeRule>>>,
155    pubsub_tx: mpsc::UnboundedSender<InvalidationEvent>,
156}
157
158impl InvalidationManager {
159    /// Create a new invalidation manager
160    pub fn new(
161        cache: Arc<RedisCache>,
162        config: InvalidationConfig,
163    ) -> (Self, mpsc::UnboundedReceiver<InvalidationEvent>) {
164        let (tx, rx) = mpsc::unbounded_channel();
165
166        let manager = Self {
167            cache,
168            config,
169            tag_registry: TagRegistry::new(),
170            cascade_rules: Arc::new(RwLock::new(Vec::new())),
171            pubsub_tx: tx,
172        };
173
174        (manager, rx)
175    }
176
177    /// Register a cascade rule
178    pub fn add_cascade_rule(&self, rule: CascadeRule) {
179        info!(
180            source = %rule.source_tag,
181            targets = ?rule.target_tags,
182            "Added cascade invalidation rule"
183        );
184        self.cascade_rules.write().push(rule);
185    }
186
187    /// Register a key with tags
188    pub fn register_key(&self, key: String, tags: Vec<String>) {
189        self.tag_registry.register(key, tags);
190    }
191
192    /// Invalidate specific keys
193    pub async fn invalidate_keys(&self, keys: Vec<String>) -> Result<()> {
194        debug!(count = keys.len(), "Invalidating keys");
195
196        for key in &keys {
197            if let Err(e) = self.cache.delete(key).await {
198                error!(key = %key, error = %e, "Failed to invalidate key");
199            }
200        }
201
202        // Publish invalidation event
203        let event = InvalidationEvent {
204            event_type: InvalidationType::Keys,
205            keys,
206            tags: Vec::new(),
207            timestamp: chrono::Utc::now(),
208            source_instance: self.config.instance_id.clone(),
209        };
210
211        self.publish_event(event).await?;
212
213        Ok(())
214    }
215
216    /// Invalidate all keys with a specific tag
217    pub async fn invalidate_tag(&self, tag: String) -> Result<()> {
218        let keys = self.tag_registry.get_keys_for_tag(&tag);
219
220        debug!(tag = %tag, key_count = keys.len(), "Invalidating tag");
221
222        for key in &keys {
223            if let Err(e) = self.cache.delete(key).await {
224                error!(key = %key, error = %e, "Failed to invalidate key");
225            }
226        }
227
228        // Apply cascade rules if enabled
229        if self.config.enable_cascade {
230            self.apply_cascade_rules(&tag, 0).await?;
231        }
232
233        // Publish invalidation event
234        let event = InvalidationEvent {
235            event_type: InvalidationType::Tags,
236            keys: Vec::new(),
237            tags: vec![tag],
238            timestamp: chrono::Utc::now(),
239            source_instance: self.config.instance_id.clone(),
240        };
241
242        self.publish_event(event).await?;
243
244        Ok(())
245    }
246
247    /// Invalidate keys matching a pattern
248    pub async fn invalidate_pattern(&self, pattern: String) -> Result<()> {
249        debug!(pattern = %pattern, "Invalidating pattern");
250
251        let deleted = self.cache.delete_pattern(&pattern).await?;
252
253        info!(pattern = %pattern, deleted = deleted, "Pattern invalidation completed");
254
255        // Publish invalidation event
256        let event = InvalidationEvent {
257            event_type: InvalidationType::Pattern,
258            keys: vec![pattern],
259            tags: Vec::new(),
260            timestamp: chrono::Utc::now(),
261            source_instance: self.config.instance_id.clone(),
262        };
263
264        self.publish_event(event).await?;
265
266        Ok(())
267    }
268
269    /// Apply cascade invalidation rules
270    fn apply_cascade_rules<'a>(
271        &'a self,
272        tag: &'a str,
273        depth: usize,
274    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
275        Box::pin(async move {
276            if depth >= self.config.max_cascade_depth {
277                warn!(tag = %tag, depth = depth, "Max cascade depth reached");
278                return Ok(());
279            }
280
281            let matching_rules: Vec<_> = {
282                let rules = self.cascade_rules.read();
283                rules
284                    .iter()
285                    .filter(|rule| rule.source_tag == tag)
286                    .cloned()
287                    .collect()
288            };
289
290            for rule in matching_rules {
291                debug!(
292                    source = %rule.source_tag,
293                    targets = ?rule.target_tags,
294                    depth = depth,
295                    "Applying cascade rule"
296                );
297
298                for target_tag in rule.target_tags {
299                    let keys = self.tag_registry.get_keys_for_tag(&target_tag);
300
301                    for key in keys {
302                        if let Err(e) = self.cache.delete(&key).await {
303                            error!(key = %key, error = %e, "Failed to cascade invalidate key");
304                        }
305                    }
306
307                    // Recursively apply cascade rules
308                    self.apply_cascade_rules(&target_tag, depth + 1).await?;
309                }
310            }
311
312            Ok(())
313        })
314    }
315
316    /// Publish invalidation event via Redis pub/sub
317    async fn publish_event(&self, event: InvalidationEvent) -> Result<()> {
318        let _json = serde_json::to_string(&event)
319            .map_err(|e| DbError::Cache(format!("Serialization error: {}", e)))?;
320
321        if let Err(e) = self.pubsub_tx.send(event) {
322            error!(error = %e, "Failed to send invalidation event to channel");
323        }
324
325        debug!("Published invalidation event");
326
327        Ok(())
328    }
329
330    /// Start pub/sub subscriber
331    pub async fn start_subscriber(self: Arc<Self>, redis_url: String) -> Result<()> {
332        let client = Client::open(redis_url.as_str())
333            .map_err(|e| DbError::Connection(format!("Redis client error: {}", e)))?;
334
335        let mut pubsub = client
336            .get_async_pubsub()
337            .await
338            .map_err(|e| DbError::Connection(format!("Redis pubsub error: {}", e)))?;
339        pubsub
340            .subscribe(&self.config.pubsub_channel)
341            .await
342            .map_err(|e| DbError::Cache(format!("Subscribe error: {}", e)))?;
343
344        info!(channel = %self.config.pubsub_channel, "Started invalidation subscriber");
345
346        tokio::spawn(async move {
347            loop {
348                match pubsub.on_message().next().await {
349                    Some(msg) => {
350                        let payload: String = match msg.get_payload() {
351                            Ok(p) => p,
352                            Err(e) => {
353                                error!(error = %e, "Failed to get message payload");
354                                continue;
355                            }
356                        };
357
358                        let event: InvalidationEvent = match serde_json::from_str(&payload) {
359                            Ok(e) => e,
360                            Err(e) => {
361                                error!(error = %e, "Failed to deserialize event");
362                                continue;
363                            }
364                        };
365
366                        // Skip events from self
367                        if event.source_instance == self.config.instance_id {
368                            continue;
369                        }
370
371                        debug!(
372                            event_type = ?event.event_type,
373                            source = %event.source_instance,
374                            "Received invalidation event"
375                        );
376
377                        // Process the event
378                        match event.event_type {
379                            InvalidationType::Keys => {
380                                for key in &event.keys {
381                                    if let Err(e) = self.cache.delete(key).await {
382                                        error!(key = %key, error = %e, "Failed to invalidate key");
383                                    }
384                                }
385                            }
386                            InvalidationType::Tags => {
387                                for tag in &event.tags {
388                                    let keys = self.tag_registry.get_keys_for_tag(tag);
389                                    for key in keys {
390                                        if let Err(e) = self.cache.delete(&key).await {
391                                            error!(key = %key, error = %e, "Failed to invalidate key");
392                                        }
393                                    }
394                                }
395                            }
396                            InvalidationType::Pattern => {
397                                for pattern in &event.keys {
398                                    if let Err(e) = self.cache.delete_pattern(pattern).await {
399                                        error!(pattern = %pattern, error = %e, "Failed to invalidate pattern");
400                                    }
401                                }
402                            }
403                            InvalidationType::Cascade => {
404                                // Cascade handled by publisher
405                            }
406                        }
407                    }
408                    None => {
409                        warn!("Pub/sub connection closed, reconnecting...");
410                        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
411                    }
412                }
413            }
414        });
415
416        Ok(())
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|&s| String::from(s)).collect()
    }

    #[test]
    fn test_invalidation_config_default() {
        let InvalidationConfig {
            pubsub_channel,
            enable_cascade,
            max_cascade_depth,
            ..
        } = InvalidationConfig::default();

        assert_eq!(pubsub_channel, "cache:invalidation");
        assert!(enable_cascade);
        assert_eq!(max_cascade_depth, 5);
    }

    #[test]
    fn test_tag_registry_register() {
        let registry = TagRegistry::new();
        registry.register("key1".into(), strings(&["tag1", "tag2"]));

        assert_eq!(registry.get_keys_for_tag("tag1"), strings(&["key1"]));
    }

    #[test]
    fn test_tag_registry_get_tags() {
        let registry = TagRegistry::new();
        registry.register("key1".into(), strings(&["tag1", "tag2"]));

        // Tag order comes from a HashSet, so normalize before comparing.
        let mut tags = registry.get_tags_for_key("key1");
        tags.sort();
        assert_eq!(tags, strings(&["tag1", "tag2"]));
    }

    #[test]
    fn test_tag_registry_unregister() {
        let registry = TagRegistry::new();
        registry.register("key1".into(), strings(&["tag1"]));
        registry.unregister("key1");

        assert!(registry.get_keys_for_tag("tag1").is_empty());
    }

    #[test]
    fn test_cascade_rule_creation() {
        let rule = CascadeRule {
            source_tag: "user".into(),
            target_tags: strings(&["user_profile", "user_orders"]),
        };

        assert_eq!(rule.source_tag, "user");
        assert_eq!(rule.target_tags.len(), 2);
    }

    #[test]
    fn test_invalidation_event_serialization() {
        let original = InvalidationEvent {
            event_type: InvalidationType::Keys,
            keys: strings(&["key1"]),
            tags: Vec::new(),
            timestamp: chrono::Utc::now(),
            source_instance: "instance1".into(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let round_trip: InvalidationEvent = serde_json::from_str(&json).unwrap();

        assert_eq!(round_trip.event_type, InvalidationType::Keys);
        assert_eq!(round_trip.keys.len(), 1);
    }
}
499}