multi_tier_cache/
invalidation.rs

1//! Cache invalidation and synchronization module
2//!
3//! This module provides cross-instance cache invalidation using Redis Pub/Sub.
4//! It supports both cache removal (invalidation) and cache updates (refresh).
5
6use anyhow::{Context, Result};
7use redis::AsyncCommands;
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use futures_util::StreamExt;
11
12/// Invalidation message types sent across cache instances via Redis Pub/Sub
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type")]
15pub enum InvalidationMessage {
16    /// Remove a single key from all cache instances
17    Remove {
18        key: String,
19    },
20
21    /// Update a key with new value across all cache instances
22    /// This is more efficient than Remove for hot keys as it avoids cache miss
23    Update {
24        key: String,
25        value: serde_json::Value,
26        #[serde(skip_serializing_if = "Option::is_none")]
27        ttl_secs: Option<u64>,
28    },
29
30    /// Remove all keys matching a pattern from all cache instances
31    /// Uses glob-style patterns (e.g., "user:*", "product:123:*")
32    RemovePattern {
33        pattern: String,
34    },
35
36    /// Bulk remove multiple keys at once
37    RemoveBulk {
38        keys: Vec<String>,
39    },
40}
41
42impl InvalidationMessage {
43    /// Create a Remove message
44    pub fn remove(key: impl Into<String>) -> Self {
45        Self::Remove { key: key.into() }
46    }
47
48    /// Create an Update message
49    pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
50        Self::Update {
51            key: key.into(),
52            value,
53            ttl_secs: ttl.map(|d| d.as_secs()),
54        }
55    }
56
57    /// Create a RemovePattern message
58    pub fn remove_pattern(pattern: impl Into<String>) -> Self {
59        Self::RemovePattern {
60            pattern: pattern.into(),
61        }
62    }
63
64    /// Create a RemoveBulk message
65    pub fn remove_bulk(keys: Vec<String>) -> Self {
66        Self::RemoveBulk { keys }
67    }
68
69    /// Serialize to JSON for transmission
70    pub fn to_json(&self) -> Result<String> {
71        serde_json::to_string(self).context("Failed to serialize invalidation message")
72    }
73
74    /// Deserialize from JSON
75    pub fn from_json(json: &str) -> Result<Self> {
76        serde_json::from_str(json).context("Failed to deserialize invalidation message")
77    }
78
79    /// Get TTL as Duration if present
80    pub fn ttl(&self) -> Option<Duration> {
81        match self {
82            Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
83            _ => None,
84        }
85    }
86}
87
/// Settings for cross-instance cache invalidation.
///
/// The same configuration is used by both the publishing and the subscribing side.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Redis Pub/Sub channel that invalidation messages are published on and
    /// subscribed to.
    pub channel: String,
    /// Whether writes should automatically broadcast an invalidation.
    ///
    /// NOTE(review): not read anywhere in this module; presumably consumed by
    /// the cache manager. Confirm.
    pub auto_broadcast_on_write: bool,
    /// When `true`, every published message is also appended to
    /// `audit_stream` (a Redis Stream) as an audit record.
    pub enable_audit_stream: bool,
    /// Name of the Redis Stream that holds the invalidation audit trail.
    pub audit_stream: String,
    /// Approximate length cap for the audit stream (`XADD … MAXLEN ~ n`).
    /// Older entries are trimmed; `None` leaves the stream unbounded.
    pub audit_stream_maxlen: Option<usize>,
}
106
107impl Default for InvalidationConfig {
108    fn default() -> Self {
109        Self {
110            channel: "cache:invalidate".to_string(),
111            auto_broadcast_on_write: false, // Conservative default
112            enable_audit_stream: false,
113            audit_stream: "cache:invalidations".to_string(),
114            audit_stream_maxlen: Some(10000),
115        }
116    }
117}
118
119/// Handle for sending invalidation messages
120pub struct InvalidationPublisher {
121    connection: redis::aio::ConnectionManager,
122    config: InvalidationConfig,
123}
124
125impl InvalidationPublisher {
126    /// Create a new publisher
127    pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
128        Self { connection, config }
129    }
130
131    /// Publish an invalidation message to all subscribers
132    pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
133        let json = message.to_json()?;
134
135        // Publish to Pub/Sub channel
136        let _: () = self
137            .connection
138            .publish(&self.config.channel, &json)
139            .await
140            .context("Failed to publish invalidation message")?;
141
142        // Optionally publish to audit stream
143        if self.config.enable_audit_stream {
144            if let Err(e) = self.publish_to_audit_stream(message).await {
145                // Don't fail the invalidation if audit logging fails
146                eprintln!("Warning: Failed to publish to audit stream: {}", e);
147            }
148        }
149
150        Ok(())
151    }
152
153    /// Publish to audit stream for observability
154    async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
155        let timestamp = SystemTime::now()
156            .duration_since(UNIX_EPOCH)
157            .unwrap_or(Duration::ZERO)
158            .as_secs()
159            .to_string();
160
161        // Use &str to avoid unnecessary allocations
162        let (type_str, key_str): (&str, &str);
163        let extra_str: String;
164
165        match message {
166            InvalidationMessage::Remove { key } => {
167                type_str = "remove";
168                key_str = key.as_str();
169                extra_str = String::new();
170            }
171            InvalidationMessage::Update { key, .. } => {
172                type_str = "update";
173                key_str = key.as_str();
174                extra_str = String::new();
175            }
176            InvalidationMessage::RemovePattern { pattern } => {
177                type_str = "remove_pattern";
178                key_str = pattern.as_str();
179                extra_str = String::new();
180            }
181            InvalidationMessage::RemoveBulk { keys } => {
182                type_str = "remove_bulk";
183                key_str = "";
184                extra_str = keys.len().to_string();
185            }
186        }
187
188        let mut fields = vec![
189            ("type", type_str),
190            ("timestamp", timestamp.as_str()),
191        ];
192
193        if !key_str.is_empty() {
194            fields.push(("key", key_str));
195        }
196        if !extra_str.is_empty() {
197            fields.push(("count", extra_str.as_str()));
198        }
199
200        let mut cmd = redis::cmd("XADD");
201        cmd.arg(&self.config.audit_stream);
202
203        if let Some(maxlen) = self.config.audit_stream_maxlen {
204            cmd.arg("MAXLEN").arg("~").arg(maxlen);
205        }
206
207        cmd.arg("*"); // Auto-generate ID
208
209        for (key, value) in fields {
210            cmd.arg(key).arg(value);
211        }
212
213        let _: String = cmd
214            .query_async(&mut self.connection)
215            .await
216            .context("Failed to add to audit stream")?;
217
218        Ok(())
219    }
220}
221
/// Point-in-time copy of the invalidation counters.
///
/// Obtained from `AtomicInvalidationStats::snapshot`.
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// Invalidation messages published.
    ///
    /// NOTE(review): the publisher in this module does not update any stats.
    /// Confirm where this counter is maintained.
    pub messages_sent: u64,
    /// Invalidation messages received and successfully decoded.
    pub messages_received: u64,
    /// `Remove` messages received.
    pub removes_received: u64,
    /// `Update` messages received.
    pub updates_received: u64,
    /// `RemovePattern` messages received.
    pub patterns_received: u64,
    /// `RemoveBulk` messages received.
    pub bulk_removes_received: u64,
    /// Failures while reading, decoding or handling received messages, plus
    /// subscriber connection/subscription errors.
    pub processing_errors: u64,
}
246
247use std::sync::atomic::{AtomicU64, Ordering};
248
/// Atomic invalidation counters that can be updated concurrently.
///
/// A subscriber shares these with its background task through an `Arc`.
/// Each field mirrors the same-named field of `InvalidationStats`.
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    /// Invalidation messages published.
    pub messages_sent: AtomicU64,
    /// Invalidation messages received and successfully decoded.
    pub messages_received: AtomicU64,
    /// `Remove` messages received.
    pub removes_received: AtomicU64,
    /// `Update` messages received.
    pub updates_received: AtomicU64,
    /// `RemovePattern` messages received.
    pub patterns_received: AtomicU64,
    /// `RemoveBulk` messages received.
    pub bulk_removes_received: AtomicU64,
    /// Message-processing failures and subscriber connection errors.
    pub processing_errors: AtomicU64,
}
260
261impl AtomicInvalidationStats {
262    pub fn snapshot(&self) -> InvalidationStats {
263        InvalidationStats {
264            messages_sent: self.messages_sent.load(Ordering::Relaxed),
265            messages_received: self.messages_received.load(Ordering::Relaxed),
266            removes_received: self.removes_received.load(Ordering::Relaxed),
267            updates_received: self.updates_received.load(Ordering::Relaxed),
268            patterns_received: self.patterns_received.load(Ordering::Relaxed),
269            bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
270            processing_errors: self.processing_errors.load(Ordering::Relaxed),
271        }
272    }
273}
274
275use std::sync::Arc;
276use tokio::sync::broadcast;
277
278/// Handle for subscribing to invalidation messages
279///
280/// This spawns a background task that listens to Redis Pub/Sub and processes
281/// invalidation messages by calling the provided handler callback.
282pub struct InvalidationSubscriber {
283    /// Redis client for creating Pub/Sub connections
284    client: redis::Client,
285    /// Configuration
286    config: InvalidationConfig,
287    /// Statistics
288    stats: Arc<AtomicInvalidationStats>,
289    /// Shutdown signal sender
290    shutdown_tx: broadcast::Sender<()>,
291}
292
293impl InvalidationSubscriber {
294    /// Create a new subscriber
295    ///
296    /// # Arguments
297    /// * `redis_url` - Redis connection URL
298    /// * `config` - Invalidation configuration
299    pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
300        let client = redis::Client::open(redis_url)
301            .context("Failed to create Redis client for subscriber")?;
302
303        let (shutdown_tx, _) = broadcast::channel(1);
304
305        Ok(Self {
306            client,
307            config,
308            stats: Arc::new(AtomicInvalidationStats::default()),
309            shutdown_tx,
310        })
311    }
312
313    /// Get a snapshot of current statistics
314    pub fn stats(&self) -> InvalidationStats {
315        self.stats.snapshot()
316    }
317
318    /// Start the subscriber background task
319    ///
320    /// # Arguments
321    /// * `handler` - Async function to handle each invalidation message
322    ///
323    /// # Returns
324    /// Join handle for the background task
325    pub fn start<F, Fut>(
326        &self,
327        handler: F,
328    ) -> tokio::task::JoinHandle<()>
329    where
330        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
331        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
332    {
333        let client = self.client.clone();
334        let channel = self.config.channel.clone();
335        let stats = Arc::clone(&self.stats);
336        let mut shutdown_rx = self.shutdown_tx.subscribe();
337
338        tokio::spawn(async move {
339            let handler = Arc::new(handler);
340
341            loop {
342                // Check for shutdown signal
343                if shutdown_rx.try_recv().is_ok() {
344                    println!("🛑 Invalidation subscriber shutting down...");
345                    break;
346                }
347
348                // Attempt to connect and subscribe
349                match Self::run_subscriber_loop(
350                    &client,
351                    &channel,
352                    Arc::clone(&handler),
353                    Arc::clone(&stats),
354                    &mut shutdown_rx,
355                ).await {
356                    Ok(_) => {
357                        println!("✅ Invalidation subscriber loop completed normally");
358                        break;
359                    }
360                    Err(e) => {
361                        eprintln!("⚠️  Invalidation subscriber error: {}. Reconnecting in 5s...", e);
362                        stats.processing_errors.fetch_add(1, Ordering::Relaxed);
363
364                        // Wait before reconnecting
365                        tokio::select! {
366                            _ = tokio::time::sleep(Duration::from_secs(5)) => {},
367                            _ = shutdown_rx.recv() => {
368                                println!("🛑 Invalidation subscriber shutting down...");
369                                break;
370                            }
371                        }
372                    }
373                }
374            }
375        })
376    }
377
378    /// Internal subscriber loop
379    async fn run_subscriber_loop<F, Fut>(
380        client: &redis::Client,
381        channel: &str,
382        handler: Arc<F>,
383        stats: Arc<AtomicInvalidationStats>,
384        shutdown_rx: &mut broadcast::Receiver<()>,
385    ) -> Result<()>
386    where
387        F: Fn(InvalidationMessage) -> Fut + Send + 'static,
388        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
389    {
390        // Get Pub/Sub connection
391        let mut pubsub = client.get_async_pubsub().await
392            .context("Failed to get pubsub connection")?;
393
394        // Subscribe to channel
395        pubsub.subscribe(channel).await
396            .context("Failed to subscribe to channel")?;
397
398        println!("📡 Subscribed to invalidation channel: {}", channel);
399
400        // Get message stream
401        let mut stream = pubsub.on_message();
402
403        loop {
404            // Wait for message or shutdown signal
405            tokio::select! {
406                msg_result = stream.next() => {
407                    match msg_result {
408                        Some(msg) => {
409                            // Get payload
410                            let payload: String = match msg.get_payload() {
411                                Ok(p) => p,
412                                Err(e) => {
413                                    eprintln!("⚠️  Failed to get message payload: {}", e);
414                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
415                                    continue;
416                                }
417                            };
418
419                            // Deserialize message
420                            let invalidation_msg = match InvalidationMessage::from_json(&payload) {
421                                Ok(m) => m,
422                                Err(e) => {
423                                    eprintln!("⚠️  Failed to deserialize invalidation message: {}", e);
424                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
425                                    continue;
426                                }
427                            };
428
429                            // Update stats
430                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
431                            match &invalidation_msg {
432                                InvalidationMessage::Remove { .. } => {
433                                    stats.removes_received.fetch_add(1, Ordering::Relaxed);
434                                }
435                                InvalidationMessage::Update { .. } => {
436                                    stats.updates_received.fetch_add(1, Ordering::Relaxed);
437                                }
438                                InvalidationMessage::RemovePattern { .. } => {
439                                    stats.patterns_received.fetch_add(1, Ordering::Relaxed);
440                                }
441                                InvalidationMessage::RemoveBulk { .. } => {
442                                    stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
443                                }
444                            }
445
446                            // Call handler
447                            if let Err(e) = handler(invalidation_msg).await {
448                                eprintln!("⚠️  Invalidation handler error: {}", e);
449                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
450                            }
451                        }
452                        None => {
453                            // Stream ended
454                            return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
455                        }
456                    }
457                }
458                _ = shutdown_rx.recv() => {
459                    return Ok(());
460                }
461            }
462        }
463    }
464
465    /// Signal the subscriber to shutdown
466    pub fn shutdown(&self) {
467        let _ = self.shutdown_tx.send(());
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `msg` to JSON and parses it back, panicking if either step fails.
    fn round_trip(msg: &InvalidationMessage) -> InvalidationMessage {
        let json = msg.to_json().expect("serialization should succeed");
        InvalidationMessage::from_json(&json).expect("deserialization should succeed")
    }

    #[test]
    fn test_invalidation_message_serialization() {
        // Remove
        match round_trip(&InvalidationMessage::remove("test_key")) {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        // Update
        let msg = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        match round_trip(&msg) {
            InvalidationMessage::Update { key, value, ttl_secs } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        // RemovePattern
        match round_trip(&InvalidationMessage::remove_pattern("user:*")) {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        // RemoveBulk
        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        match round_trip(&msg) {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_wire_format_is_internally_tagged() {
        // Other cache instances decode this exact shape, so pin it.
        let json = InvalidationMessage::remove("k").to_json().unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value, serde_json::json!({"type": "Remove", "key": "k"}));
    }

    #[test]
    fn test_update_without_ttl_omits_field() {
        let msg = InvalidationMessage::update("k", serde_json::json!(1), None);
        let value: serde_json::Value = serde_json::from_str(&msg.to_json().unwrap()).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"type": "Update", "key": "k", "value": 1})
        );
        match round_trip(&msg) {
            InvalidationMessage::Update { ttl_secs, .. } => assert_eq!(ttl_secs, None),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_ttl_accessor() {
        let with_ttl = InvalidationMessage::update(
            "k",
            serde_json::Value::Null,
            Some(Duration::from_secs(300)),
        );
        assert_eq!(with_ttl.ttl(), Some(Duration::from_secs(300)));

        let without_ttl = InvalidationMessage::update("k", serde_json::Value::Null, None);
        assert_eq!(without_ttl.ttl(), None);

        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(Vec::new()).ttl(), None);
    }

    #[test]
    fn test_from_json_rejects_invalid_input() {
        assert!(InvalidationMessage::from_json("not json").is_err());
        assert!(InvalidationMessage::from_json(r#"{"type":"Unknown","key":"k"}"#).is_err());
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }

    #[test]
    fn test_atomic_stats_snapshot() {
        let stats = AtomicInvalidationStats::default();
        stats.messages_received.fetch_add(2, Ordering::Relaxed);
        stats.processing_errors.fetch_add(1, Ordering::Relaxed);

        let snap = stats.snapshot();
        assert_eq!(snap.messages_received, 2);
        assert_eq!(snap.processing_errors, 1);
        assert_eq!(snap.messages_sent, 0);
        assert_eq!(snap.removes_received, 0);
    }
}