use anyhow::Result;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};
use uuid::Uuid;
/// A single change notification fed into `CacheInvalidationMiddleware`.
///
/// Rules are matched against an event by `entity_type` and `operation`, and every
/// processed event is also appended to the middleware's event history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationEvent {
/// Unique identifier of this event.
pub event_id: Uuid,
/// Kind of change that raised the event.
pub event_type: InvalidationEventType,
/// Entity kind the event refers to (the middleware recognises `"task"`, `"project"`
/// and `"area"` when cascading).
pub entity_type: String,
/// The specific entity affected, if known. Cascade events are always created with
/// `None`; entity-scoped strategies and cascading skip events without an id.
pub entity_id: Option<Uuid>,
/// Operation name compared against `InvalidationRule::operations`.
pub operation: String,
/// When the event happened; compared against the configured retention window when
/// pruning the event history.
pub timestamp: DateTime<Utc>,
/// Cache types the event is aimed at; how this list is interpreted is up to each
/// handler's `can_handle`.
pub affected_caches: Vec<String>,
/// Free-form extra data. NOTE(review): not read by the middleware itself.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Why an `InvalidationEvent` was raised.
///
/// Rule matching uses the event's `operation` string, not this type; the type is
/// carried along for handlers and logging (its `Display` form is the variant name).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum InvalidationEventType {
/// An entity was created.
Created,
/// An entity was updated.
Updated,
/// An entity was deleted.
Deleted,
/// An entity was marked completed.
Completed,
/// A change affecting many entities at once.
BulkOperation,
/// Raised by `CacheInvalidationMiddleware::manual_invalidate`.
ManualInvalidation,
/// NOTE(review): not produced anywhere in this file.
Expired,
/// Raised by the middleware for entities that depend on a changed entity.
CascadeInvalidation,
}
impl std::fmt::Display for InvalidationEventType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InvalidationEventType::Created => write!(f, "Created"),
InvalidationEventType::Updated => write!(f, "Updated"),
InvalidationEventType::Deleted => write!(f, "Deleted"),
InvalidationEventType::Completed => write!(f, "Completed"),
InvalidationEventType::BulkOperation => write!(f, "BulkOperation"),
InvalidationEventType::ManualInvalidation => write!(f, "ManualInvalidation"),
InvalidationEventType::Expired => write!(f, "Expired"),
InvalidationEventType::CascadeInvalidation => write!(f, "CascadeInvalidation"),
}
}
}
/// Decides which cache handlers are invoked for matching events.
///
/// A rule applies to an event when it is `enabled`, its `entity_type` equals the
/// event's `entity_type`, and `operations` is empty or contains the event's `operation`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationRule {
/// Identifier of the rule. The middleware keys rules by `name`, not by this id.
pub rule_id: Uuid,
/// Rule name; adding a rule whose name is already registered replaces the old rule.
pub name: String,
/// Human-readable explanation of the rule.
pub description: String,
/// Entity type this rule reacts to (exact string match).
pub entity_type: String,
/// Operations this rule reacts to; an empty list matches every operation.
pub operations: Vec<String>,
/// NOTE(review): not consulted by the middleware in this file; handler selection
/// comes from `invalidation_strategy` and each handler's `can_handle`.
pub affected_cache_types: Vec<String>,
/// How the handlers to invoke are chosen.
pub invalidation_strategy: InvalidationStrategy,
/// Rules with `enabled == false` are skipped.
pub enabled: bool,
/// Creation timestamp (not read by the middleware).
pub created_at: DateTime<Utc>,
/// Last-modification timestamp (not read by the middleware).
pub updated_at: DateTime<Utc>,
}
/// How an `InvalidationRule` chooses the handlers to invoke for a matching event.
///
/// In every variant that invokes handlers, a handler whose `can_handle` rejects the
/// event is skipped.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum InvalidationStrategy {
/// Invoke every registered handler.
InvalidateAll,
/// Invoke only the handlers registered under the listed cache types; unknown names
/// are ignored.
InvalidateSpecific(Vec<String>),
/// Invoke every registered handler, but only if the event carries an `entity_id`.
InvalidateByEntity,
/// Invoke every registered handler, but only if the event's `entity_type` or
/// `operation` contains the given substring.
InvalidateByPattern(String),
/// Raise cascade events for the entity types that depend on the event's entity type.
CascadeInvalidation,
}
/// Routes `InvalidationEvent`s to registered cache handlers according to
/// `InvalidationRule`s, keeping a bounded event history and running statistics.
///
/// All mutable state sits behind `parking_lot` read/write locks, so every public
/// method takes `&self`.
// NOTE(review): the struct is not `Clone` and nothing in this file clones these `Arc`s;
// plain locks would suffice unless shared ownership is added later.
pub struct CacheInvalidationMiddleware {
// Rules keyed by `InvalidationRule::name`.
rules: Arc<RwLock<HashMap<String, InvalidationRule>>>,
// Event history in insertion order (oldest first).
events: Arc<RwLock<Vec<InvalidationEvent>>>,
// Handlers keyed by `CacheInvalidationHandler::cache_type`.
handlers: Arc<RwLock<HashMap<String, Box<dyn CacheInvalidationHandler + Send + Sync>>>>,
// Settings fixed at construction; never mutated afterwards.
config: InvalidationConfig,
// Counters returned (as a snapshot) by `get_stats`.
stats: Arc<RwLock<InvalidationStats>>,
}
/// A cache backend that the middleware asks to drop entries for an `InvalidationEvent`.
///
/// The middleware stores handlers as `Box<dyn CacheInvalidationHandler + Send + Sync>`.
pub trait CacheInvalidationHandler {
/// Called for an event that a matching rule routed to this handler.
///
/// # Errors
///
/// Returning `Err` stops the current rule (remaining handlers for that rule are
/// skipped); the middleware logs it and counts it as a failed invalidation.
fn invalidate(&self, event: &InvalidationEvent) -> Result<()>;
/// Key under which the handler is registered. `InvalidateSpecific` selects handlers
/// by this key, and registering another handler with the same key replaces this one.
fn cache_type(&self) -> &str;
/// Checked before each `invalidate` call; returning `false` skips this handler for
/// the event.
fn can_handle(&self, event: &InvalidationEvent) -> bool;
}
/// Settings for `CacheInvalidationMiddleware`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationConfig {
/// Maximum number of events kept in the history; the oldest are dropped first.
pub max_events: usize,
/// How long stored events are kept; older events are pruned whenever a new event
/// is stored.
pub event_retention: Duration,
/// Whether processing an event automatically raises cascade events for the entity
/// types that depend on it.
pub enable_cascade: bool,
/// Intended maximum number of cascade levels below an originally submitted event.
/// NOTE(review): confirm the middleware enforces this bound.
pub cascade_depth: u32,
/// NOTE(review): not read by the middleware in this file.
pub enable_batching: bool,
/// NOTE(review): not read by the middleware in this file.
pub batch_size: usize,
/// NOTE(review): not read by the middleware in this file.
pub batch_timeout: Duration,
}
impl Default for InvalidationConfig {
    /// Keeps up to 10 000 events for one day, cascades with `cascade_depth` 3, and
    /// enables batching with 100-item batches and a 5-second timeout.
    fn default() -> Self {
        const ONE_DAY_SECS: u64 = 24 * 60 * 60;

        Self {
            // History limits.
            max_events: 10_000,
            event_retention: Duration::from_secs(ONE_DAY_SECS),
            // Cascading.
            enable_cascade: true,
            cascade_depth: 3,
            // Batching.
            enable_batching: true,
            batch_size: 100,
            batch_timeout: Duration::from_secs(5),
        }
    }
}
/// Running counters kept by `CacheInvalidationMiddleware`; `get_stats` returns a copy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct InvalidationStats {
/// Events processed, counting each cascaded event separately.
pub total_events: u64,
/// Rule applications that completed without error.
pub successful_invalidations: u64,
/// Rule applications that returned an error.
pub failed_invalidations: u64,
/// Cascaded events that were processed successfully.
pub cascade_invalidations: u64,
/// `manual_invalidate` calls that completed successfully.
pub manual_invalidations: u64,
/// NOTE(review): never incremented in this file.
pub expired_invalidations: u64,
/// Running mean of per-event processing time in milliseconds; an event's time
/// includes the time spent on the cascades it triggered.
pub average_processing_time_ms: f64,
/// When the most recent successful rule application was recorded.
pub last_invalidation: Option<DateTime<Utc>>,
}
impl CacheInvalidationMiddleware {
#[must_use]
pub fn new(config: InvalidationConfig) -> Self {
Self {
rules: Arc::new(RwLock::new(HashMap::new())),
events: Arc::new(RwLock::new(Vec::new())),
handlers: Arc::new(RwLock::new(HashMap::new())),
config,
stats: Arc::new(RwLock::new(InvalidationStats::default())),
}
}
#[must_use]
pub fn new_default() -> Self {
Self::new(InvalidationConfig::default())
}
pub fn register_handler(&self, handler: Box<dyn CacheInvalidationHandler + Send + Sync>) {
let mut handlers = self.handlers.write();
handlers.insert(handler.cache_type().to_string(), handler);
}
pub fn add_rule(&self, rule: InvalidationRule) {
let mut rules = self.rules.write();
rules.insert(rule.name.clone(), rule);
}
pub async fn process_event(&self, event: InvalidationEvent) -> Result<()> {
let start_time = std::time::Instant::now();
self.store_event(&event);
let applicable_rules = self.find_applicable_rules(&event);
for rule in applicable_rules {
if let Err(e) = self.process_rule(&event, &rule).await {
warn!("Failed to process invalidation rule {}: {}", rule.name, e);
self.record_failed_invalidation();
} else {
self.record_successful_invalidation();
}
}
if self.config.enable_cascade {
self.handle_cascade_invalidation(&event).await?;
}
#[allow(clippy::cast_precision_loss)]
let processing_time = start_time.elapsed().as_millis().min(u128::from(u64::MAX)) as f64;
{
let mut stats = self.stats.write();
stats.total_events += 1;
}
self.update_processing_time(processing_time);
debug!(
"Processed invalidation event: {} for entity: {}:{}",
event.event_type,
event.entity_type,
event
.entity_id
.map_or_else(|| "none".to_string(), |id| id.to_string())
);
Ok(())
}
pub async fn manual_invalidate(
&self,
entity_type: &str,
entity_id: Option<Uuid>,
cache_types: Option<Vec<String>>,
) -> Result<()> {
let event = InvalidationEvent {
event_id: Uuid::new_v4(),
event_type: InvalidationEventType::ManualInvalidation,
entity_type: entity_type.to_string(),
entity_id,
operation: "manual_invalidation".to_string(),
timestamp: Utc::now(),
affected_caches: cache_types.unwrap_or_default(),
metadata: HashMap::new(),
};
self.process_event(event).await?;
self.record_manual_invalidation();
Ok(())
}
#[must_use]
pub fn get_stats(&self) -> InvalidationStats {
self.stats.read().clone()
}
#[must_use]
pub fn get_recent_events(&self, limit: usize) -> Vec<InvalidationEvent> {
let events = self.events.read();
events.iter().rev().take(limit).cloned().collect()
}
#[must_use]
pub fn get_events_by_entity_type(&self, entity_type: &str) -> Vec<InvalidationEvent> {
let events = self.events.read();
events
.iter()
.filter(|event| event.entity_type == entity_type)
.cloned()
.collect()
}
fn store_event(&self, event: &InvalidationEvent) {
let mut events = self.events.write();
events.push(event.clone());
if events.len() > self.config.max_events {
let excess = events.len() - self.config.max_events;
events.drain(0..excess);
}
let cutoff_time = Utc::now()
- chrono::Duration::from_std(self.config.event_retention).unwrap_or_default();
events.retain(|event| event.timestamp > cutoff_time);
}
fn find_applicable_rules(&self, event: &InvalidationEvent) -> Vec<InvalidationRule> {
let rules = self.rules.read();
rules
.values()
.filter(|rule| {
rule.enabled
&& rule.entity_type == event.entity_type
&& (rule.operations.is_empty() || rule.operations.contains(&event.operation))
})
.cloned()
.collect()
}
async fn process_rule(&self, event: &InvalidationEvent, rule: &InvalidationRule) -> Result<()> {
match &rule.invalidation_strategy {
InvalidationStrategy::InvalidateAll => {
let handlers_guard = self.handlers.read();
for handler in handlers_guard.values() {
if handler.can_handle(event) {
handler.invalidate(event)?;
}
}
}
InvalidationStrategy::InvalidateSpecific(cache_types) => {
let handlers_guard = self.handlers.read();
for cache_type in cache_types {
if let Some(handler) = handlers_guard.get(cache_type) {
if handler.can_handle(event) {
handler.invalidate(event)?;
}
}
}
}
InvalidationStrategy::InvalidateByEntity => {
if let Some(_entity_id) = event.entity_id {
let handlers_guard = self.handlers.read();
for handler in handlers_guard.values() {
if handler.can_handle(event) {
handler.invalidate(event)?;
}
}
}
}
InvalidationStrategy::InvalidateByPattern(pattern) => {
let handlers_guard = self.handlers.read();
for handler in handlers_guard.values() {
if handler.can_handle(event) && Self::matches_pattern(event, pattern) {
handler.invalidate(event)?;
}
}
}
InvalidationStrategy::CascadeInvalidation => {
self.handle_cascade_invalidation(event).await?;
}
}
Ok(())
}
async fn handle_cascade_invalidation(&self, event: &InvalidationEvent) -> Result<()> {
let dependent_entities = Self::find_dependent_entities(event);
for dependent_entity in dependent_entities {
let dependent_event = InvalidationEvent {
event_id: Uuid::new_v4(),
event_type: InvalidationEventType::CascadeInvalidation,
entity_type: dependent_entity.entity_type.clone(),
entity_id: dependent_entity.entity_id,
operation: "cascade_invalidation".to_string(),
timestamp: Utc::now(),
affected_caches: dependent_entity.affected_caches.clone(),
metadata: HashMap::new(),
};
Box::pin(self.process_event(dependent_event)).await?;
self.record_cascade_invalidation();
}
Ok(())
}
fn find_dependent_entities(event: &InvalidationEvent) -> Vec<DependentEntity> {
let mut dependent_entities = Vec::new();
match event.entity_type.as_str() {
"task" => {
if let Some(_task_id) = event.entity_id {
dependent_entities.push(DependentEntity {
entity_type: "project".to_string(),
entity_id: None, affected_caches: vec!["l1".to_string(), "l2".to_string()],
});
dependent_entities.push(DependentEntity {
entity_type: "area".to_string(),
entity_id: None, affected_caches: vec!["l1".to_string(), "l2".to_string()],
});
}
}
"project" => {
if let Some(_project_id) = event.entity_id {
dependent_entities.push(DependentEntity {
entity_type: "task".to_string(),
entity_id: None, affected_caches: vec!["l1".to_string(), "l2".to_string()],
});
}
}
"area" => {
if let Some(_area_id) = event.entity_id {
dependent_entities.push(DependentEntity {
entity_type: "project".to_string(),
entity_id: None,
affected_caches: vec!["l1".to_string(), "l2".to_string()],
});
dependent_entities.push(DependentEntity {
entity_type: "task".to_string(),
entity_id: None,
affected_caches: vec!["l1".to_string(), "l2".to_string()],
});
}
}
_ => {
}
}
dependent_entities
}
fn matches_pattern(event: &InvalidationEvent, pattern: &str) -> bool {
event.entity_type.contains(pattern) || event.operation.contains(pattern)
}
fn record_successful_invalidation(&self) {
let mut stats = self.stats.write();
stats.successful_invalidations += 1;
stats.last_invalidation = Some(Utc::now());
}
fn record_failed_invalidation(&self) {
let mut stats = self.stats.write();
stats.failed_invalidations += 1;
}
fn record_cascade_invalidation(&self) {
let mut stats = self.stats.write();
stats.cascade_invalidations += 1;
}
fn record_manual_invalidation(&self) {
let mut stats = self.stats.write();
stats.manual_invalidations += 1;
}
fn update_processing_time(&self, processing_time: f64) {
let mut stats = self.stats.write();
#[allow(clippy::cast_precision_loss)]
let total_events = stats.total_events as f64;
stats.average_processing_time_ms =
(stats.average_processing_time_ms * (total_events - 1.0) + processing_time)
/ total_events;
}
}
/// An entity that must be invalidated because a related entity changed; the middleware
/// turns each one into a `CascadeInvalidation` event.
#[derive(Debug, Clone)]
struct DependentEntity {
// Entity type of the cascade event.
entity_type: String,
// Specific entity, if any (the middleware currently always uses `None`).
entity_id: Option<Uuid>,
// Copied into the cascade event's `affected_caches`.
affected_caches: Vec<String>,
}
/// Cascade scope selector.
///
/// NOTE(review): not referenced anywhere else in this file; the variant meanings below
/// are inferred from the names and should be confirmed against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CascadeInvalidationEvent {
/// Presumably: cascade to every cache.
InvalidateAll,
/// Presumably: cascade only to the named caches.
InvalidateSpecific(Vec<String>),
/// Presumably: cascade only to caches at the given level.
InvalidateByLevel(u32),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Shared record of the events a [`MockCacheHandler`] was asked to invalidate.
    type EventLog = Arc<RwLock<Vec<InvalidationEvent>>>;

    /// Handler that records every invalidation it receives. It accepts an event when
    /// the event names no caches or names this handler's cache type.
    struct MockCacheHandler {
        cache_type: String,
        invalidated_events: EventLog,
    }

    impl MockCacheHandler {
        fn new(cache_type: &str) -> Self {
            Self {
                cache_type: cache_type.to_string(),
                invalidated_events: Arc::new(RwLock::new(Vec::new())),
            }
        }

        /// Handle to the log that stays readable after the handler has been boxed and
        /// moved into the middleware.
        fn event_log(&self) -> EventLog {
            Arc::clone(&self.invalidated_events)
        }
    }

    impl CacheInvalidationHandler for MockCacheHandler {
        fn invalidate(&self, event: &InvalidationEvent) -> Result<()> {
            self.invalidated_events.write().push(event.clone());
            Ok(())
        }

        fn cache_type(&self) -> &str {
            &self.cache_type
        }

        fn can_handle(&self, event: &InvalidationEvent) -> bool {
            event.affected_caches.is_empty() || event.affected_caches.contains(&self.cache_type)
        }
    }

    /// Registers a mock handler for `cache_type` and returns its event log.
    fn register_mock(middleware: &CacheInvalidationMiddleware, cache_type: &str) -> EventLog {
        let handler = MockCacheHandler::new(cache_type);
        let log = handler.event_log();
        middleware.register_handler(Box::new(handler));
        log
    }

    /// Entity types recorded in `log`, in invalidation order.
    fn logged_entity_types(log: &EventLog) -> Vec<String> {
        log.read()
            .iter()
            .map(|event| event.entity_type.clone())
            .collect()
    }

    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// Builds an enabled rule using the `InvalidateAll` strategy.
    fn invalidate_all_rule(
        name: &str,
        entity_type: &str,
        operations: &[&str],
        cache_types: &[&str],
    ) -> InvalidationRule {
        InvalidationRule {
            rule_id: Uuid::new_v4(),
            name: name.to_string(),
            description: format!("Rule for {entity_type} invalidation"),
            entity_type: entity_type.to_string(),
            operations: strings(operations),
            affected_cache_types: strings(cache_types),
            invalidation_strategy: InvalidationStrategy::InvalidateAll,
            enabled: true,
            created_at: Utc::now(),
            updated_at: Utc::now(),
        }
    }

    /// Builds an event about a freshly generated entity id.
    fn entity_event(
        event_type: InvalidationEventType,
        entity_type: &str,
        operation: &str,
        affected_caches: &[&str],
    ) -> InvalidationEvent {
        InvalidationEvent {
            event_id: Uuid::new_v4(),
            event_type,
            entity_type: entity_type.to_string(),
            entity_id: Some(Uuid::new_v4()),
            operation: operation.to_string(),
            timestamp: Utc::now(),
            affected_caches: strings(affected_caches),
            metadata: HashMap::new(),
        }
    }

    /// Adds the rules that react to the project/area events cascaded from a task event.
    fn add_cascade_target_rules(middleware: &CacheInvalidationMiddleware) {
        for entity_type in ["project", "area"] {
            middleware.add_rule(invalidate_all_rule(
                &format!("{entity_type}_rule"),
                entity_type,
                &["cascade_invalidation"],
                &["l1", "l2"],
            ));
        }
    }

    #[tokio::test]
    async fn test_invalidation_middleware_basic() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let l1_log = register_mock(&middleware, "l1");
        let l2_log = register_mock(&middleware, "l2");
        middleware.add_rule(invalidate_all_rule(
            "task_rule",
            "task",
            &["updated"],
            &["l1", "l2"],
        ));
        add_cascade_target_rules(&middleware);

        let event = entity_event(InvalidationEventType::Updated, "task", "updated", &["l1", "l2"]);
        middleware.process_event(event).await.unwrap();

        // The task event itself plus the project and area events it cascades to.
        let stats = middleware.get_stats();
        assert_eq!(stats.total_events, 3);
        assert_eq!(stats.successful_invalidations, 3);
        assert_eq!(stats.cascade_invalidations, 2);
        let expected = vec!["task", "project", "area"];
        assert_eq!(logged_entity_types(&l1_log), expected);
        assert_eq!(logged_entity_types(&l2_log), expected);
    }

    #[tokio::test]
    async fn test_manual_invalidation() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let l1_log = register_mock(&middleware, "l1");
        let entity_id = Uuid::new_v4();
        middleware
            .manual_invalidate("task", Some(entity_id), None)
            .await
            .unwrap();

        let stats = middleware.get_stats();
        assert_eq!(stats.manual_invalidations, 1);
        // Without rules nothing reaches the handler, but the event is still recorded.
        assert!(l1_log.read().is_empty());
        let task_events = middleware.get_events_by_entity_type("task");
        assert_eq!(task_events.len(), 1);
        assert_eq!(
            task_events[0].event_type,
            InvalidationEventType::ManualInvalidation
        );
        assert_eq!(task_events[0].entity_id, Some(entity_id));
        assert_eq!(task_events[0].operation, "manual_invalidation");
        assert!(task_events[0].affected_caches.is_empty());
    }

    #[test]
    fn test_event_storage() {
        let middleware = CacheInvalidationMiddleware::new_default();
        middleware.store_event(&entity_event(
            InvalidationEventType::Created,
            "task",
            "created",
            &[],
        ));
        let recent_events = middleware.get_recent_events(1);
        assert_eq!(recent_events.len(), 1);
        assert_eq!(recent_events[0].entity_type, "task");
    }

    #[test]
    fn test_expired_events_are_pruned() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let mut stale = entity_event(InvalidationEventType::Created, "task", "created", &[]);
        // Older than the default one-day retention window.
        stale.timestamp = Utc::now() - chrono::Duration::days(2);
        middleware.store_event(&stale);
        assert!(middleware.get_recent_events(10).is_empty());
    }

    #[test]
    fn test_max_events_keeps_newest() {
        let config = InvalidationConfig {
            max_events: 2,
            ..InvalidationConfig::default()
        };
        let middleware = CacheInvalidationMiddleware::new(config);
        for i in 0..3 {
            middleware.store_event(&entity_event(
                InvalidationEventType::Created,
                &format!("task_{i}"),
                "created",
                &[],
            ));
        }
        let kept: Vec<String> = middleware
            .get_recent_events(10)
            .into_iter()
            .map(|event| event.entity_type)
            .collect();
        assert_eq!(kept, vec!["task_2", "task_1"]);
    }

    #[test]
    fn test_invalidation_middleware_creation() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let stats = middleware.get_stats();
        assert_eq!(stats.total_events, 0);
        assert_eq!(stats.successful_invalidations, 0);
        assert_eq!(stats.failed_invalidations, 0);
        assert_eq!(stats.manual_invalidations, 0);
    }

    #[test]
    fn test_invalidation_middleware_with_config() {
        let config = InvalidationConfig {
            enable_cascade: true,
            max_events: 1000,
            event_retention: Duration::from_secs(3600),
            batch_size: 10,
            batch_timeout: Duration::from_secs(30),
            cascade_depth: 3,
            enable_batching: true,
        };
        let middleware = CacheInvalidationMiddleware::new(config);
        assert_eq!(middleware.get_stats().total_events, 0);
    }

    #[tokio::test]
    async fn test_add_rule() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let log = register_mock(&middleware, "task_cache");
        middleware.add_rule(invalidate_all_rule(
            "test_rule",
            "task",
            &["updated"],
            &["task_cache"],
        ));

        let event = entity_event(InvalidationEventType::Updated, "task", "updated", &["task_cache"]);
        middleware.process_event(event).await.unwrap();

        assert_eq!(middleware.get_stats().successful_invalidations, 1);
        assert_eq!(logged_entity_types(&log), vec!["task"]);
    }

    #[tokio::test]
    async fn test_disabled_rule_is_ignored() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let log = register_mock(&middleware, "task_cache");
        let mut rule = invalidate_all_rule("disabled_rule", "task", &["updated"], &["task_cache"]);
        rule.enabled = false;
        middleware.add_rule(rule);

        let event = entity_event(InvalidationEventType::Updated, "task", "updated", &["task_cache"]);
        middleware.process_event(event).await.unwrap();

        assert_eq!(middleware.get_stats().successful_invalidations, 0);
        assert!(log.read().is_empty());
    }

    #[tokio::test]
    async fn test_register_handler() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let first = register_mock(&middleware, "test_cache");
        let second = register_mock(&middleware, "test_cache");
        middleware.add_rule(invalidate_all_rule("test_rule", "task", &[], &["test_cache"]));

        let event = entity_event(InvalidationEventType::Created, "task", "created", &["test_cache"]);
        middleware.process_event(event).await.unwrap();

        // Registering the same cache type again replaces the earlier handler.
        assert!(first.read().is_empty());
        assert_eq!(logged_entity_types(&second), vec!["task"]);
    }

    #[tokio::test]
    async fn test_process_event_with_handler() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let test_log = register_mock(&middleware, "test_cache");
        let l1_log = register_mock(&middleware, "l1");
        let l2_log = register_mock(&middleware, "l2");
        middleware.add_rule(invalidate_all_rule(
            "test_rule",
            "task",
            &["created", "updated"],
            &["test_cache"],
        ));
        add_cascade_target_rules(&middleware);

        let event = entity_event(InvalidationEventType::Created, "task", "created", &["test_cache"]);
        middleware.process_event(event).await.unwrap();

        let stats = middleware.get_stats();
        assert_eq!(stats.total_events, 3);
        assert_eq!(stats.successful_invalidations, 3);
        // The task event targets only `test_cache`; the cascaded events target l1/l2.
        assert_eq!(logged_entity_types(&test_log), vec!["task"]);
        assert_eq!(logged_entity_types(&l1_log), vec!["project", "area"]);
        assert_eq!(logged_entity_types(&l2_log), vec!["project", "area"]);
    }

    #[test]
    fn test_get_recent_events() {
        let middleware = CacheInvalidationMiddleware::new_default();
        for i in 0..5 {
            middleware.store_event(&entity_event(
                InvalidationEventType::Created,
                &format!("task_{i}"),
                "created",
                &[],
            ));
        }

        let recent: Vec<String> = middleware
            .get_recent_events(3)
            .into_iter()
            .map(|event| event.entity_type)
            .collect();
        assert_eq!(recent, vec!["task_4", "task_3", "task_2"]);
        assert_eq!(middleware.get_recent_events(10).len(), 5);
    }
}