systemprompt_events/services/
broadcaster.rs1use async_trait::async_trait;
2use axum::response::sse::{Event, KeepAlive};
3use std::collections::HashMap;
4use std::marker::PhantomData;
5use std::sync::Arc;
6use std::time::Duration;
7use systemprompt_identifiers::UserId;
8use tokio::sync::RwLock;
9
10use crate::{Broadcaster, EventSender, ToSse};
11
/// JSON body carried by the keep-alive `heartbeat` SSE event.
pub const HEARTBEAT_JSON: &str = "{\"type\":\"heartbeat\"}";
/// Interval handed to `KeepAlive::interval` by `standard_keep_alive` (15 s).
pub const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(15_000);
14
15pub fn standard_keep_alive() -> KeepAlive {
16 KeepAlive::new()
17 .interval(HEARTBEAT_INTERVAL)
18 .event(Event::default().event("heartbeat").data(HEARTBEAT_JSON))
19}
20
21pub struct GenericBroadcaster<E: ToSse + Clone + Send + Sync> {
22 connections: Arc<RwLock<HashMap<String, HashMap<String, EventSender>>>>,
23 _phantom: PhantomData<E>,
24}
25
26impl<E: ToSse + Clone + Send + Sync> std::fmt::Debug for GenericBroadcaster<E> {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("GenericBroadcaster")
29 .field("connections", &"<RwLock<HashMap>>")
30 .finish()
31 }
32}
33
34impl<E: ToSse + Clone + Send + Sync> GenericBroadcaster<E> {
35 pub fn new() -> Self {
36 Self {
37 connections: Arc::new(RwLock::new(HashMap::new())),
38 _phantom: PhantomData,
39 }
40 }
41
42 pub async fn connected_users(&self) -> Vec<String> {
43 let connections = self.connections.read().await;
44 connections.keys().cloned().collect()
45 }
46
47 pub async fn connection_info(&self) -> (usize, usize) {
48 let (user_count, conn_count) = {
49 let connections = self.connections.read().await;
50 (
51 connections.len(),
52 connections.values().map(HashMap::len).sum(),
53 )
54 };
55 (user_count, conn_count)
56 }
57}
58
59impl<E: ToSse + Clone + Send + Sync> Default for GenericBroadcaster<E> {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65#[async_trait]
66impl<E: ToSse + Clone + Send + Sync + 'static> Broadcaster for GenericBroadcaster<E> {
67 type Event = E;
68
69 #[allow(clippy::significant_drop_tightening)]
70 async fn register(&self, user_id: &UserId, connection_id: &str, sender: EventSender) {
71 let mut connections = self.connections.write().await;
72 let user_connections = connections.entry(user_id.to_string()).or_default();
73 user_connections.insert(connection_id.to_string(), sender);
74 }
75
76 async fn unregister(&self, user_id: &UserId, connection_id: &str) {
77 let mut connections = self.connections.write().await;
78 if let Some(user_connections) = connections.get_mut(user_id.as_str()) {
79 user_connections.remove(connection_id);
80 if user_connections.is_empty() {
81 connections.remove(user_id.as_str());
82 }
83 }
84 }
85
86 async fn broadcast(&self, user_id: &UserId, event: Self::Event) -> usize {
87 let sse_event: Event = match event.to_sse() {
88 Ok(e) => e,
89 Err(e) => {
90 tracing::error!(error = %e, event_type = ?std::any::type_name_of_val(&event), "Failed to serialize SSE event");
91 return 0;
92 },
93 };
94
95 let senders: Vec<(String, EventSender)> = {
96 let connections = self.connections.read().await;
97 match connections.get(user_id.as_str()) {
98 Some(user_connections) => user_connections
99 .iter()
100 .map(|(id, sender)| (id.clone(), sender.clone()))
101 .collect(),
102 None => return 0,
103 }
104 };
105
106 let mut successful = 0;
107 let mut failed_ids = Vec::new();
108
109 for (conn_id, sender) in senders {
110 if sender.send(Ok(sse_event.clone())).is_ok() {
111 successful += 1;
112 } else {
113 failed_ids.push(conn_id);
114 }
115 }
116
117 if !failed_ids.is_empty() {
118 let mut connections = self.connections.write().await;
119 if let Some(user_connections) = connections.get_mut(user_id.as_str()) {
120 for conn_id in &failed_ids {
121 user_connections.remove(conn_id);
122 }
123 if user_connections.is_empty() {
124 connections.remove(user_id.as_str());
125 }
126 }
127 }
128
129 successful
130 }
131
132 async fn connection_count(&self, user_id: &UserId) -> usize {
133 let connections = self.connections.read().await;
134 connections.get(user_id.as_str()).map_or(0, HashMap::len)
135 }
136
137 async fn total_connections(&self) -> usize {
138 let connections = self.connections.read().await;
139 connections.values().map(HashMap::len).sum()
140 }
141}
142
143use systemprompt_models::{A2AEvent, AgUiEvent, AnalyticsEvent, ContextEvent};
144
145pub type AgUiBroadcaster = GenericBroadcaster<AgUiEvent>;
146pub type A2ABroadcaster = GenericBroadcaster<A2AEvent>;
147pub type ContextBroadcaster = GenericBroadcaster<ContextEvent>;
148pub type AnalyticsBroadcaster = GenericBroadcaster<AnalyticsEvent>;
149
150pub struct ConnectionGuard<E: ToSse + Clone + Send + Sync + 'static> {
151 broadcaster: &'static std::sync::LazyLock<GenericBroadcaster<E>>,
152 user_id: UserId,
153 connection_id: String,
154}
155
156#[allow(clippy::missing_fields_in_debug)]
157impl<E: ToSse + Clone + Send + Sync + 'static> std::fmt::Debug for ConnectionGuard<E> {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("ConnectionGuard")
160 .field("user_id", &self.user_id)
161 .field("connection_id", &self.connection_id)
162 .finish_non_exhaustive()
163 }
164}
165
166impl<E: ToSse + Clone + Send + Sync + 'static> ConnectionGuard<E> {
167 pub fn new(
168 broadcaster: &'static std::sync::LazyLock<GenericBroadcaster<E>>,
169 user_id: UserId,
170 connection_id: String,
171 ) -> Self {
172 Self {
173 broadcaster,
174 user_id,
175 connection_id,
176 }
177 }
178}
179
180impl<E: ToSse + Clone + Send + Sync + 'static> Drop for ConnectionGuard<E> {
181 fn drop(&mut self) {
182 let broadcaster = self.broadcaster;
183 let user_id = self.user_id.clone();
184 let conn_id = self.connection_id.clone();
185
186 tokio::spawn(async move {
187 broadcaster.unregister(&user_id, &conn_id).await;
188 });
189 }
190}