multi_tier_cache/
invalidation.rs

1//! Cache invalidation and synchronization module
2//!
3//! This module provides cross-instance cache invalidation using Redis Pub/Sub.
4//! It supports both cache removal (invalidation) and cache updates (refresh).
5
6use anyhow::{Context, Result};
7use futures_util::StreamExt;
8use redis::AsyncCommands;
9use serde::{Deserialize, Serialize};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use tracing::{error, info, warn};
12
13/// Invalidation message types sent across cache instances via Redis Pub/Sub
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type")]
16pub enum InvalidationMessage {
17    /// Remove a single key from all cache instances
18    Remove { key: String },
19
20    /// Update a key with new value across all cache instances
21    /// This is more efficient than Remove for hot keys as it avoids cache miss
22    Update {
23        key: String,
24        value: serde_json::Value,
25        #[serde(skip_serializing_if = "Option::is_none")]
26        ttl_secs: Option<u64>,
27    },
28
29    /// Remove all keys matching a pattern from all cache instances
30    /// Uses glob-style patterns (e.g., "user:*", "product:123:*")
31    RemovePattern { pattern: String },
32
33    /// Bulk remove multiple keys at once
34    RemoveBulk { keys: Vec<String> },
35}
36
37impl InvalidationMessage {
38    /// Create a Remove message
39    pub fn remove(key: impl Into<String>) -> Self {
40        Self::Remove { key: key.into() }
41    }
42
43    /// Create an Update message
44    pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
45        Self::Update {
46            key: key.into(),
47            value,
48            ttl_secs: ttl.map(|d| d.as_secs()),
49        }
50    }
51
52    /// Create a `RemovePattern` message
53    pub fn remove_pattern(pattern: impl Into<String>) -> Self {
54        Self::RemovePattern {
55            pattern: pattern.into(),
56        }
57    }
58
59    /// Create a `RemoveBulk` message
60    #[must_use]
61    pub fn remove_bulk(keys: Vec<String>) -> Self {
62        Self::RemoveBulk { keys }
63    }
64
65    /// Serialize to JSON for transmission
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if serialization fails.
70    pub fn to_json(&self) -> Result<String> {
71        serde_json::to_string(self).context("Failed to serialize invalidation message")
72    }
73
74    /// Deserialize from JSON
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if deserialization fails.
79    pub fn from_json(json: &str) -> Result<Self> {
80        serde_json::from_str(json).context("Failed to deserialize invalidation message")
81    }
82
83    /// Get TTL as Duration if present
84    pub fn ttl(&self) -> Option<Duration> {
85        match self {
86            Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
87            _ => None,
88        }
89    }
90}
91
/// Settings shared by [`InvalidationPublisher`] and [`InvalidationSubscriber`].
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Name of the Redis Pub/Sub channel that carries invalidation messages.
    pub channel: String,

    /// Whether writes should automatically broadcast an invalidation.
    /// Not read by any type in this module.
    pub auto_broadcast_on_write: bool,

    /// When `true`, each published message is also appended to `audit_stream`.
    pub enable_audit_stream: bool,

    /// Redis Stream key that receives the audit trail.
    pub audit_stream: String,

    /// Approximate cap on the audit stream length (`XADD MAXLEN ~`);
    /// `None` disables trimming.
    pub audit_stream_maxlen: Option<usize>,
}
110
111impl Default for InvalidationConfig {
112    fn default() -> Self {
113        Self {
114            channel: "cache:invalidate".to_string(),
115            auto_broadcast_on_write: false, // Conservative default
116            enable_audit_stream: false,
117            audit_stream: "cache:invalidations".to_string(),
118            audit_stream_maxlen: Some(10000),
119        }
120    }
121}
122
123/// Handle for sending invalidation messages
124pub struct InvalidationPublisher {
125    connection: redis::aio::ConnectionManager,
126    config: InvalidationConfig,
127}
128
129impl InvalidationPublisher {
130    /// Create a new publisher
131    #[must_use]
132    pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
133        Self { connection, config }
134    }
135
136    /// Publish an invalidation message to all subscribers
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if serialization or publishing fails.
141    pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
142        let json = message.to_json()?;
143
144        // Publish to Pub/Sub channel
145        let _: () = self
146            .connection
147            .publish(&self.config.channel, &json)
148            .await
149            .context("Failed to publish invalidation message")?;
150
151        // Optionally publish to audit stream
152        if self.config.enable_audit_stream {
153            if let Err(e) = self.publish_to_audit_stream(message).await {
154                // Don't fail the invalidation if audit logging fails
155                warn!("Failed to publish to audit stream: {}", e);
156            }
157        }
158
159        Ok(())
160    }
161
162    /// Publish to audit stream for observability
163    async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
164        let timestamp = SystemTime::now()
165            .duration_since(UNIX_EPOCH)
166            .unwrap_or(Duration::ZERO)
167            .as_secs()
168            .to_string();
169
170        // Use &str to avoid unnecessary allocations
171        let (type_str, key_str): (&str, &str);
172        let extra_str: String;
173
174        match message {
175            InvalidationMessage::Remove { key } => {
176                type_str = "remove";
177                key_str = key.as_str();
178                extra_str = String::new();
179            }
180            InvalidationMessage::Update { key, .. } => {
181                type_str = "update";
182                key_str = key.as_str();
183                extra_str = String::new();
184            }
185            InvalidationMessage::RemovePattern { pattern } => {
186                type_str = "remove_pattern";
187                key_str = pattern.as_str();
188                extra_str = String::new();
189            }
190            InvalidationMessage::RemoveBulk { keys } => {
191                type_str = "remove_bulk";
192                key_str = "";
193                extra_str = keys.len().to_string();
194            }
195        }
196
197        let mut fields = vec![("type", type_str), ("timestamp", timestamp.as_str())];
198
199        if !key_str.is_empty() {
200            fields.push(("key", key_str));
201        }
202        if !extra_str.is_empty() {
203            fields.push(("count", extra_str.as_str()));
204        }
205
206        let mut cmd = redis::cmd("XADD");
207        cmd.arg(&self.config.audit_stream);
208
209        if let Some(maxlen) = self.config.audit_stream_maxlen {
210            cmd.arg("MAXLEN").arg("~").arg(maxlen);
211        }
212
213        cmd.arg("*"); // Auto-generate ID
214
215        for (key, value) in fields {
216            cmd.arg(key).arg(value);
217        }
218
219        let _: String = cmd
220            .query_async(&mut self.connection)
221            .await
222            .context("Failed to add to audit stream")?;
223
224        Ok(())
225    }
226}
227
/// Point-in-time counters describing invalidation traffic.
#[derive(Debug, Default)]
#[derive(Clone)]
pub struct InvalidationStats {
    /// How many invalidation messages were published.
    pub messages_sent: u64,

