use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::Stream;
use tokio::sync::broadcast;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::channels::web::types::SseEvent;
const MAX_CONNECTIONS: u64 = 100;
pub struct SseManager {
tx: broadcast::Sender<SseEvent>,
connection_count: Arc<AtomicU64>,
max_connections: u64,
}
impl SseManager {
pub fn new() -> Self {
let (tx, _) = broadcast::channel(256);
Self {
tx,
connection_count: Arc::new(AtomicU64::new(0)),
max_connections: MAX_CONNECTIONS,
}
}
pub fn broadcast(&self, event: SseEvent) {
let _ = self.tx.send(event);
}
pub fn connection_count(&self) -> u64 {
self.connection_count.load(Ordering::Relaxed)
}
pub fn subscribe_raw(&self) -> Option<impl Stream<Item = SseEvent> + Send + 'static + use<>> {
let counter = Arc::clone(&self.connection_count);
let max = self.max_connections;
counter
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
if current < max {
Some(current + 1)
} else {
None
}
})
.ok()?;
let rx = self.tx.subscribe();
let stream = BroadcastStream::new(rx).filter_map(|result| result.ok());
Some(CountedStream {
inner: stream,
counter,
})
}
pub fn subscribe(
&self,
) -> Option<Sse<impl Stream<Item = Result<Event, Infallible>> + Send + 'static + use<>>> {
let counter = Arc::clone(&self.connection_count);
let max = self.max_connections;
counter
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
if current < max {
Some(current + 1)
} else {
None
}
})
.ok()?;
let rx = self.tx.subscribe();
let stream = BroadcastStream::new(rx)
.filter_map(|result| result.ok())
.map(|event| {
let data = serde_json::to_string(&event).unwrap_or_default();
let event_type = match &event {
SseEvent::Response { .. } => "response",
SseEvent::Thinking { .. } => "thinking",
SseEvent::ToolStarted { .. } => "tool_started",
SseEvent::ToolCompleted { .. } => "tool_completed",
SseEvent::ToolResult { .. } => "tool_result",
SseEvent::StreamChunk { .. } => "stream_chunk",
SseEvent::Status { .. } => "status",
SseEvent::ApprovalNeeded { .. } => "approval_needed",
SseEvent::AuthRequired { .. } => "auth_required",
SseEvent::AuthCompleted { .. } => "auth_completed",
SseEvent::Error { .. } => "error",
SseEvent::JobStarted { .. } => "job_started",
SseEvent::JobMessage { .. } => "job_message",
SseEvent::JobToolUse { .. } => "job_tool_use",
SseEvent::JobToolResult { .. } => "job_tool_result",
SseEvent::JobStatus { .. } => "job_status",
SseEvent::JobResult { .. } => "job_result",
SseEvent::Heartbeat => "heartbeat",
};
Ok(Event::default().event(event_type).data(data))
});
let counted_stream = CountedStream {
inner: stream,
counter,
};
Some(
Sse::new(counted_stream)
.keep_alive(KeepAlive::new().interval(Duration::from_secs(30)).text("")),
)
}
}
impl Default for SseManager {
fn default() -> Self {
Self::new()
}
}
struct CountedStream<S> {
inner: S,
counter: Arc<AtomicU64>,
}
impl<S: Stream + Unpin> Stream for CountedStream<S> {
type Item = S::Item;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
std::pin::Pin::new(&mut self.inner).poll_next(cx)
}
}
impl<S> Drop for CountedStream<S> {
fn drop(&mut self) {
self.counter.fetch_sub(1, Ordering::Relaxed);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sse_manager_creation() {
let manager = SseManager::new();
assert_eq!(manager.connection_count(), 0);
}
#[test]
fn test_broadcast_without_receivers() {
let manager = SseManager::new();
manager.broadcast(SseEvent::Heartbeat);
}
#[tokio::test]
async fn test_broadcast_to_receiver() {
let manager = SseManager::new();
let mut rx = BroadcastStream::new(manager.tx.subscribe());
manager.broadcast(SseEvent::Status {
message: "test".to_string(),
thread_id: None,
});
let event = rx.next().await;
assert!(event.is_some());
let event = event.unwrap().unwrap();
match event {
SseEvent::Status { message, .. } => assert_eq!(message, "test"),
_ => panic!("unexpected event type"),
}
}
#[tokio::test]
async fn test_subscribe_raw_receives_events() {
let manager = SseManager::new();
let mut stream = Box::pin(manager.subscribe_raw().expect("should subscribe"));
assert_eq!(manager.connection_count(), 1);
manager.broadcast(SseEvent::Thinking {
message: "working".to_string(),
thread_id: None,
});
let event = stream.next().await.unwrap();
match event {
SseEvent::Thinking { message, .. } => assert_eq!(message, "working"),
_ => panic!("Expected Thinking event"),
}
}
#[tokio::test]
async fn test_subscribe_raw_decrements_on_drop() {
let manager = SseManager::new();
{
let _stream = Box::pin(manager.subscribe_raw().expect("should subscribe"));
assert_eq!(manager.connection_count(), 1);
}
assert_eq!(manager.connection_count(), 0);
}
#[tokio::test]
async fn test_subscribe_raw_multiple_subscribers() {
let manager = SseManager::new();
let mut s1 = Box::pin(manager.subscribe_raw().expect("should subscribe"));
let mut s2 = Box::pin(manager.subscribe_raw().expect("should subscribe"));
assert_eq!(manager.connection_count(), 2);
manager.broadcast(SseEvent::Heartbeat);
let e1 = s1.next().await.unwrap();
let e2 = s2.next().await.unwrap();
assert!(matches!(e1, SseEvent::Heartbeat));
assert!(matches!(e2, SseEvent::Heartbeat));
drop(s1);
assert_eq!(manager.connection_count(), 1);
drop(s2);
assert_eq!(manager.connection_count(), 0);
}
#[tokio::test]
async fn test_subscribe_raw_rejects_over_limit() {
let mut manager = SseManager::new();
manager.max_connections = 2;
let _s1 = Box::pin(manager.subscribe_raw().expect("first should succeed"));
let _s2 = Box::pin(manager.subscribe_raw().expect("second should succeed"));
assert_eq!(manager.connection_count(), 2);
assert!(manager.subscribe_raw().is_none());
assert!(manager.subscribe().is_none());
}
}