use std::collections::VecDeque;
use tokio::sync::mpsc;
use crate::{
connection::SharedSubscriptionControl,
error::Result,
models::ChangeEvent,
seq_tracking,
subscription::{buffer_event, event_progress},
timeouts::KalamLinkTimeouts,
SeqId,
};
/// Consumer-side handle for a single subscription backed by a shared
/// connection control handle.
///
/// Events arrive on an mpsc channel and pass through local buffering and
/// replay filtering (`buffer_event`). They are then handed out one at a time by
/// [`SubscriptionManager::next`].
pub struct SubscriptionManager {
    // ----- identity -----
    /// Identifier sent to the shared control on progress reports and unsubscribe.
    subscription_id: String,
    /// Generation sent alongside `subscription_id` to the shared control.
    /// NOTE(review): presumably distinguishes re-subscriptions that reuse the
    /// same id — confirm against `SharedSubscriptionControl`.
    generation: u64,

    // ----- transport -----
    /// Incoming events (or errors) for this subscription.
    event_rx: mpsc::Receiver<Result<ChangeEvent>>,
    /// Handle used to report progress and unsubscribe. It is `take()`n by
    /// `close()` and `Drop`, so at most one unsubscribe is issued.
    shared_control: Option<SharedSubscriptionControl>,

    // ----- delivery state -----
    /// Events ready to be returned by `next()`, in FIFO order.
    event_queue: VecDeque<ChangeEvent>,
    /// Changes held back by `buffer_event`; presumably while the initial load
    /// is still in progress — see `buffer_event`.
    buffered_changes: Vec<ChangeEvent>,
    /// Loading flag owned by `buffer_event`; starts as `true` in `from_shared`.
    is_loading: bool,
    /// Highest sequence seen so far. It is advanced as events are consumed and
    /// used by `buffer_event` to filter replayed events.
    resume_from: Option<SeqId>,

    // ----- configuration / lifecycle -----
    /// Timeout configuration copied from the creating client.
    timeouts: KalamLinkTimeouts,
    /// Set by `close()` or when the event channel ends.
    closed: bool,
}
impl SubscriptionManager {
    /// Creates a manager backed by a shared connection control handle.
    ///
    /// The manager starts open (`closed = false`) and in the loading state
    /// (`is_loading = true`), with empty local queues. `timeouts` is cloned so
    /// the manager owns its own copy.
    pub(crate) fn from_shared(
        subscription_id: String,
        event_rx: mpsc::Receiver<Result<ChangeEvent>>,
        shared_control: SharedSubscriptionControl,
        generation: u64,
        resume_from: Option<SeqId>,
        timeouts: &KalamLinkTimeouts,
    ) -> Self {
        Self {
            subscription_id,
            event_rx,
            shared_control: Some(shared_control),
            generation,
            event_queue: VecDeque::new(),
            buffered_changes: Vec::new(),
            is_loading: true,
            resume_from,
            timeouts: timeouts.clone(),
            closed: false,
        }
    }

    /// Records the sequence progress carried by `event`, if it carries any.
    ///
    /// The local `resume_from` watermark is always advanced via
    /// `seq_tracking::advance_seq`. If the shared control handle is still
    /// held (i.e. not yet taken by `close()`), the progress is also forwarded
    /// to it along with the subscription id and generation.
    async fn report_shared_progress(&mut self, event: &ChangeEvent) {
        let Some(progress) = event_progress(event) else {
            return;
        };
        seq_tracking::advance_seq(&mut self.resume_from, progress.seq_id);
        let Some(shared_control) = self.shared_control.as_ref() else {
            return;
        };
        shared_control
            .progress(
                self.subscription_id.clone(),
                self.generation,
                progress.seq_id,
                progress.advance_resume,
            )
            .await;
    }

    /// Passes an incoming event to `buffer_event`.
    ///
    /// Depending on the loading state and `resume_from`, `buffer_event` may
    /// queue the event for delivery, hold it in `buffered_changes`, or drop
    /// it as an already-seen replay.
    fn apply_buffering(&mut self, event: ChangeEvent) {
        buffer_event(
            &mut self.event_queue,
            &mut self.buffered_changes,
            &mut self.is_loading,
            self.resume_from,
            event,
        );
    }

    /// Waits for the next event to hand to the caller.
    ///
    /// Events already in the local queue are returned first, even after
    /// [`close`](Self::close). Each one has its progress recorded before it is
    /// returned. Otherwise the channel is polled. A received event goes
    /// through buffering and may not be surfaced immediately, so the loop
    /// keeps polling until something is ready.
    ///
    /// Returns `None` when the manager is closed and the queue is empty, or
    /// when every sender of the event channel has been dropped. In the latter
    /// case the manager is marked closed. An `Err` received from the channel
    /// is passed straight through and does not close the manager.
    pub async fn next(&mut self) -> Option<Result<ChangeEvent>> {
        loop {
            if let Some(event) = self.event_queue.pop_front() {
                self.report_shared_progress(&event).await;
                return Some(Ok(event));
            }
            if self.closed {
                return None;
            }
            match self.event_rx.recv().await {
                Some(Ok(event)) => {
                    self.apply_buffering(event);
                },
                Some(Err(e)) => return Some(Err(e)),
                None => {
                    self.closed = true;
                    return None;
                },
            }
        }
    }

    /// Returns the subscription identifier.
    pub fn subscription_id(&self) -> &str {
        &self.subscription_id
    }