    /// How many invalidation messages were received and decoded.
    pub messages_received: u64,

    /// How many `Remove` messages were received.
    pub removes_received: u64,

    /// How many `Update` messages were received.
    pub updates_received: u64,

    /// How many `RemovePattern` messages were received.
    pub patterns_received: u64,

    /// How many `RemoveBulk` messages were received.
    pub bulk_removes_received: u64,

    /// How many times processing failed (bad payload, bad JSON, handler or connection error).
    pub processing_errors: u64,
}
252
253use std::sync::atomic::{AtomicU64, Ordering};
254
/// Lock-free counterpart of [`InvalidationStats`], safe to update from
/// several tasks at once.
#[derive(Default)]
#[derive(Debug)]
pub struct AtomicInvalidationStats {
    /// Invalidation messages published.
    pub messages_sent: AtomicU64,
    /// Invalidation messages received and decoded.
    pub messages_received: AtomicU64,
    /// `Remove` messages received.
    pub removes_received: AtomicU64,
    /// `Update` messages received.
    pub updates_received: AtomicU64,
    /// `RemovePattern` messages received.
    pub patterns_received: AtomicU64,
    /// `RemoveBulk` messages received.
    pub bulk_removes_received: AtomicU64,
    /// Failed processing attempts.
    pub processing_errors: AtomicU64,
}
266
267impl AtomicInvalidationStats {
268    pub fn snapshot(&self) -> InvalidationStats {
269        InvalidationStats {
270            messages_sent: self.messages_sent.load(Ordering::Relaxed),
271            messages_received: self.messages_received.load(Ordering::Relaxed),
272            removes_received: self.removes_received.load(Ordering::Relaxed),
273            updates_received: self.updates_received.load(Ordering::Relaxed),
274            patterns_received: self.patterns_received.load(Ordering::Relaxed),
275            bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
276            processing_errors: self.processing_errors.load(Ordering::Relaxed),
277        }
278    }
279}
280
281use std::sync::Arc;
282use tokio::sync::broadcast;
283
284/// Handle for subscribing to invalidation messages
285///
286/// This spawns a background task that listens to Redis Pub/Sub and processes
287/// invalidation messages by calling the provided handler callback.
288pub struct InvalidationSubscriber {
289    /// Redis client for creating Pub/Sub connections
290    client: redis::Client,
291    /// Configuration
292    config: InvalidationConfig,
293    /// Statistics
294    stats: Arc<AtomicInvalidationStats>,
295    /// Shutdown signal sender
296    shutdown_tx: broadcast::Sender<()>,
297}
298
299impl InvalidationSubscriber {
300    /// Create a new subscriber
301    ///
302    /// # Arguments
303    /// * `redis_url` - Redis connection URL
304    /// * `config` - Invalidation configuration
305    /// # Errors
306    ///
307    /// Returns an error if Redis client creation fails.
308    pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
309        let client = redis::Client::open(redis_url)
310            .context("Failed to create Redis client for subscriber")?;
311
312        let (shutdown_tx, _) = broadcast::channel(1);
313
314        Ok(Self {
315            client,
316            config,
317            stats: Arc::new(AtomicInvalidationStats::default()),
318            shutdown_tx,
319        })
320    }
321
322    /// Get a snapshot of current statistics
323    #[must_use]
324    pub fn stats(&self) -> InvalidationStats {
325        self.stats.snapshot()
326    }
327
328    /// Start the subscriber background task
329    ///
330    /// # Arguments
331    /// * `handler` - Async function to handle each invalidation message
332    ///
333    /// # Returns
334    /// Join handle for the background task
335    pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
336    where
337        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
338        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
339    {
340        let client = self.client.clone();
341        let channel = self.config.channel.clone();
342        let stats = Arc::clone(&self.stats);
343        let mut shutdown_rx = self.shutdown_tx.subscribe();
344
345        tokio::spawn(async move {
346            let handler = Arc::new(handler);
347
348            loop {
349                // Check for shutdown signal
350                if shutdown_rx.try_recv().is_ok() {
351                    info!("Invalidation subscriber shutting down...");
352                    break;
353                }
354
355                // Attempt to connect and subscribe
356                match Self::run_subscriber_loop(
357                    &client,
358                    &channel,
359                    Arc::clone(&handler),
360                    Arc::clone(&stats),
361                    &mut shutdown_rx,
362                )
363                .await
364                {
365                    Ok(()) => {
366                        info!("Invalidation subscriber loop completed normally");
367                        break;
368                    }
369                    Err(e) => {
370                        error!(
371                            "Invalidation subscriber error: {}. Reconnecting in 5s...",
372                            e
373                        );
374                        stats.processing_errors.fetch_add(1, Ordering::Relaxed);
375
376                        // Wait before reconnecting
377                        tokio::select! {
378                            () = tokio::time::sleep(Duration::from_secs(5)) => {},
379                            _ = shutdown_rx.recv() => {
380                                info!("Invalidation subscriber shutting down...");
381                                break;
382                            }
383                        }
384                    }
385                }
386            }
387        })
388    }
389
390    /// Internal subscriber loop
391    async fn run_subscriber_loop<F, Fut>(
392        client: &redis::Client,
393        channel: &str,
394        handler: Arc<F>,
395        stats: Arc<AtomicInvalidationStats>,
396        shutdown_rx: &mut broadcast::Receiver<()>,
397    ) -> Result<()>
398    where
399        F: Fn(InvalidationMessage) -> Fut + Send + 'static,
400        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
401    {
402        // Get Pub/Sub connection
403        let mut pubsub = client
404            .get_async_pubsub()
405            .await
406            .context("Failed to get pubsub connection")?;
407
408        // Subscribe to channel
409        pubsub
410            .subscribe(channel)
411            .await
412            .context("Failed to subscribe to channel")?;
413
414        info!("Subscribed to invalidation channel: {}", channel);
415
416        // Get message stream
417        let mut stream = pubsub.on_message();
418
419        loop {
420            // Wait for message or shutdown signal
421            tokio::select! {
422                msg_result = stream.next() => {
423                    match msg_result {
424                        Some(msg) => {
425                            // Get payload
426                            let payload: String = match msg.get_payload() {
427                                Ok(p) => p,
428                                Err(e) => {
429                                    warn!("Failed to get message payload: {}", e);
430                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
431                                    continue;
432                                }
433                            };
434
435                            // Deserialize message
436                            let invalidation_msg = match InvalidationMessage::from_json(&payload) {
437                                Ok(m) => m,
438                                Err(e) => {
439                                    warn!("Failed to deserialize invalidation message: {}", e);
440                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
441                                    continue;
442                                }
443                            };
444
445                            // Update stats
446                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
447                            match &invalidation_msg {
448                                InvalidationMessage::Remove { .. } => {
449                                    stats.removes_received.fetch_add(1, Ordering::Relaxed);
450                                }
451                                InvalidationMessage::Update { .. } => {
452                                    stats.updates_received.fetch_add(1, Ordering::Relaxed);
453                                }
454                                InvalidationMessage::RemovePattern { .. } => {
455                                    stats.patterns_received.fetch_add(1, Ordering::Relaxed);
456                                }
457                                InvalidationMessage::RemoveBulk { .. } => {
458                                    stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
459                                }
460                            }
461
462                            // Call handler
463                            if let Err(e) = handler(invalidation_msg).await {
464                                error!("Invalidation handler error: {}", e);
465                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
466                            }
467                        }
468                        None => {
469                            // Stream ended
470                            return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
471                        }
472                    }
473                }
474                _ = shutdown_rx.recv() => {
475                    return Ok(());
476                }
477            }
478        }
479    }
480
481    /// Signal the subscriber to shutdown
482    pub fn shutdown(&self) {
483        let _ = self.shutdown_tx.send(());
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_invalidation_message_serialization() -> Result<()> {
        // Test Remove
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        // Test Update
        let msg = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        // Test RemovePattern
        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        // Test RemoveBulk
        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
        Ok(())
    }

    /// Other instances decode by the `"type"` tag, so pin the wire format.
    #[test]
    fn test_wire_format_uses_type_tag() -> Result<()> {
        let remove: serde_json::Value =
            serde_json::from_str(&InvalidationMessage::remove("k").to_json()?)?;
        assert_eq!(remove["type"], "Remove");
        assert_eq!(remove["key"], "k");

        let update: serde_json::Value = serde_json::from_str(
            &InvalidationMessage::update("k", serde_json::json!(1), None).to_json()?,
        )?;
        assert_eq!(update["type"], "Update");
        // `ttl_secs` is skipped entirely when absent.
        assert!(update.get("ttl_secs").is_none());
        Ok(())
    }

    /// Payloads from senders that omit `ttl_secs` must still decode.
    #[test]
    fn test_update_without_ttl_field_deserializes() -> Result<()> {
        let parsed =
            InvalidationMessage::from_json(r#"{"type":"Update","key":"k","value":null}"#)?;
        assert_eq!(parsed.ttl(), None);
        match parsed {
            InvalidationMessage::Update { key, .. } => assert_eq!(key, "k"),
            _ => panic!("Wrong message type"),
        }
        Ok(())
    }

    #[test]
    fn test_ttl_accessor() {
        let with_ttl = InvalidationMessage::update(
            "k",
            serde_json::Value::Null,
            Some(Duration::from_secs(300)),
        );
        assert_eq!(with_ttl.ttl(), Some(Duration::from_secs(300)));

        let without_ttl = InvalidationMessage::update("k", serde_json::Value::Null, None);
        assert_eq!(without_ttl.ttl(), None);

        // Non-Update variants never carry a TTL.
        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("p:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(vec![]).ttl(), None);
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }
}