use std::sync::Arc;
use tokio::sync::broadcast;
/// Channel-name string `"skp_cache:invalidate"`.
///
/// NOTE(review): not referenced anywhere in this file. It is presumably the name of an
/// external pub/sub channel that carries `InvalidationEvent::to_message` strings between
/// processes. Confirm against its users.
pub const INVALIDATION_CHANNEL: &str = "skp_cache:invalidate";
/// One cache-invalidation instruction.
///
/// On the wire every event is a plain string. [`InvalidationEvent::Clear`] is the bare
/// word `"clear"`. Every other variant is `"<kind>:<value>"`, where `<kind>` is `key`,
/// `pattern` or `tag`. How each kind is applied to a cache is up to the consumer; that is
/// not shown here.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum InvalidationEvent {
    /// Encoded as `key:<value>`.
    Key(String),
    /// Encoded as `pattern:<value>`; the pattern syntax is interpreted by the consumer.
    Pattern(String),
    /// Encoded as `tag:<value>`.
    Tag(String),
    /// Encoded as the bare string `clear`.
    Clear,
}

impl InvalidationEvent {
    /// Encodes this event into its wire string.
    ///
    /// The value is written verbatim after the first `:`, so it may itself contain `:`.
    pub fn to_message(&self) -> String {
        let (kind, value) = match self {
            Self::Key(value) => ("key", value),
            Self::Pattern(value) => ("pattern", value),
            Self::Tag(value) => ("tag", value),
            Self::Clear => return String::from("clear"),
        };
        format!("{}:{}", kind, value)
    }

    /// Decodes a wire string such as the ones produced by [`InvalidationEvent::to_message`].
    ///
    /// Only the first `:` separates kind from value, so values containing `:` round-trip
    /// intact. Returns `None` when the kind is unknown, or when the string has no `:` and
    /// is not exactly `"clear"`.
    pub fn from_message(msg: &str) -> Option<Self> {
        let Some((kind, value)) = msg.split_once(':') else {
            // The only colon-free message understood is the bare `clear`.
            return (msg == "clear").then_some(Self::Clear);
        };
        let value = value.to_owned();
        match kind {
            "key" => Some(Self::Key(value)),
            "pattern" => Some(Self::Pattern(value)),
            "tag" => Some(Self::Tag(value)),
            _ => None,
        }
    }
}
/// Sending half of the in-process invalidation broadcast.
///
/// Wraps a `tokio::sync::broadcast::Sender`. Clones all feed the same underlying channel.
// `Debug` is derived so the public type can be printed in diagnostics; tokio's
// `broadcast::Sender<T>` implements `Debug` for any `T`.
#[derive(Clone, Debug)]
pub struct InvalidationPublisher {
    tx: broadcast::Sender<InvalidationEvent>,
}
impl InvalidationPublisher {
    /// Creates a broadcast channel that buffers up to `capacity` events. Returns the
    /// publisher together with a first subscriber.
    ///
    /// # Panics
    ///
    /// Delegates to `tokio::sync::broadcast::channel`, which panics when `capacity` is 0
    /// (or larger than `usize::MAX / 2`).
    pub fn new(capacity: usize) -> (Self, InvalidationSubscriber) {
        let (sender, receiver) = broadcast::channel(capacity);
        let subscriber = InvalidationSubscriber { rx: receiver };
        (Self { tx: sender }, subscriber)
    }

    /// Broadcasts `event` to every currently subscribed receiver.
    ///
    /// On success, returns the count reported by the underlying `broadcast::Sender::send`.
    ///
    /// # Errors
    ///
    /// Returns [`PublishError::NoSubscribers`] whenever the underlying send fails. The
    /// event is discarded in that case.
    pub fn publish(&self, event: InvalidationEvent) -> Result<usize, PublishError> {
        let outcome = self.tx.send(event);
        outcome.map_err(|_unsent| PublishError::NoSubscribers)
    }

    /// Registers an additional subscriber on the same channel.
    ///
    /// Per tokio's `broadcast::Sender::subscribe` documentation, the new subscriber only
    /// receives events published after this call.
    pub fn subscribe(&self) -> InvalidationSubscriber {
        let rx = self.tx.subscribe();
        InvalidationSubscriber { rx }
    }
}
/// Receiving half of the invalidation broadcast.
///
/// Obtained from [`InvalidationPublisher::new`] or [`InvalidationPublisher::subscribe`].
// `Debug` is derived so the public type can be printed in diagnostics; tokio's
// `broadcast::Receiver<T>` implements `Debug` for any `T`.
#[derive(Debug)]
pub struct InvalidationSubscriber {
    rx: broadcast::Receiver<InvalidationEvent>,
}
impl InvalidationSubscriber {
pub async fn recv(&mut self) -> Result<InvalidationEvent, SubscribeError> {
self.rx.recv().await.map_err(|e| match e {
broadcast::error::RecvError::Closed => SubscribeError::Closed,
broadcast::error::RecvError::Lagged(n) => SubscribeError::Lagged(n),
})
}
}
/// Failure returned by [`InvalidationPublisher::publish`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PublishError {
    /// The underlying broadcast send failed. tokio reports this when no receiver is
    /// subscribed.
    NoSubscribers,
}

