1use anyhow::{Context, Result};
7use redis::AsyncCommands;
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use futures_util::StreamExt;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type")]
15pub enum InvalidationMessage {
16 Remove {
18 key: String,
19 },
20
21 Update {
24 key: String,
25 value: serde_json::Value,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 ttl_secs: Option<u64>,
28 },
29
30 RemovePattern {
33 pattern: String,
34 },
35
36 RemoveBulk {
38 keys: Vec<String>,
39 },
40}
41
42impl InvalidationMessage {
43 pub fn remove(key: impl Into<String>) -> Self {
45 Self::Remove { key: key.into() }
46 }
47
48 pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
50 Self::Update {
51 key: key.into(),
52 value,
53 ttl_secs: ttl.map(|d| d.as_secs()),
54 }
55 }
56
57 pub fn remove_pattern(pattern: impl Into<String>) -> Self {
59 Self::RemovePattern {
60 pattern: pattern.into(),
61 }
62 }
63
64 pub fn remove_bulk(keys: Vec<String>) -> Self {
66 Self::RemoveBulk { keys }
67 }
68
69 pub fn to_json(&self) -> Result<String> {
71 serde_json::to_string(self).context("Failed to serialize invalidation message")
72 }
73
74 pub fn from_json(json: &str) -> Result<Self> {
76 serde_json::from_str(json).context("Failed to deserialize invalidation message")
77 }
78
79 pub fn ttl(&self) -> Option<Duration> {
81 match self {
82 Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
83 _ => None,
84 }
85 }
86}
87
/// Settings shared by the invalidation publisher and subscriber.
#[derive(Clone, Debug)]
pub struct InvalidationConfig {
    /// Name of the Pub/Sub channel that messages are published on and
    /// subscribed to.
    pub channel: String,

    /// NOTE(review): not read anywhere in this file; presumably consulted by
    /// the cache write path to decide whether to broadcast — confirm.
    pub auto_broadcast_on_write: bool,

    /// When `true`, every published message is also recorded in the audit
    /// stream.
    pub enable_audit_stream: bool,

    /// Key of the stream that audit entries are appended to.
    pub audit_stream: String,

    /// Approximate cap on the audit stream length (`MAXLEN ~ n`); `None`
    /// appends without trimming.
    pub audit_stream_maxlen: Option<usize>,
}
106
107impl Default for InvalidationConfig {
108 fn default() -> Self {
109 Self {
110 channel: "cache:invalidate".to_string(),
111 auto_broadcast_on_write: false, enable_audit_stream: false,
113 audit_stream: "cache:invalidations".to_string(),
114 audit_stream_maxlen: Some(10000),
115 }
116 }
117}
118
119pub struct InvalidationPublisher {
121 connection: redis::aio::ConnectionManager,
122 config: InvalidationConfig,
123}
124
125impl InvalidationPublisher {
126 pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
128 Self { connection, config }
129 }
130
131 pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
133 let json = message.to_json()?;
134
135 let _: () = self
137 .connection
138 .publish(&self.config.channel, &json)
139 .await
140 .context("Failed to publish invalidation message")?;
141
142 if self.config.enable_audit_stream {
144 if let Err(e) = self.publish_to_audit_stream(message).await {
145 eprintln!("Warning: Failed to publish to audit stream: {}", e);
147 }
148 }
149
150 Ok(())
151 }
152
153 async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
155 let timestamp = SystemTime::now()
156 .duration_since(UNIX_EPOCH)
157 .unwrap()
158 .as_secs()
159 .to_string();
160
161 let (type_str, key_str, extra_str) = match message {
162 InvalidationMessage::Remove { key } => {
163 ("remove".to_string(), key.clone(), String::new())
164 }
165 InvalidationMessage::Update { key, .. } => {
166 ("update".to_string(), key.clone(), String::new())
167 }
168 InvalidationMessage::RemovePattern { pattern } => {
169 ("remove_pattern".to_string(), pattern.clone(), String::new())
170 }
171 InvalidationMessage::RemoveBulk { keys } => {
172 ("remove_bulk".to_string(), String::new(), keys.len().to_string())
173 }
174 };
175
176 let mut fields = vec![
177 ("type", type_str.as_str()),
178 ("timestamp", timestamp.as_str()),
179 ];
180
181 if !key_str.is_empty() {
182 fields.push(("key", key_str.as_str()));
183 }
184 if !extra_str.is_empty() {
185 fields.push(("count", extra_str.as_str()));
186 }
187
188 let mut cmd = redis::cmd("XADD");
189 cmd.arg(&self.config.audit_stream);
190
191 if let Some(maxlen) = self.config.audit_stream_maxlen {
192 cmd.arg("MAXLEN").arg("~").arg(maxlen);
193 }
194
195 cmd.arg("*"); for (key, value) in fields {
198 cmd.arg(key).arg(value);
199 }
200
201 let _: String = cmd
202 .query_async(&mut self.connection)
203 .await
204 .context("Failed to add to audit stream")?;
205
206 Ok(())
207 }
208}
209
/// Plain-value snapshot of the invalidation counters.
#[derive(Clone, Debug, Default)]
pub struct InvalidationStats {
    /// Number of messages sent.
    /// NOTE(review): nothing in this file increments the underlying counter.
    pub messages_sent: u64,

