use std::sync::Arc;
use nodedb::control::change_stream::{ChangeStream, LiveSubscriptionSet, Subscription};
/// Verifies that dropping a per-connection [`LiveSubscriptionSet`] aborts every
/// forwarder task it spawned, so each task's [`Subscription`] is dropped and the
/// change stream's subscriber count returns to 0.
#[tokio::test]
async fn dropping_live_set_aborts_spawned_tasks_and_drops_subscriptions() {
    use std::time::Duration;

    let cs = Arc::new(ChangeStream::new(64));
    assert_eq!(cs.subscriber_count(), 0);

    {
        let mut set = LiveSubscriptionSet::new();
        for _ in 0..10 {
            let sub: Subscription = cs.subscribe(Some("orders".into()), None);
            set.spawn_forwarder(sub, |_event| {});
        }
        assert_eq!(
            cs.subscriber_count(),
            10,
            "10 LIVE subscriptions registered while the connection is alive"
        );
        // `set` goes out of scope here, simulating the connection closing.
    }

    // Aborted forwarder tasks may not be dropped until the runtime gets a
    // chance to run them. Poll with an upper bound instead of relying on one
    // fixed-length sleep, which can be too short on a loaded machine.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while cs.subscriber_count() != 0 && tokio::time::Instant::now() < deadline {
        tokio::time::sleep(Duration::from_millis(5)).await;
    }

    assert_eq!(
        cs.subscriber_count(),
        0,
        "dropping the per-connection LiveSubscriptionSet must abort every \
         spawned forwarder task so the Subscription's Drop runs and \
         active_subscriptions returns to 0"
    );
}
/// Verifies that [`LiveSubscriptionSet::abort_all`] tears down every spawned
/// forwarder, so no [`Subscription`] outlives shutdown and the subscriber count
/// returns to 0.
#[tokio::test]
async fn spawned_forwarder_observes_shutdown_abort() {
    use std::time::Duration;

    let cs = Arc::new(ChangeStream::new(64));
    let mut set = LiveSubscriptionSet::new();
    for _ in 0..5 {
        let sub: Subscription = cs.subscribe(None, None);
        set.spawn_forwarder(sub, |_| {});
    }
    assert_eq!(cs.subscriber_count(), 5);

    set.abort_all();

    // Aborted tasks may be dropped only once the runtime runs them again, so
    // wait (bounded) for the count to settle rather than sleeping a fixed,
    // possibly-too-short interval.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while cs.subscriber_count() != 0 && tokio::time::Instant::now() < deadline {
        tokio::time::sleep(Duration::from_millis(5)).await;
    }

    assert_eq!(
        cs.subscriber_count(),
        0,
        "abort_all must tear down every forwarder so no Subscription outlives shutdown"
    );
}
/// End-to-end check over the HTTP/WebSocket server.
///
/// A client opens five LIVE queries on `/v1/ws` and then disconnects abruptly
/// by dropping the socket without a close handshake. The server must tear down
/// every per-connection forwarder so the change stream's subscriber count
/// returns to 0 instead of leaking the five subscriptions.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ws_disconnect_drops_active_subscriptions() {
    use futures::{SinkExt, StreamExt};
    use nodedb::bridge::dispatch::Dispatcher;
    use nodedb::config::auth::AuthMode;
    use nodedb::control::state::SharedState;
    use nodedb::wal::WalManager;
    use std::time::Duration;
    use tokio_tungstenite::tungstenite::Message;

    // Upper bound on each network round-trip, so a misbehaving server fails
    // the test instead of hanging it indefinitely.
    const IO_TIMEOUT: Duration = Duration::from_secs(5);

    let dir = tempfile::tempdir().unwrap();
    let wal = Arc::new(WalManager::open_for_testing(&dir.path().join("ws.wal")).unwrap());
    let (dispatcher, _data_sides) = Dispatcher::new(1, 64);
    let shared = SharedState::new(dispatcher, wal);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local_addr = listener.local_addr().unwrap();
    let (bus, _) = nodedb::control::shutdown::ShutdownBus::new(Arc::clone(&shared.shutdown));
    let shared_http = Arc::clone(&shared);
    let server_handle = tokio::spawn(async move {
        nodedb::control::server::http::server::run_with_listener(
            listener,
            shared_http,
            AuthMode::Trust,
            None,
            bus,
        )
        .await
        .ok();
    });

    // Give the server task a moment to start accepting.
    tokio::time::sleep(Duration::from_millis(50)).await;
    assert_eq!(shared.change_stream.subscriber_count(), 0);

    let url = format!("ws://{local_addr}/v1/ws");
    let (mut ws, _) = tokio::time::timeout(IO_TIMEOUT, tokio_tungstenite::connect_async(&url))
        .await
        .expect("timed out connecting to /v1/ws")
        .unwrap();

    for i in 0..5 {
        let sql = format!("LIVE SELECT * FROM orders_{i}");
        let req = serde_json::json!({
            "id": i,
            "method": "live",
            "params": {"sql": sql},
        })
        .to_string();
        ws.send(Message::Text(req.into())).await.unwrap();

        // Wait for the server's reply before issuing the next request.
        let _ack = tokio::time::timeout(IO_TIMEOUT, ws.next())
            .await
            .expect("timed out waiting for LIVE ack")
            .expect("server closed the websocket before acking LIVE")
            .expect("websocket error while reading LIVE ack");
    }

    assert_eq!(
        shared.change_stream.subscriber_count(),
        5,
        "server must have 5 active subscriptions while the WS client is connected"
    );

    // Abrupt disconnect: no close frame, the TCP connection just goes away.
    drop(ws);

    // The server only notices the disconnect asynchronously, so poll for the
    // count to drain, bounded by a deadline.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while shared.change_stream.subscriber_count() != 0 && tokio::time::Instant::now() < deadline
    {
        tokio::time::sleep(Duration::from_millis(25)).await;
    }

    assert_eq!(
        shared.change_stream.subscriber_count(),
        0,
        "disconnect must abort every LIVE forwarder and return active_subscriptions to 0 — \
         detached spawn would leak 5 subscriptions here"
    );
    server_handle.abort();
}