use super::SubscriptionId;
use std::collections::HashMap;
use tokio::sync::mpsc;
use tracing::warn;
use vibesql_storage::{ChangeEvent, ChangeEventReceiver};
/// One change notification addressed to a single subscription, as delivered to a
/// session's channel by [`ChangeRouter`].
#[derive(Debug, Clone)]
pub struct SubscriptionUpdate {
/// The subscription this update was produced for.
pub subscription_id: SubscriptionId,
/// Name of the table the change occurred on (copied from the event's table name).
pub table_name: String,
/// The underlying storage change event.
pub event: ChangeEvent,
}
/// Routes storage change events to client sessions.
///
/// Polls a [`ChangeEventReceiver`] and, for every subscription registered on the event's
/// table, pushes a [`SubscriptionUpdate`] into every registered session channel.
///
/// NOTE(review): subscriptions carry no session owner here, so each session receives the
/// updates for every subscription on a table — confirm this broadcast is intended.
pub struct ChangeRouter {
/// Source of storage change events.
change_receiver: ChangeEventReceiver,
/// Table name -> subscriptions registered on that table (duplicates are not filtered).
table_subscriptions: HashMap<String, Vec<SubscriptionId>>,
/// Session id -> bounded channel used to deliver updates to that session.
session_senders: HashMap<String, mpsc::Sender<SubscriptionUpdate>>,
/// Milliseconds the polling loop sleeps between drains of `change_receiver`.
batch_timeout_ms: u64,
}
impl ChangeRouter {
    /// Default delay, in milliseconds, between polls of the change receiver in [`Self::run`].
    const DEFAULT_BATCH_TIMEOUT_MS: u64 = 10;

    /// Creates a router that reads change events from `change_receiver`.
    ///
    /// The router starts with no subscriptions, no sessions, and a poll interval of
    /// [`Self::DEFAULT_BATCH_TIMEOUT_MS`] (change it with [`Self::set_batch_timeout`]).
    pub fn new(change_receiver: ChangeEventReceiver) -> Self {
        Self {
            change_receiver,
            table_subscriptions: HashMap::new(),
            session_senders: HashMap::new(),
            batch_timeout_ms: Self::DEFAULT_BATCH_TIMEOUT_MS,
        }
    }

    /// Records `subscription_id` as interested in changes to `table`.
    ///
    /// Duplicates are not filtered: registering the same id twice for one table makes each
    /// matching event produce two updates for that id.
    pub fn register_subscription(&mut self, table: String, subscription_id: SubscriptionId) {
        self.table_subscriptions.entry(table).or_default().push(subscription_id);
    }

    /// Removes every registration of `subscription_id` for `table`.
    ///
    /// When the last subscription for a table goes away, the table entry itself is dropped
    /// so tables that are no longer watched do not accumulate. Unknown tables/ids are ignored.
    pub fn unregister_subscription(&mut self, table: &str, subscription_id: SubscriptionId) {
        if let Some(subs) = self.table_subscriptions.get_mut(table) {
            subs.retain(|s| s != &subscription_id);
            if subs.is_empty() {
                self.table_subscriptions.remove(table);
            }
        }
    }

    /// Registers the channel used to deliver updates to `session_id`, replacing any
    /// previously registered channel for the same id.
    pub fn register_session(
        &mut self,
        session_id: String,
        sender: mpsc::Sender<SubscriptionUpdate>,
    ) {
        self.session_senders.insert(session_id, sender);
    }

    /// Stops delivering updates to `session_id`. Unknown ids are ignored.
    pub fn unregister_session(&mut self, session_id: &str) {
        self.session_senders.remove(session_id);
    }

    /// Drives the router: drains all pending events, routes each one, then sleeps for the
    /// configured batch timeout before polling again.
    ///
    /// This future never completes on its own; drop or abort its task to stop routing.
    pub async fn run(&mut self) {
        loop {
            let events = self.change_receiver.recv_all();
            for event in events {
                self.process_change(&event).await;
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(self.batch_timeout_ms)).await;
        }
    }

