1use anyhow::{Context, Result};
7use futures_util::StreamExt;
8use redis::AsyncCommands;
9use serde::{Deserialize, Serialize};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use tracing::{error, info, warn};
12
/// A cache-invalidation event exchanged between processes as JSON
/// (see `to_json` / `from_json`).
///
/// Serialized as an internally tagged object: the variant tag is written into a
/// `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum InvalidationMessage {
    /// Requests removal of the entry stored under `key`.
    Remove { key: String },

    /// Requests that the entry under `key` be replaced by `value`.
    Update {
        /// Key of the entry to update.
        key: String,
        /// New value for the entry, as arbitrary JSON.
        value: serde_json::Value,
        /// Optional time-to-live in whole seconds; omitted from the JSON when `None`
        /// and read back as `None` when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl_secs: Option<u64>,
    },

    /// Requests removal of entries matching `pattern` (e.g. `user:*`); how the
    /// pattern is interpreted is up to the receiving handler.
    RemovePattern { pattern: String },

    /// Requests removal of every key in `keys`.
    RemoveBulk { keys: Vec<String> },
}
36
37impl InvalidationMessage {
38 pub fn remove(key: impl Into<String>) -> Self {
40 Self::Remove { key: key.into() }
41 }
42
43 pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
45 Self::Update {
46 key: key.into(),
47 value,
48 ttl_secs: ttl.map(|d| d.as_secs()),
49 }
50 }
51
52 pub fn remove_pattern(pattern: impl Into<String>) -> Self {
54 Self::RemovePattern {
55 pattern: pattern.into(),
56 }
57 }
58
59 #[must_use]
61 pub fn remove_bulk(keys: Vec<String>) -> Self {
62 Self::RemoveBulk { keys }
63 }
64
65 pub fn to_json(&self) -> Result<String> {
71 serde_json::to_string(self).context("Failed to serialize invalidation message")
72 }
73
74 pub fn from_json(json: &str) -> Result<Self> {
80 serde_json::from_str(json).context("Failed to deserialize invalidation message")
81 }
82
83 pub fn ttl(&self) -> Option<Duration> {
85 match self {
86 Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
87 _ => None,
88 }
89 }
90}
91
/// Settings shared by the invalidation publisher and subscriber.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Pub/Sub channel on which invalidation messages are published and received.
    pub channel: String,

    /// Not read by any code visible in this module; presumably makes cache writes
    /// broadcast automatically — TODO confirm against its users.
    pub auto_broadcast_on_write: bool,

    /// When `true`, the publisher also appends each message to `audit_stream`.
    pub enable_audit_stream: bool,

    /// Redis Stream key that receives audit records.
    pub audit_stream: String,

    /// Passed to `XADD` as `MAXLEN ~ <n>` to bound the audit stream; `None` skips trimming.
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Channel `cache:invalidate`, audit stream `cache:invalidations` capped at
    /// roughly 10 000 entries, with auto-broadcast and auditing both disabled.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");

        Self {
            channel,
            auto_broadcast_on_write: false,
            enable_audit_stream: false,
            audit_stream,
            audit_stream_maxlen: Some(10_000),
        }
    }
}
122
123pub struct InvalidationPublisher {
125 connection: redis::aio::ConnectionManager,
126 config: InvalidationConfig,
127}
128
129impl InvalidationPublisher {
130 #[must_use]
132 pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
133 Self { connection, config }
134 }
135
136 pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
142 let json = message.to_json()?;
143
144 let _: () = self
146 .connection
147 .publish(&self.config.channel, &json)
148 .await
149 .context("Failed to publish invalidation message")?;
150
151 if self.config.enable_audit_stream {
153 if let Err(e) = self.publish_to_audit_stream(message).await {
154 warn!("Failed to publish to audit stream: {}", e);
156 }
157 }
158
159 Ok(())
160 }
161
162 async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
164 let timestamp = SystemTime::now()
165 .duration_since(UNIX_EPOCH)
166 .unwrap_or(Duration::ZERO)
167 .as_secs()
168 .to_string();
169
170 let (type_str, key_str): (&str, &str);
172 let extra_str: String;
173
174 match message {
175 InvalidationMessage::Remove { key } => {
176 type_str = "remove";
177 key_str = key.as_str();
178 extra_str = String::new();
179 }
180 InvalidationMessage::Update { key, .. } => {
181 type_str = "update";
182 key_str = key.as_str();
183 extra_str = String::new();
184 }
185 InvalidationMessage::RemovePattern { pattern } => {
186 type_str = "remove_pattern";
187 key_str = pattern.as_str();
188 extra_str = String::new();
189 }
190 InvalidationMessage::RemoveBulk { keys } => {
191 type_str = "remove_bulk";
192 key_str = "";
193 extra_str = keys.len().to_string();
194 }
195 }
196
197 let mut fields = vec![("type", type_str), ("timestamp", timestamp.as_str())];
198
199 if !key_str.is_empty() {
200 fields.push(("key", key_str));
201 }
202 if !extra_str.is_empty() {
203 fields.push(("count", extra_str.as_str()));
204 }
205
206 let mut cmd = redis::cmd("XADD");
207 cmd.arg(&self.config.audit_stream);
208
209 if let Some(maxlen) = self.config.audit_stream_maxlen {
210 cmd.arg("MAXLEN").arg("~").arg(maxlen);
211 }
212
213 cmd.arg("*"); for (key, value) in fields {
216 cmd.arg(key).arg(value);
217 }
218
219 let _: String = cmd
220 .query_async(&mut self.connection)
221 .await
222 .context("Failed to add to audit stream")?;
223
224 Ok(())
225 }
226}
227
/// Plain-value copy of the invalidation counters, as returned by
/// [`AtomicInvalidationStats::snapshot`].
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// Messages sent. NOTE(review): no code in this module increments the
    /// matching atomic counter — confirm where it is meant to be updated.
    pub messages_sent: u64,

    /// Messages successfully decoded from the channel.
    pub messages_received: u64,

    /// Decoded `Remove` messages.
    pub removes_received: u64,

    /// Decoded `Update` messages.
    pub updates_received: u64,

    /// Decoded `RemovePattern` messages.
    pub patterns_received: u64,

    /// Decoded `RemoveBulk` messages.
    pub bulk_removes_received: u64,

    /// Payload, decode, handler, and connection failures.
    pub processing_errors: u64,
}

