Skip to main content

multi_tier_cache/
invalidation.rs

1//! Cache invalidation and synchronization module
2//!
3//! This module provides cross-instance cache invalidation using Redis Pub/Sub.
4//! It supports both cache removal (invalidation) and cache updates (refresh).
5
6use anyhow::{Context, Result};
7use bytes::Bytes;
8use futures_util::StreamExt;
9use redis::AsyncCommands;
10use serde::{Deserialize, Serialize};
11use std::time::{Duration, SystemTime, UNIX_EPOCH};
12use tracing::{error, info, warn};
13
14/// Invalidation message types sent across cache instances via Redis Pub/Sub
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(tag = "type")]
17pub enum InvalidationMessage {
18    /// Remove a single key from all cache instances
19    Remove { key: String },
20
21    /// Update a key with new value across all cache instances
22    /// This is more efficient than Remove for hot keys as it avoids cache miss
23    Update {
24        key: String,
25        #[serde(with = "serde_bytes_wrapper")]
26        value: Bytes,
27        #[serde(skip_serializing_if = "Option::is_none")]
28        ttl_secs: Option<u64>,
29    },
30
31    /// Remove all keys matching a pattern from all cache instances
32    /// Uses glob-style patterns (e.g., "user:*", "product:123:*")
33    RemovePattern { pattern: String },
34
35    /// Bulk remove multiple keys at once
36    RemoveBulk { keys: Vec<String> },
37}
38
39impl InvalidationMessage {
40    /// Create a Remove message
41    pub fn remove(key: impl Into<String>) -> Self {
42        Self::Remove { key: key.into() }
43    }
44
45    /// Create an Update message
46    pub fn update(key: impl Into<String>, value: Bytes, ttl: Option<Duration>) -> Self {
47        Self::Update {
48            key: key.into(),
49            value,
50            ttl_secs: ttl.map(|d| d.as_secs()),
51        }
52    }
53
54    /// Create a `RemovePattern` message
55    pub fn remove_pattern(pattern: impl Into<String>) -> Self {
56        Self::RemovePattern {
57            pattern: pattern.into(),
58        }
59    }
60
61    /// Create a `RemoveBulk` message
62    #[must_use]
63    pub fn remove_bulk(keys: Vec<String>) -> Self {
64        Self::RemoveBulk { keys }
65    }
66
67    /// Serialize to JSON for transmission
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if serialization fails.
72    pub fn to_json(&self) -> Result<String> {
73        serde_json::to_string(self).context("Failed to serialize invalidation message")
74    }
75
76    /// Deserialize from JSON
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if deserialization fails.
81    pub fn from_json(json: &str) -> Result<Self> {
82        serde_json::from_str(json).context("Failed to deserialize invalidation message")
83    }
84
85    /// Get TTL as Duration if present
86    pub fn ttl(&self) -> Option<Duration> {
87        match self {
88            Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
89            _ => None,
90        }
91    }
92}
93
94/// Helper module for Bytes serialization in JSON
95mod serde_bytes_wrapper {
96    use bytes::Bytes;
97    use serde::{Deserialize, Deserializer, Serializer};
98
99    pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
100    where
101        S: Serializer,
102    {
103        // For JSON, we use a vector of bytes.
104        // In a real production system, we'd use Base64.
105        serializer.serialize_bytes(bytes)
106    }
107
108    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
109    where
110        D: Deserializer<'de>,
111    {
112        let v: Vec<u8> = Vec::deserialize(deserializer)?;
113        Ok(Bytes::from(v))
114    }
115}
116
/// Settings used by the invalidation publisher and subscriber.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Pub/Sub channel that invalidation messages are published to and subscribed from.
    pub channel: String,

    /// Whether writes should automatically broadcast an invalidation.
    pub auto_broadcast_on_write: bool,

    /// Whether each published message is also recorded in a Redis Stream for auditing.
    pub enable_audit_stream: bool,

    /// Name of the Redis Stream holding the audit trail.
    pub audit_stream: String,

    /// Approximate cap on the audit stream's length (`None` disables trimming).
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Conservative defaults: nothing is broadcast automatically and auditing is off.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");

        Self {
            audit_stream_maxlen: Some(10_000),
            enable_audit_stream: false,
            auto_broadcast_on_write: false,
            audit_stream,
            channel,
        }
    }
}
147
148/// Handle for sending invalidation messages
149pub struct InvalidationPublisher {
150    connection: redis::aio::ConnectionManager,
151    config: InvalidationConfig,
152}
153
154impl InvalidationPublisher {
155    /// Create a new publisher
156    #[must_use]
157    pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
158        Self { connection, config }
159    }
160
161    /// Publish an invalidation message to all subscribers
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if serialization or publishing fails.
166    pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
167        let json = message.to_json()?;
168
169        // Publish to Pub/Sub channel
170        let _: () = self
171            .connection
172            .publish(&self.config.channel, &json)
173            .await
174            .context("Failed to publish invalidation message")?;
175
176        // Optionally publish to audit stream
177        if self.config.enable_audit_stream
178            && let Err(e) = self.publish_to_audit_stream(message).await
179        {
180            // Don't fail the invalidation if audit logging fails
181            warn!("Failed to publish to audit stream: {}", e);
182        }
183
184        Ok(())
185    }
186
187    /// Publish to audit stream for observability
188    async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
189        let timestamp = SystemTime::now()
190            .duration_since(UNIX_EPOCH)
191            .unwrap_or(Duration::ZERO)
192            .as_secs()
193            .to_string();
194
195        // Use &str to avoid unnecessary allocations
196        let (type_str, key_str): (&str, &str);
197        let extra_str: String;
198
199        match message {
200            InvalidationMessage::Remove { key } => {
201                type_str = "remove";
202                key_str = key.as_str();
203                extra_str = String::new();
204            }
205            InvalidationMessage::Update { key, .. } => {
206                type_str = "update";
207                key_str = key.as_str();
208                extra_str = String::new();
209            }
210            InvalidationMessage::RemovePattern { pattern } => {
211                type_str = "remove_pattern";
212                key_str = pattern.as_str();
213                extra_str = String::new();
214            }
215            InvalidationMessage::RemoveBulk { keys } => {
216                type_str = "remove_bulk";
217                key_str = "";
218                extra_str = keys.len().to_string();
219            }
220        }
221
222        let mut fields = vec![("type", type_str), ("timestamp", timestamp.as_str())];
223
224        if !key_str.is_empty() {
225            fields.push(("key", key_str));
226        }
227        if !extra_str.is_empty() {
228            fields.push(("count", extra_str.as_str()));
229        }
230
231        let mut cmd = redis::cmd("XADD");
232        cmd.arg(&self.config.audit_stream);
233
234        if let Some(maxlen) = self.config.audit_stream_maxlen {
235            cmd.arg("MAXLEN").arg("~").arg(maxlen);
236        }
237
238        cmd.arg("*"); // Auto-generate ID
239
240        for (key, value) in fields {
241            cmd.arg(key).arg(value);
242        }
243
244        let _: String = cmd
245            .query_async(&mut self.connection)
246            .await
247            .context("Failed to add to audit stream")?;
248
249        Ok(())
250    }
251}
252
/// Statistics for invalidation operations.
///
/// A plain-value snapshot, produced by `AtomicInvalidationStats::snapshot`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct InvalidationStats {
    /// Number of invalidation messages published.
    ///
    /// NOTE(review): `InvalidationPublisher` in this module does not update this counter;
    /// confirm where (or whether) it is incremented.
    pub messages_sent: u64,

    /// Number of invalidation messages received and successfully decoded.
    pub messages_received: u64,

    /// Number of `Remove` messages received.
    pub removes_received: u64,

    /// Number of `Update` messages received.
    pub updates_received: u64,

    /// Number of `RemovePattern` messages received.
    pub patterns_received: u64,

    /// Number of `RemoveBulk` messages received.
    pub bulk_removes_received: u64,

    /// Number of failures: unreadable payloads, undecodable JSON, handler errors, and
    /// subscriber connection/stream errors.
    pub processing_errors: u64,
}

