Skip to main content

forge_reasoning/
websocket.rs

1//! WebSocket API server for checkpointing
2//!
3//! Provides real-time remote access to checkpoint operations
4
5use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use futures_util::{SinkExt, StreamExt};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{broadcast, mpsc, RwLock};
12use tokio_tungstenite::{accept_async, tungstenite::Message};
13
14use crate::errors::{ReasoningError, Result};
15use crate::service::CheckpointEvent;
16use crate::service::CheckpointService;
17use crate::SessionId;
18
/// Settings that control how the WebSocket server admits and authenticates clients.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// When `true`, a connection must complete the `authenticate` command
    /// before any other command is accepted.
    pub require_auth: bool,
    /// Shared secret that the `token` parameter of `authenticate` is checked
    /// against. With `None`, no token is ever accepted.
    pub auth_token: Option<String>,
    /// Upper bound on simultaneously registered clients; connections beyond
    /// this are told "Server at capacity" and dropped.
    pub max_connections: usize,
}
26
27impl Default for WebSocketConfig {
28    fn default() -> Self {
29        Self {
30            require_auth: false,
31            auth_token: None,
32            max_connections: 100,
33        }
34    }
35}
36
37/// WebSocket command from client
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
39pub struct WebSocketCommand {
40    pub id: String,
41    pub method: String,
42    pub params: serde_json::Value,
43}
44
45/// WebSocket response to client
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct WebSocketResponse {
48    pub id: String,
49    pub success: bool,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub result: Option<serde_json::Value>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub error: Option<String>,
54}
55
56impl WebSocketResponse {
57    pub fn success(id: String, result: impl serde::Serialize) -> Self {
58        Self {
59            id,
60            success: true,
61            result: serde_json::to_value(result).ok(),
62            error: None,
63        }
64    }
65
66    pub fn error(id: String, message: impl Into<String>) -> Self {
67        Self {
68            id,
69            success: false,
70            result: None,
71            error: Some(message.into()),
72        }
73    }
74}
75
76/// WebSocket event broadcast
77#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
78pub struct WebSocketEvent {
79    pub event_type: String,
80    pub data: serde_json::Value,
81}
82
83impl WebSocketEvent {
84    pub fn checkpoint_created(checkpoint_id: impl ToString, session_id: impl ToString) -> Self {
85        Self {
86            event_type: "checkpoint_created".to_string(),
87            data: serde_json::json!({
88                "checkpoint_id": checkpoint_id.to_string(),
89                "session_id": session_id.to_string(),
90                "timestamp": chrono::Utc::now().to_rfc3339(),
91            }),
92        }
93    }
94
95    pub fn checkpoint_restored(checkpoint_id: impl ToString, session_id: impl ToString) -> Self {
96        Self {
97            event_type: "checkpoint_restored".to_string(),
98            data: serde_json::json!({
99                "checkpoint_id": checkpoint_id.to_string(),
100                "session_id": session_id.to_string(),
101                "timestamp": chrono::Utc::now().to_rfc3339(),
102            }),
103        }
104    }
105
106    pub fn checkpoint_deleted(checkpoint_id: impl ToString, session_id: impl ToString) -> Self {
107        Self {
108            event_type: "checkpoint_deleted".to_string(),
109            data: serde_json::json!({
110                "checkpoint_id": checkpoint_id.to_string(),
111                "session_id": session_id.to_string(),
112                "timestamp": chrono::Utc::now().to_rfc3339(),
113            }),
114        }
115    }
116
117    pub fn checkpoints_compacted(session_id: impl ToString, remaining: usize) -> Self {
118        Self {
119            event_type: "checkpoints_compacted".to_string(),
120            data: serde_json::json!({
121                "session_id": session_id.to_string(),
122                "remaining": remaining,
123                "timestamp": chrono::Utc::now().to_rfc3339(),
124            }),
125        }
126    }
127
128    /// Convert from CheckpointEvent to WebSocketEvent
129    pub fn from_checkpoint_event(event: &CheckpointEvent) -> Self {
130        match event {
131            CheckpointEvent::Created { checkpoint_id, session_id, .. } => {
132                Self::checkpoint_created(checkpoint_id.to_string(), session_id.to_string())
133            }
134            CheckpointEvent::Restored { checkpoint_id, session_id } => {
135                Self::checkpoint_restored(checkpoint_id.to_string(), session_id.to_string())
136            }
137            CheckpointEvent::Deleted { checkpoint_id, session_id } => {
138                Self::checkpoint_deleted(checkpoint_id.to_string(), session_id.to_string())
139            }
140            CheckpointEvent::Compacted { session_id, remaining } => {
141                Self::checkpoints_compacted(session_id.to_string(), *remaining)
142            }
143        }
144    }
145}
146
147/// Client connection state
148#[derive(Debug, Clone)]
149struct ClientState {
150    _id: String,
151    authenticated: bool,
152    subscriptions: Vec<SessionId>,
153}
154
155/// WebSocket server for checkpointing
156pub struct CheckpointWebSocketServer {
157    bind_addr: String,
158    service: Arc<CheckpointService>,
159    config: WebSocketConfig,
160    shutdown_tx: Option<broadcast::Sender<()>>,
161    clients: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<Message>>>>,
162}
163
164impl CheckpointWebSocketServer {
165    /// Create a new WebSocket server with default config
166    pub fn new(bind_addr: impl Into<String>, service: Arc<CheckpointService>) -> Self {
167        Self::with_config(bind_addr, service, WebSocketConfig::default())
168    }
169
170    /// Create a new WebSocket server with custom config
171    pub fn with_config(
172        bind_addr: impl Into<String>,
173        service: Arc<CheckpointService>,
174        config: WebSocketConfig,
175    ) -> Self {
176        Self {
177            bind_addr: bind_addr.into(),
178            service,
179            config,
180            shutdown_tx: None,
181            clients: Arc::new(RwLock::new(HashMap::new())),
182        }
183    }
184
185    /// Start the server and return the bound address
186    pub async fn start(&mut self) -> Result<SocketAddr> {
187        let listener = TcpListener::bind(&self.bind_addr).await
188            .map_err(|e| ReasoningError::Io(std::io::Error::new(
189                std::io::ErrorKind::AddrNotAvailable,
190                format!("Failed to bind: {}", e)
191            )))?;
192
193        let addr = listener.local_addr()
194            .map_err(|e| ReasoningError::Io(e))?;
195
196        let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1);
197        self.shutdown_tx = Some(shutdown_tx.clone());
198
199        let service = Arc::clone(&self.service);
200        let clients = Arc::clone(&self.clients);
201        let config = self.config.clone();
202
203        tokio::spawn(async move {
204            loop {
205                tokio::select! {
206                    Ok((stream, peer_addr)) = listener.accept() => {
207                        let service = Arc::clone(&service);
208                        let clients = Arc::clone(&clients);
209                        let config = config.clone();
210
211                        tokio::spawn(async move {
212                            if let Err(e) = handle_connection(
213                                stream,
214                                peer_addr,
215                                service,
216                                clients,
217                                config,
218                            ).await {
219                                tracing::warn!("WebSocket connection error: {}", e);
220                            }
221                        });
222                    }
223                    _ = shutdown_rx.recv() => {
224                        tracing::info!("WebSocket server shutting down");
225                        break;
226                    }
227                }
228            }
229        });
230
231        tracing::info!("WebSocket server started on {}", addr);
232        Ok(addr)
233    }
234
235    /// Stop the server
236    pub async fn stop(&mut self) -> Result<()> {
237        if let Some(tx) = self.shutdown_tx.take() {
238            let _ = tx.send(());
239        }
240        
241        // Clear all clients
242        let mut clients = self.clients.write().await;
243        clients.clear();
244        
245        Ok(())
246    }
247
248    /// Get number of connected clients
249    pub async fn client_count(&self) -> usize {
250        self.clients.read().await.len()
251    }
252}
253
254/// Commands sent from message handler to event forwarding task
255type SubscribeCommand = (SessionId, tokio::sync::mpsc::UnboundedSender<WebSocketEvent>);
256
257async fn handle_connection(
258    stream: TcpStream,
259    peer_addr: SocketAddr,
260    service: Arc<CheckpointService>,
261    clients: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<Message>>>>,
262    config: WebSocketConfig,
263) -> Result<()> {
264    let ws_stream = accept_async(stream).await
265        .map_err(|e| ReasoningError::Io(std::io::Error::new(
266            std::io::ErrorKind::ConnectionRefused,
267            format!("WebSocket handshake failed: {}", e)
268        )))?;
269
270    let client_id = uuid::Uuid::new_v4().to_string();
271    tracing::info!("New WebSocket connection: {} from {}", client_id, peer_addr);
272
273    let (mut ws_tx, mut ws_rx) = ws_stream.split();
274    let (tx, mut rx) = mpsc::unbounded_channel();
275
276    // Register client
277    {
278        let mut clients_guard = clients.write().await;
279        if clients_guard.len() >= config.max_connections {
280            let _ = ws_tx.send(Message::Text(
281                serde_json::to_string(&WebSocketResponse::error(
282                    "init".to_string(),
283                    "Server at capacity"
284                )).unwrap()
285            )).await;
286            return Ok(());
287        }
288        clients_guard.insert(client_id.clone(), tx);
289    }
290
291    let mut state = ClientState {
292        _id: client_id.clone(),
293        authenticated: !config.require_auth,
294        subscriptions: Vec::new(),
295    };
296
297    // Channel for coordinating subscriptions between message handler and event task
298    let (sub_tx, mut sub_rx) = mpsc::unbounded_channel::<SubscribeCommand>();
299
300    // Spawn task to forward messages from channel to WebSocket
301    let client_id_clone = client_id.clone();
302    let clients_clone = Arc::clone(&clients);
303    let forward_task = tokio::spawn(async move {
304        while let Some(msg) = rx.recv().await {
305            if ws_tx.send(msg).await.is_err() {
306                break;
307            }
308        }
309        // Remove client on disconnect
310        clients_clone.write().await.remove(&client_id_clone);
311    });
312
313    // Spawn task to forward service events to WebSocket client
314    let service_for_events = Arc::clone(&service);
315    let clients_for_events = Arc::clone(&clients);
316    let client_id_for_events = client_id.clone();
317    let event_forward_task = tokio::spawn(async move {
318        let mut event_receivers: HashMap<SessionId, mpsc::Receiver<CheckpointEvent>> = HashMap::new();
319        
320        loop {
321            tokio::select! {
322                // Handle new subscription requests
323                Some((session_id, notify_tx)) = sub_rx.recv() => {
324                    // Subscribe to service events for this session
325                    match service_for_events.subscribe(&session_id) {
326                        Ok(rx) => {
327                            event_receivers.insert(session_id, rx);
328                            // Notify that subscription is active
329                            let _ = notify_tx.send(WebSocketEvent {
330                                event_type: "subscribed".to_string(),
331                                data: serde_json::json!({
332                                    "session_id": session_id.to_string(),
333                                }),
334                            });
335                        }
336                        Err(e) => {
337                            let _ = notify_tx.send(WebSocketEvent {
338                                event_type: "subscribe_error".to_string(),
339                                data: serde_json::json!({
340                                    "session_id": session_id.to_string(),
341                                    "error": e.to_string(),
342                                }),
343                            });
344                        }
345                    }
346                }
347                
348                // Listen for events from all subscribed sessions
349                Some((_session_id, event)) = async {
350                    // Poll all receivers
351                    for (session_id, rx) in &mut event_receivers {
352                        if let Ok(event) = rx.try_recv() {
353                            return Some((*session_id, event));
354                        }
355                    }
356                    None
357                } => {
358                    let ws_event = WebSocketEvent::from_checkpoint_event(&event);
359                    let msg = Message::Text(serde_json::to_string(&ws_event).unwrap_or_default());
360                    
361                    // Send to the client's message channel
362                    if let Some(client_tx) = clients_for_events.read().await.get(&client_id_for_events) {
363                        let _ = client_tx.send(msg);
364                    }
365                }
366                
367                // Small sleep to prevent busy-waiting when no events
368                _ = tokio::time::sleep(tokio::time::Duration::from_millis(10)) => {}
369            }
370        }
371    });
372
373    // Handle incoming messages
374    while let Some(msg) = ws_rx.next().await {
375        match msg {
376            Ok(Message::Text(text)) => {
377                let response = handle_message(
378                    &text,
379                    &mut state,
380                    &service,
381                    &config,
382                    &sub_tx,
383                ).await;
384
385                let response_text = serde_json::to_string(&response)?;
386                let tx = clients.read().await.get(&client_id).cloned();
387                if let Some(tx) = tx {
388                    let _ = tx.send(Message::Text(response_text));
389                }
390            }
391            Ok(Message::Close(_)) => {
392                tracing::info!("Client {} disconnected", client_id);
393                break;
394            }
395            Ok(Message::Ping(data)) => {
396                let tx = clients.read().await.get(&client_id).cloned();
397                if let Some(tx) = tx {
398                    let _ = tx.send(Message::Pong(data));
399                }
400            }
401            Err(e) => {
402                tracing::warn!("WebSocket error from {}: {}", client_id, e);
403                break;
404            }
405            _ => {}
406        }
407    }
408
409    // Cleanup
410    event_forward_task.abort();
411    forward_task.abort();
412    clients.write().await.remove(&client_id);
413    tracing::info!("Client {} removed", client_id);
414
415    Ok(())
416}
417
418async fn handle_message(
419    text: &str,
420    state: &mut ClientState,
421    service: &Arc<CheckpointService>,
422    config: &WebSocketConfig,
423    sub_tx: &mpsc::UnboundedSender<SubscribeCommand>,
424) -> WebSocketResponse {
425    // Parse command
426    let cmd: WebSocketCommand = match serde_json::from_str(text) {
427        Ok(cmd) => cmd,
428        Err(e) => {
429            return WebSocketResponse::error(
430                "unknown".to_string(),
431                format!("Invalid JSON: {}", e)
432            );
433        }
434    };
435
436    // Check authentication
437    if config.require_auth && !state.authenticated && cmd.method != "authenticate" {
438        return WebSocketResponse::error(
439            cmd.id,
440            "Authentication required"
441        );
442    }
443
444    // Handle command
445    match cmd.method.as_str() {
446        "authenticate" => handle_authenticate(&cmd, state, config).await,
447        "create_session" => handle_create_session(&cmd, service).await,
448        "list_checkpoints" => handle_list_checkpoints(&cmd, service).await,
449        "checkpoint" => handle_checkpoint(&cmd, service).await,
450        "subscribe" => handle_subscribe(&cmd, state, sub_tx).await,
451        "metrics" => handle_metrics(&cmd, service).await,
452        _ => WebSocketResponse::error(
453            cmd.id,
454            format!("Unknown method: {}", cmd.method)
455        ),
456    }
457}
458
459async fn handle_authenticate(
460    cmd: &WebSocketCommand,
461    state: &mut ClientState,
462    config: &WebSocketConfig,
463) -> WebSocketResponse {
464    let token = cmd.params.get("token").and_then(|v| v.as_str());
465    
466    match (&config.auth_token, token) {
467        (Some(expected), Some(provided)) if expected == provided => {
468            state.authenticated = true;
469            WebSocketResponse::success(cmd.id.clone(), serde_json::json!({ "authenticated": true }))
470        }
471        _ => {
472            WebSocketResponse::error(cmd.id.clone(), "Invalid authentication token")
473        }
474    }
475}
476
477async fn handle_create_session(
478    cmd: &WebSocketCommand,
479    service: &Arc<CheckpointService>,
480) -> WebSocketResponse {
481    let name = cmd.params.get("name")
482        .and_then(|v| v.as_str())
483        .unwrap_or("unnamed");
484
485    match service.create_session(name) {
486        Ok(session_id) => {
487            WebSocketResponse::success(cmd.id.clone(), session_id.to_string())
488        }
489        Err(e) => {
490            WebSocketResponse::error(cmd.id.clone(), e.to_string())
491        }
492    }
493}
494
495async fn handle_list_checkpoints(
496    cmd: &WebSocketCommand,
497    service: &Arc<CheckpointService>,
498) -> WebSocketResponse {
499    let session_id_str = match cmd.params.get("session_id").and_then(|v| v.as_str()) {
500        Some(s) => s,
501        None => {
502            return WebSocketResponse::error(cmd.id.clone(), "Missing session_id parameter");
503        }
504    };
505
506    let session_id: SessionId = match uuid::Uuid::parse_str(session_id_str) {
507        Ok(uuid) => SessionId(uuid),
508        Err(_) => {
509            return WebSocketResponse::error(cmd.id.clone(), "Invalid session_id format");
510        }
511    };
512
513    match service.list_checkpoints(&session_id) {
514        Ok(checkpoints) => {
515            WebSocketResponse::success(cmd.id.clone(), checkpoints)
516        }
517        Err(e) => {
518            WebSocketResponse::error(cmd.id.clone(), e.to_string())
519        }
520    }
521}
522
523async fn handle_checkpoint(
524    cmd: &WebSocketCommand,
525    service: &Arc<CheckpointService>,
526) -> WebSocketResponse {
527    let session_id_str = match cmd.params.get("session_id").and_then(|v| v.as_str()) {
528        Some(s) => s,
529        None => {
530            return WebSocketResponse::error(cmd.id.clone(), "Missing session_id parameter");
531        }
532    };
533
534    let session_id: SessionId = match uuid::Uuid::parse_str(session_id_str) {
535        Ok(uuid) => SessionId(uuid),
536        Err(_) => {
537            return WebSocketResponse::error(cmd.id.clone(), "Invalid session_id format");
538        }
539    };
540
541    let message = cmd.params.get("message")
542        .and_then(|v| v.as_str())
543        .unwrap_or("Checkpoint");
544
545    match service.checkpoint(&session_id, message) {
546        Ok(checkpoint_id) => {
547            WebSocketResponse::success(cmd.id.clone(), checkpoint_id.to_string())
548        }
549        Err(e) => {
550            WebSocketResponse::error(cmd.id.clone(), e.to_string())
551        }
552    }
553}
554
555async fn handle_subscribe(
556    cmd: &WebSocketCommand,
557    state: &mut ClientState,
558    sub_tx: &mpsc::UnboundedSender<SubscribeCommand>,
559) -> WebSocketResponse {
560    let session_id_str = match cmd.params.get("session_id").and_then(|v| v.as_str()) {
561        Some(s) => s,
562        None => {
563            return WebSocketResponse::error(cmd.id.clone(), "Missing session_id parameter");
564        }
565    };
566
567    let session_id: SessionId = match uuid::Uuid::parse_str(session_id_str) {
568        Ok(uuid) => SessionId(uuid),
569        Err(_) => {
570            return WebSocketResponse::error(cmd.id.clone(), "Invalid session_id format");
571        }
572    };
573
574    state.subscriptions.push(session_id);
575    
576    // Channel to receive subscription confirmation from event task
577    let (notify_tx, mut notify_rx) = mpsc::unbounded_channel();
578    
579    // Send subscription request to event forwarding task
580    if let Err(e) = sub_tx.send((session_id, notify_tx)) {
581        return WebSocketResponse::error(
582            cmd.id.clone(),
583            format!("Failed to setup subscription: {}", e)
584        );
585    }
586    
587    // Wait for subscription confirmation (with timeout)
588    match tokio::time::timeout(
589        tokio::time::Duration::from_secs(5),
590        notify_rx.recv()
591    ).await {
592        Ok(Some(event)) if event.event_type == "subscribed" => {
593            WebSocketResponse::success(cmd.id.clone(), serde_json::json!({
594                "subscribed": true,
595                "session_id": session_id.to_string()
596            }))
597        }
598        Ok(Some(event)) if event.event_type == "subscribe_error" => {
599            WebSocketResponse::error(cmd.id.clone(), 
600                event.data.get("error")
601                    .and_then(|v| v.as_str())
602                    .unwrap_or("Subscription failed"))
603        }
604        _ => {
605            WebSocketResponse::error(cmd.id.clone(), "Subscription timeout")
606        }
607    }
608}
609
610async fn handle_metrics(
611    cmd: &WebSocketCommand,
612    service: &Arc<CheckpointService>,
613) -> WebSocketResponse {
614    match service.metrics() {
615        Ok(metrics) => {
616            WebSocketResponse::success(cmd.id.clone(), metrics)
617        }
618        Err(e) => {
619            WebSocketResponse::error(cmd.id.clone(), e.to_string())
620        }
621    }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration leaves auth off and allows 100 clients.
    #[tokio::test]
    async fn test_websocket_config_default() {
        let WebSocketConfig { require_auth, max_connections, .. } = WebSocketConfig::default();
        assert!(!require_auth);
        assert_eq!(max_connections, 100);
    }

    /// A success response keeps the id, sets the flag and carries no error.
    #[tokio::test]
    async fn test_websocket_response_success() {
        let id = String::from("test-id");
        let response = WebSocketResponse::success(id, "hello");
        assert!(response.success);
        assert_eq!(response.id, "test-id");
        assert!(response.error.is_none());
    }

    /// An error response keeps the id, clears the flag and carries the message.
    #[tokio::test]
    async fn test_websocket_response_error() {
        let id = String::from("test-id");
        let response = WebSocketResponse::error(id, "something went wrong");
        assert!(!response.success);
        assert_eq!(response.id, "test-id");
        assert_eq!(response.error.unwrap(), "something went wrong");
    }
}