1use futures::StreamExt;
9use parking_lot::RwLock;
10use redis::Client;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tracing::{debug, error, info, warn};
16
17use crate::cache::RedisCache;
18use crate::error::{DbError, Result};
19
/// A cache invalidation notification exchanged between instances.
///
/// The subscriber decodes these from JSON pub/sub payloads with `serde_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationEvent {
    /// What kind of invalidation this event describes.
    pub event_type: InvalidationType,
    /// Cache keys to delete for `Keys` events; for `Pattern` events this field
    /// carries the key pattern(s) instead.
    pub keys: Vec<String>,
    /// Tags whose registered keys should be deleted (`Tags` events).
    pub tags: Vec<String>,
    /// Creation time of the event, in UTC.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// `instance_id` of the publishing instance; a subscriber ignores events
    /// carrying its own id.
    pub source_instance: String,
}
34
/// The kind of invalidation an [`InvalidationEvent`] describes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum InvalidationType {
    /// Delete exactly the keys listed in the event's `keys` field.
    Keys,
    /// Delete every key registered under the tags in the event's `tags` field.
    Tags,
    /// Cascade invalidation. `InvalidationManager` in this file never publishes
    /// this variant, and its subscriber takes no action on it.
    Cascade,
    /// Delete every key matching the pattern(s) in the event's `keys` field.
    Pattern,
}
47
/// Settings for [`InvalidationManager`].
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Redis pub/sub channel the subscriber listens on.
    pub pubsub_channel: String,
    /// Identifier stamped on published events; incoming events carrying this
    /// id are skipped by the subscriber.
    pub instance_id: String,
    /// Whether tag invalidations also follow registered cascade rules.
    pub enable_cascade: bool,
    /// Maximum recursion depth when following cascade rules.
    pub max_cascade_depth: usize,
}
60
61impl Default for InvalidationConfig {
62 fn default() -> Self {
63 Self {
64 pubsub_channel: "cache:invalidation".to_string(),
65 instance_id: uuid::Uuid::new_v4().to_string(),
66 enable_cascade: true,
67 max_cascade_depth: 5,
68 }
69 }
70}
71
/// Bidirectional index between cache keys and the tags attached to them.
///
/// Both maps sit behind `Arc`, so clones of a registry share the same data.
#[derive(Debug, Clone)]
pub struct TagRegistry {
    /// tag -> set of keys carrying that tag.
    tags: Arc<RwLock<HashMap<String, HashSet<String>>>>,
    /// key -> set of tags attached to that key.
    keys: Arc<RwLock<HashMap<String, HashSet<String>>>>,
}
80
81impl Default for TagRegistry {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl TagRegistry {
88 pub fn new() -> Self {
90 Self {
91 tags: Arc::new(RwLock::new(HashMap::new())),
92 keys: Arc::new(RwLock::new(HashMap::new())),
93 }
94 }
95
96 pub fn register(&self, key: String, tags: Vec<String>) {
98 let mut tag_map = self.tags.write();
99 let mut key_map = self.keys.write();
100
101 for tag in &tags {
102 tag_map.entry(tag.clone()).or_default().insert(key.clone());
103 }
104
105 key_map.insert(key, tags.into_iter().collect());
106 }
107
108 pub fn get_keys_for_tag(&self, tag: &str) -> Vec<String> {
110 self.tags
111 .read()
112 .get(tag)
113 .map(|keys| keys.iter().cloned().collect())
114 .unwrap_or_default()
115 }
116
117 pub fn get_tags_for_key(&self, key: &str) -> Vec<String> {
119 self.keys
120 .read()
121 .get(key)
122 .map(|tags| tags.iter().cloned().collect())
123 .unwrap_or_default()
124 }
125
126 pub fn unregister(&self, key: &str) {
128 let mut key_map = self.keys.write();
129 if let Some(tags) = key_map.remove(key) {
130 let mut tag_map = self.tags.write();
131 for tag in tags {
132 if let Some(keys) = tag_map.get_mut(&tag) {
133 keys.remove(key);
134 }
135 }
136 }
137 }
138}
139
/// Declares that invalidating `source_tag` should also invalidate `target_tags`.
#[derive(Debug, Clone)]
pub struct CascadeRule {
    /// Tag whose invalidation triggers this rule.
    pub source_tag: String,
    /// Tags invalidated as a consequence; their own rules are followed
    /// recursively up to `InvalidationConfig::max_cascade_depth`.
    pub target_tags: Vec<String>,
}
148
/// Coordinates cache invalidation by key, tag and pattern, and emits an
/// [`InvalidationEvent`] for each invalidation so other instances can follow.
pub struct InvalidationManager {
    /// Cache that invalidated keys are deleted from.
    cache: Arc<RedisCache>,
    /// Channel name, instance id and cascade settings.
    config: InvalidationConfig,
    /// This instance's key <-> tag index, used to resolve tag invalidations.
    tag_registry: TagRegistry,
    /// Cascade rules consulted when a tag is invalidated.
    cascade_rules: Arc<RwLock<Vec<CascadeRule>>>,
    /// Sender for outgoing events; the matching receiver is returned by `new`.
    pubsub_tx: mpsc::UnboundedSender<InvalidationEvent>,
}
157
158impl InvalidationManager {
159 pub fn new(
161 cache: Arc<RedisCache>,
162 config: InvalidationConfig,
163 ) -> (Self, mpsc::UnboundedReceiver<InvalidationEvent>) {
164 let (tx, rx) = mpsc::unbounded_channel();
165
166 let manager = Self {
167 cache,
168 config,
169 tag_registry: TagRegistry::new(),
170 cascade_rules: Arc::new(RwLock::new(Vec::new())),
171 pubsub_tx: tx,
172 };
173
174 (manager, rx)
175 }
176
177 pub fn add_cascade_rule(&self, rule: CascadeRule) {
179 info!(
180 source = %rule.source_tag,
181 targets = ?rule.target_tags,
182 "Added cascade invalidation rule"
183 );
184 self.cascade_rules.write().push(rule);
185 }
186
187 pub fn register_key(&self, key: String, tags: Vec<String>) {
189 self.tag_registry.register(key, tags);
190 }
191
192 pub async fn invalidate_keys(&self, keys: Vec<String>) -> Result<()> {
194 debug!(count = keys.len(), "Invalidating keys");
195
196 for key in &keys {
197 if let Err(e) = self.cache.delete(key).await {
198 error!(key = %key, error = %e, "Failed to invalidate key");
199 }
200 }
201
202 let event = InvalidationEvent {
204 event_type: InvalidationType::Keys,
205 keys,
206 tags: Vec::new(),
207 timestamp: chrono::Utc::now(),
208 source_instance: self.config.instance_id.clone(),
209 };
210
211 self.publish_event(event).await?;
212
213 Ok(())
214 }
215
216 pub async fn invalidate_tag(&self, tag: String) -> Result<()> {
218 let keys = self.tag_registry.get_keys_for_tag(&tag);
219
220 debug!(tag = %tag, key_count = keys.len(), "Invalidating tag");
221
222 for key in &keys {
223 if let Err(e) = self.cache.delete(key).await {
224 error!(key = %key, error = %e, "Failed to invalidate key");
225 }
226 }
227
228 if self.config.enable_cascade {
230 self.apply_cascade_rules(&tag, 0).await?;
231 }
232
233 let event = InvalidationEvent {
235 event_type: InvalidationType::Tags,
236 keys: Vec::new(),
237 tags: vec![tag],
238 timestamp: chrono::Utc::now(),
239 source_instance: self.config.instance_id.clone(),
240 };
241
242 self.publish_event(event).await?;
243
244 Ok(())
245 }
246
247 pub async fn invalidate_pattern(&self, pattern: String) -> Result<()> {
249 debug!(pattern = %pattern, "Invalidating pattern");
250
251 let deleted = self.cache.delete_pattern(&pattern).await?;
252
253 info!(pattern = %pattern, deleted = deleted, "Pattern invalidation completed");
254
255 let event = InvalidationEvent {
257 event_type: InvalidationType::Pattern,
258 keys: vec![pattern],
259 tags: Vec::new(),
260 timestamp: chrono::Utc::now(),
261 source_instance: self.config.instance_id.clone(),
262 };
263
264 self.publish_event(event).await?;
265
266 Ok(())
267 }
268
269 fn apply_cascade_rules<'a>(
271 &'a self,
272 tag: &'a str,
273 depth: usize,
274 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
275 Box::pin(async move {
276 if depth >= self.config.max_cascade_depth {
277 warn!(tag = %tag, depth = depth, "Max cascade depth reached");
278 return Ok(());
279 }
280
281 let matching_rules: Vec<_> = {
282 let rules = self.cascade_rules.read();
283 rules
284 .iter()
285 .filter(|rule| rule.source_tag == tag)
286 .cloned()
287 .collect()
288 };
289
290 for rule in matching_rules {
291 debug!(
292 source = %rule.source_tag,
293 targets = ?rule.target_tags,
294 depth = depth,
295 "Applying cascade rule"
296 );
297
298 for target_tag in rule.target_tags {
299 let keys = self.tag_registry.get_keys_for_tag(&target_tag);
300
301 for key in keys {
302 if let Err(e) = self.cache.delete(&key).await {
303 error!(key = %key, error = %e, "Failed to cascade invalidate key");
304 }
305 }
306
307 self.apply_cascade_rules(&target_tag, depth + 1).await?;
309 }
310 }
311
312 Ok(())
313 })
314 }
315
316 async fn publish_event(&self, event: InvalidationEvent) -> Result<()> {
318 let _json = serde_json::to_string(&event)
319 .map_err(|e| DbError::Cache(format!("Serialization error: {}", e)))?;
320
321 if let Err(e) = self.pubsub_tx.send(event) {
322 error!(error = %e, "Failed to send invalidation event to channel");
323 }
324
325 debug!("Published invalidation event");
326
327 Ok(())
328 }
329
330 pub async fn start_subscriber(self: Arc<Self>, redis_url: String) -> Result<()> {
332 let client = Client::open(redis_url.as_str())
333 .map_err(|e| DbError::Connection(format!("Redis client error: {}", e)))?;
334
335 let mut pubsub = client
336 .get_async_pubsub()
337 .await
338 .map_err(|e| DbError::Connection(format!("Redis pubsub error: {}", e)))?;
339 pubsub
340 .subscribe(&self.config.pubsub_channel)
341 .await
342 .map_err(|e| DbError::Cache(format!("Subscribe error: {}", e)))?;
343
344 info!(channel = %self.config.pubsub_channel, "Started invalidation subscriber");
345
346 tokio::spawn(async move {
347 loop {
348 match pubsub.on_message().next().await {
349 Some(msg) => {
350 let payload: String = match msg.get_payload() {
351 Ok(p) => p,
352 Err(e) => {
353 error!(error = %e, "Failed to get message payload");
354 continue;
355 }
356 };
357
358 let event: InvalidationEvent = match serde_json::from_str(&payload) {
359 Ok(e) => e,
360 Err(e) => {
361 error!(error = %e, "Failed to deserialize event");
362 continue;
363 }
364 };
365
366 if event.source_instance == self.config.instance_id {
368 continue;
369 }
370
371 debug!(
372 event_type = ?event.event_type,
373 source = %event.source_instance,
374 "Received invalidation event"
375 );
376
377 match event.event_type {
379 InvalidationType::Keys => {
380 for key in &event.keys {
381 if let Err(e) = self.cache.delete(key).await {
382 error!(key = %key, error = %e, "Failed to invalidate key");
383 }
384 }
385 }
386 InvalidationType::Tags => {
387 for tag in &event.tags {
388 let keys = self.tag_registry.get_keys_for_tag(tag);
389 for key in keys {
390 if let Err(e) = self.cache.delete(&key).await {
391 error!(key = %key, error = %e, "Failed to invalidate key");
392 }
393 }
394 }
395 }
396 InvalidationType::Pattern => {
397 for pattern in &event.keys {
398 if let Err(e) = self.cache.delete_pattern(pattern).await {
399 error!(pattern = %pattern, error = %e, "Failed to invalidate pattern");
400 }
401 }
402 }
403 InvalidationType::Cascade => {
404 }
406 }
407 }
408 None => {
409 warn!("Pub/sub connection closed, reconnecting...");
410 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
411 }
412 }
413 }
414 });
415
416 Ok(())
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned `Vec<String>` the APIs expect.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_invalidation_config_default() {
        let InvalidationConfig {
            pubsub_channel,
            enable_cascade,
            max_cascade_depth,
            ..
        } = InvalidationConfig::default();

        assert_eq!(pubsub_channel, "cache:invalidation");
        assert!(enable_cascade);
        assert_eq!(max_cascade_depth, 5);
    }

    #[test]
    fn test_tag_registry_register() {
        let registry = TagRegistry::new();
        registry.register("key1".to_string(), strings(&["tag1", "tag2"]));

        // Exactly one key, and it is `key1`.
        assert_eq!(registry.get_keys_for_tag("tag1"), strings(&["key1"]));
    }

    #[test]
    fn test_tag_registry_get_tags() {
        let registry = TagRegistry::new();
        registry.register("key1".to_string(), strings(&["tag1", "tag2"]));

        // Tags come out of a HashSet, so compare order-independently.
        let mut tags = registry.get_tags_for_key("key1");
        tags.sort();
        assert_eq!(tags, strings(&["tag1", "tag2"]));
    }

    #[test]
    fn test_tag_registry_unregister() {
        let registry = TagRegistry::new();
        registry.register("key1".to_string(), strings(&["tag1"]));
        registry.unregister("key1");

        assert!(registry.get_keys_for_tag("tag1").is_empty());
    }

    #[test]
    fn test_cascade_rule_creation() {
        let CascadeRule {
            source_tag,
            target_tags,
        } = CascadeRule {
            source_tag: "user".to_string(),
            target_tags: strings(&["user_profile", "user_orders"]),
        };

        assert_eq!(source_tag, "user");
        assert_eq!(target_tags.len(), 2);
    }

    #[test]
    fn test_invalidation_event_serialization() {
        let original = InvalidationEvent {
            event_type: InvalidationType::Keys,
            keys: strings(&["key1"]),
            tags: Vec::new(),
            timestamp: chrono::Utc::now(),
            source_instance: "instance1".to_string(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: InvalidationEvent = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.event_type, InvalidationType::Keys);
        assert_eq!(decoded.keys.len(), 1);
    }
}