1use std::collections::{HashMap, HashSet};
30use std::sync::Arc;
31use std::time::{SystemTime, UNIX_EPOCH};
32use thiserror::Error;
33use tokio::sync::{RwLock, broadcast};
34
/// Capacity passed to `broadcast::channel` when an `InvalidationNotifier`
/// is created, i.e. how many events the channel retains for its receivers.
const INVALIDATION_BUFFER_SIZE: usize = 1024;
37
38#[derive(Debug, Error)]
40pub enum InvalidationError {
41 #[error("Failed to send invalidation notification: {0}")]
42 SendError(String),
43
44 #[error("Invalid invalidation pattern: {0}")]
45 InvalidPattern(String),
46
47 #[error("Receiver disconnected")]
48 ReceiverDisconnected,
49}
50
/// Why a cache entry is being invalidated.
///
/// This module attaches no behaviour to individual variants; the reason is
/// carried on each event, and its `Debug` rendering (e.g. `"Updated"`) is
/// used as the key of the per-reason statistics counter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvalidationReason {
    /// The entry was updated.
    Updated,
    /// The entry was deleted.
    Deleted,
    /// The entry expired.
    Expired,
    /// The invalidation was requested manually.
    Manual,
    /// The invalidation was triggered by an error.
    Error,
    /// The invalidation was triggered by memory pressure.
    MemoryPressure,
}
67
68#[derive(Debug, Clone)]
70pub struct InvalidationEvent {
71 pub key: String,
73 pub reason: InvalidationReason,
75 pub timestamp: u64,
77 pub metadata: HashMap<String, String>,
79 pub origin_node_id: Option<String>,
81}
82
83impl InvalidationEvent {
84 #[must_use]
86 #[inline]
87 pub fn new(key: String, reason: InvalidationReason) -> Self {
88 Self {
89 key,
90 reason,
91 timestamp: current_timestamp(),
92 metadata: HashMap::new(),
93 origin_node_id: None,
94 }
95 }
96
97 #[must_use]
99 #[inline]
100 pub fn with_metadata(mut self, key: String, value: String) -> Self {
101 self.metadata.insert(key, value);
102 self
103 }
104
105 #[must_use]
107 #[inline]
108 pub fn with_origin(mut self, node_id: String) -> Self {
109 self.origin_node_id = Some(node_id);
110 self
111 }
112}
113
/// Selects cache keys for bulk invalidation.
#[derive(Debug, Clone)]
pub enum InvalidationPattern {
    /// Matches a key equal to the given string.
    Exact(String),
    /// Matches keys that start with the given string.
    Prefix(String),
    /// Matches keys that end with the given string.
    Suffix(String),
    /// Matches keys that contain the given string anywhere.
    Contains(String),
    /// A set of tags. Tags cannot be resolved from a key alone, so
    /// `matches` never reports a match for this variant.
    Tags(HashSet<String>),
}