    /// Fans `event` out to the registered sessions, one update per subscription on the
    /// event's table.
    ///
    /// Delivery uses `try_send`, so a slow session never blocks the router:
    /// * a full channel drops the update and logs a warning;
    /// * a closed channel (receiver dropped) causes that session to be unregistered, so it
    ///   is not retried on every subsequent event.
    ///
    /// NOTE(review): subscriptions are not tied to a session here, so every session receives
    /// the updates for every subscription on the table — confirm this broadcast is intended.
    async fn process_change(&mut self, event: &ChangeEvent) {
        let table = event.table_name();
        // Sessions whose receiver has gone away; removed after the fan-out because the
        // sender map is borrowed while iterating.
        let mut closed_sessions: Vec<String> = Vec::new();

        if let Some(subscription_ids) = self.table_subscriptions.get(table) {
            for subscription_id in subscription_ids.iter().copied() {
                let update = SubscriptionUpdate {
                    subscription_id,
                    table_name: table.to_string(),
                    event: event.clone(),
                };
                for (session_id, sender) in &self.session_senders {
                    match sender.try_send(update.clone()) {
                        Ok(()) => {}
                        Err(mpsc::error::TrySendError::Full(_)) => {
                            warn!(
                                subscription_id = %subscription_id,
                                session_id = %session_id,
                                table = %table,
                                "Session channel full, dropping update. \
                                 Consider increasing channel buffer size or client is consuming too slowly."
                            );
                        }
                        Err(mpsc::error::TrySendError::Closed(_)) => {
                            if !closed_sessions.contains(session_id) {
                                closed_sessions.push(session_id.clone());
                            }
                        }
                    }
                }
            }
        }

        // No await point separates detection from removal, so the sender being removed is
        // the same one that reported `Closed`.
        for session_id in closed_sessions {
            self.session_senders.remove(&session_id);
        }
    }

    /// Returns how many subscription registrations exist for `table` (0 if none).
    pub fn subscription_count_for_table(&self, table: &str) -> usize {
        self.table_subscriptions.get(table).map_or(0, Vec::len)
    }

    /// Returns the number of subscription registrations across all tables.
    pub fn total_subscription_count(&self) -> usize {
        self.table_subscriptions.values().map(Vec::len).sum()
    }

    /// Sets the delay, in milliseconds, that [`Self::run`] sleeps between polls.
    pub fn set_batch_timeout(&mut self, ms: u64) {
        self.batch_timeout_ms = ms;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_storage::change_event_channel;

    /// A single registration is reflected in the per-table count.
    #[tokio::test]
    async fn test_register_subscription() {
        let (_events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);

        router.register_subscription(String::from("users"), SubscriptionId::new());

        assert_eq!(router.subscription_count_for_table("users"), 1);
    }

    /// Unregistering the only subscription brings the table count back to zero.
    #[tokio::test]
    async fn test_unregister_subscription() {
        let (_events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);
        let id = SubscriptionId::new();

        router.register_subscription(String::from("users"), id);
        assert_eq!(router.subscription_count_for_table("users"), 1);

        router.unregister_subscription("users", id);
        assert_eq!(router.subscription_count_for_table("users"), 0);
    }

    /// Two distinct subscriptions on the same table are both counted.
    #[tokio::test]
    async fn test_multiple_subscriptions_same_table() {
        let (_events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);

        for _ in 0..2 {
            router.register_subscription(String::from("users"), SubscriptionId::new());
        }

        assert_eq!(router.subscription_count_for_table("users"), 2);
    }

    /// Subscriptions are tracked per table and summed in the total.
    #[tokio::test]
    async fn test_different_tables() {
        let (_events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);

        for table in ["users", "orders"] {
            router.register_subscription(table.to_owned(), SubscriptionId::new());
        }

        for table in ["users", "orders"] {
            assert_eq!(router.subscription_count_for_table(table), 1);
        }
        assert_eq!(router.total_subscription_count(), 2);
    }

    /// Registering a session stores its sender.
    #[tokio::test]
    async fn test_session_registration() {
        let (_events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);
        let (session_tx, _session_rx) = mpsc::channel(16);

        router.register_session(String::from("session1"), session_tx);

        assert_eq!(router.session_senders.len(), 1);
    }

    /// Unregistering a session removes its sender.
    #[tokio::test]
    async fn test_session_unregistration() {
        let (_events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);
        let (session_tx, _session_rx) = mpsc::channel(16);

        router.register_session(String::from("session1"), session_tx);
        assert_eq!(router.session_senders.len(), 1);

        router.unregister_session("session1");
        assert_eq!(router.session_senders.len(), 0);
    }

    /// An insert on a watched table reaches the registered session as a tagged update.
    #[tokio::test]
    async fn test_process_change_event() {
        let (events_tx, events_rx) = change_event_channel(16);
        let mut router = ChangeRouter::new(events_rx);
        let id = SubscriptionId::new();
        router.register_subscription(String::from("users"), id);
        let (session_tx, mut session_rx) = mpsc::channel(16);
        router.register_session(String::from("session1"), session_tx);

        events_tx.send(ChangeEvent::Insert { table_name: String::from("users"), row_index: 0 });

        let pending_events = router.change_receiver.recv_all();
        for pending in pending_events {
            router.process_change(&pending).await;
        }

        let received = session_rx.try_recv();
        assert!(received.is_ok());
        let update = received.unwrap();
        assert_eq!(update.subscription_id, id);
        assert_eq!(update.table_name, "users");
    }
}