use std::sync::atomic::{AtomicU64, Ordering};

/// Thread-safe counterpart of `InvalidationStats`, updated through atomic counters.
///
/// Counters are independent and accessed with `Ordering::Relaxed`. `snapshot` therefore
/// reads each one separately and does not give an atomic view across fields.
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    /// Counter behind `InvalidationStats::messages_sent`.
    pub messages_sent: AtomicU64,
    /// Counter behind `InvalidationStats::messages_received`.
    pub messages_received: AtomicU64,
    /// Counter behind `InvalidationStats::removes_received`.
    pub removes_received: AtomicU64,
    /// Counter behind `InvalidationStats::updates_received`.
    pub updates_received: AtomicU64,
    /// Counter behind `InvalidationStats::patterns_received`.
    pub patterns_received: AtomicU64,
    /// Counter behind `InvalidationStats::bulk_removes_received`.
    pub bulk_removes_received: AtomicU64,
    /// Counter behind `InvalidationStats::processing_errors`.
    pub processing_errors: AtomicU64,
}

impl AtomicInvalidationStats {
    /// Copy the current counter values into an `InvalidationStats`.
    #[must_use]
    pub fn snapshot(&self) -> InvalidationStats {
        let load = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        InvalidationStats {
            messages_sent: load(&self.messages_sent),
            messages_received: load(&self.messages_received),
            removes_received: load(&self.removes_received),
            updates_received: load(&self.updates_received),
            patterns_received: load(&self.patterns_received),
            bulk_removes_received: load(&self.bulk_removes_received),
            processing_errors: load(&self.processing_errors),
        }
    }
}
305
306use std::sync::Arc;
307use tokio::sync::broadcast;
308
309/// Handle for subscribing to invalidation messages
310///
311/// This spawns a background task that listens to Redis Pub/Sub and processes
312/// invalidation messages by calling the provided handler callback.
313pub struct InvalidationSubscriber {
314    /// Redis client for creating Pub/Sub connections
315    client: redis::Client,
316    /// Configuration
317    config: InvalidationConfig,
318    /// Statistics
319    stats: Arc<AtomicInvalidationStats>,
320    /// Shutdown signal sender
321    shutdown_tx: broadcast::Sender<()>,
322}
323
324impl InvalidationSubscriber {
325    /// Create a new subscriber
326    ///
327    /// # Arguments
328    /// * `redis_url` - Redis connection URL
329    /// * `config` - Invalidation configuration
330    /// # Errors
331    ///
332    /// Returns an error if Redis client creation fails.
333    pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
334        let client = redis::Client::open(redis_url)
335            .context("Failed to create Redis client for subscriber")?;
336
337        let (shutdown_tx, _) = broadcast::channel(1);
338
339        Ok(Self {
340            client,
341            config,
342            stats: Arc::new(AtomicInvalidationStats::default()),
343            shutdown_tx,
344        })
345    }
346
347    /// Get a snapshot of current statistics
348    #[must_use]
349    pub fn stats(&self) -> InvalidationStats {
350        self.stats.snapshot()
351    }
352
353    /// Start the subscriber background task
354    ///
355    /// # Arguments
356    /// * `handler` - Async function to handle each invalidation message
357    ///
358    /// # Returns
359    /// Join handle for the background task
360    pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
361    where
362        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
363        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
364    {
365        let client = self.client.clone();
366        let channel = self.config.channel.clone();
367        let stats = Arc::clone(&self.stats);
368        let mut shutdown_rx = self.shutdown_tx.subscribe();
369
370        tokio::spawn(async move {
371            let handler = Arc::new(handler);
372
373            loop {
374                // Check for shutdown signal
375                if shutdown_rx.try_recv().is_ok() {
376                    info!("Invalidation subscriber shutting down...");
377                    break;
378                }
379
380                // Attempt to connect and subscribe
381                match Self::run_subscriber_loop(
382                    &client,
383                    &channel,
384                    Arc::clone(&handler),
385                    Arc::clone(&stats),
386                    &mut shutdown_rx,
387                )
388                .await
389                {
390                    Ok(()) => {
391                        info!("Invalidation subscriber loop completed normally");
392                        break;
393                    }
394                    Err(e) => {
395                        error!(
396                            "Invalidation subscriber error: {}. Reconnecting in 5s...",
397                            e
398                        );
399                        stats.processing_errors.fetch_add(1, Ordering::Relaxed);
400
401                        // Wait before reconnecting
402                        tokio::select! {
403                            () = tokio::time::sleep(Duration::from_secs(5)) => {},
404                            _ = shutdown_rx.recv() => {
405                                info!("Invalidation subscriber shutting down...");
406                                break;
407                            }
408                        }
409                    }
410                }
411            }
412        })
413    }
414
415    /// Internal subscriber loop
416    async fn run_subscriber_loop<F, Fut>(
417        client: &redis::Client,
418        channel: &str,
419        handler: Arc<F>,
420        stats: Arc<AtomicInvalidationStats>,
421        shutdown_rx: &mut broadcast::Receiver<()>,
422    ) -> Result<()>
423    where
424        F: Fn(InvalidationMessage) -> Fut + Send + 'static,
425        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
426    {
427        // Get Pub/Sub connection
428        let mut pubsub = client
429            .get_async_pubsub()
430            .await
431            .context("Failed to get pubsub connection")?;
432
433        // Subscribe to channel
434        pubsub
435            .subscribe(channel)
436            .await
437            .context("Failed to subscribe to channel")?;
438
439        info!("Subscribed to invalidation channel: {}", channel);
440
441        // Get message stream
442        let mut stream = pubsub.on_message();
443
444        loop {
445            // Wait for message or shutdown signal
446            tokio::select! {
447                msg_result = stream.next() => {
448                    match msg_result {
449                        Some(msg) => {
450                            // Get payload
451                            let payload: String = match msg.get_payload() {
452                                Ok(p) => p,
453                                Err(e) => {
454                                    warn!("Failed to get message payload: {}", e);
455                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
456                                    continue;
457                                }
458                            };
459
460                            // Deserialize message
461                            let invalidation_msg = match InvalidationMessage::from_json(&payload) {
462                                Ok(m) => m,
463                                Err(e) => {
464                                    warn!("Failed to deserialize invalidation message: {}", e);
465                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
466                                    continue;
467                                }
468                            };
469
470                            // Update stats
471                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
472                            match &invalidation_msg {
473                                InvalidationMessage::Remove { .. } => {
474                                    stats.removes_received.fetch_add(1, Ordering::Relaxed);
475                                }
476                                InvalidationMessage::Update { .. } => {
477                                    stats.updates_received.fetch_add(1, Ordering::Relaxed);
478                                }
479                                InvalidationMessage::RemovePattern { .. } => {
480                                    stats.patterns_received.fetch_add(1, Ordering::Relaxed);
481                                }
482                                InvalidationMessage::RemoveBulk { .. } => {
483                                    stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
484                                }
485                            }
486
487                            // Call handler
488                            if let Err(e) = handler(invalidation_msg).await {
489                                error!("Invalidation handler error: {}", e);
490                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
491                            }
492                        }
493                        None => {
494                            // Stream ended
495                            return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
496                        }
497                    }
498                }
499                _ = shutdown_rx.recv() => {
500                    return Ok(());
501                }
502            }
503        }
504    }
505
506    /// Signal the subscriber to shutdown
507    pub fn shutdown(&self) {
508        let _ = self.shutdown_tx.send(());
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `msg` to JSON and parse it back, as a subscriber would receive it.
    fn round_trip(msg: &InvalidationMessage) -> Result<InvalidationMessage> {
        InvalidationMessage::from_json(&msg.to_json()?)
    }

    #[test]
    fn test_invalidation_message_serialization() -> Result<()> {
        // Remove
        let msg = InvalidationMessage::remove("test_key");
        assert_eq!(msg.ttl(), None);
        match round_trip(&msg)? {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            other => panic!("Wrong message type: {other:?}"),
        }

        // Update with TTL, sent through JSON so the bytes adapter and `ttl_secs` are covered.
        let msg = InvalidationMessage::update(
            "test_key",
            Bytes::from("{\"value\": 123}"),
            Some(Duration::from_secs(3600)),
        );
        assert_eq!(msg.ttl(), Some(Duration::from_secs(3600)));
        match round_trip(&msg)? {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, Bytes::from("{\"value\": 123}"));
                assert_eq!(ttl_secs, Some(3600));
            }
            other => panic!("Expected Update message, got {other:?}"),
        }

        // Update without TTL: `ttl_secs` is omitted from the JSON and decodes as `None`.
        let msg = InvalidationMessage::update("empty", Bytes::new(), None);
        let json = msg.to_json()?;
        assert!(!json.contains("ttl_secs"));
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "empty");
                assert!(value.is_empty());
                assert_eq!(ttl_secs, None);
            }
            other => panic!("Expected Update message, got {other:?}"),
        }

        // Update with non-UTF-8 bytes survives the JSON round trip unchanged.
        let raw = Bytes::from_static(&[0, 159, 255, 10]);
        match round_trip(&InvalidationMessage::update("bin", raw.clone(), None))? {
            InvalidationMessage::Update { value, .. } => assert_eq!(value, raw),
            other => panic!("Expected Update message, got {other:?}"),
        }

        // RemovePattern
        match round_trip(&InvalidationMessage::remove_pattern("user:*"))? {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            other => panic!("Wrong message type: {other:?}"),
        }

        // RemoveBulk
        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        match round_trip(&msg)? {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            other => panic!("Wrong message type: {other:?}"),
        }
        Ok(())
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }
}