use std::collections::HashSet;
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::mpsc;
use tracing::{debug, trace, warn};
use vibesql_storage::Database;
use vibesql_storage::change_events::RecvError;
use super::{
extract_table_refs, hash_rows, Subscription, SubscriptionError, SubscriptionId,
SubscriptionUpdate,
};
/// Tracks live query subscriptions and pushes refreshed results to subscribers
/// when a table their query reads from changes.
pub struct SubscriptionManager {
    /// All live subscriptions, keyed by id.
    subscriptions: DashMap<SubscriptionId, Subscription>,
    /// Reverse index: normalized table name (see `index_key`) -> ids of the
    /// subscriptions whose query references that table.
    table_index: DashMap<String, HashSet<SubscriptionId>>,
}

impl SubscriptionManager {
    /// Creates a manager with no subscriptions.
    pub fn new() -> Self {
        Self {
            subscriptions: DashMap::new(),
            table_index: DashMap::new(),
        }
    }

    /// Registers `query` as a live subscription whose updates are sent on `notify_tx`.
    ///
    /// The subscription is indexed under every table the query references. No
    /// results are sent here; call [`Self::send_initial_results`] to deliver the
    /// first snapshot.
    ///
    /// # Errors
    ///
    /// Returns [`SubscriptionError::ParseError`] if the query does not parse or
    /// references no tables.
    pub fn subscribe(
        &self,
        query: String,
        notify_tx: mpsc::Sender<SubscriptionUpdate>,
    ) -> Result<SubscriptionId, SubscriptionError> {
        let tables = self.extract_tables(&query)?;
        if tables.is_empty() {
            return Err(SubscriptionError::ParseError(
                "Query must reference at least one table".to_string(),
            ));
        }
        // `query` is not used after this point, so it is moved instead of cloned.
        let subscription = Subscription::new(query, tables.clone(), notify_tx);
        let id = subscription.id;
        debug!(
            subscription_id = %id,
            tables = ?tables,
            "Creating new subscription"
        );
        self.subscriptions.insert(id, subscription);
        for table in &tables {
            self.table_index
                .entry(Self::index_key(table))
                .or_default()
                .insert(id);
        }
        Ok(id)
    }

    /// Removes subscription `id` and unlinks it from the table index.
    ///
    /// Unknown ids are ignored, so calling this more than once is harmless.
    /// Index entries that are left with no subscribers are removed.
    pub fn unsubscribe(&self, id: SubscriptionId) {
        if let Some((_, subscription)) = self.subscriptions.remove(&id) {
            debug!(subscription_id = %id, "Removing subscription");
            for table in &subscription.tables {
                let key = Self::index_key(table);
                if let Some(mut ids) = self.table_index.get_mut(&key) {
                    ids.remove(&id);
                }
                // The guard above is already released (re-locking the same shard
                // while holding it would deadlock). `remove_if` re-checks emptiness
                // under its own lock, so a concurrent `subscribe` to this table
                // cannot be lost.
                self.table_index.remove_if(&key, |_, ids| ids.is_empty());
            }
        }
    }

    /// Returns the number of live subscriptions.
    pub fn subscription_count(&self) -> usize {
        self.subscriptions.len()
    }

