Skip to main content

multi_tier_cache/
invalidation.rs

//! Cache invalidation and synchronization module
//!
//! This module provides cross-instance cache invalidation using Redis Pub/Sub.
//! It supports both cache removal (invalidation) and cache updates (refresh),
//! an optional Redis Streams audit trail, and a consumer-group based subscriber.
5
6use crate::error::CacheResult;
7use crate::traits::StreamingBackend;
8use bytes::Bytes;
9use futures_util::StreamExt;
10use redis::AsyncCommands;
11use serde::{Deserialize, Serialize};
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use tokio::sync::broadcast;
14use tracing::{error, info, warn};
15use uuid::Uuid;
16
17/// Invalidation message types sent across cache instances via Redis Pub/Sub
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum InvalidationMessage {
21    /// Remove a single key from all cache instances
22    Remove { key: String },
23
24    /// Update a key with new value across all cache instances
25    /// This is more efficient than Remove for hot keys as it avoids cache miss
26    Update {
27        key: String,
28        #[serde(with = "serde_bytes_wrapper")]
29        value: Bytes,
30        #[serde(skip_serializing_if = "Option::is_none")]
31        ttl_secs: Option<u64>,
32    },
33
34    /// Remove all keys matching a pattern from all cache instances
35    /// Uses glob-style patterns (e.g., "user:*", "product:123:*")
36    RemovePattern { pattern: String },
37
38    /// Bulk remove multiple keys at once
39    RemoveBulk { keys: Vec<String> },
40}
41
42impl InvalidationMessage {
43    /// Create a Remove message
44    pub fn remove(key: impl Into<String>) -> Self {
45        Self::Remove { key: key.into() }
46    }
47
48    /// Create an Update message
49    pub fn update(key: impl Into<String>, value: Bytes, ttl: Option<Duration>) -> Self {
50        Self::Update {
51            key: key.into(),
52            value,
53            ttl_secs: ttl.map(|d| d.as_secs()),
54        }
55    }
56
57    /// Create a `RemovePattern` message
58    pub fn remove_pattern(pattern: impl Into<String>) -> Self {
59        Self::RemovePattern {
60            pattern: pattern.into(),
61        }
62    }
63
64    /// Create a `RemoveBulk` message
65    #[must_use]
66    pub fn remove_bulk(keys: Vec<String>) -> Self {
67        Self::RemoveBulk { keys }
68    }
69
70    /// Serialize to JSON for transmission
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if serialization fails.
75    pub fn to_json(&self) -> CacheResult<String> {
76        serde_json::to_string(self).map_err(|e| {
77            crate::error::CacheError::SerializationError(format!(
78                "Failed to serialize invalidation message: {e}"
79            ))
80        })
81    }
82
83    /// Deserialize from JSON
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if deserialization fails.
88    pub fn from_json(json: &str) -> CacheResult<Self> {
89        serde_json::from_str(json).map_err(|e| {
90            crate::error::CacheError::SerializationError(format!(
91                "Failed to deserialize invalidation message: {e}"
92            ))
93        })
94    }
95
96    /// Get TTL as Duration if present
97    pub fn ttl(&self) -> Option<Duration> {
98        match self {
99            Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
100            _ => None,
101        }
102    }
103}
104
105/// Helper module for Bytes serialization in JSON
106mod serde_bytes_wrapper {
107    use bytes::Bytes;
108    use serde::{Deserialize, Deserializer, Serializer};
109
110    pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
111    where
112        S: Serializer,
113    {
114        // For JSON, we use a vector of bytes.
115        // In a real production system, we'd use Base64.
116        serializer.serialize_bytes(bytes)
117    }
118
119    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
120    where
121        D: Deserializer<'de>,
122    {
123        let v: Vec<u8> = Vec::deserialize(deserializer)?;
124        Ok(Bytes::from(v))
125    }
126}
127
/// Settings shared by the invalidation publisher and subscribers.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Name of the Redis Pub/Sub channel that carries invalidation messages.
    pub channel: String,

    /// Whether writes should automatically broadcast an invalidation.
    pub auto_broadcast_on_write: bool,

    /// Whether each published message is also recorded in a Redis Stream
    /// for auditing.
    pub enable_audit_stream: bool,

    /// Name of the Redis Stream that holds the audit trail.
    pub audit_stream: String,

    /// Upper bound for the audit stream length. When set, the publisher passes
    /// it to `XADD` as `MAXLEN ~ <n>`; `None` adds entries without trimming.
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Defaults: channel `cache:invalidate`, audit stream
    /// `cache:invalidations` capped at 10000 entries, with both automatic
    /// broadcasting and audit logging switched off.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");

        Self {
            channel,
            // Conservative default: nothing is broadcast unless asked for.
            auto_broadcast_on_write: false,
            enable_audit_stream: false,
            audit_stream,
            audit_stream_maxlen: Some(10_000),
        }
    }
}
158
159/// Handle for sending invalidation messages
160pub struct InvalidationPublisher {
161    connection: redis::aio::ConnectionManager,
162    config: InvalidationConfig,
163}
164
165impl InvalidationPublisher {
166    /// Create a new publisher
167    #[must_use]
168    pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
169        Self { connection, config }
170    }
171
172    /// Publish an invalidation message to all subscribers
173    ///
174    /// # Errors
175    ///
176    /// Returns an error if serialization or publishing fails.
177    pub async fn publish(&mut self, message: &InvalidationMessage) -> CacheResult<()> {
178        let json = message.to_json()?;
179
180        // Publish to Pub/Sub channel
181        let _: () = self
182            .connection
183            .publish(&self.config.channel, &json)
184            .await
185            .map_err(|e| {
186                crate::error::CacheError::InvalidationError(format!(
187                    "Failed to publish invalidation message: {e}"
188                ))
189            })?;
190
191        // Optionally publish to audit stream
192        if self.config.enable_audit_stream
193            && let Err(e) = self.publish_to_audit_stream(message).await
194        {
195            // Don't fail the invalidation if audit logging fails
196            warn!("Failed to publish to audit stream: {}", e);
197        }
198
199        Ok(())
200    }
201
202    /// Publish to audit stream for observability
203    async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> CacheResult<()> {
204        let timestamp = SystemTime::now()
205            .duration_since(UNIX_EPOCH)
206            .unwrap_or(Duration::ZERO)
207            .as_secs()
208            .to_string();
209
210        // Use &str to avoid unnecessary allocations
211        let (type_str, key_str): (&str, &str);
212        let extra_str: String;
213
214        match message {
215            InvalidationMessage::Remove { key } => {
216                type_str = "remove";
217                key_str = key.as_str();
218                extra_str = String::new();
219            }
220            InvalidationMessage::Update { key, .. } => {
221                type_str = "update";
222                key_str = key.as_str();
223                extra_str = String::new();
224            }
225            InvalidationMessage::RemovePattern { pattern } => {
226                type_str = "remove_pattern";
227                key_str = pattern.as_str();
228                extra_str = String::new();
229            }
230            InvalidationMessage::RemoveBulk { keys } => {
231                type_str = "remove_bulk";
232                key_str = "";
233                extra_str = keys.len().to_string();
234            }
235        }
236
237        let mut fields = vec![("type", type_str), ("timestamp", timestamp.as_str())];
238
239        if !key_str.is_empty() {
240            fields.push(("key", key_str));
241        }
242        if !extra_str.is_empty() {
243            fields.push(("count", extra_str.as_str()));
244        }
245
246        let mut cmd = redis::cmd("XADD");
247        cmd.arg(&self.config.audit_stream);
248
249        if let Some(maxlen) = self.config.audit_stream_maxlen {
250            cmd.arg("MAXLEN").arg("~").arg(maxlen);
251        }
252
253        cmd.arg("*"); // Auto-generate ID
254
255        for (key, value) in fields {
256            cmd.arg(key).arg(value);
257        }
258
259        let _: String = cmd.query_async(&mut self.connection).await.map_err(|e| {
260            crate::error::CacheError::BackendError(format!("Failed to add to audit stream: {e}"))
261        })?;
262
263        Ok(())
264    }
265}
266
/// Plain-value copy of the invalidation counters, as returned by
/// [`AtomicInvalidationStats::snapshot`].
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// Invalidation messages published.
    pub messages_sent: u64,

    /// Invalidation messages received.
    pub messages_received: u64,

    /// `Remove` messages received.
    pub removes_received: u64,

    /// `Update` messages received.
    pub updates_received: u64,

    /// `RemovePattern` messages received.
    pub patterns_received: u64,

    /// `RemoveBulk` messages received.
    pub bulk_removes_received: u64,

    /// Failed message processing attempts.
    pub processing_errors: u64,
}