    /// Returns the timeout configuration this manager was created with.
    pub fn timeouts(&self) -> &KalamLinkTimeouts {
        &self.timeouts
    }

    /// Closes the subscription and unsubscribes from the shared connection.
    ///
    /// Idempotent: the shared control handle is taken by the first call that
    /// finds it, so later calls only re-mark the manager closed. The awaited
    /// unsubscribe also runs when `next()` has already marked the manager
    /// closed because the event channel ended.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub async fn close(&mut self) -> Result<()> {
        self.closed = true;
        // Key idempotence on the control handle, not on `closed`. `closed`
        // is also set when the stream ends, and guarding on it would skip
        // the awaited unsubscribe in that case.
        if let Some(shared_control) = self.shared_control.take() {
            shared_control
                .unsubscribe(self.subscription_id.clone(), self.generation)
                .await;
        }
        Ok(())
    }

    /// Returns `true` once `close()` was called or the event channel ended.
    pub fn is_closed(&self) -> bool {
        self.closed
    }
}
impl Drop for SubscriptionManager {
    /// Unsubscribes synchronously when the manager is dropped without `close()`.
    ///
    /// `drop` cannot await, so this uses `try_unsubscribe`. Nothing is sent if
    /// `close()` already took the shared control handle.
    fn drop(&mut self) {
        let Some(control) = self.shared_control.take() else {
            return;
        };
        // The manager is being torn down, so the id can be moved out rather than cloned.
        let subscription_id = std::mem::take(&mut self.subscription_id);
        control.try_unsubscribe(subscription_id, self.generation);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a manager whose event channel is already closed (the only sender
    // is dropped) and whose `is_loading` flag is forced to `false`.
    fn make_test_sub() -> SubscriptionManager {
        let (event_tx, event_rx) = mpsc::channel(1);
        drop(event_tx);
        let mut subscription = SubscriptionManager::from_shared(
            "unit-test-id".to_string(),
            event_rx,
            SharedSubscriptionControl::test_control(),
            0,
            None,
            &KalamLinkTimeouts::default(),
        );
        subscription.is_loading = false;
        subscription
    }

    #[tokio::test]
    async fn test_is_not_closed_initially() {
        let sub = make_test_sub();
        assert!(!sub.is_closed(), "subscription should start as open");
    }

    #[tokio::test]
    async fn test_close_marks_subscription_as_closed() {
        let mut sub = make_test_sub();
        assert!(!sub.is_closed());
        sub.close().await.expect("close should succeed on a sub with a closed channel");
        assert!(sub.is_closed(), "subscription should be closed after close()");
    }

    #[tokio::test]
    async fn test_close_is_idempotent() {
        let mut sub = make_test_sub();
        sub.close().await.expect("first close should succeed");
        sub.close().await.expect("second close should also succeed (no-op)");
        assert!(sub.is_closed());
    }

    #[tokio::test]
    async fn test_next_returns_none_when_channel_closed() {
        let mut sub = make_test_sub();
        let result = tokio::time::timeout(std::time::Duration::from_millis(100), sub.next())
            .await
            .expect("next() should complete quickly when the event channel is closed");
        assert!(result.is_none(), "next() should return None when the event channel is closed");
        assert!(sub.is_closed(), "channel end should mark the subscription closed");
    }

    #[tokio::test]
    async fn test_next_returns_none_after_close() {
        let mut sub = make_test_sub();
        sub.close().await.unwrap();
        let result = tokio::time::timeout(std::time::Duration::from_millis(100), sub.next())
            .await
            .expect("next() should complete quickly after close");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_close_after_stream_end_succeeds() {
        let mut sub = make_test_sub();
        assert!(sub.next().await.is_none(), "closed channel should end the stream");
        sub.close().await.expect("close after stream end should succeed");
        assert!(sub.is_closed());
    }

    #[test]
    fn test_drop_without_runtime_does_not_panic() {
        let sub = make_test_sub();
        drop(sub);
    }

    #[tokio::test]
    async fn test_consumed_initial_batch_advances_local_replay_filter() {
        let mut sub = make_test_sub();
        // The initial batch and the later insert carry the same row at `_seq` 10.
        let seed_row = || {
            let mut row = std::collections::HashMap::new();
            row.insert("id".to_string(), crate::models::KalamCellValue::text("seed"));
            row.insert("_seq".to_string(), crate::models::KalamCellValue::text("10"));
            row
        };
        let event = ChangeEvent::InitialDataBatch {
            subscription_id: "unit-test-id".to_string(),
            rows: vec![seed_row()],
            batch_control: crate::models::BatchControl {
                batch_num: 0,
                has_more: true,
                status: crate::models::BatchStatus::Loading,
                last_seq_id: Some(SeqId::from_i64(10)),
            },
        };
        // Consuming the batch is expected to advance `resume_from` to seq 10 ...
        sub.report_shared_progress(&event).await;
        // ... so replaying an insert at that same seq should be filtered out.
        sub.apply_buffering(ChangeEvent::Insert {
            subscription_id: "unit-test-id".to_string(),
            rows: vec![seed_row()],
        });
        assert!(sub.event_queue.is_empty());
        assert!(sub.buffered_changes.is_empty());
    }

    #[tokio::test]
    async fn test_drop_inside_runtime_does_not_panic() {
        let sub = make_test_sub();
        drop(sub);
        tokio::task::yield_now().await;
    }
}