    /// Returns each indexed table name paired with the number of subscriptions
    /// watching it, in no particular order.
    pub fn watched_tables(&self) -> Vec<(String, usize)> {
        self.table_index
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().len()))
            .collect()
    }

    /// Returns the ids of subscriptions whose query references `table_name`.
    ///
    /// The name is normalized the same way as at registration (lowercased), so
    /// the match does not depend on letter case.
    pub fn find_affected_subscriptions(&self, table_name: &str) -> Vec<SubscriptionId> {
        self.table_index
            .get(&Self::index_key(table_name))
            .map(|ids| ids.iter().copied().collect())
            .unwrap_or_default()
    }

    /// Re-evaluates every subscription that reads from the table named in
    /// `event`, notifying those whose results changed.
    ///
    /// Affected subscriptions are processed one at a time.
    pub async fn handle_change(&self, event: vibesql_storage::ChangeEvent, db: &Database) {
        let table = event.table_name();
        trace!(
            table = %table,
            event = ?event,
            "Processing change event from storage"
        );
        let affected_ids = self.find_affected_subscriptions(table);
        if affected_ids.is_empty() {
            trace!(table = %table, "No subscriptions affected");
            return;
        }
        debug!(
            table = %table,
            affected_count = affected_ids.len(),
            "Found affected subscriptions"
        );
        for id in affected_ids {
            self.check_and_notify(id, db).await;
        }
    }

    /// Re-runs subscription `id`'s query and sends an update if one is due.
    async fn check_and_notify(&self, id: SubscriptionId, db: &Database) {
        // Evaluate under the map guard (a synchronous shard lock), but release it
        // before awaiting: holding it across `.await` would block every other
        // user of that shard while a slow subscriber's channel is full.
        let (notify_tx, update) = match self.subscriptions.get_mut(&id) {
            Some(mut sub_ref) => {
                let update = Self::next_update(id, sub_ref.value_mut(), db);
                (sub_ref.notify_tx.clone(), update)
            }
            None => {
                trace!(subscription_id = %id, "Subscription not found (may have been removed)");
                return;
            }
        };
        // NOTE(review): because the send happens after the guard is released, two
        // concurrent evaluations of the same subscription (e.g. `send_initial_results`
        // racing the event loop) may deliver in a different order than evaluated.
        if let Some(update) = update {
            self.deliver(id, &notify_tx, update).await;
        }
    }

    /// Executes `subscription`'s query and returns the update to deliver, if any.
    ///
    /// Returns:
    /// - `Some(Error)` if the query fails to parse or execute;
    /// - `Some(Full)` if the result hash differs from `last_result_hash`, which is
    ///   updated to the new hash;
    /// - `None` if results are unchanged or the query is not a SELECT.
    fn next_update(
        id: SubscriptionId,
        subscription: &mut Subscription,
        db: &Database,
    ) -> Option<SubscriptionUpdate> {
        let executor = vibesql_executor::SelectExecutor::new(db);
        let result = match vibesql_parser::Parser::parse_sql(&subscription.query) {
            Ok(vibesql_ast::Statement::Select(select)) => executor.execute(&select),
            Ok(_) => {
                warn!(
                    subscription_id = %id,
                    "Subscription query is not a SELECT"
                );
                return None;
            }
            Err(e) => {
                return Some(SubscriptionUpdate::Error {
                    message: format!("Failed to parse query: {}", e),
                });
            }
        };
        match result {
            Ok(rows) => {
                let result_rows: Vec<crate::Row> = rows
                    .iter()
                    .map(|r| crate::Row {
                        values: r.values.clone(),
                    })
                    .collect();
                let new_hash = hash_rows(&result_rows);
                if new_hash == subscription.last_result_hash {
                    trace!(
                        subscription_id = %id,
                        "Results unchanged, no notification needed"
                    );
                    return None;
                }
                debug!(
                    subscription_id = %id,
                    old_hash = subscription.last_result_hash,
                    new_hash = new_hash,
                    row_count = result_rows.len(),
                    "Results changed, notifying subscriber"
                );
                subscription.last_result_hash = new_hash;
                Some(SubscriptionUpdate::Full { rows: result_rows })
            }
            Err(e) => Some(SubscriptionUpdate::Error {
                message: format!("Query execution failed: {}", e),
            }),
        }
    }

    /// Sends `update` to the subscriber.
    ///
    /// `mpsc::Sender::send` only fails once the receiver is gone, so on failure
    /// the subscription can never be delivered to again and is removed.
    async fn deliver(
        &self,
        id: SubscriptionId,
        notify_tx: &mpsc::Sender<SubscriptionUpdate>,
        update: SubscriptionUpdate,
    ) {
        if notify_tx.send(update).await.is_err() {
            trace!(
                subscription_id = %id,
                "Notification channel closed, subscription will be cleaned up"
            );
            self.unsubscribe(id);
        }
    }

    /// Re-evaluates every live subscription; unchanged results send nothing.
    async fn refresh_all(&self, db: &Database) {
        // Snapshot the ids first so no map guard is held while awaiting.
        let ids: Vec<SubscriptionId> = self
            .subscriptions
            .iter()
            .map(|entry| *entry.key())
            .collect();
        for id in ids {
            self.check_and_notify(id, db).await;
        }
    }

    /// Processes storage change events until the change channel closes.
    ///
    /// NOTE(review): this polls with `try_recv` and only yields to the scheduler
    /// when the channel is empty, so an idle manager keeps its worker thread
    /// busy. If `ChangeEventReceiver` offers an awaitable receive, prefer it.
    pub async fn run_event_loop(
        &self,
        mut change_rx: vibesql_storage::ChangeEventReceiver,
        db: Arc<Database>,
    ) {
        loop {
            match change_rx.try_recv() {
                Ok(event) => {
                    self.handle_change(event, &db).await;
                }
                Err(RecvError::Lagged(n)) => {
                    warn!(
                        lagged_count = n,
                        "SubscriptionManager lagged behind change events"
                    );
                    // The skipped events may have touched any table, so re-check
                    // every subscription rather than leaving results stale.
                    self.refresh_all(&db).await;
                }
                Err(RecvError::Closed) => {
                    debug!("Change event channel closed, stopping subscription manager");
                    break;
                }
                Err(RecvError::Empty) => {
                    tokio::task::yield_now().await;
                }
            }
        }
    }

    /// Parses `query` and returns the set of table names it references.
    fn extract_tables(&self, query: &str) -> Result<HashSet<String>, SubscriptionError> {
        let stmt = vibesql_parser::Parser::parse_sql(query)
            .map_err(|e| SubscriptionError::ParseError(e.to_string()))?;
        Ok(extract_table_refs(&stmt))
    }

    /// Normalizes a table name into a `table_index` key.
    ///
    /// Registration, removal and lookup all go through here so that event table
    /// names and query table references agree regardless of letter case.
    fn index_key(table: &str) -> String {
        table.to_lowercase()
    }

    /// Runs subscription `id`'s query, records its result hash as the baseline
    /// for change detection, and sends the rows as a `Full` update.
    ///
    /// # Errors
    ///
    /// - [`SubscriptionError::NotFound`] if `id` is not a live subscription.
    /// - [`SubscriptionError::ParseError`] if the query fails to parse, is not a
    ///   SELECT, or fails to execute.
    /// - [`SubscriptionError::ChannelClosed`] if the subscriber's receiver is gone.
    pub async fn send_initial_results(
        &self,
        id: SubscriptionId,
        db: &Database,
    ) -> Result<(), SubscriptionError> {
        // All synchronous work happens under the map guard; it is released before
        // awaiting the send so a full channel cannot block other map users.
        let (notify_tx, result_rows) = {
            let mut sub_ref = self
                .subscriptions
                .get_mut(&id)
                .ok_or(SubscriptionError::NotFound(id))?;
            let subscription = sub_ref.value_mut();
            let executor = vibesql_executor::SelectExecutor::new(db);
            let stmt = vibesql_parser::Parser::parse_sql(&subscription.query)
                .map_err(|e| SubscriptionError::ParseError(e.to_string()))?;
            let rows = match stmt {
                vibesql_ast::Statement::Select(select) => executor
                    .execute(&select)
                    .map_err(|e| SubscriptionError::ParseError(e.to_string()))?,
                _ => return Err(SubscriptionError::ParseError("Not a SELECT query".to_string())),
            };
            let result_rows: Vec<crate::Row> = rows
                .iter()
                .map(|r| crate::Row {
                    values: r.values.clone(),
                })
                .collect();
            subscription.last_result_hash = hash_rows(&result_rows);
            (subscription.notify_tx.clone(), result_rows)
        };
        notify_tx
            .send(SubscriptionUpdate::Full { rows: result_rows })
            .await
            .map_err(|_| SubscriptionError::ChannelClosed)?;
        Ok(())
    }
}

