use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::Stream;
use tokio::sync::broadcast;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::channels::web::types::AppEvent;
/// Connection limit that every newly constructed `SseManager` starts with.
const MAX_CONNECTIONS: u64 = 100;
/// An `AppEvent` together with the user it is addressed to; this is the
/// message type carried on the broadcast channel.
#[derive(Debug, Clone)]
pub(crate) struct ScopedEvent {
/// Addressee of the event. `None` means the event is not user-scoped and
/// is delivered to every subscriber.
pub(crate) user_id: Option<String>,
/// The event payload handed to subscribers that pass the user filter.
pub(crate) event: AppEvent,
}
/// Fans `AppEvent`s out to subscribers over a broadcast channel while
/// capping the number of simultaneously open subscriptions.
pub struct SseManager {
/// Sending half of the broadcast channel; each subscription gets its own
/// receiver from it.
tx: broadcast::Sender<ScopedEvent>,
/// Number of currently live subscriptions, shared with every stream handed
/// out so it can be decremented when that stream is dropped.
connection_count: Arc<AtomicU64>,
/// Upper bound on `connection_count`; further subscriptions are refused
/// once it is reached.
max_connections: u64,
}
impl SseManager {
    /// Creates a manager backed by a fresh broadcast channel that buffers up
    /// to 256 events.
    pub fn new() -> Self {
        let (tx, _) = broadcast::channel(256);
        Self::from_sender(tx)
    }

    /// Creates a manager that publishes through an existing broadcast sender.
    ///
    /// The connection counter starts at zero and the limit is
    /// `MAX_CONNECTIONS`.
    pub(crate) fn from_sender(tx: broadcast::Sender<ScopedEvent>) -> Self {
        Self {
            tx,
            connection_count: Arc::new(AtomicU64::new(0)),
            max_connections: MAX_CONNECTIONS,
        }
    }

    /// Returns a clone of the underlying broadcast sender.
    pub(crate) fn sender(&self) -> broadcast::Sender<ScopedEvent> {
        self.tx.clone()
    }

    /// Publishes `event` to every subscriber, regardless of user.
    ///
    /// The send result is deliberately ignored: per tokio's broadcast
    /// documentation a send only fails when there are no receivers, which is
    /// not an error for a fire-and-forget notification.
    pub fn broadcast(&self, event: AppEvent) {
        let _ = self.tx.send(ScopedEvent {
            user_id: None,
            event,
        });
    }

    /// Publishes `event` addressed to `user_id`.
    ///
    /// Subscribers registered for a different user do not receive it;
    /// subscribers registered without a user do. The send result is ignored
    /// for the same reason as in [`SseManager::broadcast`].
    pub fn broadcast_for_user(&self, user_id: &str, event: AppEvent) {
        let _ = self.tx.send(ScopedEvent {
            user_id: Some(user_id.to_string()),
            event,
        });
    }

    /// Returns the number of subscriptions that are currently alive.
    pub fn connection_count(&self) -> u64 {
        self.connection_count.load(Ordering::Relaxed)
    }