use std::sync::atomic::{AtomicU64, Ordering};

/// Invalidation counters that can be updated through a shared reference.
///
/// Field meanings mirror [`InvalidationStats`].
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    pub messages_sent: AtomicU64,
    pub messages_received: AtomicU64,
    pub removes_received: AtomicU64,
    pub updates_received: AtomicU64,
    pub patterns_received: AtomicU64,
    pub bulk_removes_received: AtomicU64,
    pub processing_errors: AtomicU64,
}

impl AtomicInvalidationStats {
    /// Copy the current counter values into an [`InvalidationStats`].
    ///
    /// Each counter is loaded separately with relaxed ordering, so the result
    /// is not a single atomic view across all counters.
    pub fn snapshot(&self) -> InvalidationStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);

        InvalidationStats {
            messages_sent: read(&self.messages_sent),
            messages_received: read(&self.messages_received),
            removes_received: read(&self.removes_received),
            updates_received: read(&self.updates_received),
            patterns_received: read(&self.patterns_received),
            bulk_removes_received: read(&self.bulk_removes_received),
            processing_errors: read(&self.processing_errors),
        }
    }
}
319
320use std::sync::Arc;
321
322/// Handle for subscribing to invalidation messages
323///
324/// This spawns a background task that listens to Redis Pub/Sub and processes
325/// invalidation messages by calling the provided handler callback.
326pub struct InvalidationSubscriber {
327    /// Redis client for creating Pub/Sub connections
328    client: redis::Client,
329    /// Configuration
330    config: InvalidationConfig,
331    /// Statistics
332    stats: Arc<AtomicInvalidationStats>,
333    /// Shutdown signal sender
334    shutdown_tx: broadcast::Sender<()>,
335}
336
337impl InvalidationSubscriber {
338    /// Create a new subscriber
339    ///
340    /// # Arguments
341    /// * `redis_url` - Redis connection URL
342    /// * `config` - Invalidation configuration
343    /// # Errors
344    ///
345    /// Returns an error if Redis client creation fails.
346    pub fn new(redis_url: &str, config: InvalidationConfig) -> CacheResult<Self> {
347        let client = redis::Client::open(redis_url).map_err(|e| {
348            crate::error::CacheError::ConfigError(format!(
349                "Failed to create Redis client for subscriber: {e}"
350            ))
351        })?;
352
353        let (shutdown_tx, _) = broadcast::channel(1);
354
355        Ok(Self {
356            client,
357            config,
358            stats: Arc::new(AtomicInvalidationStats::default()),
359            shutdown_tx,
360        })
361    }
362
363    /// Get a snapshot of current statistics
364    #[must_use]
365    pub fn stats(&self) -> InvalidationStats {
366        self.stats.snapshot()
367    }
368
369    /// Start the subscriber background task
370    ///
371    /// # Arguments
372    /// * `handler` - Async function to handle each invalidation message
373    ///
374    /// # Returns
375    /// Join handle for the background task
376    pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
377    where
378        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
379        Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
380    {
381        let client = self.client.clone();
382        let channel = self.config.channel.clone();
383        let stats = Arc::clone(&self.stats);
384        let mut shutdown_rx = self.shutdown_tx.subscribe();
385
386        tokio::spawn(async move {
387            let handler = Arc::new(handler);
388
389            loop {
390                // Check for shutdown signal
391                if shutdown_rx.try_recv().is_ok() {
392                    info!("Invalidation subscriber shutting down...");
393                    break;
394                }
395
396                // Attempt to connect and subscribe
397                match Self::run_subscriber_loop(
398                    &client,
399                    &channel,
400                    Arc::clone(&handler),
401                    Arc::clone(&stats),
402                    &mut shutdown_rx,
403                )
404                .await
405                {
406                    Ok(()) => {
407                        info!("Invalidation subscriber loop completed normally");
408                        break;
409                    }
410                    Err(e) => {
411                        error!(
412                            "Invalidation subscriber error: {}. Reconnecting in 5s...",
413                            e
414                        );
415                        stats.processing_errors.fetch_add(1, Ordering::Relaxed);
416
417                        // Wait before reconnecting
418                        tokio::select! {
419                            () = tokio::time::sleep(Duration::from_secs(5)) => {},
420                            _ = shutdown_rx.recv() => {
421                                info!("Invalidation subscriber shutting down...");
422                                break;
423                            }
424                        }
425                    }
426                }
427            }
428        })
429    }
430
431    /// Internal subscriber loop
432    async fn run_subscriber_loop<F, Fut>(
433        client: &redis::Client,
434        channel: &str,
435        handler: Arc<F>,
436        stats: Arc<AtomicInvalidationStats>,
437        shutdown_rx: &mut broadcast::Receiver<()>,
438    ) -> CacheResult<()>
439    where
440        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
441        Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
442    {
443        let mut pubsub = client.get_async_pubsub().await.map_err(|e| {
444            crate::error::CacheError::BackendError(format!("Failed to get pubsub connection: {e}"))
445        })?;
446
447        // Subscribe to channel
448        pubsub.subscribe(channel).await.map_err(|e| {
449            crate::error::CacheError::InvalidationError(format!(
450                "Failed to subscribe to channel: {e}"
451            ))
452        })?;
453
454        info!("Subscribed to invalidation channel: {}", channel);
455
456        // Get message stream
457        let mut stream = pubsub.on_message();
458
459        loop {
460            // Wait for message or shutdown signal
461            tokio::select! {
462                msg_result = stream.next() => {
463                    match msg_result {
464                        Some(msg) => {
465                            // Get payload
466                            let payload: String = match msg.get_payload() {
467                                Ok(p) => p,
468                                Err(e) => {
469                                    warn!("Failed to get message payload: {}", e);
470                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
471                                    continue;
472                                }
473                            };
474
475                            // Deserialize message
476                            let invalidation_msg = match InvalidationMessage::from_json(&payload) {
477                                Ok(m) => m,
478                                Err(e) => {
479                                    warn!("Failed to deserialize invalidation message: {}", e);
480                                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
481                                    continue;
482                                }
483                            };
484
485                            // Update stats
486                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
487                            match &invalidation_msg {
488                                InvalidationMessage::Remove { .. } => {
489                                    stats.removes_received.fetch_add(1, Ordering::Relaxed);
490                                }
491                                InvalidationMessage::Update { .. } => {
492                                    stats.updates_received.fetch_add(1, Ordering::Relaxed);
493                                }
494                                InvalidationMessage::RemovePattern { .. } => {
495                                    stats.patterns_received.fetch_add(1, Ordering::Relaxed);
496                                }
497                                InvalidationMessage::RemoveBulk { .. } => {
498                                    stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
499                                }
500                            }
501
502                            // Call handler
503                            if let Err(e) = handler(invalidation_msg).await {
504                                error!("Invalidation handler error: {}", e);
505                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
506                            }
507                        }
508                        None => {
509                            // Stream ended
510                            return Err(crate::error::CacheError::InvalidationError("Pub/Sub message stream ended".to_string()));
511                        }
512                    }
513                }
514                _ = shutdown_rx.recv() => {
515                    return Ok(());
516                }
517            }
518        }
519    }
520
521    /// Signal the subscriber to shutdown
522    pub fn shutdown(&self) {
523        let _ = self.shutdown_tx.send(());
524    }
525}
526
527/// Reliable subscriber using Redis Streams and Consumer Groups
528pub struct ReliableStreamSubscriber {
529    client: redis::Client,
530    config: InvalidationConfig,
531    stats: Arc<AtomicInvalidationStats>,
532    shutdown_tx: broadcast::Sender<()>,
533    group_name: String,
534    consumer_name: String,
535}
536
537impl ReliableStreamSubscriber {
538    /// Create a new `ReliableStreamSubscriber`
539    ///
540    /// # Errors
541    ///
542    /// Returns an error if the Redis client fails to open.
543    pub fn new(redis_url: &str, config: InvalidationConfig, group_name: &str) -> CacheResult<Self> {
544        let client = redis::Client::open(redis_url).map_err(|e| {
545            crate::error::CacheError::ConfigError(format!(
546                "Failed to create Redis client for reliable subscriber: {e}"
547            ))
548        })?;
549
550        let (shutdown_tx, _) = broadcast::channel(1);
551        let consumer_name = format!("consumer-{}", Uuid::new_v4());
552
553        Ok(Self {
554            client,
555            config,
556            stats: Arc::new(AtomicInvalidationStats::default()),
557            shutdown_tx,
558            group_name: group_name.to_string(),
559            consumer_name,
560        })
561    }
562
563    pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
564    where
565        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
566        Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
567    {
568        let client = self.client.clone();
569        let stream_key = self.config.channel.clone();
570        let group_name = self.group_name.clone();
571        let consumer_name = self.consumer_name.clone();
572        let handler = Arc::new(handler);
573        let stats = self.stats.clone();
574        let mut shutdown_rx = self.shutdown_tx.subscribe();
575
576        tokio::spawn(async move {
577            info!(
578                stream = %stream_key,
579                group = %group_name,
580                consumer = %consumer_name,
581                "Starting reliable stream subscriber"
582            );
583
584            // 1. Ensure stream and group exist
585            let redis_backend = crate::redis_streams::RedisStreams::new(
586                client.get_connection_info().addr().to_string().as_str(),
587            )
588            .await;
589            if let Ok(backend) = redis_backend {
590                let _ = backend
591                    .stream_create_group(&stream_key, &group_name, "0")
592                    .await;
593
594                loop {
595                    // Check shutdown before starting loop
596                    if shutdown_rx.try_recv().is_ok() {
597                        break;
598                    }
599
600                    if let Err(e) = Self::run_reliable_loop(
601                        &backend,
602                        &stream_key,
603                        &group_name,
604                        &consumer_name,
605                        handler.clone(),
606                        stats.clone(),
607                        &mut shutdown_rx,
608                    )
609                    .await
610                    {
611                        error!("Reliable subscriber loop error: {}", e);
612
613                        tokio::select! {
614                            () = tokio::time::sleep(Duration::from_secs(5)) => {},
615                            _ = shutdown_rx.recv() => break,
616                        }
617                    } else {
618                        break; // Normal shutdown
619                    }
620                }
621            }
622        })
623    }
624
625    async fn run_reliable_loop<F, Fut>(
626        backend: &dyn crate::traits::StreamingBackend,
627        stream_key: &str,
628        group_name: &str,
629        consumer_name: &str,
630        handler: Arc<F>,
631        stats: Arc<AtomicInvalidationStats>,
632        shutdown_rx: &mut broadcast::Receiver<()>,
633    ) -> CacheResult<()>
634    where
635        F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
636        Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
637    {
638        loop {
639            tokio::select! {
640                entries_result = backend.stream_read_group(stream_key, group_name, consumer_name, 10, Some(5000)) => {
641                    let entries = entries_result?;
642                    if entries.is_empty() { continue; }
643
644                    let mut processed_ids = Vec::new();
645                    for (id, fields) in entries {
646                        // Find "payload" field or use first field if it looks like JSON
647                        let payload = fields.iter().find(|(k, _)| k == "payload")
648                            .map(|(_, v)| v.as_str())
649                            .or_else(|| fields.first().map(|(_, v)| v.as_str()));
650
651                        if let Some(msg) = payload.and_then(|json| InvalidationMessage::from_json(json).ok()) {
652                            stats.messages_received.fetch_add(1, Ordering::Relaxed);
653                            if let Err(e) = handler(msg).await {
654                                error!("Reliable handler error: {}", e);
655                                stats.processing_errors.fetch_add(1, Ordering::Relaxed);
656                            } else {
657                                processed_ids.push(id);
658                            }
659                        }
660                    }
661
662                    if !processed_ids.is_empty() {
663                        backend.stream_ack(stream_key, group_name, &processed_ids).await?;
664                    }
665                }
666                _ = shutdown_rx.recv() => return Ok(()),
667            }
668        }
669    }
670
671    /// Signal the subscriber to shutdown
672    pub fn shutdown(&self) {
673        let _ = self.shutdown_tx.send(()).unwrap_or(0);
674    }
675}
676
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must survive a JSON round trip, including `Update`, which
    /// is the only variant that goes through the custom bytes serializer.
    #[test]
    fn test_invalidation_message_serialization() -> CacheResult<()> {
        // Test Remove
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        let InvalidationMessage::Remove { key } = parsed else {
            panic!("Wrong message type");
        };
        assert_eq!(key, "test_key");

        // Test Update (with TTL), round-tripped through JSON
        let msg = InvalidationMessage::update(
            "test_key",
            Bytes::from("{\"value\": 123}"),
            Some(Duration::from_secs(3600)),
        );
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        let InvalidationMessage::Update {
            key,
            value,
            ttl_secs,
        } = parsed
        else {
            panic!("Expected Update message");
        };
        assert_eq!(key, "test_key");
        assert_eq!(value, Bytes::from("{\"value\": 123}"));
        assert_eq!(ttl_secs, Some(3600));

        // Test Update (without TTL): the field is omitted and read back as None
        let msg = InvalidationMessage::update("test_key", Bytes::from("v"), None);
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        let InvalidationMessage::Update {
            key,
            value,
            ttl_secs,
        } = parsed
        else {
            panic!("Expected Update message");
        };
        assert_eq!(key, "test_key");
        assert_eq!(value, Bytes::from("v"));
        assert_eq!(ttl_secs, None);

        // Test RemovePattern
        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        let InvalidationMessage::RemovePattern { pattern } = parsed else {
            panic!("Wrong message type");
        };
        assert_eq!(pattern, "user:*");

        // Test RemoveBulk
        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        let InvalidationMessage::RemoveBulk { keys } = parsed else {
            panic!("Wrong message type");
        };
        assert_eq!(keys, vec!["key1", "key2"]);

        Ok(())
    }

    /// `ttl()` is only populated for `Update` messages that carry a TTL.
    #[test]
    fn test_invalidation_message_ttl() {
        let with_ttl = InvalidationMessage::update(
            "test_key",
            Bytes::from("v"),
            Some(Duration::from_secs(3600)),
        );
        assert_eq!(with_ttl.ttl(), Some(Duration::from_secs(3600)));

        let without_ttl = InvalidationMessage::update("test_key", Bytes::from("v"), None);
        assert_eq!(without_ttl.ttl(), None);

        assert_eq!(InvalidationMessage::remove("test_key").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("user:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(Vec::new()).ttl(), None);
    }

    /// Invalid JSON is reported as an error instead of panicking.
    #[test]
    fn test_invalidation_message_rejects_invalid_json() {
        assert!(InvalidationMessage::from_json("not json").is_err());
        assert!(InvalidationMessage::from_json("{\"type\":\"Unknown\"}").is_err());
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }
}