impl Default for SubscriptionManager {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_types::SqlValue;

    /// Query shared by most tests: every row of `users`.
    const USERS_QUERY: &str = "SELECT * FROM users";

    /// Parses `sql` and, if it is a CREATE TABLE, executes it against `db`.
    fn exec_create_table(db: &mut Database, sql: &str) {
        let parsed = vibesql_parser::Parser::parse_sql(sql).unwrap();
        if let vibesql_ast::Statement::CreateTable(stmt) = parsed {
            vibesql_executor::CreateTableExecutor::execute(&stmt, db).unwrap();
        }
    }

    /// Parses `sql` and, if it is an INSERT, executes it against `db`.
    fn exec_insert(db: &mut Database, sql: &str) {
        let parsed = vibesql_parser::Parser::parse_sql(sql).unwrap();
        if let vibesql_ast::Statement::Insert(stmt) = parsed {
            vibesql_executor::InsertExecutor::execute(db, &stmt).unwrap();
        }
    }

    /// Builds an insert change event for row 0 of `table`.
    fn insert_event(table: &str) -> vibesql_storage::ChangeEvent {
        vibesql_storage::ChangeEvent::Insert {
            table_name: table.to_string(),
            row_index: 0,
        }
    }

    /// Builds a database containing empty `users` and `orders` tables.
    fn setup_test_db() -> Database {
        let mut db = Database::new();
        exec_create_table(
            &mut db,
            "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100), active BOOLEAN)",
        );
        exec_create_table(
            &mut db,
            "CREATE TABLE orders (id INT PRIMARY KEY, user_id INT, amount INT)",
        );
        db
    }

