1use anyhow::{Context, Result};
7use bytes::Bytes;
8use futures_util::StreamExt;
9use redis::AsyncCommands;
10use serde::{Deserialize, Serialize};
11use std::time::{Duration, SystemTime, UNIX_EPOCH};
12use tracing::{error, info, warn};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(tag = "type")]
17pub enum InvalidationMessage {
18 Remove { key: String },
20
21 Update {
24 key: String,
25 #[serde(with = "serde_bytes_wrapper")]
26 value: Bytes,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 ttl_secs: Option<u64>,
29 },
30
31 RemovePattern { pattern: String },
34
35 RemoveBulk { keys: Vec<String> },
37}
38
39impl InvalidationMessage {
40 pub fn remove(key: impl Into<String>) -> Self {
42 Self::Remove { key: key.into() }
43 }
44
45 pub fn update(key: impl Into<String>, value: Bytes, ttl: Option<Duration>) -> Self {
47 Self::Update {
48 key: key.into(),
49 value,
50 ttl_secs: ttl.map(|d| d.as_secs()),
51 }
52 }
53
54 pub fn remove_pattern(pattern: impl Into<String>) -> Self {
56 Self::RemovePattern {
57 pattern: pattern.into(),
58 }
59 }
60
61 #[must_use]
63 pub fn remove_bulk(keys: Vec<String>) -> Self {
64 Self::RemoveBulk { keys }
65 }
66
67 pub fn to_json(&self) -> Result<String> {
73 serde_json::to_string(self).context("Failed to serialize invalidation message")
74 }
75
76 pub fn from_json(json: &str) -> Result<Self> {
82 serde_json::from_str(json).context("Failed to deserialize invalidation message")
83 }
84
85 pub fn ttl(&self) -> Option<Duration> {
87 match self {
88 Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
89 _ => None,
90 }
91 }
92}
93
94mod serde_bytes_wrapper {
96 use bytes::Bytes;
97 use serde::{Deserialize, Deserializer, Serializer};
98
99 pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
100 where
101 S: Serializer,
102 {
103 serializer.serialize_bytes(bytes)
106 }
107
108 pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
109 where
110 D: Deserializer<'de>,
111 {
112 let v: Vec<u8> = Vec::deserialize(deserializer)?;
113 Ok(Bytes::from(v))
114 }
115}
116
/// Settings shared by the invalidation publisher and subscriber.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Pub/Sub channel that invalidation messages are published to and subscribed from.
    pub channel: String,

    /// Whether cache writes should broadcast an invalidation automatically.
    ///
    /// NOTE(review): not read anywhere in this module; presumably consumed by the cache
    /// layer that owns the publisher — confirm.
    pub auto_broadcast_on_write: bool,

    /// When `true`, each published message is also appended to `audit_stream` with
    /// `XADD`. Failures there are logged, not returned.
    pub enable_audit_stream: bool,

    /// Redis stream key that audit entries are appended to.
    pub audit_stream: String,

    /// Approximate length cap for the audit stream (`XADD … MAXLEN ~ n`);
    /// `None` disables trimming.
    pub audit_stream_maxlen: Option<usize>,
}
135
136impl Default for InvalidationConfig {
137 fn default() -> Self {
138 Self {
139 channel: "cache:invalidate".to_string(),
140 auto_broadcast_on_write: false, enable_audit_stream: false,
142 audit_stream: "cache:invalidations".to_string(),
143 audit_stream_maxlen: Some(10000),
144 }
145 }
146}
147
148pub struct InvalidationPublisher {
150 connection: redis::aio::ConnectionManager,
151 config: InvalidationConfig,
152}
153
154impl InvalidationPublisher {
155 #[must_use]
157 pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
158 Self { connection, config }
159 }
160
161 pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
167 let json = message.to_json()?;
168
169 let _: () = self
171 .connection
172 .publish(&self.config.channel, &json)
173 .await
174 .context("Failed to publish invalidation message")?;
175
176 if self.config.enable_audit_stream
178 && let Err(e) = self.publish_to_audit_stream(message).await
179 {
180 warn!("Failed to publish to audit stream: {}", e);
182 }
183
184 Ok(())
185 }
186
187 async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
189 let timestamp = SystemTime::now()
190 .duration_since(UNIX_EPOCH)
191 .unwrap_or(Duration::ZERO)
192 .as_secs()
193 .to_string();
194
195 let (type_str, key_str): (&str, &str);
197 let extra_str: String;
198
199 match message {
200 InvalidationMessage::Remove { key } => {
201 type_str = "remove";
202 key_str = key.as_str();
203 extra_str = String::new();
204 }
205 InvalidationMessage::Update { key, .. } => {
206 type_str = "update";
207 key_str = key.as_str();
208 extra_str = String::new();
209 }
210 InvalidationMessage::RemovePattern { pattern } => {
211 type_str = "remove_pattern";
212 key_str = pattern.as_str();
213 extra_str = String::new();
214 }
215 InvalidationMessage::RemoveBulk { keys } => {
216 type_str = "remove_bulk";
217 key_str = "";
218 extra_str = keys.len().to_string();
219 }
220 }
221
222 let mut fields = vec![("type", type_str), ("timestamp", timestamp.as_str())];
223
224 if !key_str.is_empty() {
225 fields.push(("key", key_str));
226 }
227 if !extra_str.is_empty() {
228 fields.push(("count", extra_str.as_str()));
229 }
230
231 let mut cmd = redis::cmd("XADD");
232 cmd.arg(&self.config.audit_stream);
233
234 if let Some(maxlen) = self.config.audit_stream_maxlen {
235 cmd.arg("MAXLEN").arg("~").arg(maxlen);
236 }
237
238 cmd.arg("*"); for (key, value) in fields {
241 cmd.arg(key).arg(value);
242 }
243
244 let _: String = cmd
245 .query_async(&mut self.connection)
246 .await
247 .context("Failed to add to audit stream")?;
248
249 Ok(())
250 }
251}
252
/// Plain-number copy of the invalidation counters at one point in time.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. in tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct InvalidationStats {
    /// Invalidation messages sent.
    ///
    /// NOTE(review): nothing in this file increments the underlying counter,
    /// so this presumably stays 0 — confirm who is expected to update it.
    pub messages_sent: u64,

    /// Messages received and successfully decoded.
    pub messages_received: u64,

    /// Decoded `Remove` messages.
    pub removes_received: u64,

    /// Decoded `Update` messages.
    pub updates_received: u64,

    /// Decoded `RemovePattern` messages.
    pub patterns_received: u64,

    /// Decoded `RemoveBulk` messages.
    pub bulk_removes_received: u64,

    /// Errors while subscribing or processing. This covers unreadable payloads,
    /// undecodable messages, handler failures, and subscriber-loop errors that
    /// triggered a reconnect.
    pub processing_errors: u64,
}
277
278use std::sync::atomic::{AtomicU64, Ordering};
279
/// Atomic counters that can be bumped concurrently through a shared reference
/// (the subscriber shares them with its background task through an `Arc`).
///
/// Use `snapshot` to read them as an `InvalidationStats`.
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    /// Messages sent. NOTE(review): not incremented anywhere in this file — confirm.
    pub messages_sent: AtomicU64,
    /// Messages received and successfully decoded.
    pub messages_received: AtomicU64,
    /// Decoded `Remove` messages.
    pub removes_received: AtomicU64,
    /// Decoded `Update` messages.
    pub updates_received: AtomicU64,
    /// Decoded `RemovePattern` messages.
    pub patterns_received: AtomicU64,
    /// Decoded `RemoveBulk` messages.
    pub bulk_removes_received: AtomicU64,
    /// Subscription and processing failures.
    pub processing_errors: AtomicU64,
}
291
292impl AtomicInvalidationStats {
293 pub fn snapshot(&self) -> InvalidationStats {
294 InvalidationStats {
295 messages_sent: self.messages_sent.load(Ordering::Relaxed),
296 messages_received: self.messages_received.load(Ordering::Relaxed),
297 removes_received: self.removes_received.load(Ordering::Relaxed),
298 updates_received: self.updates_received.load(Ordering::Relaxed),
299 patterns_received: self.patterns_received.load(Ordering::Relaxed),
300 bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
301 processing_errors: self.processing_errors.load(Ordering::Relaxed),
302 }
303 }
304}
305
306use std::sync::Arc;
307use tokio::sync::broadcast;
308
309pub struct InvalidationSubscriber {
314 client: redis::Client,
316 config: InvalidationConfig,
318 stats: Arc<AtomicInvalidationStats>,
320 shutdown_tx: broadcast::Sender<()>,
322}
323
324impl InvalidationSubscriber {
325 pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
334 let client = redis::Client::open(redis_url)
335 .context("Failed to create Redis client for subscriber")?;
336
337 let (shutdown_tx, _) = broadcast::channel(1);
338
339 Ok(Self {
340 client,
341 config,
342 stats: Arc::new(AtomicInvalidationStats::default()),
343 shutdown_tx,
344 })
345 }
346
347 #[must_use]
349 pub fn stats(&self) -> InvalidationStats {
350 self.stats.snapshot()
351 }
352
353 pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
361 where
362 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
363 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
364 {
365 let client = self.client.clone();
366 let channel = self.config.channel.clone();
367 let stats = Arc::clone(&self.stats);
368 let mut shutdown_rx = self.shutdown_tx.subscribe();
369
370 tokio::spawn(async move {
371 let handler = Arc::new(handler);
372
373 loop {
374 if shutdown_rx.try_recv().is_ok() {
376 info!("Invalidation subscriber shutting down...");
377 break;
378 }
379
380 match Self::run_subscriber_loop(
382 &client,
383 &channel,
384 Arc::clone(&handler),
385 Arc::clone(&stats),
386 &mut shutdown_rx,
387 )
388 .await
389 {
390 Ok(()) => {
391 info!("Invalidation subscriber loop completed normally");
392 break;
393 }
394 Err(e) => {
395 error!(
396 "Invalidation subscriber error: {}. Reconnecting in 5s...",
397 e
398 );
399 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
400
401 tokio::select! {
403 () = tokio::time::sleep(Duration::from_secs(5)) => {},
404 _ = shutdown_rx.recv() => {
405 info!("Invalidation subscriber shutting down...");
406 break;
407 }
408 }
409 }
410 }
411 }
412 })
413 }
414
415 async fn run_subscriber_loop<F, Fut>(
417 client: &redis::Client,
418 channel: &str,
419 handler: Arc<F>,
420 stats: Arc<AtomicInvalidationStats>,
421 shutdown_rx: &mut broadcast::Receiver<()>,
422 ) -> Result<()>
423 where
424 F: Fn(InvalidationMessage) -> Fut + Send + 'static,
425 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
426 {
427 let mut pubsub = client
429 .get_async_pubsub()
430 .await
431 .context("Failed to get pubsub connection")?;
432
433 pubsub
435 .subscribe(channel)
436 .await
437 .context("Failed to subscribe to channel")?;
438
439 info!("Subscribed to invalidation channel: {}", channel);
440
441 let mut stream = pubsub.on_message();
443
444 loop {
445 tokio::select! {
447 msg_result = stream.next() => {
448 match msg_result {
449 Some(msg) => {
450 let payload: String = match msg.get_payload() {
452 Ok(p) => p,
453 Err(e) => {
454 warn!("Failed to get message payload: {}", e);
455 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
456 continue;
457 }
458 };
459
460 let invalidation_msg = match InvalidationMessage::from_json(&payload) {
462 Ok(m) => m,
463 Err(e) => {
464 warn!("Failed to deserialize invalidation message: {}", e);
465 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
466 continue;
467 }
468 };
469
470 stats.messages_received.fetch_add(1, Ordering::Relaxed);
472 match &invalidation_msg {
473 InvalidationMessage::Remove { .. } => {
474 stats.removes_received.fetch_add(1, Ordering::Relaxed);
475 }
476 InvalidationMessage::Update { .. } => {
477 stats.updates_received.fetch_add(1, Ordering::Relaxed);
478 }
479 InvalidationMessage::RemovePattern { .. } => {
480 stats.patterns_received.fetch_add(1, Ordering::Relaxed);
481 }
482 InvalidationMessage::RemoveBulk { .. } => {
483 stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
484 }
485 }
486
487 if let Err(e) = handler(invalidation_msg).await {
489 error!("Invalidation handler error: {}", e);
490 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
491 }
492 }
493 None => {
494 return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
496 }
497 }
498 }
499 _ = shutdown_rx.recv() => {
500 return Ok(());
501 }
502 }
503 }
504 }
505
506 pub fn shutdown(&self) {
508 let _ = self.shutdown_tx.send(());
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_invalidation_message_serialization() -> Result<()> {
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json()?;
        // Pin the wire format: internally tagged with the variant name under "type".
        assert!(json.contains("\"type\":\"Remove\""));
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            other => panic!("Wrong message type: {other:?}"),
        }

        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json()?;
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            other => panic!("Wrong message type: {other:?}"),
        }

        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json()?;
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            other => panic!("Wrong message type: {other:?}"),
        }
        Ok(())
    }

    #[test]
    fn test_update_message_json_round_trip() -> Result<()> {
        // Update is the only variant that goes through the custom bytes serde adapter.
        let msg = InvalidationMessage::update(
            "test_key",
            Bytes::from("{\"value\": 123}"),
            Some(Duration::from_secs(3600)),
        );
        assert_eq!(msg.ttl(), Some(Duration::from_secs(3600)));

        let json = msg.to_json()?;
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, Bytes::from("{\"value\": 123}"));
                assert_eq!(ttl_secs, Some(3600));
            }
            other => panic!("Expected Update message, got {other:?}"),
        }

        // Arbitrary (non-UTF-8) bytes must survive, and a missing TTL is omitted and
        // restored as None.
        let all_bytes: Vec<u8> = (0u8..=255).collect();
        let msg = InvalidationMessage::update("bin", Bytes::from(all_bytes.clone()), None);
        assert_eq!(msg.ttl(), None);
        let json = msg.to_json()?;
        assert!(!json.contains("ttl_secs"));
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "bin");
                assert_eq!(value.as_ref(), all_bytes.as_slice());
                assert_eq!(ttl_secs, None);
            }
            other => panic!("Expected Update message, got {other:?}"),
        }
        Ok(())
    }

    #[test]
    fn test_ttl_is_none_for_non_update_messages() {
        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(Vec::new()).ttl(), None);
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }

    #[test]
    fn test_stats_snapshot_reflects_counters() {
        let stats = AtomicInvalidationStats::default();
        stats.messages_received.fetch_add(2, Ordering::Relaxed);
        stats.updates_received.fetch_add(1, Ordering::Relaxed);

        let snap = stats.snapshot();
        assert_eq!(snap.messages_received, 2);
        assert_eq!(snap.updates_received, 1);
        assert_eq!(snap.removes_received, 0);
        assert_eq!(snap.processing_errors, 0);
    }
}