    /// Subscribes to the raw [`AppEvent`] stream.
    ///
    /// `user_id` selects which events are delivered: `None` receives
    /// everything, `Some(id)` receives unscoped events plus those addressed
    /// to `id`.
    ///
    /// Returns `None` when the connection limit has been reached. The slot is
    /// released when the returned stream is dropped.
    pub fn subscribe_raw(
        &self,
        user_id: Option<String>,
    ) -> Option<impl Stream<Item = AppEvent> + Send + 'static + use<>> {
        let counter = self.try_acquire_slot()?;
        let rx = self.tx.subscribe();
        let stream = BroadcastStream::new(rx).filter_map(move |result| match result {
            Ok(ScopedEvent {
                user_id: target,
                event,
            }) => Self::is_visible_to(user_id.as_deref(), target.as_deref()).then_some(event),
            Err(err) => {
                // The receiver fell behind and the channel overwrote events it
                // had not read yet. Keep the stream alive, but leave a trace.
                tracing::warn!("SSE subscriber missed events: {}", err);
                None
            }
        });
        Some(CountedStream {
            inner: stream,
            counter,
        })
    }

    /// Subscribes and wraps the event stream in an axum [`Sse`] response.
    ///
    /// Filtering and connection accounting are those of
    /// [`SseManager::subscribe_raw`]. Each event is serialized to JSON and
    /// tagged with its `event_type()`; events that fail to serialize are
    /// logged and skipped. A keep-alive with empty text is sent every 30
    /// seconds.
    ///
    /// Returns `None` when the connection limit has been reached.
    pub fn subscribe(
        &self,
        user_id: Option<String>,
    ) -> Option<Sse<impl Stream<Item = Result<Event, Infallible>> + Send + 'static + use<>>> {
        let events = self.subscribe_raw(user_id)?.filter_map(
            |event| -> Option<Result<Event, Infallible>> {
                match serde_json::to_string(&event) {
                    Ok(data) => Some(Ok(Event::default().event(event.event_type()).data(data))),
                    Err(e) => {
                        tracing::warn!("Failed to serialize SSE event: {}", e);
                        None
                    }
                }
            },
        );
        Some(
            Sse::new(events)
                .keep_alive(KeepAlive::new().interval(Duration::from_secs(30)).text("")),
        )
    }

    /// Atomically reserves one connection slot.
    ///
    /// Returns a handle to the shared counter on success so the caller can
    /// release the slot later, or `None` when the limit is already reached.
    fn try_acquire_slot(&self) -> Option<Arc<AtomicU64>> {
        let max = self.max_connections;
        self.connection_count
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                // Lazy `then` so `current + 1` is only computed below the limit.
                (current < max).then(|| current + 1)
            })
            .ok()?;
        Some(Arc::clone(&self.connection_count))
    }

    /// Decides whether an event addressed to `target` is delivered to a
    /// subscriber registered as `subscriber`.
    ///
    /// Unscoped events reach everyone, unscoped subscribers see everything,
    /// and otherwise the two user ids must match.
    fn is_visible_to(subscriber: Option<&str>, target: Option<&str>) -> bool {
        match (subscriber, target) {
            (_, None) | (None, _) => true,
            (Some(sub), Some(ev)) => sub == ev,
        }
    }
}
impl Default for SseManager {
fn default() -> Self {
Self::new()
}
}
/// Stream wrapper that holds one slot of a shared connection counter for as
/// long as it lives; the slot is given back when the wrapper is dropped.
struct CountedStream<S> {
/// The wrapped stream that actually produces the items.
inner: S,
/// Shared counter that is decremented once when this value is dropped.
counter: Arc<AtomicU64>,
}
/// `CountedStream` is a transparent pass-through: it yields exactly the items
/// of the wrapped stream and reports the same size hint.
impl<S: Stream + Unpin> Stream for CountedStream<S> {
    type Item = S::Item;

    /// Polls the wrapped stream for its next item.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `S: Unpin`, so the field can be re-pinned from a plain mutable
        // borrow without any unsafe projection.
        std::pin::Pin::new(&mut self.inner).poll_next(cx)
    }

    /// Forwards the wrapped stream's size hint instead of falling back to the
    /// trait's default.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Releases the connection slot held by this stream.
