1use anyhow::{Context, Result};
7use redis::AsyncCommands;
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use futures_util::StreamExt;
11use tracing::{info, warn, error};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type")]
16pub enum InvalidationMessage {
17 Remove {
19 key: String,
20 },
21
22 Update {
25 key: String,
26 value: serde_json::Value,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 ttl_secs: Option<u64>,
29 },
30
31 RemovePattern {
34 pattern: String,
35 },
36
37 RemoveBulk {
39 keys: Vec<String>,
40 },
41}
42
43impl InvalidationMessage {
44 pub fn remove(key: impl Into<String>) -> Self {
46 Self::Remove { key: key.into() }
47 }
48
49 pub fn update(key: impl Into<String>, value: serde_json::Value, ttl: Option<Duration>) -> Self {
51 Self::Update {
52 key: key.into(),
53 value,
54 ttl_secs: ttl.map(|d| d.as_secs()),
55 }
56 }
57
58 pub fn remove_pattern(pattern: impl Into<String>) -> Self {
60 Self::RemovePattern {
61 pattern: pattern.into(),
62 }
63 }
64
65 pub fn remove_bulk(keys: Vec<String>) -> Self {
67 Self::RemoveBulk { keys }
68 }
69
70 pub fn to_json(&self) -> Result<String> {
72 serde_json::to_string(self).context("Failed to serialize invalidation message")
73 }
74
75 pub fn from_json(json: &str) -> Result<Self> {
77 serde_json::from_str(json).context("Failed to deserialize invalidation message")
78 }
79
80 pub fn ttl(&self) -> Option<Duration> {
82 match self {
83 Self::Update { ttl_secs, .. } => ttl_secs.map(Duration::from_secs),
84 _ => None,
85 }
86 }
87}
88
/// Settings for the invalidation publisher and subscriber.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Pub/Sub channel that invalidation messages are published to and read from.
    pub channel: String,

    /// Flag for broadcasting automatically on cache writes. It is not read in
    /// this module; presumably the cache layer consults it (confirm).
    pub auto_broadcast_on_write: bool,

    /// When true, the publisher also appends each message to `audit_stream`.
    pub enable_audit_stream: bool,

    /// Redis stream key that receives audit entries.
    pub audit_stream: String,

    /// Approximate length cap for the audit stream (`XADD ... MAXLEN ~ n`);
    /// `None` disables trimming.
    pub audit_stream_maxlen: Option<usize>,
}

