// multi_tier_cache/invalidation.rs

//! Cache invalidation and synchronization module.
//!
//! This module provides cross-instance cache invalidation using Redis Pub/Sub.
//! It supports both cache removal (invalidation) and cache updates (refresh).
5
6use anyhow::{Context, Result};
7use redis::AsyncCommands;
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use futures_util::StreamExt;
11
12/// Invalidation message types sent across cache instances via Redis Pub/Sub
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type")]
15pub enum InvalidationMessage {
16    /// Remove a single key from all cache instances
17    Remove {
18        key: String,
19    },
20
21    /// Update a key with new value across all cache instances
22    /// This is more efficient than Remove for hot keys as it avoids cache miss
23    Update {
24        key: String,
25        value: serde_json::Value,
26        #[serde(skip_serializing_if = "Option::is_none")]
27        ttl_secs: Option<u64>,
28    },
29
30    /// Remove all keys matching a pattern from all cache instances
31    /// Uses glob-style patterns (e.g., "user:*", "product:123:*")
32    RemovePattern {
33        pattern: String,
34    },
35
36    /// Bulk remove multiple keys at once
37    RemoveBulk {
38        keys: Vec<String>,
39    },
40}
41
42impl InvalidationMessage {
43    /// Create a Remove message
44    pub fn remove(key: impl Into<String>) -> Self {
45        Self::Remove { key: key.into() }
46    }
47
48    /// Create an Update message
49    pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
50        Self::Update {
51            key: key.into(),
52            value,
53            ttl_secs: ttl.map(|d| d.as_secs()),
54        }
55    }
56
57    /// Create a RemovePattern message
58    pub fn remove_pattern(pattern: impl Into<String>) -> Self {
59        Self::RemovePattern {
60            pattern: pattern.into(),
61        }
62    }
63
64    /// Create a RemoveBulk message
65    pub fn remove_bulk(keys: Vec<String>) -> Self {
66        Self::RemoveBulk { keys }
67    }
68
69    /// Serialize to JSON for transmission
70    pub fn to_json(&self) -> Result<String> {
71        serde_json::to_string(self).context("Failed to serialize invalidation message")
72    }
73
74    /// Deserialize from JSON
75    pub fn from_json(json: &str) -> Result<Self> {
76        serde_json::from_str(json).context("Failed to deserialize invalidation message")
77    }
78
79    /// Get TTL as Duration if present
80    pub fn ttl(&self) -> Option<Duration> {
81        match self {
82            Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
83            _ => None,
84        }
85    }
86}
87
/// Settings that control how cache invalidation messages are broadcast.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Name of the Redis Pub/Sub channel that carries invalidation messages.
    pub channel: String,

    /// Whether writes should automatically broadcast an invalidation.
    pub auto_broadcast_on_write: bool,

    /// Whether published invalidations are also appended to a Redis Stream for auditing.
    pub enable_audit_stream: bool,

    /// Name of the Redis Stream that holds the invalidation audit trail.
    pub audit_stream: String,

    /// Maximum length of the audit stream; older entries are trimmed.
    /// `None` means no trimming limit is requested.
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Uses channel `cache:invalidate` and stream `cache:invalidations` (capped at
    /// 10,000 entries), with auto-broadcast and auditing both switched off.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");

        Self {
            channel,
            // Conservative default: callers must opt in to automatic broadcasting.
            auto_broadcast_on_write: false,
            enable_audit_stream: false,
            audit_stream,
            audit_stream_maxlen: Some(10_000),
        }
    }
}
118
119/// Handle for sending invalidation messages
120pub struct InvalidationPublisher {
121    connection: redis::aio::ConnectionManager,
122    config: InvalidationConfig,
123}
124
125impl InvalidationPublisher {
126    /// Create a new publisher
127    pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
128        Self { connection, config }
129    }
130
131    /// Publish an invalidation message to all subscribers
132    pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
133        let json = message.to_json()?;
134
135        // Publish to Pub/Sub channel
136        let _: () = self
137            .connection
138            .publish(&self.config.channel, &json)
139            .await
140            .context("Failed to publish invalidation message")?;
141
142        // Optionally publish to audit stream
143        if self.config.enable_audit_stream {
144            if let Err(e) = self.publish_to_audit_stream(message).await {
145                // Don't fail the invalidation if audit logging fails
146                eprintln!("Warning: Failed to publish to audit stream: {}", e);
147            }
148        }
149
150        Ok(())
151    }
152
153    /// Publish to audit stream for observability
154    async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
155        let timestamp = SystemTime::now()
156            .duration_since(UNIX_EPOCH)
157            .unwrap()
158            .as_secs()
159            .to_string();
160
161        let (type_str, key_str, extra_str) = match message {
162            InvalidationMessage::Remove { key } => {
163                ("remove".to_string(), key.clone(), String::new())
164            }
165            InvalidationMessage::Update { key, .. } => {
166                ("update".to_string(), key.clone(), String::new())
167            }
168            InvalidationMessage::RemovePattern { pattern } => {
169                ("remove_pattern".to_string(), pattern.clone(), String::new())
170            }
171            InvalidationMessage::RemoveBulk { keys } => {
172                ("remove_bulk".to_string(), String::new(), keys.len().to_string())
173            }
174        };
175
176        let mut fields = vec![
177            ("type", type_str.as_str()),
178            ("timestamp", timestamp.as_str()),
179        ];
180
181        if !key_str.is_empty() {
182            fields.push(("key", key_str.as_str()));
183        }
184        if !extra_str.is_empty() {
185            fields.push(("count", extra_str.as_str()));
186        }
187
188        let mut cmd = redis::cmd("XADD");
189        cmd.arg(&self.config.audit_stream);
190
191        if let Some(maxlen) = self.config.audit_stream_maxlen {
192            cmd.arg("MAXLEN").arg("~").arg(maxlen);
193        }
194
195        cmd.arg("*"); // Auto-generate ID
196
197        for (key, value) in fields {
198            cmd.arg(key).arg(value);
199        }
200
201        let _: String = cmd
202            .query_async(&mut self.connection)
203            .await
204            .context("Failed to add to audit stream")?;
205
206        Ok(())
207    }
208}
209
/// Plain, point-in-time counters describing invalidation traffic.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// How many invalidation messages were published.
    pub messages_sent: u64,

    /// How many invalidation messages were received.
    pub messages_received: u64,

    /// How many `Remove` messages were received.
    pub removes_received: u64,

    /// How many `Update` messages were received.
    pub updates_received: u64,

    /// How many `RemovePattern` messages were received.
    pub patterns_received: u64,

    /// How many `RemoveBulk` messages were received.
    pub bulk_removes_received: u64,

    /// How many attempts to process a message (or keep the subscription alive) failed.
    pub processing_errors: u64,
}
234
235use std::sync::atomic::{AtomicU64, Ordering};
236
237/// Thread-safe statistics for invalidation operations
238#[derive(Debug, Default)]
239pub struct AtomicInvalidationStats {
240    pub messages_sent: AtomicU64,
241    pub messages_received: AtomicU64,
242    pub removes_received: AtomicU64,
243    pub updates_received: AtomicU64,
244    pub patterns_received: AtomicU64,
245    pub bulk_removes_received: AtomicU64,
246    pub processing_errors: AtomicU64,
247}
248
249impl AtomicInvalidationStats {
250    pub fn snapshot(&self) -> InvalidationStats {
251        InvalidationStats {
252            messages_sent: self.messages_sent.load(Ordering::Relaxed),
253            messages_received: self.messages_received.load(Ordering::Relaxed),
254            removes_received: self.removes_received.load(Ordering::Relaxed),
255            updates_received: self.updates_received.load(Ordering::Relaxed),
256            patterns_received: self.patterns_received.load(Ordering::Relaxed),
257            bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
258            processing_errors: self.processing_errors.load(Ordering::Relaxed),
259        }
260    }
261}
262
263use std::sync::Arc;
264use tokio::sync::broadcast;
265
266/// Handle for subscribing to invalidation messages
267///
268/// This spawns a background task that listens to Redis Pub/Sub and processes
269/// invalidation messages by calling the provided handler callback.
270pub struct InvalidationSubscriber {
271    /// Redis client for creating Pub/Sub connections
272    client: redis::Client,
273    /// Configuration
274    config: InvalidationConfig,
275    /// Statistics
276    stats: Arc<AtomicInvalidationStats>,
277    /// Shutdown signal sender
278    shutdown_tx: broadcast::Sender<()>,
279}
280
281impl InvalidationSubscriber {
282    /// Create a new subscriber
283    ///
284    /// # Arguments
285    /// * `redis_url` - Redis connection URL
286    /// * `config` - Invalidation configuration
287    pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
288        let client = redis::Client::open(redis_url)
289            .context("Failed to create Redis client for subscriber")?;
290
291        let (shutdown_tx, _) = broadcast::channel(1);
292
293        Ok(Self {
294            client,
295            config,
296            stats: Arc::new(AtomicInvalidationStats::default()),
297            shutdown_tx,
298        })
299    }
300
301    /// Get a snapshot of current statistics
302    pub fn stats(&self) -> InvalidationStats {
303        self.stats.snapshot()
304    }
305
306    /// Start the subscriber background task
307    ///
308    /// # Arguments
309    /// * `handler` - Async function to handle each invalidation message
310    ///
311    /// # Returns
312    /// Join handle for the background task
313    pub fn start<F, Fut>(
314        &self,
315        handler: F,
316    ) -> tokio::task::JoinHandle<()>
317    where
318        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
319        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
320    {
321        let client = self.client.clone();
322        let channel = self.config.channel.clone();
323        let stats = Arc::clone(&self.stats);
324        let mut shutdown_rx = self.shutdown_tx.subscribe();
325
326        tokio::spawn(async move {
327            let handler = Arc::new(handler);
328
329            loop {
330                // Check for shutdown signal
331                if shutdown_rx.try_recv().is_ok() {
332                    println!("🛑 Invalidation subscriber shutting down...");
333                    break;
334                }
335
336                // Attempt to connect and subscribe
337                match Self::run_subscriber_loop(
338                    &client,
339                    &channel,
340                    Arc::clone(&handler),
341                    Arc::clone(&stats),
342                    &mut shutdown_rx,
343                ).await {
344                    Ok(_) => {
345                        println!("✅ Invalidation subscriber loop completed normally");
346                        break;
347                    }
348                    Err(e) => {
349                        eprintln!("⚠️  Invalidation subscriber error: {}. Reconnecting in 5s...", e);
350                        stats.processing_errors.fetch_add(1, Ordering::Relaxed);
351
352                        // Wait before reconnecting
353                        tokio::select! {
354                            _ = tokio::time::sleep(Duration::from_secs(5)) => {},
355                            _ = shutdown_rx.recv() => {
356                                println!("🛑 Invalidation subscriber shutting down...");
357                                break;
358                            }
359                        }
360                    }
361                }
362            }
363        })
364    }
365
366    /// Internal subscriber loop
367    async fn run_subscriber_loop<F, Fut>(
368        client: &redis::Client,
369        channel: &str,
370        handler: Arc<F>,
371        stats: Arc<AtomicInvalidationStats>,
372        shutdown_rx: &mut broadcast::Receiver<()>,
373    ) -> Result<()>
374    where
375        F: Fn(InvalidationMessage) -> Fut + Send + 'static,
376        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
377    {
378        // Get Pub/Sub connection
379        let mut pubsub = client.get_async_pubsub().await
380            .context("Failed to get pubsub connection")?;
381
382        // Subscribe to channel
383        pubsub.subscribe(channel).await
384            .context("Failed to subscribe to channel")?;
385
386        println!("📡 Subscribed to invalidation channel: {}", channel);
387
388        // Get message stream
389        let mut stream = pubsub.on_message();
390
391        loop {
392            // Wait for message or shutdown signal
393            tokio::select! {
394                msg_result = stream.next() => {
395                    match msg_result {
396                        Some(msg) => {
397                            // Get payload
398                            let payload: String = match msg.get_payload() {
399                                Ok(p) => p,
400                                Err(e) => {
401                                    eprintln!("⚠️  Failed to get message payload: {}", e);
402                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
403                                    continue;
404                                }
405                            };
406
407                            // Deserialize message
408                            let invalidation_msg = match InvalidationMessage::from_json(&payload) {
409                                Ok(m) => m,
410                                Err(e) => {
411                                    eprintln!("⚠️  Failed to deserialize invalidation message: {}", e);
412                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
413                                    continue;
414                                }
415                            };
416
417                            // Update stats
418                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
419                            match &invalidation_msg {
420                                InvalidationMessage::Remove { .. } => {
421                                    stats.removes_received.fetch_add(1, Ordering::Relaxed);
422                                }
423                                InvalidationMessage::Update { .. } => {
424                                    stats.updates_received.fetch_add(1, Ordering::Relaxed);
425                                }
426                                InvalidationMessage::RemovePattern { .. } => {
427                                    stats.patterns_received.fetch_add(1, Ordering::Relaxed);
428                                }
429                                InvalidationMessage::RemoveBulk { .. } => {
430                                    stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
431                                }
432                            }
433
434                            // Call handler
435                            if let Err(e) = handler(invalidation_msg).await {
436                                eprintln!("⚠️  Invalidation handler error: {}", e);
437                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
438                            }
439                        }
440                        None => {
441                            // Stream ended
442                            return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
443                        }
444                    }
445                }
446                _ = shutdown_rx.recv() => {
447                    return Ok(());
448                }
449            }
450        }
451    }
452
453    /// Signal the subscriber to shutdown
454    pub fn shutdown(&self) {
455        let _ = self.shutdown_tx.send(());
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant survives a `to_json` / `from_json` round trip unchanged.
    #[test]
    fn test_invalidation_message_serialization() {
        // Test Remove
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        // Test Update
        let msg = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::Update { key, value, ttl_secs } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        // Test RemovePattern
        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        // Test RemoveBulk
        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
    }

    /// Pins the wire format: internally tagged, with the tag in a `"type"` field.
    #[test]
    fn test_serialized_form_is_internally_tagged() {
        let json = InvalidationMessage::remove("test_key").to_json().unwrap();
        assert_eq!(json, r#"{"type":"Remove","key":"test_key"}"#);
    }

    /// A `None` TTL is left out of the JSON and reads back as `None`.
    #[test]
    fn test_update_without_ttl_omits_ttl_secs() {
        let msg = InvalidationMessage::update("k", serde_json::json!(1), None);
        let json = msg.to_json().unwrap();
        assert!(!json.contains("ttl_secs"));

        let parsed = InvalidationMessage::from_json(&json).unwrap();
        assert_eq!(parsed.ttl(), None);
    }

    /// `ttl()` reports the TTL for updates and `None` for every other variant.
    #[test]
    fn test_ttl_accessor() {
        let msg = InvalidationMessage::update(
            "k",
            serde_json::json!(null),
            Some(Duration::from_secs(300)),
        );
        assert_eq!(msg.ttl(), Some(Duration::from_secs(300)));
        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(Vec::new()).ttl(), None);
    }

    /// Unknown message types are rejected rather than silently accepted.
    #[test]
    fn test_from_json_rejects_unknown_type() {
        assert!(InvalidationMessage::from_json(r#"{"type":"Nope","key":"k"}"#).is_err());
        assert!(InvalidationMessage::from_json("not json").is_err());
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }
}