impl<S> Drop for CountedStream<S> {
    fn drop(&mut self) {
        // The previous value is irrelevant here; only the decrement matters.
        let _previous = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager tracks no connections.
    #[test]
    fn test_sse_manager_creation() {
        let mgr = SseManager::new();
        assert_eq!(mgr.connection_count(), 0);
    }

    /// Broadcasting with nobody listening must not panic.
    #[test]
    fn test_broadcast_without_receivers() {
        let mgr = SseManager::new();
        mgr.broadcast(AppEvent::Heartbeat);
    }

    /// An unscoped subscriber receives a broadcast event unchanged.
    #[tokio::test]
    async fn test_broadcast_to_receiver() {
        let mgr = SseManager::new();
        let mut events = Box::pin(mgr.subscribe_raw(None).expect("should subscribe"));

        mgr.broadcast(AppEvent::Status {
            message: "test".to_string(),
            thread_id: None,
        });

        let received = events.next().await.unwrap();
        let AppEvent::Status { message, .. } = received else {
            panic!("unexpected event type")
        };
        assert_eq!(message, "test");
    }

    /// Subscribing bumps the counter and the stream yields what was broadcast.
    #[tokio::test]
    async fn test_subscribe_raw_receives_events() {
        let mgr = SseManager::new();
        let mut events = Box::pin(mgr.subscribe_raw(None).expect("should subscribe"));
        assert_eq!(mgr.connection_count(), 1);

        mgr.broadcast(AppEvent::Thinking {
            message: "working".to_string(),
            thread_id: None,
        });

        let received = events.next().await.unwrap();
        let AppEvent::Thinking { message, .. } = received else {
            panic!("Expected Thinking event")
        };
        assert_eq!(message, "working");
    }

    /// Dropping the stream gives its connection slot back.
    #[tokio::test]
    async fn test_subscribe_raw_decrements_on_drop() {
        let mgr = SseManager::new();
        {
            let _events = Box::pin(mgr.subscribe_raw(None).expect("should subscribe"));
            assert_eq!(mgr.connection_count(), 1);
        }
        assert_eq!(mgr.connection_count(), 0);
    }

    /// Every subscriber gets its own copy of an event and its own slot.
    #[tokio::test]
    async fn test_subscribe_raw_multiple_subscribers() {
        let mgr = SseManager::new();
        let mut first = Box::pin(mgr.subscribe_raw(None).expect("should subscribe"));
        let mut second = Box::pin(mgr.subscribe_raw(None).expect("should subscribe"));
        assert_eq!(mgr.connection_count(), 2);

        mgr.broadcast(AppEvent::Heartbeat);

        let from_first = first.next().await.unwrap();
        let from_second = second.next().await.unwrap();
        assert!(matches!(from_first, AppEvent::Heartbeat));
        assert!(matches!(from_second, AppEvent::Heartbeat));

        drop(first);
        assert_eq!(mgr.connection_count(), 1);
        drop(second);
        assert_eq!(mgr.connection_count(), 0);
    }

    /// Once the limit is reached, both subscribe flavours are refused.
    #[tokio::test]
    async fn test_subscribe_raw_rejects_over_limit() {
        let mut mgr = SseManager::new();
        mgr.max_connections = 2;

        let _first = Box::pin(mgr.subscribe_raw(None).expect("first should succeed"));
        let _second = Box::pin(mgr.subscribe_raw(None).expect("second should succeed"));
        assert_eq!(mgr.connection_count(), 2);

        assert!(mgr.subscribe_raw(None).is_none());
        assert!(mgr.subscribe(None).is_none());
    }

    /// User-scoped events reach only the matching user; unscoped events reach all.
    #[tokio::test]
    async fn test_scoped_events_filtered_by_user() {
        let mgr = SseManager::new();
        let mut alice = Box::pin(
            mgr.subscribe_raw(Some("alice".to_string()))
                .expect("subscribe"),
        );
        let mut bob = Box::pin(
            mgr.subscribe_raw(Some("bob".to_string()))
                .expect("subscribe"),
        );

        mgr.broadcast_for_user(
            "alice",
            AppEvent::Status {
                message: "alice only".to_string(),
                thread_id: None,
            },
        );
        mgr.broadcast(AppEvent::Heartbeat);

        // Alice sees her scoped event first, then the global heartbeat.
        let received = alice.next().await.unwrap();
        assert!(matches!(received, AppEvent::Status { .. }));
        let received = alice.next().await.unwrap();
        assert!(matches!(received, AppEvent::Heartbeat));

        // Bob's first event is the heartbeat: Alice's event was filtered out.
        let received = bob.next().await.unwrap();
        assert!(matches!(received, AppEvent::Heartbeat));
    }
}