use std::sync::atomic::{AtomicU64, Ordering};

/// Counters that can be bumped concurrently through a shared reference;
/// mirrors [`InvalidationStats`] field for field.
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    pub messages_sent: AtomicU64,
    pub messages_received: AtomicU64,
    pub removes_received: AtomicU64,
    pub updates_received: AtomicU64,
    pub patterns_received: AtomicU64,
    pub bulk_removes_received: AtomicU64,
    pub processing_errors: AtomicU64,
}

impl AtomicInvalidationStats {
    /// Reads every counter into an [`InvalidationStats`].
    ///
    /// Each counter is loaded separately with `Relaxed` ordering, so the result
    /// is not an atomic snapshot across fields.
    pub fn snapshot(&self) -> InvalidationStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);

        InvalidationStats {
            messages_sent: read(&self.messages_sent),
            messages_received: read(&self.messages_received),
            removes_received: read(&self.removes_received),
            updates_received: read(&self.updates_received),
            patterns_received: read(&self.patterns_received),
            bulk_removes_received: read(&self.bulk_removes_received),
            processing_errors: read(&self.processing_errors),
        }
    }
}
280
281use std::sync::Arc;
282use tokio::sync::broadcast;
283
284pub struct InvalidationSubscriber {
289 client: redis::Client,
291 config: InvalidationConfig,
293 stats: Arc<AtomicInvalidationStats>,
295 shutdown_tx: broadcast::Sender<()>,
297}
298
299impl InvalidationSubscriber {
300 pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
309 let client = redis::Client::open(redis_url)
310 .context("Failed to create Redis client for subscriber")?;
311
312 let (shutdown_tx, _) = broadcast::channel(1);
313
314 Ok(Self {
315 client,
316 config,
317 stats: Arc::new(AtomicInvalidationStats::default()),
318 shutdown_tx,
319 })
320 }
321
322 #[must_use]
324 pub fn stats(&self) -> InvalidationStats {
325 self.stats.snapshot()
326 }
327
328 pub fn start<F, Fut>(&self, handler: F) -> tokio::task::JoinHandle<()>
336 where
337 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
338 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
339 {
340 let client = self.client.clone();
341 let channel = self.config.channel.clone();
342 let stats = Arc::clone(&self.stats);
343 let mut shutdown_rx = self.shutdown_tx.subscribe();
344
345 tokio::spawn(async move {
346 let handler = Arc::new(handler);
347
348 loop {
349 if shutdown_rx.try_recv().is_ok() {
351 info!("Invalidation subscriber shutting down...");
352 break;
353 }
354
355 match Self::run_subscriber_loop(
357 &client,
358 &channel,
359 Arc::clone(&handler),
360 Arc::clone(&stats),
361 &mut shutdown_rx,
362 )
363 .await
364 {
365 Ok(()) => {
366 info!("Invalidation subscriber loop completed normally");
367 break;
368 }
369 Err(e) => {
370 error!(
371 "Invalidation subscriber error: {}. Reconnecting in 5s...",
372 e
373 );
374 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
375
376 tokio::select! {
378 () = tokio::time::sleep(Duration::from_secs(5)) => {},
379 _ = shutdown_rx.recv() => {
380 info!("Invalidation subscriber shutting down...");
381 break;
382 }
383 }
384 }
385 }
386 }
387 })
388 }
389
390 async fn run_subscriber_loop<F, Fut>(
392 client: &redis::Client,
393 channel: &str,
394 handler: Arc<F>,
395 stats: Arc<AtomicInvalidationStats>,
396 shutdown_rx: &mut broadcast::Receiver<()>,
397 ) -> Result<()>
398 where
399 F: Fn(InvalidationMessage) -> Fut + Send + 'static,
400 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
401 {
402 let mut pubsub = client
404 .get_async_pubsub()
405 .await
406 .context("Failed to get pubsub connection")?;
407
408 pubsub
410 .subscribe(channel)
411 .await
412 .context("Failed to subscribe to channel")?;
413
414 info!("Subscribed to invalidation channel: {}", channel);
415
416 let mut stream = pubsub.on_message();
418
419 loop {
420 tokio::select! {
422 msg_result = stream.next() => {
423 match msg_result {
424 Some(msg) => {
425 let payload: String = match msg.get_payload() {
427 Ok(p) => p,
428 Err(e) => {
429 warn!("Failed to get message payload: {}", e);
430 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
431 continue;
432 }
433 };
434
435 let invalidation_msg = match InvalidationMessage::from_json(&payload) {
437 Ok(m) => m,
438 Err(e) => {
439 warn!("Failed to deserialize invalidation message: {}", e);
440 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
441 continue;
442 }
443 };
444
445 stats.messages_received.fetch_add(1, Ordering::Relaxed);
447 match &invalidation_msg {
448 InvalidationMessage::Remove { .. } => {
449 stats.removes_received.fetch_add(1, Ordering::Relaxed);
450 }
451 InvalidationMessage::Update { .. } => {
452 stats.updates_received.fetch_add(1, Ordering::Relaxed);
453 }
454 InvalidationMessage::RemovePattern { .. } => {
455 stats.patterns_received.fetch_add(1, Ordering::Relaxed);
456 }
457 InvalidationMessage::RemoveBulk { .. } => {
458 stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
459 }
460 }
461
462 if let Err(e) = handler(invalidation_msg).await {
464 error!("Invalidation handler error: {}", e);
465 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
466 }
467 }
468 None => {
469 return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
471 }
472 }
473 }
474 _ = shutdown_rx.recv() => {
475 return Ok(());
476 }
477 }
478 }
479 }
480
481 pub fn shutdown(&self) {
483 let _ = self.shutdown_tx.send(());
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `msg` to JSON and parses it back.
    fn roundtrip(msg: &InvalidationMessage) -> Result<InvalidationMessage> {
        InvalidationMessage::from_json(&msg.to_json()?)
    }

    #[test]
    fn test_invalidation_message_serialization() -> Result<()> {
        match roundtrip(&InvalidationMessage::remove("test_key"))? {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        let update = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        match roundtrip(&update)? {
            InvalidationMessage::Update {
                key,
                value,
                ttl_secs,
            } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        match roundtrip(&InvalidationMessage::remove_pattern("user:*"))? {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        let bulk = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        match roundtrip(&bulk)? {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
        Ok(())
    }

    #[test]
    fn test_update_without_ttl_omits_field() -> Result<()> {
        let msg = InvalidationMessage::update("k", serde_json::json!(null), None);
        let json = msg.to_json()?;
        // `skip_serializing_if = "Option::is_none"` drops the field entirely.
        assert!(!json.contains("ttl_secs"));
        match InvalidationMessage::from_json(&json)? {
            InvalidationMessage::Update { ttl_secs, .. } => assert_eq!(ttl_secs, None),
            _ => panic!("Wrong message type"),
        }
        Ok(())
    }

    #[test]
    fn test_ttl_accessor() {
        let with_ttl = InvalidationMessage::update(
            "k",
            serde_json::json!(null),
            Some(Duration::from_secs(300)),
        );
        assert_eq!(with_ttl.ttl(), Some(Duration::from_secs(300)));

        let without_ttl = InvalidationMessage::update("k", serde_json::json!(null), None);
        assert_eq!(without_ttl.ttl(), None);

        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(Vec::new()).ttl(), None);
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }
}
550}