use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::broadcast::Receiver;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::{Stream, StreamExt as _};
use crate::types::{SessionEvent, SessionLifecycleEvent};
/// Error reporting that a subscription fell behind its broadcast channel and
/// some events were skipped before they could be received.
///
/// A lagged subscription is not dead: the next receive resumes with the
/// oldest event the channel still retains.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[error("subscription lagged behind by {0} events")]
pub struct Lagged(u64);

impl Lagged {
    /// Returns how many events were skipped because the subscription lagged.
    #[must_use]
    pub const fn skipped(&self) -> u64 {
        self.0
    }
}
/// Error returned by a subscription's `recv` method.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RecvError {
    /// The underlying stream has ended (e.g. every sender was dropped) and no
    /// buffered events remain.
    #[error("subscription closed")]
    Closed,
    /// The subscription fell behind and events were skipped; display and
    /// `source()` are forwarded to the wrapped [`Lagged`].
    #[error(transparent)]
    Lagged(Lagged),
}

impl From<Lagged> for RecvError {
    fn from(lagged: Lagged) -> Self {
        Self::Lagged(lagged)
    }
}
/// Generates a public subscription newtype wrapping a [`BroadcastStream`].
///
/// The generated type exposes an async `recv` method and a [`Stream`]
/// implementation. Both surface lag as an error and keep going afterwards;
/// `recv` is implemented on top of the `Stream` impl so the two always agree.
macro_rules! define_subscription {
    (
        $(#[$meta:meta])*
        $name:ident, $item:ty $(,)?
    ) => {
        $(#[$meta])*
        #[must_use = "subscriptions are inert until polled"]
        pub struct $name {
            inner: BroadcastStream<$item>,
        }

        impl $name {
            /// Wraps a broadcast receiver in a subscription.
            pub(crate) fn new(rx: Receiver<$item>) -> Self {
                Self {
                    inner: BroadcastStream::new(rx),
                }
            }

            /// Waits for the next event on this subscription.
            ///
            /// # Errors
            ///
            /// - [`RecvError::Lagged`] if the subscription fell behind and events
            ///   were skipped; the following call resumes with the oldest event
            ///   the channel still retains.
            /// - [`RecvError::Closed`] once the underlying stream has ended and
            ///   no buffered events remain.
            pub async fn recv(&mut self) -> Result<$item, RecvError> {
                // Go through the `Stream` impl so the lag conversion lives in
                // exactly one place and `recv` cannot drift from `poll_next`.
                match self.next().await {
                    Some(item) => item.map_err(RecvError::from),
                    None => Err(RecvError::Closed),
                }
            }
        }

        impl Stream for $name {
            /// Either an event or a [`Lagged`] notice; the stream keeps yielding
            /// after a lag and ends when the underlying stream ends.
            type Item = Result<$item, Lagged>;

            fn poll_next(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                Pin::new(&mut self.inner)
                    .poll_next(cx)
                    .map_err(|err| match err {
                        BroadcastStreamRecvError::Lagged(skipped) => Lagged(skipped),
                    })
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                // Pure adapter: every inner item maps to exactly one outer item.
                self.inner.size_hint()
            }
        }
    };
}
define_subscription!(
    /// Subscription to [`SessionEvent`]s received from a broadcast channel.
    EventSubscription,
    SessionEvent,
);

define_subscription!(
    /// Subscription to [`SessionLifecycleEvent`]s received from a broadcast channel.
    LifecycleSubscription,
    SessionLifecycleEvent,
);
#[cfg(test)]
mod tests {
    use tokio::sync::broadcast;
    use super::*;

    /// Builds a minimal `SessionEvent` whose only distinguishing field is `id`.
    fn make_event(id: &str) -> SessionEvent {
        SessionEvent {
            id: id.into(),
            timestamp: "2025-01-01T00:00:00Z".into(),
            parent_id: None,
            ephemeral: None,
            agent_id: None,
            debug_cli_received_at_ms: None,
            debug_ws_forwarded_at_ms: None,
            event_type: "noop".into(),
            data: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn recv_yields_then_closes_on_drop_sender() {
        let (tx, rx) = broadcast::channel(8);
        let mut sub = EventSubscription::new(rx);
        tx.send(make_event("a")).unwrap();
        tx.send(make_event("b")).unwrap();
        drop(tx);
        assert_eq!(sub.recv().await.unwrap().id, "a");
        assert_eq!(sub.recv().await.unwrap().id, "b");
        assert!(matches!(sub.recv().await, Err(RecvError::Closed)));
    }

    #[tokio::test]
    async fn recv_surfaces_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut sub = EventSubscription::new(rx);
        for id in ["a", "b", "c", "d"] {
            tx.send(make_event(id)).unwrap();
        }
        match sub.recv().await {
            Err(RecvError::Lagged(l)) => assert_eq!(l.skipped(), 2),
            other => panic!("expected Lagged, got {other:?}"),
        }
        assert_eq!(sub.recv().await.unwrap().id, "c");
        assert_eq!(sub.recv().await.unwrap().id, "d");
    }

    #[tokio::test]
    async fn stream_impl_matches_recv_semantics() {
        let (tx, rx) = broadcast::channel(8);
        let mut sub = EventSubscription::new(rx);
        tx.send(make_event("a")).unwrap();
        drop(tx);
        let next = sub.next().await;
        assert_eq!(next.unwrap().unwrap().id, "a");
        assert!(sub.next().await.is_none());
    }

    // The `Stream` impl must report lag as an item (not end the stream),
    // then resume with the oldest retained event and end once closed.
    #[tokio::test]
    async fn stream_surfaces_lag_then_resumes() {
        let (tx, rx) = broadcast::channel(2);
        let mut sub = EventSubscription::new(rx);
        for id in ["a", "b", "c", "d"] {
            tx.send(make_event(id)).unwrap();
        }
        drop(tx);
        let lagged = sub.next().await.unwrap().unwrap_err();
        assert_eq!(lagged.skipped(), 2);
        assert_eq!(sub.next().await.unwrap().unwrap().id, "c");
        assert_eq!(sub.next().await.unwrap().unwrap().id, "d");
        assert!(sub.next().await.is_none());
    }
}