use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use crate::commit::TenantId;
/// Buffer size used by [`InvalidationBus::new`] (and therefore `Default`): the number of
/// unread events the channel is sized for before a slow subscriber starts lagging.
pub const DEFAULT_BUS_CAPACITY: usize = 1 << 12;
/// A notification that some tenant-scoped data changed, published on an [`InvalidationBus`].
///
/// # Wire format
///
/// Serde encodes this enum as an internally tagged object: the variant name is written
/// under a `"kind"` key next to the variant's own fields, e.g.
/// `{"kind":"Updated","tenant_id":…,"rid":"mem_a"}`. No `rename` is applied, so the tag
/// value equals [`InvalidationEvent::variant_name`], and renaming a variant or field changes
/// the serialized form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum InvalidationEvent {
    /// The item identified by `rid` in `tenant_id` was tombstoned.
    Tombstoned { tenant_id: TenantId, rid: String },
    /// The item identified by `rid` in `tenant_id` was updated.
    Updated { tenant_id: TenantId, rid: String },
    /// The edge between `src` and `dst` in `tenant_id` changed.
    EdgeChanged {
        tenant_id: TenantId,
        src: String,
        dst: String,
    },
    /// The configuration entry `key` of `tenant_id` changed.
    TenantConfigChanged { tenant_id: TenantId, key: String },
}
impl InvalidationEvent {
pub fn tenant_id(&self) -> TenantId {
match self {
InvalidationEvent::Tombstoned { tenant_id, .. }
| InvalidationEvent::Updated { tenant_id, .. }
| InvalidationEvent::EdgeChanged { tenant_id, .. }
| InvalidationEvent::TenantConfigChanged { tenant_id, .. } => *tenant_id,
}
}
pub fn variant_name(&self) -> &'static str {
match self {
InvalidationEvent::Tombstoned { .. } => "Tombstoned",
InvalidationEvent::Updated { .. } => "Updated",
InvalidationEvent::EdgeChanged { .. } => "EdgeChanged",
InvalidationEvent::TenantConfigChanged { .. } => "TenantConfigChanged",
}
}
}
/// In-process fan-out channel for [`InvalidationEvent`]s, backed by `tokio::sync::broadcast`.
///
/// Cloning the bus clones the underlying sender, so every clone publishes into, and hands
/// out subscriptions from, the same channel.
#[derive(Clone, Debug)]
pub struct InvalidationBus {
    /// Sending half of the broadcast channel; receivers are created on demand by `subscribe`.
    sender: broadcast::Sender<InvalidationEvent>,
}
impl InvalidationBus {
    /// Creates a bus sized for [`DEFAULT_BUS_CAPACITY`] buffered events.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_BUS_CAPACITY)
    }

    /// Creates a bus whose channel is sized for `cap` buffered events (tokio may round this
    /// up internally).
    ///
    /// A `cap` of 0 is raised to 1, because `broadcast::channel` panics on a zero capacity.
    ///
    /// # Panics
    ///
    /// Panics (inside tokio) if `cap` is larger than `usize::MAX / 2`.
    pub fn with_capacity(cap: usize) -> Self {
        // The receiver created with the channel is discarded, so `subscriber_count` starts
        // at 0 and `publish` reports 0 deliveries until someone calls `subscribe`.
        let (sender, _rx) = broadcast::channel(cap.max(1));
        Self { sender }
    }

    /// Sends `event` to every current subscriber and returns how many subscribers it was
    /// sent to.
    ///
    /// Publishing is best-effort and never waits on readers. With no subscribers, the event
    /// is dropped and 0 is returned rather than an error. A subscriber that has fallen
    /// further behind than the channel capacity gets `RecvError::Lagged` on its next `recv`.
    pub fn publish(&self, event: InvalidationEvent) -> usize {
        // `send` only fails when there are no receivers; for an invalidation bus that simply
        // means nobody needed the event, so report zero deliveries.
        self.sender.send(event).unwrap_or(0)
    }

    /// Returns a new receiver that observes events published after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<InvalidationEvent> {
        self.sender.subscribe()
    }

    /// Returns the number of receivers that are currently alive.
    pub fn subscriber_count(&self) -> usize {
        self.sender.receiver_count()
    }
}
impl Default for InvalidationBus {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Tombstoned` event for tenant 1 so the bus tests stay focused on delivery.
    fn tombstone(rid: &str) -> InvalidationEvent {
        InvalidationEvent::Tombstoned {
            tenant_id: TenantId::new(1),
            rid: rid.into(),
        }
    }

    /// Returns the `rid` of a `Tombstoned` event, panicking on any other variant.
    fn tombstoned_rid(evt: InvalidationEvent) -> String {
        match evt {
            InvalidationEvent::Tombstoned { rid, .. } => rid,
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[tokio::test]
    async fn empty_bus_starts_with_zero_subscribers() {
        let bus = InvalidationBus::new();
        assert_eq!(bus.subscriber_count(), 0);
    }

    #[tokio::test]
    async fn publish_with_no_subscribers_returns_zero_not_error() {
        let bus = InvalidationBus::new();
        assert_eq!(bus.publish(tombstone("x")), 0);
    }

    #[tokio::test]
    async fn subscribe_receives_published_event() {
        let bus = InvalidationBus::new();
        let mut rx = bus.subscribe();
        assert_eq!(bus.publish(tombstone("mem_a")), 1);
        match rx.recv().await.expect("receive") {
            InvalidationEvent::Tombstoned { tenant_id, rid } => {
                assert_eq!(tenant_id, TenantId::new(1));
                assert_eq!(rid, "mem_a");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[tokio::test]
    async fn late_subscriber_only_sees_events_published_after_subscribing() {
        let bus = InvalidationBus::new();
        // An early subscriber keeps `mem_before` buffered, so the late subscriber would see
        // it if subscriptions replayed history.
        let mut early = bus.subscribe();
        bus.publish(tombstone("mem_before"));
        let mut late = bus.subscribe();
        bus.publish(tombstone("mem_after"));
        assert_eq!(tombstoned_rid(late.recv().await.expect("receive")), "mem_after");
        assert_eq!(tombstoned_rid(early.recv().await.expect("receive")), "mem_before");
    }

    #[tokio::test]
    async fn multi_subscriber_fanout() {
        let bus = InvalidationBus::new();
        let mut rx1 = bus.subscribe();
        let mut rx2 = bus.subscribe();
        let mut rx3 = bus.subscribe();
        let n = bus.publish(InvalidationEvent::Updated {
            tenant_id: TenantId::new(1),
            rid: "mem_a".into(),
        });
        assert_eq!(n, 3);
        for rx in [&mut rx1, &mut rx2, &mut rx3] {
            let evt = rx.recv().await.expect("receive");
            assert_eq!(evt.variant_name(), "Updated");
        }
    }

    #[tokio::test]
    async fn clones_share_one_channel() {
        let bus = InvalidationBus::new();
        let handle = bus.clone();
        let mut rx = bus.subscribe();
        assert_eq!(handle.subscriber_count(), 1);
        assert_eq!(handle.publish(tombstone("mem_a")), 1);
        assert_eq!(tombstoned_rid(rx.recv().await.expect("receive")), "mem_a");
    }

    #[tokio::test]
    async fn slow_subscriber_lags_but_does_not_block_publisher() {
        let bus = InvalidationBus::with_capacity(2);
        let mut rx = bus.subscribe();
        // Nobody reads while these are published. Each send still completes and counts the
        // lagging subscriber, because the channel overwrites its oldest slot.
        for i in 0..5 {
            assert_eq!(bus.publish(tombstone(&format!("mem_{i}"))), 1);
        }
        // Only the two newest events are retained, so mem_0, mem_1 and mem_2 were skipped.
        match rx.recv().await {
            Err(broadcast::error::RecvError::Lagged(skipped)) => {
                assert_eq!(skipped, 3, "lagged past mem_0..=mem_2");
            }
            other => panic!("expected Lagged, got {other:?}"),
        }
        // After reporting the lag, the receiver resumes at the oldest retained event.
        assert_eq!(tombstoned_rid(rx.recv().await.expect("receive")), "mem_3");
        assert_eq!(tombstoned_rid(rx.recv().await.expect("receive")), "mem_4");
    }

    #[tokio::test]
    async fn subscriber_count_tracks_active_receivers() {
        let bus = InvalidationBus::new();
        assert_eq!(bus.subscriber_count(), 0);
        let _rx1 = bus.subscribe();
        assert_eq!(bus.subscriber_count(), 1);
        {
            let _rx2 = bus.subscribe();
            assert_eq!(bus.subscriber_count(), 2);
        }
        assert_eq!(bus.subscriber_count(), 1);
    }

    #[test]
    fn variant_names_are_stable() {
        let t = TenantId::new(1);
        let cases = [
            (
                InvalidationEvent::Tombstoned {
                    tenant_id: t,
                    rid: "x".into(),
                },
                "Tombstoned",
            ),
            (
                InvalidationEvent::Updated {
                    tenant_id: t,
                    rid: "x".into(),
                },
                "Updated",
            ),
            (
                InvalidationEvent::EdgeChanged {
                    tenant_id: t,
                    src: "a".into(),
                    dst: "b".into(),
                },
                "EdgeChanged",
            ),
            (
                InvalidationEvent::TenantConfigChanged {
                    tenant_id: t,
                    key: "k".into(),
                },
                "TenantConfigChanged",
            ),
        ];
        for (evt, expected) in cases {
            assert_eq!(evt.variant_name(), expected);
        }
    }

    #[test]
    fn tenant_id_extraction_works_for_all_variants() {
        let t = TenantId::new(42);
        let cases = [
            InvalidationEvent::Tombstoned {
                tenant_id: t,
                rid: "x".into(),
            },
            InvalidationEvent::Updated {
                tenant_id: t,
                rid: "x".into(),
            },
            InvalidationEvent::EdgeChanged {
                tenant_id: t,
                src: "a".into(),
                dst: "b".into(),
            },
            InvalidationEvent::TenantConfigChanged {
                tenant_id: t,
                key: "k".into(),
            },
        ];
        for evt in cases {
            assert_eq!(evt.tenant_id(), t);
        }
    }
}