use futures::StreamExt;
use redis::{AsyncCommands, Client, aio::PubSub};
use reinhardt_core::exception::{Error, Result};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
/// A cache-invalidation event carried as a JSON payload on the Redis pub/sub channel.
///
/// Variant and field names form the wire format (serde's default, externally tagged
/// enum representation), so renaming any of them breaks interoperability with
/// publishers/subscribers running a different build.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CacheInvalidationMessage {
    /// Drop the single cache entry stored under `key`.
    InvalidateKey { key: String },
    /// Drop every cache entry matching `pattern`.
    /// NOTE(review): the pattern syntax is interpreted by the subscriber, not defined here.
    InvalidatePattern { pattern: String },
    /// Drop every cache entry.
    ClearAll,
}
/// Publisher/subscriber handle for cache-invalidation events on one Redis pub/sub channel.
pub struct CacheInvalidationChannel {
    /// Redis client used to open publish connections and pub/sub subscriptions.
    client: Client,
    /// Redis channel that messages are published to and subscribed from.
    channel_name: String,
}
impl CacheInvalidationChannel {
pub async fn new(redis_url: &str) -> Result<Self> {
let client =
Client::open(redis_url).map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
Ok(Self {
client,
channel_name: "cache:invalidation".to_string(),
})
}
pub async fn with_channel_name(redis_url: &str, channel_name: String) -> Result<Self> {
let client =
Client::open(redis_url).map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
Ok(Self {
client,
channel_name,
})
}
pub async fn invalidate(&self, key: &str) -> Result<()> {
let msg = CacheInvalidationMessage::InvalidateKey {
key: key.to_string(),
};
self.publish(msg).await
}
pub async fn invalidate_pattern(&self, pattern: &str) -> Result<()> {
let msg = CacheInvalidationMessage::InvalidatePattern {
pattern: pattern.to_string(),
};
self.publish(msg).await
}
pub async fn clear_all(&self) -> Result<()> {
let msg = CacheInvalidationMessage::ClearAll;
self.publish(msg).await
}
async fn publish(&self, message: CacheInvalidationMessage) -> Result<()> {
let mut conn = self
.client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
let json = serde_json::to_string(&message)
.map_err(|e| Error::Serialization(format!("Serialization error: {}", e)))?;
let _: () = conn
.publish(&self.channel_name, json)
.await
.map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
Ok(())
}
pub async fn subscribe(&self) -> Result<CacheInvalidationSubscriber> {
let mut pubsub = self
.client
.get_async_pubsub()
.await
.map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
pubsub
.subscribe(&self.channel_name)
.await
.map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
Ok(CacheInvalidationSubscriber {
pubsub: Arc::new(Mutex::new(pubsub)),
})
}
}
/// Receiving side of the invalidation channel, produced by
/// `CacheInvalidationChannel::subscribe`.
pub struct CacheInvalidationSubscriber {
    /// Pub/sub connection that is already subscribed to the invalidation channel.
    /// It is guarded by an async (tokio) mutex so the lock can be held across `.await`.
    pubsub: Arc<Mutex<PubSub>>,
}
impl CacheInvalidationSubscriber {
pub async fn next_message(&mut self) -> Result<Option<CacheInvalidationMessage>> {
let mut pubsub = self.pubsub.lock().await;
match pubsub.on_message().next().await {
Some(msg) => {
let payload: String = msg
.get_payload()
.map_err(|e| Error::Http(format!("Redis error: {}", e)))?;
let message: CacheInvalidationMessage = serde_json::from_str(&payload)
.map_err(|e| Error::Serialization(format!("Deserialization error: {}", e)))?;
Ok(Some(message))
}
None => Ok(None),
}
}
}