multi_tier_cache/
invalidation.rs

1//! Cache invalidation and synchronization module
2//!
3//! This module provides cross-instance cache invalidation using Redis Pub/Sub.
4//! It supports both cache removal (invalidation) and cache updates (refresh).
5
6use anyhow::{Context, Result};
7use redis::AsyncCommands;
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use futures_util::StreamExt;
11use tracing::{info, warn, error};
12
13/// Invalidation message types sent across cache instances via Redis Pub/Sub
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type")]
16pub enum InvalidationMessage {
17    /// Remove a single key from all cache instances
18    Remove {
19        key: String,
20    },
21
22    /// Update a key with new value across all cache instances
23    /// This is more efficient than Remove for hot keys as it avoids cache miss
24    Update {
25        key: String,
26        value: serde_json::Value,
27        #[serde(skip_serializing_if = "Option::is_none")]
28        ttl_secs: Option<u64>,
29    },
30
31    /// Remove all keys matching a pattern from all cache instances
32    /// Uses glob-style patterns (e.g., "user:*", "product:123:*")
33    RemovePattern {
34        pattern: String,
35    },
36
37    /// Bulk remove multiple keys at once
38    RemoveBulk {
39        keys: Vec<String>,
40    },
41}
42
43impl InvalidationMessage {
44    /// Create a Remove message
45    pub fn remove(key: impl Into<String>) -> Self {
46        Self::Remove { key: key.into() }
47    }
48
49    /// Create an Update message
50    pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
51        Self::Update {
52            key: key.into(),
53            value,
54            ttl_secs: ttl.map(|d| d.as_secs()),
55        }
56    }
57
58    /// Create a RemovePattern message
59    pub fn remove_pattern(pattern: impl Into<String>) -> Self {
60        Self::RemovePattern {
61            pattern: pattern.into(),
62        }
63    }
64
65    /// Create a RemoveBulk message
66    pub fn remove_bulk(keys: Vec<String>) -> Self {
67        Self::RemoveBulk { keys }
68    }
69
70    /// Serialize to JSON for transmission
71    pub fn to_json(&self) -> Result<String> {
72        serde_json::to_string(self).context("Failed to serialize invalidation message")
73    }
74
75    /// Deserialize from JSON
76    pub fn from_json(json: &str) -> Result<Self> {
77        serde_json::from_str(json).context("Failed to deserialize invalidation message")
78    }
79
80    /// Get TTL as Duration if present
81    pub fn ttl(&self) -> Option<Duration> {
82        match self {
83            Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
84            _ => None,
85        }
86    }
87}
88
/// Settings that control how invalidation messages are published and audited.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Redis Pub/Sub channel that carries invalidation messages.
    pub channel: String,
    /// Whether writes should automatically broadcast an invalidation
    /// (not consulted by the publisher or subscriber in this module).
    pub auto_broadcast_on_write: bool,
    /// Whether published messages are also appended to a Redis Stream for auditing.
    pub enable_audit_stream: bool,
    /// Name of the Redis Stream used as the audit trail.
    pub audit_stream: String,
    /// Approximate cap on the audit stream's length; `None` disables trimming.
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Returns the defaults below.
    ///
    /// - Channel `cache:invalidate`.
    /// - No automatic broadcasting.
    /// - Auditing off; when enabled, it goes to `cache:invalidations`, capped near 10 000
    ///   entries.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");
        Self {
            channel,
            // Conservative: broadcasting on every write must be opted into.
            auto_broadcast_on_write: false,
            enable_audit_stream: false,
            audit_stream,
            audit_stream_maxlen: Some(10_000),
        }
    }
}
119
120/// Handle for sending invalidation messages
121pub struct InvalidationPublisher {
122    connection: redis::aio::ConnectionManager,
123    config: InvalidationConfig,
124}
125
126impl InvalidationPublisher {
127    /// Create a new publisher
128    pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
129        Self { connection, config }
130    }
131
132    /// Publish an invalidation message to all subscribers
133    pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
134        let json = message.to_json()?;
135
136        // Publish to Pub/Sub channel
137        let _: () = self
138            .connection
139            .publish(&self.config.channel, &json)
140            .await
141            .context("Failed to publish invalidation message")?;
142
143        // Optionally publish to audit stream
144        if self.config.enable_audit_stream {
145            if let Err(e) = self.publish_to_audit_stream(message).await {
146                // Don't fail the invalidation if audit logging fails
147                warn!("Failed to publish to audit stream: {}", e);
148            }
149        }
150
151        Ok(())
152    }
153
154    /// Publish to audit stream for observability
155    async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
156        let timestamp = SystemTime::now()
157            .duration_since(UNIX_EPOCH)
158            .unwrap_or(Duration::ZERO)
159            .as_secs()
160            .to_string();
161
162        // Use &str to avoid unnecessary allocations
163        let (type_str, key_str): (&str, &str);
164        let extra_str: String;
165
166        match message {
167            InvalidationMessage::Remove { key } => {
168                type_str = "remove";
169                key_str = key.as_str();
170                extra_str = String::new();
171            }
172            InvalidationMessage::Update { key, .. } => {
173                type_str = "update";
174                key_str = key.as_str();
175                extra_str = String::new();
176            }
177            InvalidationMessage::RemovePattern { pattern } => {
178                type_str = "remove_pattern";
179                key_str = pattern.as_str();
180                extra_str = String::new();
181            }
182            InvalidationMessage::RemoveBulk { keys } => {
183                type_str = "remove_bulk";
184                key_str = "";
185                extra_str = keys.len().to_string();
186            }
187        }
188
189        let mut fields = vec![
190            ("type", type_str),
191            ("timestamp", timestamp.as_str()),
192        ];
193
194        if !key_str.is_empty() {
195            fields.push(("key", key_str));
196        }
197        if !extra_str.is_empty() {
198            fields.push(("count", extra_str.as_str()));
199        }
200
201        let mut cmd = redis::cmd("XADD");
202        cmd.arg(&self.config.audit_stream);
203
204        if let Some(maxlen) = self.config.audit_stream_maxlen {
205            cmd.arg("MAXLEN").arg("~").arg(maxlen);
206        }
207
208        cmd.arg("*"); // Auto-generate ID
209
210        for (key, value) in fields {
211            cmd.arg(key).arg(value);
212        }
213
214        let _: String = cmd
215            .query_async(&mut self.connection)
216            .await
217            .context("Failed to add to audit stream")?;
218
219        Ok(())
220    }
221}
222
/// Plain, point-in-time copy of the invalidation counters.
///
/// Holds only `u64`s, so it is cheap to clone, compare and log.
#[derive(Clone, Debug, Default)]
pub struct InvalidationStats {
    /// Invalidation messages published.
    pub messages_sent: u64,
    /// Invalidation messages received and decoded.
    pub messages_received: u64,
    /// `Remove` messages received.
    pub removes_received: u64,
    /// `Update` messages received.
    pub updates_received: u64,
    /// `RemovePattern` messages received.
    pub patterns_received: u64,
    /// `RemoveBulk` messages received.
    pub bulk_removes_received: u64,
    /// Failures while receiving, decoding or handling a message.
    pub processing_errors: u64,
}
247
248use std::sync::atomic::{AtomicU64, Ordering};
249
250/// Thread-safe statistics for invalidation operations
251#[derive(Debug, Default)]
252pub struct AtomicInvalidationStats {
253    pub messages_sent: AtomicU64,
254    pub messages_received: AtomicU64,
255    pub removes_received: AtomicU64,
256    pub updates_received: AtomicU64,
257    pub patterns_received: AtomicU64,
258    pub bulk_removes_received: AtomicU64,
259    pub processing_errors: AtomicU64,
260}
261
262impl AtomicInvalidationStats {
263    pub fn snapshot(&self) -> InvalidationStats {
264        InvalidationStats {
265            messages_sent: self.messages_sent.load(Ordering::Relaxed),
266            messages_received: self.messages_received.load(Ordering::Relaxed),
267            removes_received: self.removes_received.load(Ordering::Relaxed),
268            updates_received: self.updates_received.load(Ordering::Relaxed),
269            patterns_received: self.patterns_received.load(Ordering::Relaxed),
270            bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
271            processing_errors: self.processing_errors.load(Ordering::Relaxed),
272        }
273    }
274}
275
276use std::sync::Arc;
277use tokio::sync::broadcast;
278
279/// Handle for subscribing to invalidation messages
280///
281/// This spawns a background task that listens to Redis Pub/Sub and processes
282/// invalidation messages by calling the provided handler callback.
283pub struct InvalidationSubscriber {
284    /// Redis client for creating Pub/Sub connections
285    client: redis::Client,
286    /// Configuration
287    config: InvalidationConfig,
288    /// Statistics
289    stats: Arc<AtomicInvalidationStats>,
290    /// Shutdown signal sender
291    shutdown_tx: broadcast::Sender<()>,
292}
293
294impl InvalidationSubscriber {
295    /// Create a new subscriber
296    ///
297    /// # Arguments
298    /// * `redis_url` - Redis connection URL
299    /// * `config` - Invalidation configuration
300    pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
301        let client = redis::Client::open(redis_url)
302            .context("Failed to create Redis client for subscriber")?;
303
304        let (shutdown_tx, _) = broadcast::channel(1);
305
306        Ok(Self {
307            client,
308            config,
309            stats: Arc::new(AtomicInvalidationStats::default()),
310            shutdown_tx,
311        })
312    }
313
314    /// Get a snapshot of current statistics
315    pub fn stats(&self) -> InvalidationStats {
316        self.stats.snapshot()
317    }
318
319    /// Start the subscriber background task
320    ///
321    /// # Arguments
322    /// * `handler` - Async function to handle each invalidation message
323    ///
324    /// # Returns
325    /// Join handle for the background task
326    pub fn start<F, Fut>(
327        &self,
328        handler: F,
329    ) -> tokio::task::JoinHandle<()>
330    where
331        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
332        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
333    {
334        let client = self.client.clone();
335        let channel = self.config.channel.clone();
336        let stats = Arc::clone(&self.stats);
337        let mut shutdown_rx = self.shutdown_tx.subscribe();
338
339        tokio::spawn(async move {
340            let handler = Arc::new(handler);
341
342            loop {
343                // Check for shutdown signal
344                if shutdown_rx.try_recv().is_ok() {
345                    info!("Invalidation subscriber shutting down...");
346                    break;
347                }
348
349                // Attempt to connect and subscribe
350                match Self::run_subscriber_loop(
351                    &client,
352                    &channel,
353                    Arc::clone(&handler),
354                    Arc::clone(&stats),
355                    &mut shutdown_rx,
356                ).await {
357                    Ok(_) => {
358                        info!("Invalidation subscriber loop completed normally");
359                        break;
360                    }
361                    Err(e) => {
362                        error!("Invalidation subscriber error: {}. Reconnecting in 5s...", e);
363                        stats.processing_errors.fetch_add(1, Ordering::Relaxed);
364
365                        // Wait before reconnecting
366                        tokio::select! {
367                            _ = tokio::time::sleep(Duration::from_secs(5)) => {},
368                            _ = shutdown_rx.recv() => {
369                                info!("Invalidation subscriber shutting down...");
370                                break;
371                            }
372                        }
373                    }
374                }
375            }
376        })
377    }
378
379    /// Internal subscriber loop
380    async fn run_subscriber_loop<F, Fut>(
381        client: &redis::Client,
382        channel: &str,
383        handler: Arc<F>,
384        stats: Arc<AtomicInvalidationStats>,
385        shutdown_rx: &mut broadcast::Receiver<()>,
386    ) -> Result<()>
387    where
388        F: Fn(InvalidationMessage) -> Fut + Send + 'static,
389        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
390    {
391        // Get Pub/Sub connection
392        let mut pubsub = client.get_async_pubsub().await
393            .context("Failed to get pubsub connection")?;
394
395        // Subscribe to channel
396        pubsub.subscribe(channel).await
397            .context("Failed to subscribe to channel")?;
398
399        info!("Subscribed to invalidation channel: {}", channel);
400
401        // Get message stream
402        let mut stream = pubsub.on_message();
403
404        loop {
405            // Wait for message or shutdown signal
406            tokio::select! {
407                msg_result = stream.next() => {
408                    match msg_result {
409                        Some(msg) => {
410                            // Get payload
411                            let payload: String = match msg.get_payload() {
412                                Ok(p) => p,
413                                Err(e) => {
414                                    warn!("Failed to get message payload: {}", e);
415                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
416                                    continue;
417                                }
418                            };
419
420                            // Deserialize message
421                            let invalidation_msg = match InvalidationMessage::from_json(&payload) {
422                                Ok(m) => m,
423                                Err(e) => {
424                                    warn!("Failed to deserialize invalidation message: {}", e);
425                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
426                                    continue;
427                                }
428                            };
429
430                            // Update stats
431                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
432                            match &invalidation_msg {
433                                InvalidationMessage::Remove { .. } => {
434                                    stats.removes_received.fetch_add(1, Ordering::Relaxed);
435                                }
436                                InvalidationMessage::Update { .. } => {
437                                    stats.updates_received.fetch_add(1, Ordering::Relaxed);
438                                }
439                                InvalidationMessage::RemovePattern { .. } => {
440                                    stats.patterns_received.fetch_add(1, Ordering::Relaxed);
441                                }
442                                InvalidationMessage::RemoveBulk { .. } => {
443                                    stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
444                                }
445                            }
446
447                            // Call handler
448                            if let Err(e) = handler(invalidation_msg).await {
449                                error!("Invalidation handler error: {}", e);
450                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
451                            }
452                        }
453                        None => {
454                            // Stream ended
455                            return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
456                        }
457                    }
458                }
459                _ = shutdown_rx.recv() => {
460                    return Ok(());
461                }
462            }
463        }
464    }
465
466    /// Signal the subscriber to shutdown
467    pub fn shutdown(&self) {
468        let _ = self.shutdown_tx.send(());
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must survive a JSON round trip unchanged.
    #[test]
    fn test_invalidation_message_serialization() {
        // Remove
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        // Update
        let msg = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::Update { key, value, ttl_secs } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        // RemovePattern
        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        // RemoveBulk
        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
    }

    /// `ttl()` reports the TTL only for `Update` messages that carry one.
    #[test]
    fn test_invalidation_message_ttl() {
        let msg = InvalidationMessage::update(
            "k",
            serde_json::json!(1),
            Some(Duration::from_secs(300)),
        );
        assert_eq!(msg.ttl(), Some(Duration::from_secs(300)));

        let msg = InvalidationMessage::update("k", serde_json::json!(1), None);
        assert_eq!(msg.ttl(), None);

        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(vec![]).ttl(), None);
    }

    /// A missing TTL is omitted on the wire, and a message without it still decodes.
    #[test]
    fn test_update_without_ttl_omits_field() {
        let json = InvalidationMessage::update("k", serde_json::json!(null), None)
            .to_json()
            .unwrap();
        assert!(!json.contains("ttl_secs"));

        let parsed =
            InvalidationMessage::from_json(r#"{"type":"Update","key":"k","value":1}"#).unwrap();
        assert_eq!(parsed.ttl(), None);
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }

    /// `snapshot()` copies each atomic counter into the plain stats struct.
    #[test]
    fn test_stats_snapshot() {
        let stats = AtomicInvalidationStats::default();
        stats.removes_received.fetch_add(2, Ordering::Relaxed);
        stats.processing_errors.fetch_add(1, Ordering::Relaxed);

        let snap = stats.snapshot();
        assert_eq!(snap.removes_received, 2);
        assert_eq!(snap.processing_errors, 1);
        assert_eq!(snap.messages_sent, 0);
        assert_eq!(snap.updates_received, 0);
    }
}