1use anyhow::{Context, Result};
7use redis::AsyncCommands;
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use futures_util::StreamExt;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type")]
15pub enum InvalidationMessage {
16 Remove {
18 key: String,
19 },
20
21 Update {
24 key: String,
25 value: serde_json::Value,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 ttl_secs: Option<u64>,
28 },
29
30 RemovePattern {
33 pattern: String,
34 },
35
36 RemoveBulk {
38 keys: Vec<String>,
39 },
40}
41
42impl InvalidationMessage {
43 pub fn remove(key: impl Into<String>) -> Self {
45 Self::Remove { key: key.into() }
46 }
47
48 pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
50 Self::Update {
51 key: key.into(),
52 value,
53 ttl_secs: ttl.map(|d| d.as_secs()),
54 }
55 }
56
57 pub fn remove_pattern(pattern: impl Into<String>) -> Self {
59 Self::RemovePattern {
60 pattern: pattern.into(),
61 }
62 }
63
64 pub fn remove_bulk(keys: Vec<String>) -> Self {
66 Self::RemoveBulk { keys }
67 }
68
69 pub fn to_json(&self) -> Result<String> {
71 serde_json::to_string(self).context("Failed to serialize invalidation message")
72 }
73
74 pub fn from_json(json: &str) -> Result<Self> {
76 serde_json::from_str(json).context("Failed to deserialize invalidation message")
77 }
78
79 pub fn ttl(&self) -> Option<Duration> {
81 match self {
82 Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
83 _ => None,
84 }
85 }
86}
87
/// Settings shared by [`InvalidationPublisher`] and [`InvalidationSubscriber`].
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Pub/Sub channel that messages are published to and subscribed from.
    pub channel: String,

    /// Not read anywhere in this module.
    /// NOTE(review): presumably makes the cache layer publish on every write — confirm.
    pub auto_broadcast_on_write: bool,

    /// When `true`, `InvalidationPublisher::publish` also appends each message
    /// to `audit_stream` (best-effort).
    pub enable_audit_stream: bool,

    /// Name of the Redis stream that receives audit entries.
    pub audit_stream: String,

    /// Length cap for the audit stream, sent as `MAXLEN ~ <n>` on `XADD`;
    /// `None` sends no cap.
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Channel `cache:invalidate` and stream `cache:invalidations` capped at
    /// about 10 000 entries. Both `auto_broadcast_on_write` and
    /// `enable_audit_stream` are disabled.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");

        Self {
            channel,
            audit_stream,
            audit_stream_maxlen: Some(10_000),
            auto_broadcast_on_write: false,
            enable_audit_stream: false,
        }
    }
}
118
119pub struct InvalidationPublisher {
121 connection: redis::aio::ConnectionManager,
122 config: InvalidationConfig,
123}
124
125impl InvalidationPublisher {
126 pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
128 Self { connection, config }
129 }
130
131 pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
133 let json = message.to_json()?;
134
135 let _: () = self
137 .connection
138 .publish(&self.config.channel, &json)
139 .await
140 .context("Failed to publish invalidation message")?;
141
142 if self.config.enable_audit_stream {
144 if let Err(e) = self.publish_to_audit_stream(message).await {
145 eprintln!("Warning: Failed to publish to audit stream: {}", e);
147 }
148 }
149
150 Ok(())
151 }
152
153 async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
155 let timestamp = SystemTime::now()
156 .duration_since(UNIX_EPOCH)
157 .unwrap_or(Duration::ZERO)
158 .as_secs()
159 .to_string();
160
161 let (type_str, key_str): (&str, &str);
163 let extra_str: String;
164
165 match message {
166 InvalidationMessage::Remove { key } => {
167 type_str = "remove";
168 key_str = key.as_str();
169 extra_str = String::new();
170 }
171 InvalidationMessage::Update { key, .. } => {
172 type_str = "update";
173 key_str = key.as_str();
174 extra_str = String::new();
175 }
176 InvalidationMessage::RemovePattern { pattern } => {
177 type_str = "remove_pattern";
178 key_str = pattern.as_str();
179 extra_str = String::new();
180 }
181 InvalidationMessage::RemoveBulk { keys } => {
182 type_str = "remove_bulk";
183 key_str = "";
184 extra_str = keys.len().to_string();
185 }
186 }
187
188 let mut fields = vec![
189 ("type", type_str),
190 ("timestamp", timestamp.as_str()),
191 ];
192
193 if !key_str.is_empty() {
194 fields.push(("key", key_str));
195 }
196 if !extra_str.is_empty() {
197 fields.push(("count", extra_str.as_str()));
198 }
199
200 let mut cmd = redis::cmd("XADD");
201 cmd.arg(&self.config.audit_stream);
202
203 if let Some(maxlen) = self.config.audit_stream_maxlen {
204 cmd.arg("MAXLEN").arg("~").arg(maxlen);
205 }
206
207 cmd.arg("*"); for (key, value) in fields {
210 cmd.arg(key).arg(value);
211 }
212
213 let _: String = cmd
214 .query_async(&mut self.connection)
215 .await
216 .context("Failed to add to audit stream")?;
217
218 Ok(())
219 }
220}
221
/// Plain-value copy of the invalidation counters, as returned by
/// `InvalidationSubscriber::stats`.
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// Messages sent.
    /// NOTE(review): nothing in this module increments it — confirm where it is updated.
    pub messages_sent: u64,

    /// Messages received that decoded successfully.
    pub messages_received: u64,

    /// Decoded `Remove` messages.
    pub removes_received: u64,

    /// Decoded `Update` messages.
    pub updates_received: u64,

    /// Decoded `RemovePattern` messages.
    pub patterns_received: u64,

    /// Decoded `RemoveBulk` messages.
    pub bulk_removes_received: u64,

    /// Unreadable payloads, undecodable JSON, handler errors and failed
    /// subscription sessions, all counted together.
    pub processing_errors: u64,
}
246
247use std::sync::atomic::{AtomicU64, Ordering};
248
/// Atomic counterpart of [`InvalidationStats`]. It can be updated through a
/// shared reference, e.g. from the subscriber's background task via an `Arc`.
#[derive(Debug, Default)]
pub struct AtomicInvalidationStats {
    /// See [`InvalidationStats::messages_sent`].
    pub messages_sent: AtomicU64,
    /// See [`InvalidationStats::messages_received`].
    pub messages_received: AtomicU64,
    /// See [`InvalidationStats::removes_received`].
    pub removes_received: AtomicU64,
    /// See [`InvalidationStats::updates_received`].
    pub updates_received: AtomicU64,
    /// See [`InvalidationStats::patterns_received`].
    pub patterns_received: AtomicU64,
    /// See [`InvalidationStats::bulk_removes_received`].
    pub bulk_removes_received: AtomicU64,
    /// See [`InvalidationStats::processing_errors`].
    pub processing_errors: AtomicU64,
}
260
261impl AtomicInvalidationStats {
262 pub fn snapshot(&self) -> InvalidationStats {
263 InvalidationStats {
264 messages_sent: self.messages_sent.load(Ordering::Relaxed),
265 messages_received: self.messages_received.load(Ordering::Relaxed),
266 removes_received: self.removes_received.load(Ordering::Relaxed),
267 updates_received: self.updates_received.load(Ordering::Relaxed),
268 patterns_received: self.patterns_received.load(Ordering::Relaxed),
269 bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
270 processing_errors: self.processing_errors.load(Ordering::Relaxed),
271 }
272 }
273}
274
275use std::sync::Arc;
276use tokio::sync::broadcast;
277
278pub struct InvalidationSubscriber {
283 client: redis::Client,
285 config: InvalidationConfig,
287 stats: Arc<AtomicInvalidationStats>,
289 shutdown_tx: broadcast::Sender<()>,
291}
292
293impl InvalidationSubscriber {
294 pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
300 let client = redis::Client::open(redis_url)
301 .context("Failed to create Redis client for subscriber")?;
302
303 let (shutdown_tx, _) = broadcast::channel(1);
304
305 Ok(Self {
306 client,
307 config,
308 stats: Arc::new(AtomicInvalidationStats::default()),
309 shutdown_tx,
310 })
311 }
312
313 pub fn stats(&self) -> InvalidationStats {
315 self.stats.snapshot()
316 }
317
318 pub fn start<F, Fut>(
326 &self,
327 handler: F,
328 ) -> tokio::task::JoinHandle<()>
329 where
330 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
331 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
332 {
333 let client = self.client.clone();
334 let channel = self.config.channel.clone();
335 let stats = Arc::clone(&self.stats);
336 let mut shutdown_rx = self.shutdown_tx.subscribe();
337
338 tokio::spawn(async move {
339 let handler = Arc::new(handler);
340
341 loop {
342 if shutdown_rx.try_recv().is_ok() {
344 println!("🛑 Invalidation subscriber shutting down...");
345 break;
346 }
347
348 match Self::run_subscriber_loop(
350 &client,
351 &channel,
352 Arc::clone(&handler),
353 Arc::clone(&stats),
354 &mut shutdown_rx,
355 ).await {
356 Ok(_) => {
357 println!("✅ Invalidation subscriber loop completed normally");
358 break;
359 }
360 Err(e) => {
361 eprintln!("⚠️ Invalidation subscriber error: {}. Reconnecting in 5s...", e);
362 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
363
364 tokio::select! {
366 _ = tokio::time::sleep(Duration::from_secs(5)) => {},
367 _ = shutdown_rx.recv() => {
368 println!("🛑 Invalidation subscriber shutting down...");
369 break;
370 }
371 }
372 }
373 }
374 }
375 })
376 }
377
378 async fn run_subscriber_loop<F, Fut>(
380 client: &redis::Client,
381 channel: &str,
382 handler: Arc<F>,
383 stats: Arc<AtomicInvalidationStats>,
384 shutdown_rx: &mut broadcast::Receiver<()>,
385 ) -> Result<()>
386 where
387 F: Fn(InvalidationMessage) -> Fut + Send + 'static,
388 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
389 {
390 let mut pubsub = client.get_async_pubsub().await
392 .context("Failed to get pubsub connection")?;
393
394 pubsub.subscribe(channel).await
396 .context("Failed to subscribe to channel")?;
397
398 println!("📡 Subscribed to invalidation channel: {}", channel);
399
400 let mut stream = pubsub.on_message();
402
403 loop {
404 tokio::select! {
406 msg_result = stream.next() => {
407 match msg_result {
408 Some(msg) => {
409 let payload: String = match msg.get_payload() {
411 Ok(p) => p,
412 Err(e) => {
413 eprintln!("⚠️ Failed to get message payload: {}", e);
414 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
415 continue;
416 }
417 };
418
419 let invalidation_msg = match InvalidationMessage::from_json(&payload) {
421 Ok(m) => m,
422 Err(e) => {
423 eprintln!("⚠️ Failed to deserialize invalidation message: {}", e);
424 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
425 continue;
426 }
427 };
428
429 stats.messages_received.fetch_add(1, Ordering::Relaxed);
431 match &invalidation_msg {
432 InvalidationMessage::Remove { .. } => {
433 stats.removes_received.fetch_add(1, Ordering::Relaxed);
434 }
435 InvalidationMessage::Update { .. } => {
436 stats.updates_received.fetch_add(1, Ordering::Relaxed);
437 }
438 InvalidationMessage::RemovePattern { .. } => {
439 stats.patterns_received.fetch_add(1, Ordering::Relaxed);
440 }
441 InvalidationMessage::RemoveBulk { .. } => {
442 stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
443 }
444 }
445
446 if let Err(e) = handler(invalidation_msg).await {
448 eprintln!("⚠️ Invalidation handler error: {}", e);
449 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
450 }
451 }
452 None => {
453 return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
455 }
456 }
457 }
458 _ = shutdown_rx.recv() => {
459 return Ok(());
460 }
461 }
462 }
463 }
464
465 pub fn shutdown(&self) {
467 let _ = self.shutdown_tx.send(());
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `msg` to JSON and parses it back.
    fn round_trip(msg: &InvalidationMessage) -> InvalidationMessage {
        let json = msg.to_json().expect("serialization should succeed");
        InvalidationMessage::from_json(&json).expect("deserialization should succeed")
    }

    #[test]
    fn test_invalidation_message_serialization() {
        match round_trip(&InvalidationMessage::remove("test_key")) {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            other => panic!("Wrong message type: {:?}", other),
        }

        let msg = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        match round_trip(&msg) {
            InvalidationMessage::Update { key, value, ttl_secs } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            other => panic!("Wrong message type: {:?}", other),
        }

        match round_trip(&InvalidationMessage::remove_pattern("user:*")) {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            other => panic!("Wrong message type: {:?}", other),
        }

        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        match round_trip(&msg) {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_update_without_ttl_omits_field() {
        let msg = InvalidationMessage::update("k", serde_json::json!(1), None);
        let json = msg.to_json().unwrap();
        assert!(!json.contains("ttl_secs"));
        assert_eq!(round_trip(&msg).ttl(), None);
    }

    #[test]
    fn test_ttl_accessor() {
        let msg = InvalidationMessage::update(
            "k",
            serde_json::json!(null),
            Some(Duration::from_secs(300)),
        );
        assert_eq!(msg.ttl(), Some(Duration::from_secs(300)));
        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("p:*").ttl(), None);
        assert_eq!(InvalidationMessage::remove_bulk(vec![]).ttl(), None);
    }

    #[test]
    fn test_from_json_rejects_malformed_input() {
        assert!(InvalidationMessage::from_json("not json").is_err());
        assert!(InvalidationMessage::from_json(r#"{"type":"Nope"}"#).is_err());
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }

    #[test]
    fn test_atomic_stats_snapshot() {
        let stats = AtomicInvalidationStats::default();
        stats.messages_received.fetch_add(3, Ordering::Relaxed);
        stats.removes_received.fetch_add(2, Ordering::Relaxed);
        stats.processing_errors.fetch_add(1, Ordering::Relaxed);

        let snap = stats.snapshot();
        assert_eq!(snap.messages_received, 3);
        assert_eq!(snap.removes_received, 2);
        assert_eq!(snap.processing_errors, 1);
        assert_eq!(snap.messages_sent, 0);
        assert_eq!(snap.updates_received, 0);
    }
}