Skip to main content

systemprompt_events/services/
broadcaster.rs

1use async_trait::async_trait;
2use axum::response::sse::{Event, KeepAlive};
3use std::collections::HashMap;
4use std::marker::PhantomData;
5use std::sync::Arc;
6use std::time::Duration;
7use systemprompt_identifiers::UserId;
8use tokio::sync::RwLock;
9
10use crate::{Broadcaster, EventSender, ToSse};
11
12pub const HEARTBEAT_JSON: &str = r#"{"type":"heartbeat"}"#;
13pub const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
14
15pub fn standard_keep_alive() -> KeepAlive {
16    KeepAlive::new()
17        .interval(HEARTBEAT_INTERVAL)
18        .event(Event::default().event("heartbeat").data(HEARTBEAT_JSON))
19}
20
21pub struct GenericBroadcaster<E: ToSse + Clone + Send + Sync> {
22    connections: Arc<RwLock<HashMap<String, HashMap<String, EventSender>>>>,
23    _phantom: PhantomData<E>,
24}
25
26impl<E: ToSse + Clone + Send + Sync> std::fmt::Debug for GenericBroadcaster<E> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("GenericBroadcaster")
29            .field("connections", &"<RwLock<HashMap>>")
30            .finish()
31    }
32}
33
34impl<E: ToSse + Clone + Send + Sync> GenericBroadcaster<E> {
35    pub fn new() -> Self {
36        Self {
37            connections: Arc::new(RwLock::new(HashMap::new())),
38            _phantom: PhantomData,
39        }
40    }
41
42    pub async fn connected_users(&self) -> Vec<String> {
43        let connections = self.connections.read().await;
44        connections.keys().cloned().collect()
45    }
46
47    pub async fn connection_info(&self) -> (usize, usize) {
48        let (user_count, conn_count) = {
49            let connections = self.connections.read().await;
50            (
51                connections.len(),
52                connections.values().map(HashMap::len).sum(),
53            )
54        };
55        (user_count, conn_count)
56    }
57}
58
59impl<E: ToSse + Clone + Send + Sync> Default for GenericBroadcaster<E> {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65#[async_trait]
66impl<E: ToSse + Clone + Send + Sync + 'static> Broadcaster for GenericBroadcaster<E> {
67    type Event = E;
68
69    #[allow(clippy::significant_drop_tightening)]
70    async fn register(&self, user_id: &UserId, connection_id: &str, sender: EventSender) {
71        let mut connections = self.connections.write().await;
72        let user_connections = connections.entry(user_id.to_string()).or_default();
73        user_connections.insert(connection_id.to_string(), sender);
74    }
75
76    async fn unregister(&self, user_id: &UserId, connection_id: &str) {
77        let mut connections = self.connections.write().await;
78        if let Some(user_connections) = connections.get_mut(user_id.as_str()) {
79            user_connections.remove(connection_id);
80            if user_connections.is_empty() {
81                connections.remove(user_id.as_str());
82            }
83        }
84    }
85
86    async fn broadcast(&self, user_id: &UserId, event: Self::Event) -> usize {
87        let sse_event: Event = match event.to_sse() {
88            Ok(e) => e,
89            Err(e) => {
90                tracing::error!(error = %e, event_type = ?std::any::type_name_of_val(&event), "Failed to serialize SSE event");
91                return 0;
92            },
93        };
94
95        let senders: Vec<(String, EventSender)> = {
96            let connections = self.connections.read().await;
97            match connections.get(user_id.as_str()) {
98                Some(user_connections) => user_connections
99                    .iter()
100                    .map(|(id, sender)| (id.clone(), sender.clone()))
101                    .collect(),
102                None => return 0,
103            }
104        };
105
106        let mut successful = 0;
107        let mut failed_ids = Vec::new();
108
109        for (conn_id, sender) in senders {
110            if sender.send(Ok(sse_event.clone())).is_ok() {
111                successful += 1;
112            } else {
113                failed_ids.push(conn_id);
114            }
115        }
116
117        if !failed_ids.is_empty() {
118            let mut connections = self.connections.write().await;
119            if let Some(user_connections) = connections.get_mut(user_id.as_str()) {
120                for conn_id in &failed_ids {
121                    user_connections.remove(conn_id);
122                }
123                if user_connections.is_empty() {
124                    connections.remove(user_id.as_str());
125                }
126            }
127        }
128
129        successful
130    }
131
132    async fn connection_count(&self, user_id: &UserId) -> usize {
133        let connections = self.connections.read().await;
134        connections.get(user_id.as_str()).map_or(0, HashMap::len)
135    }
136
137    async fn total_connections(&self) -> usize {
138        let connections = self.connections.read().await;
139        connections.values().map(HashMap::len).sum()
140    }
141}
142
143use systemprompt_models::{A2AEvent, AgUiEvent, AnalyticsEvent, ContextEvent};
144
145pub type AgUiBroadcaster = GenericBroadcaster<AgUiEvent>;
146pub type A2ABroadcaster = GenericBroadcaster<A2AEvent>;
147pub type ContextBroadcaster = GenericBroadcaster<ContextEvent>;
148pub type AnalyticsBroadcaster = GenericBroadcaster<AnalyticsEvent>;
149
150pub struct ConnectionGuard<E: ToSse + Clone + Send + Sync + 'static> {
151    broadcaster: &'static std::sync::LazyLock<GenericBroadcaster<E>>,
152    user_id: UserId,
153    connection_id: String,
154}
155
156#[allow(clippy::missing_fields_in_debug)]
157impl<E: ToSse + Clone + Send + Sync + 'static> std::fmt::Debug for ConnectionGuard<E> {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("ConnectionGuard")
160            .field("user_id", &self.user_id)
161            .field("connection_id", &self.connection_id)
162            .finish_non_exhaustive()
163    }
164}
165
166impl<E: ToSse + Clone + Send + Sync + 'static> ConnectionGuard<E> {
167    pub fn new(
168        broadcaster: &'static std::sync::LazyLock<GenericBroadcaster<E>>,
169        user_id: UserId,
170        connection_id: String,
171    ) -> Self {
172        Self {
173            broadcaster,
174            user_id,
175            connection_id,
176        }
177    }
178}
179
180impl<E: ToSse + Clone + Send + Sync + 'static> Drop for ConnectionGuard<E> {
181    fn drop(&mut self) {
182        let broadcaster = self.broadcaster;
183        let user_id = self.user_id.clone();
184        let conn_id = self.connection_id.clone();
185
186        tokio::spawn(async move {
187            broadcaster.unregister(&user_id, &conn_id).await;
188        });
189    }
190}