impl InvalidationPattern {
    /// Returns `true` when `key` satisfies this pattern.
    ///
    /// String comparisons are exact and case-sensitive. `Tags` always
    /// yields `false` because this method only sees the key text.
    #[inline]
    pub fn matches(&self, key: &str) -> bool {
        match self {
            // No tag information is available here, so nothing can match.
            Self::Tags(_) => false,
            Self::Exact(wanted) => wanted.as_str() == key,
            Self::Prefix(head) => key.starts_with(head.as_str()),
            Self::Suffix(tail) => key.ends_with(tail.as_str()),
            Self::Contains(needle) => key.contains(needle.as_str()),
        }
    }
}
142
/// Counters describing the invalidation traffic handled by a notifier.
///
/// `Default` yields all-zero counters and an empty per-reason map.
#[derive(Debug, Clone, Default)]
pub struct InvalidationStats {
    /// Number of invalidation events the notifier attempted to send.
    pub total_invalidations: u64,
    /// Event count per reason, keyed by the reason's `Debug` rendering
    /// (for example `"Updated"`).
    pub by_reason: HashMap<String, u64>,
    /// Receiver count sampled when the most recent event was sent.
    pub active_subscribers: usize,
    /// Number of events for which the channel send returned an error.
    pub failed_sends: u64,
}
155
156pub struct InvalidationNotifier {
158 sender: broadcast::Sender<InvalidationEvent>,
160 stats: Arc<RwLock<InvalidationStats>>,
162 tag_index: Arc<RwLock<HashMap<String, HashSet<String>>>>,
164}
165
166impl InvalidationNotifier {
167 #[must_use]
169 #[inline]
170 pub fn new() -> Self {
171 let (sender, _) = broadcast::channel(INVALIDATION_BUFFER_SIZE);
172 Self {
173 sender,
174 stats: Arc::new(RwLock::new(InvalidationStats::default())),
175 tag_index: Arc::new(RwLock::new(HashMap::new())),
176 }
177 }
178
179 #[must_use]
181 #[inline]
182 pub fn subscribe(&self) -> InvalidationReceiver {
183 let receiver = self.sender.subscribe();
184 InvalidationReceiver { receiver }
185 }
186
187 pub async fn invalidate_key(&self, key: &str, reason: InvalidationReason) {
189 let event = InvalidationEvent::new(key.to_string(), reason.clone());
190 self.send_event(event).await;
191 }
192
193 pub async fn invalidate_keys(&self, keys: &[String], reason: InvalidationReason) {
195 for key in keys {
196 let event = InvalidationEvent::new(key.clone(), reason.clone());
197 self.send_event(event).await;
198 }
199 }
200
201 pub async fn invalidate_pattern(
203 &self,
204 pattern: InvalidationPattern,
205 reason: InvalidationReason,
206 known_keys: &[String],
207 ) {
208 for key in known_keys {
209 if pattern.matches(key) {
210 let event = InvalidationEvent::new(key.clone(), reason.clone());
211 self.send_event(event).await;
212 }
213 }
214 }
215
216 pub async fn invalidate_tag(&self, tag: &str, reason: InvalidationReason) {
218 let keys = {
219 let index = self.tag_index.read().await;
220 index.get(tag).cloned().unwrap_or_default()
221 };
222
223 for key in keys {
224 let event = InvalidationEvent::new(key, reason.clone());
225 self.send_event(event).await;
226 }
227 }
228
229 pub async fn tag_key(&self, key: String, tags: Vec<String>) {
231 let mut index = self.tag_index.write().await;
232 for tag in tags {
233 index
234 .entry(tag)
235 .or_insert_with(HashSet::new)
236 .insert(key.clone());
237 }
238 }
239
240 pub async fn untag_key(&self, key: &str) {
242 let mut index = self.tag_index.write().await;
243 for keys in index.values_mut() {
244 keys.remove(key);
245 }
246 }
247
248 async fn send_event(&self, event: InvalidationEvent) {
250 let reason_key = format!("{:?}", event.reason);
251
252 {
254 let mut stats = self.stats.write().await;
255 stats.total_invalidations += 1;
256 *stats.by_reason.entry(reason_key).or_insert(0) += 1;
257 stats.active_subscribers = self.sender.receiver_count();
258 }
259
260 if self.sender.send(event).is_err() {
262 let mut stats = self.stats.write().await;
263 stats.failed_sends += 1;
264 }
265 }
266
267 #[must_use]
269 #[inline]
270 pub async fn stats(&self) -> InvalidationStats {
271 self.stats.read().await.clone()
272 }
273
274 #[must_use]
276 #[inline]
277 pub fn subscriber_count(&self) -> usize {
278 self.sender.receiver_count()
279 }
280}
281
282impl Default for InvalidationNotifier {
283 #[inline]
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289pub struct InvalidationReceiver {
291 receiver: broadcast::Receiver<InvalidationEvent>,
292}
293
294impl InvalidationReceiver {
295 pub async fn recv(&mut self) -> Option<InvalidationEvent> {
297 loop {
298 match self.receiver.recv().await {
299 Ok(event) => return Some(event),
300 Err(broadcast::error::RecvError::Lagged(skipped)) => {
301 eprintln!(
302 "Warning: Invalidation receiver lagged, skipped {} events",
303 skipped
304 );
305 continue;
307 }
308 Err(broadcast::error::RecvError::Closed) => return None,
309 }
310 }
311 }
312
313 pub fn try_recv(&mut self) -> Result<InvalidationEvent, InvalidationError> {
315 self.receiver.try_recv().map_err(|e| match e {
316 broadcast::error::TryRecvError::Empty => InvalidationError::ReceiverDisconnected,
317 broadcast::error::TryRecvError::Lagged(_) => InvalidationError::ReceiverDisconnected,
318 broadcast::error::TryRecvError::Closed => InvalidationError::ReceiverDisconnected,
319 })
320 }
321}
322
323pub struct BatchInvalidation {
325 notifier: Arc<InvalidationNotifier>,
326 batch: Vec<(String, InvalidationReason)>,
327 max_batch_size: usize,
328}
329
330impl BatchInvalidation {
331 #[must_use]
333 #[inline]
334 pub fn new(notifier: Arc<InvalidationNotifier>, max_batch_size: usize) -> Self {
335 Self {
336 notifier,
337 batch: Vec::with_capacity(max_batch_size),
338 max_batch_size,
339 }
340 }
341
342 #[inline]
344 pub fn add(&mut self, key: String, reason: InvalidationReason) {
345 self.batch.push((key, reason));
346 if self.batch.len() >= self.max_batch_size {
347 }
350 }
351
352 pub async fn flush(&mut self) {
354 for (key, reason) in self.batch.drain(..) {
355 self.notifier.invalidate_key(&key, reason).await;
356 }
357 }
358
359 #[must_use]
361 #[inline]
362 pub fn len(&self) -> usize {
363 self.batch.len()
364 }
365
366 #[must_use]
368 #[inline]
369 pub fn is_empty(&self) -> bool {
370 self.batch.is_empty()
371 }
372}
373
/// Returns the current wall-clock time as whole seconds since the Unix
/// epoch, or `0` when the system clock reports a time before the epoch.
#[inline]
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is set before 1970: fall back to a zero duration.
        Err(_) => 0,
    }
}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// A single invalidation reaches a subscriber with key and reason intact.
    #[tokio::test]
    async fn test_single_invalidation() {
        let notifier = InvalidationNotifier::new();
        let mut receiver = notifier.subscribe();

        notifier
            .invalidate_key("test:key", InvalidationReason::Updated)
            .await;

        let event = receiver.recv().await.unwrap();
        assert_eq!(event.key, "test:key");
        assert_eq!(event.reason, InvalidationReason::Updated);
    }

    /// Every subscriber receives its own copy of the same event.
    #[tokio::test]
    async fn test_multiple_subscribers() {
        let notifier = InvalidationNotifier::new();
        let mut receiver1 = notifier.subscribe();
        let mut receiver2 = notifier.subscribe();

        notifier
            .invalidate_key("test:key", InvalidationReason::Deleted)
            .await;

        let event1 = receiver1.recv().await.unwrap();
        let event2 = receiver2.recv().await.unwrap();

        assert_eq!(event1.key, event2.key);
        assert_eq!(event1.reason, event2.reason);
    }

    /// `invalidate_keys` emits one event per key.
    #[tokio::test]
    async fn test_batch_invalidation() {
        let notifier = InvalidationNotifier::new();
        let mut receiver = notifier.subscribe();

        let keys = vec!["key1".to_string(), "key2".to_string(), "key3".to_string()];
        notifier
            .invalidate_keys(&keys, InvalidationReason::Expired)
            .await;

        for _ in 0..3 {
            let event = receiver.recv().await.unwrap();
            assert!(keys.contains(&event.key));
            assert_eq!(event.reason, InvalidationReason::Expired);
        }
    }

    // Pattern matching is purely synchronous, so these tests do not need
    // the async runtime.

    #[test]
    fn test_pattern_exact() {
        let pattern = InvalidationPattern::Exact("content:abc".to_string());
        assert!(pattern.matches("content:abc"));
        assert!(!pattern.matches("content:abcd"));
    }

    #[test]
    fn test_pattern_prefix() {
        let pattern = InvalidationPattern::Prefix("content:".to_string());
        assert!(pattern.matches("content:abc123"));
        assert!(!pattern.matches("metadata:abc123"));
    }

    #[test]
    fn test_pattern_suffix() {
        let pattern = InvalidationPattern::Suffix(":metadata".to_string());
        assert!(pattern.matches("content:metadata"));
        assert!(!pattern.matches("content:data"));
    }

    #[test]
    fn test_pattern_contains() {
        let pattern = InvalidationPattern::Contains("temp".to_string());
        assert!(pattern.matches("cache:temp:data"));
        assert!(!pattern.matches("cache:perm:data"));
    }

    /// A tag pattern cannot be decided from the key text alone.
    #[test]
    fn test_pattern_tags_never_match_by_key() {
        let mut tags = HashSet::new();
        tags.insert("user:123".to_string());
        let pattern = InvalidationPattern::Tags(tags);
        assert!(!pattern.matches("user:123"));
    }

    /// Only keys registered under the invalidated tag are notified.
    #[tokio::test]
    async fn test_tag_based_invalidation() {
        let notifier = InvalidationNotifier::new();
        let mut receiver = notifier.subscribe();

        notifier
            .tag_key("key1".to_string(), vec!["user:123".to_string()])
            .await;
        notifier
            .tag_key("key2".to_string(), vec!["user:123".to_string()])
            .await;
        notifier
            .tag_key("key3".to_string(), vec!["user:456".to_string()])
            .await;

        notifier
            .invalidate_tag("user:123", InvalidationReason::Updated)
            .await;

        // Events are buffered by the time `invalidate_tag` returns, so a
        // non-blocking drain sees all of them and cannot hang if one is
        // missing.
        let mut received_keys = HashSet::new();
        while let Ok(event) = receiver.try_recv() {
            received_keys.insert(event.key);
        }

        assert_eq!(received_keys.len(), 2);
        assert!(received_keys.contains("key1"));
        assert!(received_keys.contains("key2"));
        assert!(!received_keys.contains("key3"));
    }

    /// Counters track totals, per-reason counts and failed sends.
    #[tokio::test]
    async fn test_invalidation_stats() {
        let notifier = InvalidationNotifier::new();

        notifier
            .invalidate_key("key1", InvalidationReason::Updated)
            .await;
        notifier
            .invalidate_key("key2", InvalidationReason::Deleted)
            .await;
        notifier
            .invalidate_key("key3", InvalidationReason::Updated)
            .await;

        let stats = notifier.stats().await;
        assert_eq!(stats.total_invalidations, 3);
        assert_eq!(*stats.by_reason.get("Updated").unwrap_or(&0), 2);
        assert_eq!(*stats.by_reason.get("Deleted").unwrap_or(&0), 1);
        // Nobody subscribed, so every broadcast send reported a failure.
        assert_eq!(stats.active_subscribers, 0);
        assert_eq!(stats.failed_sends, 3);
    }

    /// After untagging, invalidating the tag sends nothing for that key.
    #[tokio::test]
    async fn test_untag_key() {
        let notifier = InvalidationNotifier::new();

        notifier
            .tag_key("key1".to_string(), vec!["tag1".to_string()])
            .await;
        notifier.untag_key("key1").await;

        let mut receiver = notifier.subscribe();
        notifier
            .invalidate_tag("tag1", InvalidationReason::Manual)
            .await;

        // Sending is synchronous, so no delay is needed before checking.
        assert!(receiver.try_recv().is_err());
    }

    /// Queued entries are only sent on `flush`, which empties the batch.
    #[tokio::test]
    async fn test_batch_invalidation_manager() {
        let notifier = Arc::new(InvalidationNotifier::new());
        let mut receiver = notifier.subscribe();

        let mut batch = BatchInvalidation::new(notifier.clone(), 10);
        batch.add("key1".to_string(), InvalidationReason::Manual);
        batch.add("key2".to_string(), InvalidationReason::Manual);

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());

        batch.flush().await;
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());

        for _ in 0..2 {
            assert!(receiver.recv().await.is_some());
        }
    }

    /// `subscriber_count` follows the number of live receivers.
    #[tokio::test]
    async fn test_subscriber_count() {
        let notifier = InvalidationNotifier::new();
        assert_eq!(notifier.subscriber_count(), 0);

        let _receiver1 = notifier.subscribe();
        assert_eq!(notifier.subscriber_count(), 1);

        let _receiver2 = notifier.subscribe();
        assert_eq!(notifier.subscriber_count(), 2);
    }

    /// Builder methods set metadata and origin without touching the key.
    #[test]
    fn test_invalidation_event_builder() {
        let event = InvalidationEvent::new("test:key".to_string(), InvalidationReason::Updated)
            .with_metadata("version".to_string(), "2".to_string())
            .with_origin("node123".to_string());

        assert_eq!(event.key, "test:key");
        assert_eq!(event.metadata.get("version").unwrap(), "2");
        assert_eq!(event.origin_node_id.as_ref().unwrap(), "node123");
    }
}