1use actix_web::{web, Error, HttpRequest, HttpResponse};
13use futures_util::StreamExt;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::RwLock;
17use std::time::{Duration, Instant};
18use tokio::sync::broadcast;
19use uuid::Uuid;
20
/// How often the server pings each connected client (and checks for a timeout).
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(30_000);
/// A client is disconnected once this long has passed without an inbound
/// frame that refreshes its heartbeat.
const CLIENT_TIMEOUT: Duration = Duration::from_millis(60_000);
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(tag = "type", rename_all = "snake_case")]
35pub enum WsClientMessage {
36 Subscribe { channel: String },
38 Unsubscribe { channel: String },
40
41 StreamStart { session_id: String, model: String },
43 StreamCancel { session_id: String },
45 StreamInput { session_id: String, content: String },
47
48 AgentCommand {
50 agent_id: String,
51 command: String,
52 params: Option<serde_json::Value>,
53 },
54
55 SyncRequest { from_version: u64 },
57
58 Ping { timestamp: i64 },
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64#[serde(tag = "type", rename_all = "snake_case")]
65pub enum WsServerMessage {
66 Connected { client_id: String, version: u64 },
68 Error { code: String, message: String },
70
71 Subscribed { channel: String },
73 Unsubscribed { channel: String },
75
76 StreamToken { session_id: String, token: String },
78 StreamComplete {
80 session_id: String,
81 message_id: String,
82 },
83 StreamError { session_id: String, error: String },
85
86 AgentEvent {
88 agent_id: String,
89 event: String,
90 data: Option<serde_json::Value>,
91 },
92
93 SyncEvent {
95 entity_type: String,
96 entity_id: String,
97 operation: String,
98 data: Option<serde_json::Value>,
99 version: u64,
100 },
101
102 Pong { timestamp: i64 },
104}
105
/// Book-keeping for one connected WebSocket client.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Connection id (the key this entry is stored under in `WebSocketState::clients`).
    pub id: String,
    /// When the client was registered.
    pub connected_at: Instant,
    /// Set at registration. NOTE(review): nothing in this module refreshes it
    /// afterwards (the connection loop tracks liveness locally) — confirm intent.
    pub last_heartbeat: Instant,
    /// Channel names the client has subscribed to, without duplicates.
    pub subscriptions: Vec<String>,
}
118
119pub struct WebSocketState {
121 pub broadcast_tx: broadcast::Sender<WsServerMessage>,
123 pub channel_senders: RwLock<HashMap<String, broadcast::Sender<WsServerMessage>>>,
125 pub clients: RwLock<HashMap<String, ClientInfo>>,
127 pub version: std::sync::atomic::AtomicU64,
129}
130
131impl WebSocketState {
132 pub fn new() -> Self {
133 let (broadcast_tx, _) = broadcast::channel(1024);
134 Self {
135 broadcast_tx,
136 channel_senders: RwLock::new(HashMap::new()),
137 clients: RwLock::new(HashMap::new()),
138 version: std::sync::atomic::AtomicU64::new(1),
139 }
140 }
141
142 pub fn get_channel_sender(&self, channel: &str) -> broadcast::Sender<WsServerMessage> {
144 {
145 let channels = self.channel_senders.read().unwrap();
146 if let Some(sender) = channels.get(channel) {
147 return sender.clone();
148 }
149 }
150
151 let mut channels = self.channel_senders.write().unwrap();
152 let entry = channels
153 .entry(channel.to_string())
154 .or_insert_with(|| broadcast::channel(256).0);
155 entry.clone()
156 }
157
158 pub fn broadcast(&self, msg: WsServerMessage) {
160 let _ = self.broadcast_tx.send(msg);
161 }
162
163 pub fn broadcast_to_channel(&self, channel: &str, msg: WsServerMessage) {
165 let channels = self.channel_senders.read().unwrap();
166 if let Some(sender) = channels.get(channel) {
167 let _ = sender.send(msg);
168 }
169 }
170
171 pub fn increment_version(&self) -> u64 {
173 self.version
174 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
175 + 1
176 }
177
178 pub fn current_version(&self) -> u64 {
180 self.version.load(std::sync::atomic::Ordering::Relaxed)
181 }
182
183 pub fn register_client(&self, id: &str) {
185 let mut clients = self.clients.write().unwrap();
186 clients.insert(
187 id.to_string(),
188 ClientInfo {
189 id: id.to_string(),
190 connected_at: Instant::now(),
191 last_heartbeat: Instant::now(),
192 subscriptions: Vec::new(),
193 },
194 );
195 }
196
197 pub fn unregister_client(&self, id: &str) {
199 let mut clients = self.clients.write().unwrap();
200 clients.remove(id);
201 }
202
203 pub fn client_count(&self) -> usize {
205 self.clients.read().unwrap().len()
206 }
207}
208
209impl Default for WebSocketState {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215fn handle_client_message(
221 client_id: &str,
222 msg: WsClientMessage,
223 state: &WebSocketState,
224) -> Option<WsServerMessage> {
225 match msg {
226 WsClientMessage::Subscribe { channel } => {
227 if let Ok(mut clients) = state.clients.write() {
229 if let Some(client) = clients.get_mut(client_id) {
230 if !client.subscriptions.contains(&channel) {
231 client.subscriptions.push(channel.clone());
232 }
233 }
234 }
235 Some(WsServerMessage::Subscribed { channel })
236 }
237
238 WsClientMessage::Unsubscribe { channel } => {
239 if let Ok(mut clients) = state.clients.write() {
241 if let Some(client) = clients.get_mut(client_id) {
242 client.subscriptions.retain(|c| c != &channel);
243 }
244 }
245 Some(WsServerMessage::Unsubscribed { channel })
246 }
247
248 WsClientMessage::Ping { timestamp } => Some(WsServerMessage::Pong { timestamp }),
249
250 WsClientMessage::StreamStart { session_id, model } => {
251 log::info!(
252 "Client {} requested stream start for {} with model {}",
253 client_id,
254 session_id,
255 model
256 );
257 None
259 }
260
261 WsClientMessage::StreamCancel { session_id } => {
262 log::info!(
263 "Client {} requested stream cancel for {}",
264 client_id,
265 session_id
266 );
267 None
269 }
270
271 WsClientMessage::StreamInput {
272 session_id,
273 content,
274 } => {
275 log::info!(
276 "Client {} sent input for {}: {} bytes",
277 client_id,
278 session_id,
279 content.len()
280 );
281 None
283 }
284
285 WsClientMessage::AgentCommand {
286 agent_id,
287 command,
288 params,
289 } => {
290 log::info!(
291 "Client {} sent agent command {} to {}: {:?}",
292 client_id,
293 command,
294 agent_id,
295 params
296 );
297 None
299 }
300
301 WsClientMessage::SyncRequest { from_version } => {
302 log::info!(
303 "Client {} requested sync from version {}",
304 client_id,
305 from_version
306 );
307 None
309 }
310 }
311}
312
313pub async fn ws_handler(
315 req: HttpRequest,
316 body: web::Payload,
317 state: web::Data<WebSocketState>,
318) -> Result<HttpResponse, Error> {
319 let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?;
321
322 let client_id = Uuid::new_v4().to_string();
323 let state_clone = state.clone();
324
325 state.register_client(&client_id);
327
328 let connected_msg = WsServerMessage::Connected {
330 client_id: client_id.clone(),
331 version: state.current_version(),
332 };
333 if let Ok(json) = serde_json::to_string(&connected_msg) {
334 let _ = session.text(json).await;
335 }
336
337 log::info!("WebSocket client {} connected", client_id);
338
339 let mut broadcast_rx = state.broadcast_tx.subscribe();
341
342 let client_id_clone = client_id.clone();
344 actix_web::rt::spawn(async move {
345 let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
346 let mut last_heartbeat = Instant::now();
347
348 loop {
349 tokio::select! {
350 Some(msg_result) = msg_stream.next() => {
352 match msg_result {
353 Ok(actix_ws::Message::Text(text)) => {
354 last_heartbeat = Instant::now();
355 if let Ok(client_msg) = serde_json::from_str::<WsClientMessage>(&text) {
356 if let Some(response) = handle_client_message(
357 &client_id_clone,
358 client_msg,
359 &state_clone,
360 ) {
361 if let Ok(json) = serde_json::to_string(&response) {
362 let _ = session.text(json).await;
363 }
364 }
365 } else {
366 let error_msg = WsServerMessage::Error {
367 code: "invalid_message".to_string(),
368 message: "Failed to parse message".to_string(),
369 };
370 if let Ok(json) = serde_json::to_string(&error_msg) {
371 let _ = session.text(json).await;
372 }
373 }
374 }
375 Ok(actix_ws::Message::Ping(data)) => {
376 last_heartbeat = Instant::now();
377 let _ = session.pong(&data).await;
378 }
379 Ok(actix_ws::Message::Pong(_)) => {
380 last_heartbeat = Instant::now();
381 }
382 Ok(actix_ws::Message::Close(_)) => {
383 log::info!("WebSocket client {} requested close", client_id_clone);
384 break;
385 }
386 _ => {}
387 }
388 }
389
390 Ok(msg) = broadcast_rx.recv() => {
392 if let Ok(json) = serde_json::to_string(&msg) {
393 let _ = session.text(json).await;
394 }
395 }
396
397 _ = heartbeat_interval.tick() => {
399 if Instant::now().duration_since(last_heartbeat) > CLIENT_TIMEOUT {
400 log::warn!("WebSocket client {} timed out", client_id_clone);
401 break;
402 }
403 let _ = session.ping(b"").await;
404 }
405 }
406 }
407
408 state_clone.unregister_client(&client_id_clone);
410 let _ = session.close(None).await;
411 log::info!("WebSocket client {} disconnected", client_id_clone);
412 });
413
414 Ok(response)
415}
416
417pub fn configure_websocket_routes(cfg: &mut web::ServiceConfig, state: web::Data<WebSocketState>) {
419 cfg.app_data(state).route("/ws", web::get().to(ws_handler));
420}
421
422pub fn broadcast_sync_event(
428 state: &WebSocketState,
429 entity_type: &str,
430 entity_id: &str,
431 operation: &str,
432 data: Option<serde_json::Value>,
433) {
434 let version = state.increment_version();
435 let msg = WsServerMessage::SyncEvent {
436 entity_type: entity_type.to_string(),
437 entity_id: entity_id.to_string(),
438 operation: operation.to_string(),
439 data,
440 version,
441 };
442 state.broadcast(msg);
443}
444
445pub fn broadcast_stream_token(state: &WebSocketState, session_id: &str, token: &str) {
447 let msg = WsServerMessage::StreamToken {
448 session_id: session_id.to_string(),
449 token: token.to_string(),
450 };
451 state.broadcast_to_channel(&format!("session:{}", session_id), msg);
452}
453
454pub fn broadcast_stream_complete(state: &WebSocketState, session_id: &str, message_id: &str) {
456 let msg = WsServerMessage::StreamComplete {
457 session_id: session_id.to_string(),
458 message_id: message_id.to_string(),
459 };
460 state.broadcast_to_channel(&format!("session:{}", session_id), msg);
461}
462
463pub fn broadcast_agent_event(
465 state: &WebSocketState,
466 agent_id: &str,
467 event: &str,
468 data: Option<serde_json::Value>,
469) {
470 let msg = WsServerMessage::AgentEvent {
471 agent_id: agent_id.to_string(),
472 event: event.to_string(),
473 data,
474 };
475 state.broadcast_to_channel(&format!("agent:{}", agent_id), msg);
476}