    /// Number of messages received and successfully decoded.
    pub messages_received: u64,

    /// Number of `Remove` messages received.
    pub removes_received: u64,

    /// Number of `Update` messages received.
    pub updates_received: u64,

    /// Number of `RemovePattern` messages received.
    pub patterns_received: u64,

    /// Number of `RemoveBulk` messages received.
    pub bulk_removes_received: u64,

    /// Number of errors hit while receiving, decoding or handling messages,
    /// including subscriber connection failures.
    pub processing_errors: u64,
}
234
235use std::sync::atomic::{AtomicU64, Ordering};
236
/// Invalidation counters stored as atomics so they can be updated through a
/// shared reference.
///
/// Each field has a same-named plain `u64` counterpart in `InvalidationStats`.
#[derive(Default, Debug)]
pub struct AtomicInvalidationStats {
    /// Messages sent.
    pub messages_sent: AtomicU64,
    /// Messages received and successfully decoded.
    pub messages_received: AtomicU64,
    /// `Remove` messages received.
    pub removes_received: AtomicU64,
    /// `Update` messages received.
    pub updates_received: AtomicU64,
    /// `RemovePattern` messages received.
    pub patterns_received: AtomicU64,
    /// `RemoveBulk` messages received.
    pub bulk_removes_received: AtomicU64,
    /// Errors encountered while receiving, decoding or handling messages.
    pub processing_errors: AtomicU64,
}
248
249impl AtomicInvalidationStats {
250 pub fn snapshot(&self) -> InvalidationStats {
251 InvalidationStats {
252 messages_sent: self.messages_sent.load(Ordering::Relaxed),
253 messages_received: self.messages_received.load(Ordering::Relaxed),
254 removes_received: self.removes_received.load(Ordering::Relaxed),
255 updates_received: self.updates_received.load(Ordering::Relaxed),
256 patterns_received: self.patterns_received.load(Ordering::Relaxed),
257 bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
258 processing_errors: self.processing_errors.load(Ordering::Relaxed),
259 }
260 }
261}
262
263use std::sync::Arc;
264use tokio::sync::broadcast;
265
266pub struct InvalidationSubscriber {
271 client: redis::Client,
273 config: InvalidationConfig,
275 stats: Arc<AtomicInvalidationStats>,
277 shutdown_tx: broadcast::Sender<()>,
279}
280
281impl InvalidationSubscriber {
282 pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
288 let client = redis::Client::open(redis_url)
289 .context("Failed to create Redis client for subscriber")?;
290
291 let (shutdown_tx, _) = broadcast::channel(1);
292
293 Ok(Self {
294 client,
295 config,
296 stats: Arc::new(AtomicInvalidationStats::default()),
297 shutdown_tx,
298 })
299 }
300
301 pub fn stats(&self) -> InvalidationStats {
303 self.stats.snapshot()
304 }
305
306 pub fn start<F, Fut>(
314 &self,
315 handler: F,
316 ) -> tokio::task::JoinHandle<()>
317 where
318 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
319 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
320 {
321 let client = self.client.clone();
322 let channel = self.config.channel.clone();
323 let stats = Arc::clone(&self.stats);
324 let mut shutdown_rx = self.shutdown_tx.subscribe();
325
326 tokio::spawn(async move {
327 let handler = Arc::new(handler);
328
329 loop {
330 if shutdown_rx.try_recv().is_ok() {
332 println!("🛑 Invalidation subscriber shutting down...");
333 break;
334 }
335
336 match Self::run_subscriber_loop(
338 &client,
339 &channel,
340 Arc::clone(&handler),
341 Arc::clone(&stats),
342 &mut shutdown_rx,
343 ).await {
344 Ok(_) => {
345 println!("✅ Invalidation subscriber loop completed normally");
346 break;
347 }
348 Err(e) => {
349 eprintln!("⚠️ Invalidation subscriber error: {}. Reconnecting in 5s...", e);
350 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
351
352 tokio::select! {
354 _ = tokio::time::sleep(Duration::from_secs(5)) => {},
355 _ = shutdown_rx.recv() => {
356 println!("🛑 Invalidation subscriber shutting down...");
357 break;
358 }
359 }
360 }
361 }
362 }
363 })
364 }
365
366 async fn run_subscriber_loop<F, Fut>(
368 client: &redis::Client,
369 channel: &str,
370 handler: Arc<F>,
371 stats: Arc<AtomicInvalidationStats>,
372 shutdown_rx: &mut broadcast::Receiver<()>,
373 ) -> Result<()>
374 where
375 F: Fn(InvalidationMessage) -> Fut + Send + 'static,
376 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
377 {
378 let mut pubsub = client.get_async_pubsub().await
380 .context("Failed to get pubsub connection")?;
381
382 pubsub.subscribe(channel).await
384 .context("Failed to subscribe to channel")?;
385
386 println!("📡 Subscribed to invalidation channel: {}", channel);
387
388 let mut stream = pubsub.on_message();
390
391 loop {
392 tokio::select! {
394 msg_result = stream.next() => {
395 match msg_result {
396 Some(msg) => {
397 let payload: String = match msg.get_payload() {
399 Ok(p) => p,
400 Err(e) => {
401 eprintln!("⚠️ Failed to get message payload: {}", e);
402 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
403 continue;
404 }
405 };
406
407 let invalidation_msg = match InvalidationMessage::from_json(&payload) {
409 Ok(m) => m,
410 Err(e) => {
411 eprintln!("⚠️ Failed to deserialize invalidation message: {}", e);
412 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
413 continue;
414 }
415 };
416
417 stats.messages_received.fetch_add(1, Ordering::Relaxed);
419 match &invalidation_msg {
420 InvalidationMessage::Remove { .. } => {
421 stats.removes_received.fetch_add(1, Ordering::Relaxed);
422 }
423 InvalidationMessage::Update { .. } => {
424 stats.updates_received.fetch_add(1, Ordering::Relaxed);
425 }
426 InvalidationMessage::RemovePattern { .. } => {
427 stats.patterns_received.fetch_add(1, Ordering::Relaxed);
428 }
429 InvalidationMessage::RemoveBulk { .. } => {
430 stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
431 }
432 }
433
434 if let Err(e) = handler(invalidation_msg).await {
436 eprintln!("⚠️ Invalidation handler error: {}", e);
437 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
438 }
439 }
440 None => {
441 return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
443 }
444 }
445 }
446 _ = shutdown_rx.recv() => {
447 return Ok(());
448 }
449 }
450 }
451 }
452
453 pub fn shutdown(&self) {
455 let _ = self.shutdown_tx.send(());
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg` to JSON and decodes it again.
    fn round_trip(msg: &InvalidationMessage) -> InvalidationMessage {
        let json = msg.to_json().unwrap();
        InvalidationMessage::from_json(&json).unwrap()
    }

    /// Every variant must survive a JSON round trip with its fields intact.
    #[test]
    fn test_invalidation_message_serialization() {
        // Remove
        match round_trip(&InvalidationMessage::remove("test_key")) {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        // Update, including the TTL
        let original = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        match round_trip(&original) {
            InvalidationMessage::Update { key, value, ttl_secs } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        // RemovePattern
        match round_trip(&InvalidationMessage::remove_pattern("user:*")) {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        // RemoveBulk
        let bulk = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        match round_trip(&bulk) {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
    }

    /// The default configuration uses the standard channel and keeps the
    /// optional features switched off.
    #[test]
    fn test_invalidation_config_default() {
        let InvalidationConfig {
            channel,
            auto_broadcast_on_write,
            enable_audit_stream,
            ..
        } = InvalidationConfig::default();

        assert_eq!(channel, "cache:invalidate");
        assert!(!auto_broadcast_on_write);
        assert!(!enable_audit_stream);
    }
}