Skip to main content

chasm/api/
websocket.rs

// Copyright (c) 2024-2026 Nervosys LLC
// SPDX-License-Identifier: AGPL-3.0-only
//! WebSocket handler for bidirectional real-time communication.
//!
//! This module provides WebSocket-based communication for scenarios requiring
//! bidirectional messaging, such as live chat streaming, agent control, and
//! collaborative editing.
//!
//! Note: for simpler use cases, the SSE-based sync (in sync.rs) may be preferred,
//! as it has better HTTP/2 compatibility and doesn't require connection upgrades.
11
12use actix_web::{web, Error, HttpRequest, HttpResponse};
13use futures_util::StreamExt;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::RwLock;
17use std::time::{Duration, Instant};
18use tokio::sync::broadcast;
19use uuid::Uuid;
20
// =============================================================================
// WebSocket Configuration
// =============================================================================
24
/// How often the server pings each connected client and checks it for inactivity.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);

/// Longest silence tolerated from a client before its connection is dropped.
///
/// The check only runs on heartbeat ticks, so a silent client is dropped on the
/// first tick after this much time has elapsed.
const CLIENT_TIMEOUT: Duration = Duration::from_secs(60);
27
// =============================================================================
// WebSocket Message Types
// =============================================================================
31
32/// Messages sent from client to server
33#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(tag = "type", rename_all = "snake_case")]
35pub enum WsClientMessage {
36    /// Subscribe to a channel
37    Subscribe { channel: String },
38    /// Unsubscribe from a channel
39    Unsubscribe { channel: String },
40
41    /// Start streaming response for a session
42    StreamStart { session_id: String, model: String },
43    /// Cancel streaming for a session
44    StreamCancel { session_id: String },
45    /// Send input during streaming
46    StreamInput { session_id: String, content: String },
47
48    /// Send command to an agent
49    AgentCommand {
50        agent_id: String,
51        command: String,
52        params: Option<serde_json::Value>,
53    },
54
55    /// Request sync delta from version
56    SyncRequest { from_version: u64 },
57
58    /// Ping message for keepalive
59    Ping { timestamp: i64 },
60}
61
62/// Messages sent from server to client
63#[derive(Debug, Clone, Serialize, Deserialize)]
64#[serde(tag = "type", rename_all = "snake_case")]
65pub enum WsServerMessage {
66    /// Connection established
67    Connected { client_id: String, version: u64 },
68    /// Error occurred
69    Error { code: String, message: String },
70
71    /// Successfully subscribed to channel
72    Subscribed { channel: String },
73    /// Successfully unsubscribed from channel
74    Unsubscribed { channel: String },
75
76    /// Streaming token received
77    StreamToken { session_id: String, token: String },
78    /// Streaming completed
79    StreamComplete {
80        session_id: String,
81        message_id: String,
82    },
83    /// Streaming error occurred
84    StreamError { session_id: String, error: String },
85
86    /// Agent event received
87    AgentEvent {
88        agent_id: String,
89        event: String,
90        data: Option<serde_json::Value>,
91    },
92
93    /// Sync event for real-time updates
94    SyncEvent {
95        entity_type: String,
96        entity_id: String,
97        operation: String,
98        data: Option<serde_json::Value>,
99        version: u64,
100    },
101
102    /// Pong response to ping
103    Pong { timestamp: i64 },
104}
105
// =============================================================================
// WebSocket State Management
// =============================================================================
109
/// Bookkeeping kept by the server for one connected WebSocket client.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Identifier assigned to the connection.
    pub id: String,
    /// Moment the client was registered.
    pub connected_at: Instant,
    /// Set when the client is registered.
    ///
    /// NOTE(review): nothing in this module updates it afterwards; liveness is
    /// tracked by a local variable inside the connection task instead.
    pub last_heartbeat: Instant,
    /// Names of the channels the client subscribed to (kept free of duplicates).
    pub subscriptions: Vec<String>,
}
118
119/// Global WebSocket state shared across connections
120pub struct WebSocketState {
121    /// Broadcast channel for server-wide messages
122    pub broadcast_tx: broadcast::Sender<WsServerMessage>,
123    /// Per-channel broadcast senders
124    pub channel_senders: RwLock<HashMap<String, broadcast::Sender<WsServerMessage>>>,
125    /// Connected clients info
126    pub clients: RwLock<HashMap<String, ClientInfo>>,
127    /// Current sync version
128    pub version: std::sync::atomic::AtomicU64,
129}
130
131impl WebSocketState {
132    pub fn new() -> Self {
133        let (broadcast_tx, _) = broadcast::channel(1024);
134        Self {
135            broadcast_tx,
136            channel_senders: RwLock::new(HashMap::new()),
137            clients: RwLock::new(HashMap::new()),
138            version: std::sync::atomic::AtomicU64::new(1),
139        }
140    }
141
142    /// Get or create a channel sender
143    pub fn get_channel_sender(&self, channel: &str) -> broadcast::Sender<WsServerMessage> {
144        {
145            let channels = self.channel_senders.read().unwrap();
146            if let Some(sender) = channels.get(channel) {
147                return sender.clone();
148            }
149        }
150
151        let mut channels = self.channel_senders.write().unwrap();
152        let entry = channels
153            .entry(channel.to_string())
154            .or_insert_with(|| broadcast::channel(256).0);
155        entry.clone()
156    }
157
158    /// Broadcast to all clients
159    pub fn broadcast(&self, msg: WsServerMessage) {
160        let _ = self.broadcast_tx.send(msg);
161    }
162
163    /// Broadcast to a specific channel
164    pub fn broadcast_to_channel(&self, channel: &str, msg: WsServerMessage) {
165        let channels = self.channel_senders.read().unwrap();
166        if let Some(sender) = channels.get(channel) {
167            let _ = sender.send(msg);
168        }
169    }
170
171    /// Increment version and return new value
172    pub fn increment_version(&self) -> u64 {
173        self.version
174            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
175            + 1
176    }
177
178    /// Get current version
179    pub fn current_version(&self) -> u64 {
180        self.version.load(std::sync::atomic::Ordering::Relaxed)
181    }
182
183    /// Register a new client
184    pub fn register_client(&self, id: &str) {
185        let mut clients = self.clients.write().unwrap();
186        clients.insert(
187            id.to_string(),
188            ClientInfo {
189                id: id.to_string(),
190                connected_at: Instant::now(),
191                last_heartbeat: Instant::now(),
192                subscriptions: Vec::new(),
193            },
194        );
195    }
196
197    /// Unregister a client
198    pub fn unregister_client(&self, id: &str) {
199        let mut clients = self.clients.write().unwrap();
200        clients.remove(id);
201    }
202
203    /// Get client count
204    pub fn client_count(&self) -> usize {
205        self.clients.read().unwrap().len()
206    }
207}
208
209impl Default for WebSocketState {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
// =============================================================================
// WebSocket Handler
// =============================================================================
218
219/// Handle incoming WebSocket message
220fn handle_client_message(
221    client_id: &str,
222    msg: WsClientMessage,
223    state: &WebSocketState,
224) -> Option<WsServerMessage> {
225    match msg {
226        WsClientMessage::Subscribe { channel } => {
227            // Update client subscriptions
228            if let Ok(mut clients) = state.clients.write() {
229                if let Some(client) = clients.get_mut(client_id) {
230                    if !client.subscriptions.contains(&channel) {
231                        client.subscriptions.push(channel.clone());
232                    }
233                }
234            }
235            Some(WsServerMessage::Subscribed { channel })
236        }
237
238        WsClientMessage::Unsubscribe { channel } => {
239            // Update client subscriptions
240            if let Ok(mut clients) = state.clients.write() {
241                if let Some(client) = clients.get_mut(client_id) {
242                    client.subscriptions.retain(|c| c != &channel);
243                }
244            }
245            Some(WsServerMessage::Unsubscribed { channel })
246        }
247
248        WsClientMessage::Ping { timestamp } => Some(WsServerMessage::Pong { timestamp }),
249
250        WsClientMessage::StreamStart { session_id, model } => {
251            log::info!(
252                "Client {} requested stream start for {} with model {}",
253                client_id,
254                session_id,
255                model
256            );
257            // TODO: Implement streaming start
258            None
259        }
260
261        WsClientMessage::StreamCancel { session_id } => {
262            log::info!(
263                "Client {} requested stream cancel for {}",
264                client_id,
265                session_id
266            );
267            // TODO: Implement streaming cancel
268            None
269        }
270
271        WsClientMessage::StreamInput {
272            session_id,
273            content,
274        } => {
275            log::info!(
276                "Client {} sent input for {}: {} bytes",
277                client_id,
278                session_id,
279                content.len()
280            );
281            // TODO: Implement streaming input
282            None
283        }
284
285        WsClientMessage::AgentCommand {
286            agent_id,
287            command,
288            params,
289        } => {
290            log::info!(
291                "Client {} sent agent command {} to {}: {:?}",
292                client_id,
293                command,
294                agent_id,
295                params
296            );
297            // TODO: Implement agent commands
298            None
299        }
300
301        WsClientMessage::SyncRequest { from_version } => {
302            log::info!(
303                "Client {} requested sync from version {}",
304                client_id,
305                from_version
306            );
307            // TODO: Implement sync delta response
308            None
309        }
310    }
311}
312
313/// WebSocket endpoint handler using actix-ws
314pub async fn ws_handler(
315    req: HttpRequest,
316    body: web::Payload,
317    state: web::Data<WebSocketState>,
318) -> Result<HttpResponse, Error> {
319    // Perform WebSocket handshake
320    let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?;
321
322    let client_id = Uuid::new_v4().to_string();
323    let state_clone = state.clone();
324
325    // Register client
326    state.register_client(&client_id);
327
328    // Send connected message
329    let connected_msg = WsServerMessage::Connected {
330        client_id: client_id.clone(),
331        version: state.current_version(),
332    };
333    if let Ok(json) = serde_json::to_string(&connected_msg) {
334        let _ = session.text(json).await;
335    }
336
337    log::info!("WebSocket client {} connected", client_id);
338
339    // Subscribe to broadcast channel
340    let mut broadcast_rx = state.broadcast_tx.subscribe();
341
342    // Spawn handler task
343    let client_id_clone = client_id.clone();
344    actix_web::rt::spawn(async move {
345        let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
346        let mut last_heartbeat = Instant::now();
347
348        loop {
349            tokio::select! {
350                // Handle incoming messages
351                Some(msg_result) = msg_stream.next() => {
352                    match msg_result {
353                        Ok(actix_ws::Message::Text(text)) => {
354                            last_heartbeat = Instant::now();
355                            if let Ok(client_msg) = serde_json::from_str::<WsClientMessage>(&text) {
356                                if let Some(response) = handle_client_message(
357                                    &client_id_clone,
358                                    client_msg,
359                                    &state_clone,
360                                ) {
361                                    if let Ok(json) = serde_json::to_string(&response) {
362                                        let _ = session.text(json).await;
363                                    }
364                                }
365                            } else {
366                                let error_msg = WsServerMessage::Error {
367                                    code: "invalid_message".to_string(),
368                                    message: "Failed to parse message".to_string(),
369                                };
370                                if let Ok(json) = serde_json::to_string(&error_msg) {
371                                    let _ = session.text(json).await;
372                                }
373                            }
374                        }
375                        Ok(actix_ws::Message::Ping(data)) => {
376                            last_heartbeat = Instant::now();
377                            let _ = session.pong(&data).await;
378                        }
379                        Ok(actix_ws::Message::Pong(_)) => {
380                            last_heartbeat = Instant::now();
381                        }
382                        Ok(actix_ws::Message::Close(_)) => {
383                            log::info!("WebSocket client {} requested close", client_id_clone);
384                            break;
385                        }
386                        _ => {}
387                    }
388                }
389
390                // Handle broadcast messages
391                Ok(msg) = broadcast_rx.recv() => {
392                    if let Ok(json) = serde_json::to_string(&msg) {
393                        let _ = session.text(json).await;
394                    }
395                }
396
397                // Heartbeat check
398                _ = heartbeat_interval.tick() => {
399                    if Instant::now().duration_since(last_heartbeat) > CLIENT_TIMEOUT {
400                        log::warn!("WebSocket client {} timed out", client_id_clone);
401                        break;
402                    }
403                    let _ = session.ping(b"").await;
404                }
405            }
406        }
407
408        // Cleanup
409        state_clone.unregister_client(&client_id_clone);
410        let _ = session.close(None).await;
411        log::info!("WebSocket client {} disconnected", client_id_clone);
412    });
413
414    Ok(response)
415}
416
417/// Configure WebSocket routes
418pub fn configure_websocket_routes(cfg: &mut web::ServiceConfig, state: web::Data<WebSocketState>) {
419    cfg.app_data(state).route("/ws", web::get().to(ws_handler));
420}
421
// =============================================================================
// Helper Functions for Broadcasting
// =============================================================================
425
426/// Broadcast a sync event to all clients
427pub fn broadcast_sync_event(
428    state: &WebSocketState,
429    entity_type: &str,
430    entity_id: &str,
431    operation: &str,
432    data: Option<serde_json::Value>,
433) {
434    let version = state.increment_version();
435    let msg = WsServerMessage::SyncEvent {
436        entity_type: entity_type.to_string(),
437        entity_id: entity_id.to_string(),
438        operation: operation.to_string(),
439        data,
440        version,
441    };
442    state.broadcast(msg);
443}
444
445/// Broadcast a stream token to a specific session channel
446pub fn broadcast_stream_token(state: &WebSocketState, session_id: &str, token: &str) {
447    let msg = WsServerMessage::StreamToken {
448        session_id: session_id.to_string(),
449        token: token.to_string(),
450    };
451    state.broadcast_to_channel(&format!("session:{}", session_id), msg);
452}
453
454/// Broadcast stream completion
455pub fn broadcast_stream_complete(state: &WebSocketState, session_id: &str, message_id: &str) {
456    let msg = WsServerMessage::StreamComplete {
457        session_id: session_id.to_string(),
458        message_id: message_id.to_string(),
459    };
460    state.broadcast_to_channel(&format!("session:{}", session_id), msg);
461}
462
463/// Broadcast an agent event
464pub fn broadcast_agent_event(
465    state: &WebSocketState,
466    agent_id: &str,
467    event: &str,
468    data: Option<serde_json::Value>,
469) {
470    let msg = WsServerMessage::AgentEvent {
471        agent_id: agent_id.to_string(),
472        event: event.to_string(),
473        data,
474    };
475    state.broadcast_to_channel(&format!("agent:{}", agent_id), msg);
476}