    #[test]
    fn test_subscribe_simple() {
        let manager = SubscriptionManager::new();
        let (tx, _rx) = mpsc::channel(16);
        let result = manager.subscribe(USERS_QUERY.to_string(), tx);
        assert!(result.is_ok());
        let _id = result.unwrap();
        assert_eq!(manager.subscription_count(), 1);
        let watched = manager.watched_tables();
        assert_eq!(watched.len(), 1);
        assert!(watched.iter().any(|(t, c)| t == "users" && *c == 1));
    }

    #[test]
    fn test_subscribe_with_join() {
        let manager = SubscriptionManager::new();
        let (tx, _rx) = mpsc::channel(16);
        let join_query = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id";
        assert!(manager.subscribe(join_query.to_string(), tx).is_ok());
        let watched = manager.watched_tables();
        assert_eq!(watched.len(), 2);
        for expected in ["users", "orders"] {
            assert!(watched.iter().any(|(t, _)| t == expected));
        }
    }

    #[test]
    fn test_unsubscribe() {
        let manager = SubscriptionManager::new();
        let (tx, _rx) = mpsc::channel(16);
        let id = manager.subscribe(USERS_QUERY.to_string(), tx).unwrap();
        assert_eq!(manager.subscription_count(), 1);
        manager.unsubscribe(id);
        assert_eq!(manager.subscription_count(), 0);
        assert!(manager.watched_tables().iter().all(|(_, c)| *c == 0));
    }

