use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use tracing::{debug, warn};
use crate::types::TenantId;
/// Per-subscriber queue capacity used by `NotifyBus::default()`.
pub const DEFAULT_QUEUE_CAP: usize = 1024;
/// A single notification as delivered to a subscriber's receiver.
#[derive(Debug, Clone)]
pub struct Notification {
/// Channel name. `NotifyBus::notify` fills this with the normalized (lowercased) name.
pub channel: String,
/// Payload string supplied by the caller of `NotifyBus::notify`.
pub payload: String,
/// Process id attached to the notification. `NotifyBus::notify` in this file always sets 0.
// NOTE(review): presumably the notifying backend's PID, as in PostgreSQL NOTIFY — confirm.
pub pid: i32,
}
/// Map key identifying one channel within one tenant.
///
/// Because the tenant id is part of the key, same-named channels of different
/// tenants map to separate subscriber lists.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BusKey {
/// Numeric tenant id, obtained via `TenantId::as_u64()`.
tenant_id: u64,
/// Channel name after `normalize_channel`.
channel: String,
}
/// Sending half of one subscriber's bounded queue.
struct SessionSink {
/// Bounded channel sender; `NotifyBus::notify` pushes into it with `try_send`.
tx: tokio::sync::mpsc::Sender<Notification>,
/// Capacity the channel was created with; only read when logging a full queue.
cap: usize,
}
/// Publish/subscribe bus that fans notifications out to per-subscriber bounded queues.
pub struct NotifyBus {
/// Subscriber lists keyed by (tenant, normalized channel), guarded by a read/write lock.
subscribers: RwLock<HashMap<BusKey, Vec<(u64 /* session_id */, SessionSink)>>>,
/// Source of session ids; starts at 1 and is incremented on every `listen`.
next_session_id: AtomicU64,
/// Number of notifications discarded because a subscriber's queue was full.
pub dropped: AtomicU64,
/// Capacity given to every channel created by `listen`.
queue_cap: usize,
}
impl Default for NotifyBus {
fn default() -> Self {
Self::new(DEFAULT_QUEUE_CAP)
}
}
impl NotifyBus {
pub fn new(queue_cap: usize) -> Self {
Self {
subscribers: RwLock::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
dropped: AtomicU64::new(0),
queue_cap,
}
}
pub fn listen(
&self,
tenant_id: TenantId,
channel: &str,
) -> (u64, tokio::sync::mpsc::Receiver<Notification>) {
let key = BusKey {
tenant_id: tenant_id.as_u64(),
channel: normalize_channel(channel),
};
let session_id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = tokio::sync::mpsc::channel(self.queue_cap);
let sink = SessionSink {
tx,
cap: self.queue_cap,
};
let mut map = self.subscribers.write().unwrap_or_else(|p| p.into_inner());
map.entry(key.clone()).or_default().push((session_id, sink));
debug!(
session_id,
tenant = tenant_id.as_u64(),
channel = key.channel.as_str(),
"LISTEN registered"
);
(session_id, rx)
}
pub fn unlisten(&self, tenant_id: TenantId, channel: &str, session_id: u64) {
let key = BusKey {
tenant_id: tenant_id.as_u64(),
channel: normalize_channel(channel),
};
let mut map = self.subscribers.write().unwrap_or_else(|p| p.into_inner());
if let Some(sinks) = map.get_mut(&key) {
sinks.retain(|(id, _)| *id != session_id);
if sinks.is_empty() {
map.remove(&key);
}
}
debug!(
session_id,
tenant = tenant_id.as_u64(),
channel = key.channel.as_str(),
"UNLISTEN"
);
}
pub fn unlisten_all(&self, tenant_id: TenantId, session_ids: &[(String, u64)]) {
if session_ids.is_empty() {
return;
}
let mut map = self.subscribers.write().unwrap_or_else(|p| p.into_inner());
for (channel, session_id) in session_ids {
let key = BusKey {
tenant_id: tenant_id.as_u64(),
channel: normalize_channel(channel),
};
if let Some(sinks) = map.get_mut(&key) {
sinks.retain(|(id, _)| id != session_id);
if sinks.is_empty() {
map.remove(&key);
}
}
}
debug!(
tenant = tenant_id.as_u64(),
count = session_ids.len(),
"UNLISTEN * (session disconnect)"
);
}
pub fn notify(&self, tenant_id: TenantId, channel: &str, payload: &str) {
let key = BusKey {
tenant_id: tenant_id.as_u64(),
channel: normalize_channel(channel),
};
let notification = Notification {
channel: key.channel.clone(),
payload: payload.to_string(),
pid: 0,
};
let map = self.subscribers.read().unwrap_or_else(|p| p.into_inner());
let sinks = match map.get(&key) {
Some(s) => s,
None => return, };
let mut dead = Vec::new();
for (session_id, sink) in sinks {
match sink.tx.try_send(notification.clone()) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
self.dropped.fetch_add(1, Ordering::Relaxed);
warn!(
session_id,
channel = key.channel.as_str(),
cap = sink.cap,
"NOTIFY queue full — dropping oldest (metric incremented)"
);
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
dead.push(*session_id);
}
}
}
drop(map);
if !dead.is_empty() {
let mut map = self.subscribers.write().unwrap_or_else(|p| p.into_inner());
if let Some(sinks) = map.get_mut(&key) {
sinks.retain(|(id, _)| !dead.contains(id));
if sinks.is_empty() {
map.remove(&key);
}
}
}
}
pub fn total_dropped(&self) -> u64 {
self.dropped.load(Ordering::Relaxed)
}
pub fn subscription_count(&self) -> usize {
let map = self.subscribers.read().unwrap_or_else(|p| p.into_inner());
map.values().map(|v| v.len()).sum()
}
}
/// Returns the canonical spelling of a channel name: its Unicode lowercase form.
///
/// Channel names are compared in this form, which makes lookups case-insensitive.
pub fn normalize_channel(channel: &str) -> String {
    str::to_lowercase(channel)
}
/// Bundles one subscription's channel name, session id and receiver.
// NOTE(review): not constructed anywhere in this file; presumably callers build it from
// the `(session_id, rx)` pair returned by `NotifyBus::listen` — confirm.
pub struct ListenHandle {
/// Channel name the subscription belongs to.
pub channel: String,
/// Session id of the subscription.
pub session_id: u64,
/// Receiving end of the subscription's notification queue.
pub rx: tokio::sync::mpsc::Receiver<Notification>,
}
/// Reference-counted (`Arc`) handle to a `NotifyBus`.
pub type NotifyBusHandle = Arc<NotifyBus>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a tenant id from a plain integer.
    fn tenant(n: u64) -> TenantId {
        TenantId::new(n)
    }

    /// A notification sent after LISTEN arrives with channel and payload intact.
    #[tokio::test]
    async fn basic_listen_notify() {
        let bus = NotifyBus::new(64);
        let owner = tenant(1);
        let (_, mut inbox) = bus.listen(owner, "orders");

        bus.notify(owner, "orders", "hello");

        let delivered = inbox.try_recv().unwrap();
        assert_eq!(delivered.channel, "orders");
        assert_eq!(delivered.payload, "hello");
    }

    /// After UNLISTEN the receiver no longer gets anything.
    #[tokio::test]
    async fn unlisten_stops_delivery() {
        let bus = NotifyBus::new(64);
        let owner = tenant(1);
        let (session, mut inbox) = bus.listen(owner, "orders");

        bus.notify(owner, "orders", "first");
        assert!(inbox.try_recv().is_ok());

        bus.unlisten(owner, "orders", session);
        bus.notify(owner, "orders", "second");
        assert!(inbox.try_recv().is_err());
    }

    /// Same channel name under two tenants must not leak notifications across them.
    #[tokio::test]
    async fn tenant_isolation() {
        let bus = NotifyBus::new(64);
        let first = tenant(1);
        let second = tenant(2);
        let (_, mut first_inbox) = bus.listen(first, "ch");
        let (_, mut second_inbox) = bus.listen(second, "ch");

        bus.notify(first, "ch", "for-t1");

        assert!(first_inbox.try_recv().is_ok());
        assert!(second_inbox.try_recv().is_err());
    }

    /// With capacity 2, the third undelivered notification bumps the drop counter.
    #[tokio::test]
    async fn queue_full_increments_dropped() {
        let bus = NotifyBus::new(2);
        let owner = tenant(1);
        // The receiver must stay alive, otherwise sends fail as "closed" instead of "full".
        let (_, _inbox) = bus.listen(owner, "ch");

        bus.notify(owner, "ch", "a");
        bus.notify(owner, "ch", "b");
        bus.notify(owner, "ch", "c");

        assert_eq!(bus.total_dropped(), 1);
    }

    /// Batch removal unsubscribes every listed (channel, session) pair.
    #[tokio::test]
    async fn unlisten_all() {
        let bus = NotifyBus::new(64);
        let owner = tenant(1);
        let (first_session, mut first_inbox) = bus.listen(owner, "ch1");
        let (second_session, mut second_inbox) = bus.listen(owner, "ch2");

        let subscriptions = [
            ("ch1".to_string(), first_session),
            ("ch2".to_string(), second_session),
        ];
        bus.unlisten_all(owner, &subscriptions);

        bus.notify(owner, "ch1", "msg");
        bus.notify(owner, "ch2", "msg");

        assert!(first_inbox.try_recv().is_err());
        assert!(second_inbox.try_recv().is_err());
    }

    /// Normalization lowercases and leaves already-lowercase names alone.
    #[test]
    fn channel_normalize() {
        assert_eq!(normalize_channel("Orders"), "orders");
        assert_eq!(normalize_channel("my_channel"), "my_channel");
    }
}