impl Default for InvalidationConfig {
    /// Channel `cache:invalidate`, no auto-broadcast, auditing disabled,
    /// audit stream `cache:invalidations` trimmed to roughly 10 000 entries.
    fn default() -> Self {
        let channel = String::from("cache:invalidate");
        let audit_stream = String::from("cache:invalidations");

        Self {
            channel,
            auto_broadcast_on_write: false,
            enable_audit_stream: false,
            audit_stream,
            audit_stream_maxlen: Some(10_000),
        }
    }
}
119
120pub struct InvalidationPublisher {
122 connection: redis::aio::ConnectionManager,
123 config: InvalidationConfig,
124}
125
126impl InvalidationPublisher {
127 pub fn new(connection: redis::aio::ConnectionManager, config: InvalidationConfig) -> Self {
129 Self { connection, config }
130 }
131
132 pub async fn publish(&mut self, message: &InvalidationMessage) -> Result<()> {
134 let json = message.to_json()?;
135
136 let _: () = self
138 .connection
139 .publish(&self.config.channel, &json)
140 .await
141 .context("Failed to publish invalidation message")?;
142
143 if self.config.enable_audit_stream {
145 if let Err(e) = self.publish_to_audit_stream(message).await {
146 warn!("Failed to publish to audit stream: {}", e);
148 }
149 }
150
151 Ok(())
152 }
153
154 async fn publish_to_audit_stream(&mut self, message: &InvalidationMessage) -> Result<()> {
156 let timestamp = SystemTime::now()
157 .duration_since(UNIX_EPOCH)
158 .unwrap_or(Duration::ZERO)
159 .as_secs()
160 .to_string();
161
162 let (type_str, key_str): (&str, &str);
164 let extra_str: String;
165
166 match message {
167 InvalidationMessage::Remove { key } => {
168 type_str = "remove";
169 key_str = key.as_str();
170 extra_str = String::new();
171 }
172 InvalidationMessage::Update { key, .. } => {
173 type_str = "update";
174 key_str = key.as_str();
175 extra_str = String::new();
176 }
177 InvalidationMessage::RemovePattern { pattern } => {
178 type_str = "remove_pattern";
179 key_str = pattern.as_str();
180 extra_str = String::new();
181 }
182 InvalidationMessage::RemoveBulk { keys } => {
183 type_str = "remove_bulk";
184 key_str = "";
185 extra_str = keys.len().to_string();
186 }
187 }
188
189 let mut fields = vec![
190 ("type", type_str),
191 ("timestamp", timestamp.as_str()),
192 ];
193
194 if !key_str.is_empty() {
195 fields.push(("key", key_str));
196 }
197 if !extra_str.is_empty() {
198 fields.push(("count", extra_str.as_str()));
199 }
200
201 let mut cmd = redis::cmd("XADD");
202 cmd.arg(&self.config.audit_stream);
203
204 if let Some(maxlen) = self.config.audit_stream_maxlen {
205 cmd.arg("MAXLEN").arg("~").arg(maxlen);
206 }
207
208 cmd.arg("*"); for (key, value) in fields {
211 cmd.arg(key).arg(value);
212 }
213
214 let _: String = cmd
215 .query_async(&mut self.connection)
216 .await
217 .context("Failed to add to audit stream")?;
218
219 Ok(())
220 }
221}
222
/// Plain, copyable snapshot of invalidation counters.
#[derive(Clone, Debug, Default)]
pub struct InvalidationStats {
    /// Messages published.
    pub messages_sent: u64,
    /// Messages received and successfully decoded.
    pub messages_received: u64,
    /// Received `Remove` messages.
    pub removes_received: u64,
    /// Received `Update` messages.
    pub updates_received: u64,
    /// Received `RemovePattern` messages.
    pub patterns_received: u64,
    /// Received `RemoveBulk` messages.
    pub bulk_removes_received: u64,
    /// Errors counted while receiving or handling messages.
    pub processing_errors: u64,
}
247
248use std::sync::atomic::{AtomicU64, Ordering};
249
250#[derive(Debug, Default)]
252pub struct AtomicInvalidationStats {
253 pub messages_sent: AtomicU64,
254 pub messages_received: AtomicU64,
255 pub removes_received: AtomicU64,
256 pub updates_received: AtomicU64,
257 pub patterns_received: AtomicU64,
258 pub bulk_removes_received: AtomicU64,
259 pub processing_errors: AtomicU64,
260}
261
262impl AtomicInvalidationStats {
263 pub fn snapshot(&self) -> InvalidationStats {
264 InvalidationStats {
265 messages_sent: self.messages_sent.load(Ordering::Relaxed),
266 messages_received: self.messages_received.load(Ordering::Relaxed),
267 removes_received: self.removes_received.load(Ordering::Relaxed),
268 updates_received: self.updates_received.load(Ordering::Relaxed),
269 patterns_received: self.patterns_received.load(Ordering::Relaxed),
270 bulk_removes_received: self.bulk_removes_received.load(Ordering::Relaxed),
271 processing_errors: self.processing_errors.load(Ordering::Relaxed),
272 }
273 }
274}
275
276use std::sync::Arc;
277use tokio::sync::broadcast;
278
279pub struct InvalidationSubscriber {
284 client: redis::Client,
286 config: InvalidationConfig,
288 stats: Arc<AtomicInvalidationStats>,
290 shutdown_tx: broadcast::Sender<()>,
292}
293
294impl InvalidationSubscriber {
295 pub fn new(redis_url: &str, config: InvalidationConfig) -> Result<Self> {
301 let client = redis::Client::open(redis_url)
302 .context("Failed to create Redis client for subscriber")?;
303
304 let (shutdown_tx, _) = broadcast::channel(1);
305
306 Ok(Self {
307 client,
308 config,
309 stats: Arc::new(AtomicInvalidationStats::default()),
310 shutdown_tx,
311 })
312 }
313
314 pub fn stats(&self) -> InvalidationStats {
316 self.stats.snapshot()
317 }
318
319 pub fn start<F, Fut>(
327 &self,
328 handler: F,
329 ) -> tokio::task::JoinHandle<()>
330 where
331 F: Fn(InvalidationMessage) -> Fut + Send + Sync + 'static,
332 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
333 {
334 let client = self.client.clone();
335 let channel = self.config.channel.clone();
336 let stats = Arc::clone(&self.stats);
337 let mut shutdown_rx = self.shutdown_tx.subscribe();
338
339 tokio::spawn(async move {
340 let handler = Arc::new(handler);
341
342 loop {
343 if shutdown_rx.try_recv().is_ok() {
345 info!("Invalidation subscriber shutting down...");
346 break;
347 }
348
349 match Self::run_subscriber_loop(
351 &client,
352 &channel,
353 Arc::clone(&handler),
354 Arc::clone(&stats),
355 &mut shutdown_rx,
356 ).await {
357 Ok(_) => {
358 info!("Invalidation subscriber loop completed normally");
359 break;
360 }
361 Err(e) => {
362 error!("Invalidation subscriber error: {}. Reconnecting in 5s...", e);
363 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
364
365 tokio::select! {
367 _ = tokio::time::sleep(Duration::from_secs(5)) => {},
368 _ = shutdown_rx.recv() => {
369 info!("Invalidation subscriber shutting down...");
370 break;
371 }
372 }
373 }
374 }
375 }
376 })
377 }
378
379 async fn run_subscriber_loop<F, Fut>(
381 client: &redis::Client,
382 channel: &str,
383 handler: Arc<F>,
384 stats: Arc<AtomicInvalidationStats>,
385 shutdown_rx: &mut broadcast::Receiver<()>,
386 ) -> Result<()>
387 where
388 F: Fn(InvalidationMessage) -> Fut + Send + 'static,
389 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
390 {
391 let mut pubsub = client.get_async_pubsub().await
393 .context("Failed to get pubsub connection")?;
394
395 pubsub.subscribe(channel).await
397 .context("Failed to subscribe to channel")?;
398
399 info!("Subscribed to invalidation channel: {}", channel);
400
401 let mut stream = pubsub.on_message();
403
404 loop {
405 tokio::select! {
407 msg_result = stream.next() => {
408 match msg_result {
409 Some(msg) => {
410 let payload: String = match msg.get_payload() {
412 Ok(p) => p,
413 Err(e) => {
414 warn!("Failed to get message payload: {}", e);
415 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
416 continue;
417 }
418 };
419
420 let invalidation_msg = match InvalidationMessage::from_json(&payload) {
422 Ok(m) => m,
423 Err(e) => {
424 warn!("Failed to deserialize invalidation message: {}", e);
425 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
426 continue;
427 }
428 };
429
430 stats.messages_received.fetch_add(1, Ordering::Relaxed);
432 match &invalidation_msg {
433 InvalidationMessage::Remove { .. } => {
434 stats.removes_received.fetch_add(1, Ordering::Relaxed);
435 }
436 InvalidationMessage::Update { .. } => {
437 stats.updates_received.fetch_add(1, Ordering::Relaxed);
438 }
439 InvalidationMessage::RemovePattern { .. } => {
440 stats.patterns_received.fetch_add(1, Ordering::Relaxed);
441 }
442 InvalidationMessage::RemoveBulk { .. } => {
443 stats.bulk_removes_received.fetch_add(1, Ordering::Relaxed);
444 }
445 }
446
447 if let Err(e) = handler(invalidation_msg).await {
449 error!("Invalidation handler error: {}", e);
450 stats.processing_errors.fetch_add(1, Ordering::Relaxed);
451 }
452 }
453 None => {
454 return Err(anyhow::anyhow!("Pub/Sub message stream ended"));
456 }
457 }
458 }
459 _ = shutdown_rx.recv() => {
460 return Ok(());
461 }
462 }
463 }
464 }
465
466 pub fn shutdown(&self) {
468 let _ = self.shutdown_tx.send(());
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must survive a JSON round trip unchanged.
    #[test]
    fn test_invalidation_message_serialization() {
        let msg = InvalidationMessage::remove("test_key");
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::Remove { key } => assert_eq!(key, "test_key"),
            _ => panic!("Wrong message type"),
        }

        let msg = InvalidationMessage::update(
            "test_key",
            serde_json::json!({"value": 123}),
            Some(Duration::from_secs(300)),
        );
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::Update { key, value, ttl_secs } => {
                assert_eq!(key, "test_key");
                assert_eq!(value, serde_json::json!({"value": 123}));
                assert_eq!(ttl_secs, Some(300));
            }
            _ => panic!("Wrong message type"),
        }

        let msg = InvalidationMessage::remove_pattern("user:*");
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::RemovePattern { pattern } => assert_eq!(pattern, "user:*"),
            _ => panic!("Wrong message type"),
        }

        let msg = InvalidationMessage::remove_bulk(vec!["key1".to_string(), "key2".to_string()]);
        let json = msg.to_json().unwrap();
        let parsed = InvalidationMessage::from_json(&json).unwrap();
        match parsed {
            InvalidationMessage::RemoveBulk { keys } => assert_eq!(keys, vec!["key1", "key2"]),
            _ => panic!("Wrong message type"),
        }
    }

    /// `ttl()` exposes whole-second TTLs on updates and `None` elsewhere.
    #[test]
    fn test_ttl_accessor() {
        let with_ttl = InvalidationMessage::update(
            "k",
            serde_json::json!(null),
            Some(Duration::from_secs(60)),
        );
        assert_eq!(with_ttl.ttl(), Some(Duration::from_secs(60)));

        let without_ttl = InvalidationMessage::update("k", serde_json::json!(null), None);
        assert_eq!(without_ttl.ttl(), None);

        assert_eq!(InvalidationMessage::remove("k").ttl(), None);
        assert_eq!(InvalidationMessage::remove_pattern("k:*").ttl(), None);
    }

    /// Pins every default configuration value.
    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.channel, "cache:invalidate");
        assert!(!config.auto_broadcast_on_write);
        assert!(!config.enable_audit_stream);
        assert_eq!(config.audit_stream, "cache:invalidations");
        assert_eq!(config.audit_stream_maxlen, Some(10000));
    }

    /// `snapshot()` copies each atomic counter into the matching field.
    #[test]
    fn test_atomic_stats_snapshot() {
        let stats = AtomicInvalidationStats::default();
        stats.messages_received.fetch_add(3, Ordering::Relaxed);
        stats.removes_received.fetch_add(2, Ordering::Relaxed);
        stats.updates_received.fetch_add(1, Ordering::Relaxed);
        stats.processing_errors.fetch_add(4, Ordering::Relaxed);

        let snap = stats.snapshot();
        assert_eq!(snap.messages_sent, 0);
        assert_eq!(snap.messages_received, 3);
        assert_eq!(snap.removes_received, 2);
        assert_eq!(snap.updates_received, 1);
        assert_eq!(snap.patterns_received, 0);
        assert_eq!(snap.bulk_removes_received, 0);
        assert_eq!(snap.processing_errors, 4);
    }
}