    #[test]
    fn test_invalid_query() {
        let manager = SubscriptionManager::new();
        let (tx, _rx) = mpsc::channel(16);
        let outcome = manager.subscribe("SELECT * FROM".to_string(), tx);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(SubscriptionError::ParseError(_))));
    }

    #[test]
    fn test_query_without_tables() {
        let manager = SubscriptionManager::new();
        let (tx, _rx) = mpsc::channel(16);
        assert!(manager.subscribe("SELECT 1 + 1".to_string(), tx).is_err());
    }

    #[tokio::test]
    async fn test_handle_change_notifies_subscribers() {
        let manager = SubscriptionManager::new();
        let (tx, mut rx) = mpsc::channel(16);
        let db = setup_test_db();
        let _id = manager.subscribe(USERS_QUERY.to_string(), tx).unwrap();
        manager.handle_change(insert_event("users"), &db).await;
        let update = rx.try_recv();
        assert!(update.is_ok());
        if let SubscriptionUpdate::Full { rows } = update.unwrap() {
            assert!(rows.is_empty());
        } else {
            panic!("Expected Full update");
        }
    }

    #[tokio::test]
    async fn test_handle_change_ignores_unrelated_tables() {
        let manager = SubscriptionManager::new();
        let (tx, mut rx) = mpsc::channel(16);
        let db = setup_test_db();
        let _id = manager.subscribe(USERS_QUERY.to_string(), tx).unwrap();
        manager.handle_change(insert_event("orders"), &db).await;
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_send_initial_results() {
        let manager = SubscriptionManager::new();
        let (tx, mut rx) = mpsc::channel(16);
        let mut db = setup_test_db();
        exec_insert(&mut db, "INSERT INTO users VALUES (1, 'Alice', TRUE)");
        let id = manager.subscribe(USERS_QUERY.to_string(), tx).unwrap();
        manager.send_initial_results(id, &db).await.unwrap();
        if let SubscriptionUpdate::Full { rows } = rx.recv().await.unwrap() {
            assert_eq!(rows.len(), 1);
            assert_eq!(rows[0].values[0], SqlValue::Integer(1));
        } else {
            panic!("Expected Full update");
        }
    }

    #[tokio::test]
    async fn test_results_changed_detection() {
        let manager = SubscriptionManager::new();
        let (tx, mut rx) = mpsc::channel(16);
        let mut db = setup_test_db();
        let id = manager.subscribe(USERS_QUERY.to_string(), tx).unwrap();
        manager.send_initial_results(id, &db).await.unwrap();
        // Discard the initial (empty) snapshot.
        let _ = rx.recv().await;
        exec_insert(&mut db, "INSERT INTO users VALUES (1, 'Alice', TRUE)");
        manager.handle_change(insert_event("users"), &db).await;
        if let SubscriptionUpdate::Full { rows } = rx.recv().await.unwrap() {
            assert_eq!(rows.len(), 1);
        } else {
            panic!("Expected Full update");
        }
    }

    #[tokio::test]
    async fn test_no_notification_when_unchanged() {
        let manager = SubscriptionManager::new();
        let (tx, mut rx) = mpsc::channel(16);
        let db = setup_test_db();
        let id = manager.subscribe(USERS_QUERY.to_string(), tx).unwrap();
        manager.send_initial_results(id, &db).await.unwrap();
        // Discard the initial snapshot; the data does not change afterwards.
        let _ = rx.recv().await;
        manager.handle_change(insert_event("users"), &db).await;
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_multiple_subscriptions_same_table() {
        let manager = SubscriptionManager::new();
        let (tx1, _rx1) = mpsc::channel(16);
        let (tx2, _rx2) = mpsc::channel(16);
        let _id1 = manager.subscribe(USERS_QUERY.to_string(), tx1).unwrap();
        let _id2 = manager
            .subscribe("SELECT * FROM users WHERE active = TRUE".to_string(), tx2)
            .unwrap();
        assert_eq!(manager.subscription_count(), 2);
        let watched = manager.watched_tables();
        let (_, users_count) = watched.iter().find(|(t, _)| t == "users").unwrap();
        assert_eq!(*users_count, 2);
    }
}