1use crate::error::CacheResult;
7use crate::traits::StreamingBackend;
8use bytes::Bytes;
9use futures_util::StreamExt;
10use redis::AsyncCommands;
11use serde::{Deserialize, Serialize};
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use tokio::sync::broadcast;
14use tracing::{error, info, warn};
15use uuid::Uuid;
16
/// Cache invalidation message exchanged through the invalidation channel/stream.
///
/// Encoded as JSON by [`InvalidationMessage::to_json`]. The enum is internally tagged, so the
/// variant name is written into a `"type"` field, e.g. `{"type":"Remove","key":"user:1"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum InvalidationMessage {
    /// Invalidate a single key.
    Remove { key: String },

    /// Carries a new value for `key`.
    Update {
        key: String,
        /// Raw value bytes, encoded through `serde_bytes_wrapper` (`serialize_bytes`, which
        /// `serde_json` writes as a JSON array of integers).
        #[serde(with = "serde_bytes_wrapper")]
        value: Bytes,
        /// Optional time-to-live in whole seconds; left out of the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl_secs: Option<u64>,
    },

    /// Invalidate keys matching `pattern` (e.g. `"user:*"`); the pattern syntax is interpreted
    /// by whichever handler receives the message.
    RemovePattern { pattern: String },

    /// Invalidate several keys in one message.
    RemoveBulk { keys: Vec<String> },
}
41
42impl InvalidationMessage {
43 pub fn remove(key: impl Into<String>) -> Self {
45 Self::Remove { key: key.into() }
46 }
47
48 pub fn update(key: impl Into<String>, value: Bytes, ttl: Option<Duration>) -> Self {
50 Self::Update {
51 key: key.into(),
52 value,
53 ttl_secs: ttl.map(|d| d.as_secs()),
54 }
55 }
56
57 pub fn remove_pattern(pattern: impl Into<String>) -> Self {
59 Self::RemovePattern {
60 pattern: pattern.into(),
61 }
62 }
63
64 #[must_use]
66 pub fn remove_bulk(keys: Vec<String>) -> Self {
67 Self::RemoveBulk { keys }
68 }
69
70 pub fn to_json(&self) -> CacheResult<String> {
76 serde_json::to_string(self).map_err(|e| {
77 crate::error::CacheError::SerializationError(format!(
78 "Failed to serialize invalidation message: {e}"
79 ))
80 })
81 }
82
83 pub fn from_json(json: &str) -> CacheResult<Self> {
89 serde_json::from_str(json).map_err(|e| {
90 crate::error::CacheError::SerializationError(format!(
91 "Failed to deserialize invalidation message: {e}"
92 ))
93 })
94 }
95
96 pub fn ttl(&self) -> Option<Duration> {
98 match self {
99 Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
100 _ => None,
101 }
102 }
103}
104
105mod serde_bytes_wrapper {
107 use bytes::Bytes;
108 use serde::{Deserialize, Deserializer, Serializer};
109
110 pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
111 where
112 S: Serializer,
113 {
114 serializer.serialize_bytes(bytes)
117 }
118
119 pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
120 where
121 D: Deserializer<'de>,
122 {
123 let v: Vec<u8> = Vec::deserialize(deserializer)?;
124 Ok(Bytes::from(v))
125 }
126}
127
/// Settings shared by the invalidation publisher and subscribers.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Redis Pub/Sub channel for invalidation messages.
    ///
    /// `ReliableStreamSubscriber` reuses this name as its stream key.
    pub channel: String,

    /// Whether writes should broadcast an invalidation automatically.
    ///
    /// NOTE(review): not read in this module; presumably consumed by the cache layer — confirm.
    pub auto_broadcast_on_write: bool,

    /// When true, `InvalidationPublisher::publish` also appends an entry to `audit_stream`.
    pub enable_audit_stream: bool,

    /// Redis stream key that receives audit entries.
    pub audit_stream: String,

    /// Approximate length cap for the audit stream (`XADD ... MAXLEN ~ n`).
    ///
    /// `None` disables trimming.
    pub audit_stream_maxlen: Option<usize>,
}
146
147impl Default for InvalidationConfig {
148 fn default() -> Self {
149 Self {
150 channel: "cache:invalidate".to_string(),
151 auto_broadcast_on_write: false, enable_audit_stream: false,
153 audit_stream: "cache:invalidations".to_string(),
154 audit_stream_maxlen: Some(10000),
155 }
156 }
157}
158
/// Publishes [`InvalidationMessage`]s over Redis Pub/Sub, optionally mirroring them to an audit
/// stream.
pub struct InvalidationPublisher {
    /// Connection used for both the PUBLISH and the audit-stream XADD commands.
    connection: redis::aio::ConnectionManager,
    /// Supplies the channel name and audit-stream settings.
    config: InvalidationConfig,
}
164
165impl InvalidationPublisher {
166 #[must_use]
168 pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
169 Self { connection, config }
170 }
171
172 pub async fn publish(&mut self, message: &InvalidationMessage) -> CacheResult<()> {
178 let json = message.to_json()?;
179
180 let _: () = self
182 .connection
183 .publish(&self.config.channel, &json)
184 .await
185 .map_err(|e| {
186 crate::error::CacheError::InvalidationError(format!(
187 "Failed to publish invalidation message: {e}"
188 ))
189 })?;
190
191 if self.config.enable_audit_stream
193 && let Err(e) = self.publish_to_audit_stream(message).await
194 {
195 warn!("Failed to publish to audit stream: {}", e);
197 }
198
199 Ok(())
200 }
201
202 async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> CacheResult<()> {
204 let timestamp = SystemTime::now()
205 .duration_since(UNIX_EPOCH)
206 .unwrap_or(Duration::ZERO)
207 .as_secs()
208 .to_string();
209
210 let (type_str, key_str): (&str, &str);
212 let extra_str: String;
213
214 match message {
215 InvalidationMessage::Remove { key } => {
216 type_str = "remove";
217 key_str = key.as_str();
218 extra_str = String::new();
219 }
220 InvalidationMessage::Update { key, .. } => {
221 type_str = "update";
222 key_str = key.as_str();
223 extra_str = String::new();
224 }
225 InvalidationMessage::RemovePattern { pattern } => {
226 type_str = "remove_pattern";
227 key_str = pattern.as_str();
228 extra_str = String::new();
229 }
230 InvalidationMessage::RemoveBulk { keys } => {
231 type_str = "remove_bulk";
232 key_str = "";
233 extra_str = keys.len().to_string();
234 }
235 }
236
237 let mut fields = vec![("type", type_str), ("timestamp", timestamp.as_str())];
238
239 if !key_str.is_empty() {
240 fields.push(("key", key_str));
241 }
242 if !extra_str.is_empty() {
243 fields.push(("count", extra_str.as_str()));
244 }
245
246 let mut cmd = redis::cmd("XADD");
247 cmd.arg(&self.config.audit_stream);
248
249 if let Some(maxlen) = self.config.audit_stream_maxlen {
250 cmd.arg("MAXLEN").arg("~").arg(maxlen);
251 }
252
253 cmd.arg("*"); for (key, value) in fields {
256 cmd.arg(key).arg(value);
257 }
258
259 let _: String = cmd.query_async(&mut self.connection).await.map_err(|e| {
260 crate::error::CacheError::BackendError(format!("Failed to add to audit stream: {e}"))
261 })?;
262
263 Ok(())
264 }
265}
266
/// Plain, copyable snapshot of invalidation counters (see `AtomicInvalidationStats::snapshot`).
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// Messages sent.
    ///
    /// NOTE(review): nothing in this module increments the matching atomic counter; confirm
    /// whether another component does.
    pub messages_sent: u64,

    /// Messages that a subscriber received and successfully decoded.
    pub messages_received: u64,

    /// Number of `Remove` messages received.
    pub removes_received: u64,

    /// Number of `Update` messages received.
    pub updates_received: u64,

    /// Number of `RemovePattern` messages received.
    pub patterns_received: u64,

    /// Number of `RemoveBulk` messages received.
    pub bulk_removes_received: u64,

    /// Failures counted by subscribers, such as undecodable payloads, handler errors and
    /// subscription-loop errors.
    pub processing_errors: u64,
}
291
292use std::sync::atomic::{AtomicU64, Ordering};
293
/// Lock-free counterpart of [`InvalidationStats`].
///
/// Subscribers share it with their background tasks through an `Arc`. This module updates
/// every counter with `Ordering::Relaxed` increments.
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    pub messages_sent: AtomicU64,
    pub messages_received: AtomicU64,
    pub removes_received: AtomicU64,
    pub updates_received: AtomicU64,
    pub patterns_received: AtomicU64,
    pub bulk_removes_received: AtomicU64,
    pub processing_errors: AtomicU64,
}
305
306impl AtomicInvalidationStats {
307 pub fn snapshot(&self) -> InvalidationStats {
308 InvalidationStats {
309 messages_sent: self.messages_sent.load(Ordering::Relaxed),
310 messages_received: self.messages_received.load(Ordering::Relaxed),
311 removes_received: self.removes_received.load(Ordering::Relaxed),
312 updates_received: self.updates_received.load(Ordering::Relaxed),
313 patterns_received: self.patterns_received.load(Ordering::Relaxed),
314 bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
315 processing_errors: self.processing_errors.load(Ordering::Relaxed),
316 }
317 }
318}
319
320use std::sync::Arc;
321
/// Receives invalidation messages from Redis Pub/Sub and hands each one to a user handler.
///
/// Redis Pub/Sub does not persist messages. Anything published while this subscriber is
/// disconnected is not delivered to it.
pub struct InvalidationSubscriber {
    /// Client used to open a dedicated Pub/Sub connection on every (re)connect.
    client: redis::Client,
    /// Supplies the channel name.
    config: InvalidationConfig,
    /// Counters shared with the background task spawned by `start`.
    stats: Arc<AtomicInvalidationStats>,
    /// Used by `shutdown` to signal the background tasks started by `start`.
    shutdown_tx: broadcast::Sender<()>,
}
336
337impl InvalidationSubscriber {
338 pub fn new(redis_url: &str, config: InvalidationConfig) -> CacheResult<Self> {
347 let client = redis::Client::open(redis_url).map_err(|e| {
348 crate::error::CacheError::ConfigError(format!(
349 "Failed to create Redis client for subscriber: {e}"
350 ))
351 })?;
352
353 let (shutdown_tx, _) = broadcast::channel(1);
354
355 Ok(Self {
356 client,
357 config,
358 stats: Arc::new(AtomicInvalidationStats::default()),
359 shutdown_tx,
360 })
361 }
362
363 #[must_use]
365 pub fn stats(&self) -> InvalidationStats {
366 self.stats.snapshot()
367 }
368
369 pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
377 where
378 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
379 Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
380 {
381 let client = self.client.clone();
382 let channel = self.config.channel.clone();
383 let stats = Arc::clone(&self.stats);
384 let mut shutdown_rx = self.shutdown_tx.subscribe();
385
386 tokio::spawn(async move {
387 let handler = Arc::new(handler);
388
389 loop {
390 if shutdown_rx.try_recv().is_ok() {
392 info!("Invalidation subscriber shutting down...");
393 break;
394 }
395
396 match Self::run_subscriber_loop(
398 &client,
399 &channel,
400 Arc::clone(&handler),
401 Arc::clone(&stats),
402 &mut shutdown_rx,
403 )
404 .await
405 {
406 Ok(()) => {
407 info!("Invalidation subscriber loop completed normally");
408 break;
409 }
410 Err(e) => {
411 error!(
412 "Invalidation subscriber error: {}. Reconnecting in 5s...",
413 e
414 );
415 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
416
417 tokio::select! {
419 () = tokio::time::sleep(Duration::from_secs(5)) => {},
420 _ = shutdown_rx.recv() => {
421 info!("Invalidation subscriber shutting down...");
422 break;
423 }
424 }
425 }
426 }
427 }
428 })
429 }
430
431 async fn run_subscriber_loop<F, Fut>(
433 client: &redis::Client,
434 channel: &str,
435 handler: Arc<F>,
436 stats: Arc<AtomicInvalidationStats>,
437 shutdown_rx: &mut broadcast::Receiver<()>,
438 ) -> CacheResult<()>
439 where
440 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
441 Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
442 {
443 let mut pubsub = client.get_async_pubsub().await.map_err(|e| {
444 crate::error::CacheError::BackendError(format!("Failed to get pubsub connection: {e}"))
445 })?;
446
447 pubsub.subscribe(channel).await.map_err(|e| {
449 crate::error::CacheError::InvalidationError(format!(
450 "Failed to subscribe to channel: {e}"
451 ))
452 })?;
453
454 info!("Subscribed to invalidation channel: {}", channel);
455
456 let mut stream = pubsub.on_message();
458
459 loop {
460 tokio::select! {
462 msg_result = stream.next() => {
463 match msg_result {
464 Some(msg) => {
465 let payload: String = match msg.get_payload() {
467 Ok(p) => p,
468 Err(e) => {
469 warn!("Failed to get message payload: {}", e);
470 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
471 continue;
472 }
473 };
474
475 let invalidation_msg = match InvalidationMessage::from_json(&payload) {
477 Ok(m) => m,
478 Err(e) => {
479 warn!("Failed to deserialize invalidation message: {}", e);
480 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
481 continue;
482 }
483 };
484
485 stats.messages_received.fetch_add(1, Ordering::Relaxed);
487 match &invalidation_msg {
488 InvalidationMessage::Remove { .. } => {
489 stats.removes_received.fetch_add(1, Ordering::Relaxed);
490 }
491 InvalidationMessage::Update { .. } => {
492 stats.updates_received.fetch_add(1, Ordering::Relaxed);
493 }
494 InvalidationMessage::RemovePattern { .. } => {
495 stats.patterns_received.fetch_add(1, Ordering::Relaxed);
496 }
497 InvalidationMessage::RemoveBulk { .. } => {
498 stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
499 }
500 }
501
502 if let Err(e) = handler(invalidation_msg).await {
504 error!("Invalidation handler error: {}", e);
505 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
506 }
507 }
508 None => {
509 return Err(crate::error::CacheError::InvalidationError("Pub/Sub message stream ended".to_string()));
511 }
512 }
513 }
514 _ = shutdown_rx.recv() => {
515 return Ok(());
516 }
517 }
518 }
519 }
520
521 pub fn shutdown(&self) {
523 let _ = self.shutdown_tx.send(());
524 }
525}
526
527pub struct ReliableStreamSubscriber {
529 client: redis::Client,
530 config: InvalidationConfig,
531 stats: Arc<AtomicInvalidationStats>,
532 shutdown_tx: broadcast::Sender<()>,
533 group_name: String,
534 consumer_name: String,
535}
536
537impl ReliableStreamSubscriber {
538 pub fn new(redis_url: &str, config: InvalidationConfig, group_name: &str) -> CacheResult<Self> {
544 let client = redis::Client::open(redis_url).map_err(|e| {
545 crate::error::CacheError::ConfigError(format!(
546 "Failed to create Redis client for reliable subscriber: {e}"
547 ))
548 })?;
549
550 let (shutdown_tx, _) = broadcast::channel(1);
551 let consumer_name = format!("consumer-{}", Uuid::new_v4());
552
553 Ok(Self {
554 client,
555 config,
556 stats: Arc::new(AtomicInvalidationStats::default()),
557 shutdown_tx,
558 group_name: group_name.to_string(),
559 consumer_name,
560 })
561 }
562
563 pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
564 where
565 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
566 Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
567 {
568 let client = self.client.clone();
569 let stream_key = self.config.channel.clone();
570 let group_name = self.group_name.clone();
571 let consumer_name = self.consumer_name.clone();
572 let handler = Arc::new(handler);
573 let stats = self.stats.clone();
574 let mut shutdown_rx = self.shutdown_tx.subscribe();
575
576 tokio::spawn(async move {
577 info!(
578 stream = %stream_key,
579 group = %group_name,
580 consumer = %consumer_name,
581 "Starting reliable stream subscriber"
582 );
583
584 let redis_backend = crate::redis_streams::RedisStreams::new(
586 client.get_connection_info().addr().to_string().as_str(),
587 )
588 .await;
589 if let Ok(backend) = redis_backend {
590 let _ = backend
591 .stream_create_group(&stream_key, &group_name, "0")
592 .await;
593
594 loop {
595 if shutdown_rx.try_recv().is_ok() {
597 break;
598 }
599
600 if let Err(e) = Self::run_reliable_loop(
601 &backend,
602 &stream_key,
603 &group_name,
604 &consumer_name,
605 handler.clone(),
606 stats.clone(),
607 &mut shutdown_rx,
608 )
609 .await
610 {
611 error!("Reliable subscriber loop error: {}", e);
612
613 tokio::select! {
614 () = tokio::time::sleep(Duration::from_secs(5)) => {},
615 _ = shutdown_rx.recv() => break,
616 }
617 } else {
618 break; }
620 }
621 }
622 })
623 }
624
625 async fn run_reliable_loop<F, Fut>(
626 backend: &dyn crate::traits::StreamingBackend,
627 stream_key: &str,
628 group_name: &str,
629 consumer_name: &str,
630 handler: Arc<F>,
631 stats: Arc<AtomicInvalidationStats>,
632 shutdown_rx: &mut broadcast::Receiver<()>,
633 ) -> CacheResult<()>
634 where
635 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
636 Fut: std::future::Future<Output = CacheResult<()>> + Send + 'static,
637 {
638 loop {
639 tokio::select! {
640 entries_result = backend.stream_read_group(stream_key, group_name, consumer_name, 10, Some(5000)) => {
641 let entries = entries_result?;
642 if entries.is_empty() { continue; }
643
644 let mut processed_ids = Vec::new();
645 for (id, fields) in entries {
646 let payload = fields.iter().find(|(k, _)| k == "payload")
648 .map(|(_, v)| v.as_str())
649 .or_else(|| fields.first().map(|(_, v)| v.as_str()));
650
651 if let Some(msg) = payload.and_then(|json| InvalidationMessage::from_json(json).ok()) {
652 stats.messages_received.fetch_add(1, Ordering::Relaxed);
653 if let Err(e) = handler(msg).await {
654 error!("Reliable handler error: {}", e);
655 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
656 } else {
657 processed_ids.push(id);
658 }
659 }
660 }
661
662 if !processed_ids.is_empty() {
663 backend.stream_ack(stream_key, group_name, &processed_ids).await?;
664 }
665 }
666 _ = shutdown_rx.recv() => return Ok(()),
667 }
668 }
669 }
670
671 pub fn shutdown(&self) {
673 let _ = self.shutdown_tx.send(()).unwrap_or(0);
674 }
675}
676
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_invalidation_message_serialization() -> CacheResult<()> {
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        // `Update` is the only variant using the custom `serde_bytes_wrapper` encoding, so it
        // must also make the full JSON round trip, not just be constructed.
        let msg = InvalidationMessage::update(
            "test_key",
            Bytes::from("{\"value\": 123}"),
            Some(Duration::from_secs(3600)),
        );
        assert_eq!(msg.ttl(), Some(Duration::from_secs(3600)));
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, Bytes::from("{\"value\": 123}"));
                assert_eq!(ttl_secs, Some(3600));
            }
            _ => panic!("Expected Update message"),
        }

        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json()?;
        let parsed = InvalidationMessage::from_json(&json)?;
        match parsed {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
        Ok(())
    }

    #[test]
    fn test_update_without_ttl_round_trips_binary_value() -> CacheResult<()> {
        let raw = Bytes::from_static(b"\x00\xff\x7fbinary");
        let msg = InvalidationMessage::update("k", raw.clone(), None);
        let json = msg.to_json()?;
        assert!(
            !json.contains("ttl_secs"),
            "a missing TTL should be omitted from the JSON: {json}"
        );

        let parsed = InvalidationMessage::from_json(&json)?;
        assert_eq!(parsed.ttl(), None);
        match parsed {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "k");
                assert_eq!(value, raw);
                assert_eq!(ttl_secs, None);
            }
            _ => panic!("Expected Update message"),
        }
        Ok(())
    }

    #[test]
    fn test_ttl_is_none_for_non_update_messages() {
        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
        assert_eq!(
            InvalidationMessage::remove_bulk(vec!["a".to_string()]).ttl(),
            None
        );
    }

    #[test]
    fn test_from_json_rejects_invalid_input() {
        assert!(InvalidationMessage::from_json("not json").is_err());
        assert!(InvalidationMessage::from_json(r#"{"type":"Unknown"}"#).is_err());
        assert!(InvalidationMessage::from_json(r#"{"type":"Remove"}"#).is_err());
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }
}