impl std::fmt::Display for PublishError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::NoSubscribers => "no subscribers listening",
        };
        f.write_str(text)
    }
}

impl std::error::Error for PublishError {}
/// Failure returned by [`InvalidationSubscriber::recv`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SubscribeError {
    /// The underlying broadcast channel is closed.
    Closed,
    /// The subscriber fell behind; carries the number of missed events reported by tokio.
    Lagged(u64),
}

impl std::fmt::Display for SubscribeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Closed => f.write_str("channel closed"),
            Self::Lagged(missed) => write!(f, "lagged behind by {} messages", missed),
        }
    }
}

impl std::error::Error for SubscribeError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant encodes to the expected wire string and decodes back to itself.
    #[test]
    fn test_event_serialization() {
        let events = vec![
            (InvalidationEvent::Key("foo".into()), "key:foo"),
            (InvalidationEvent::Pattern("user:*".into()), "pattern:user:*"),
            (InvalidationEvent::Tag("users".into()), "tag:users"),
            (InvalidationEvent::Clear, "clear"),
        ];
        for (event, expected) in events {
            let msg = event.to_message();
            assert_eq!(msg, expected);
            let parsed = InvalidationEvent::from_message(&msg);
            assert_eq!(parsed, Some(event));
        }
    }

    /// Unknown kinds, missing separators and near-misses of `clear` must not parse.
    #[test]
    fn test_from_message_rejects_malformed() {
        for msg in ["", "key", "CLEAR", "clear:", "bogus:x"] {
            assert_eq!(
                InvalidationEvent::from_message(msg),
                None,
                "{:?} should not parse",
                msg
            );
        }
    }

    /// Only the first `:` separates kind from value.
    #[test]
    fn test_value_may_contain_separator() {
        assert_eq!(
            InvalidationEvent::from_message("key:a:b"),
            Some(InvalidationEvent::Key("a:b".into()))
        );
    }

    #[tokio::test]
    async fn test_pubsub() {
        let (publisher, mut subscriber) = InvalidationPublisher::new(16);
        publisher.publish(InvalidationEvent::Key("test".into())).unwrap();
        let event = subscriber.recv().await.unwrap();
        assert_eq!(event, InvalidationEvent::Key("test".into()));
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let (publisher, mut sub1) = InvalidationPublisher::new(16);
        let mut sub2 = publisher.subscribe();
        publisher.publish(InvalidationEvent::Clear).unwrap();
        let e1 = sub1.recv().await.unwrap();
        let e2 = sub2.recv().await.unwrap();
        assert_eq!(e1, InvalidationEvent::Clear);
        assert_eq!(e2, InvalidationEvent::Clear);
    }

    /// With every subscriber dropped, publishing reports `NoSubscribers`.
    #[test]
    fn test_publish_without_subscribers() {
        let (publisher, subscriber) = InvalidationPublisher::new(4);
        drop(subscriber);
        assert_eq!(
            publisher.publish(InvalidationEvent::Clear),
            Err(PublishError::NoSubscribers)
        );
    }

    /// Overflowing the buffer surfaces `Lagged`, then delivery resumes with the oldest
    /// retained event (mirrors the tokio `broadcast` capacity-2 example).
    #[tokio::test]
    async fn test_lagged_subscriber() {
        let (publisher, mut subscriber) = InvalidationPublisher::new(2);
        for key in ["a", "b", "c"] {
            publisher.publish(InvalidationEvent::Key(key.into())).unwrap();
        }
        assert_eq!(subscriber.recv().await, Err(SubscribeError::Lagged(1)));
        assert_eq!(
            subscriber.recv().await,
            Ok(InvalidationEvent::Key("b".into()))
        );
        assert_eq!(
            subscriber.recv().await,
            Ok(InvalidationEvent::Key("c".into()))
        );
    }

    /// Dropping the only publisher closes the channel for subscribers.
    #[tokio::test]
    async fn test_closed_after_publisher_dropped() {
        let (publisher, mut subscriber) = InvalidationPublisher::new(4);
        drop(publisher);
        assert_eq!(subscriber.recv().await, Err(SubscribeError::Closed));
    }
}