1use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use futures_util::{SinkExt, StreamExt};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{broadcast, mpsc, RwLock};
12use tokio_tungstenite::{accept_async, tungstenite::Message};
13
14use crate::errors::{ReasoningError, Result};
15use crate::service::CheckpointEvent;
16use crate::service::CheckpointService;
17use crate::SessionId;
18
/// Runtime settings for the checkpoint WebSocket server.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// When `true`, a client must pass the `authenticate` command before any
    /// other method is dispatched.
    pub require_auth: bool,
    /// Token the `authenticate` command's `token` parameter is compared with.
    /// With `None`, no provided token is ever accepted.
    pub auth_token: Option<String>,
    /// Maximum number of simultaneously registered clients; further
    /// connections are told the server is at capacity.
    pub max_connections: usize,
}
26
27impl Default for WebSocketConfig {
28 fn default() -> Self {
29 Self {
30 require_auth: false,
31 auth_token: None,
32 max_connections: 100,
33 }
34 }
35}
36
37#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
39pub struct WebSocketCommand {
40 pub id: String,
41 pub method: String,
42 pub params: serde_json::Value,
43}
44
45#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct WebSocketResponse {
48 pub id: String,
49 pub success: bool,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub result: Option<serde_json::Value>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub error: Option<String>,
54}
55
56impl WebSocketResponse {
57 pub fn success(id: String, result: impl serde::Serialize) -> Self {
58 Self {
59 id,
60 success: true,
61 result: serde_json::to_value(result).ok(),
62 error: None,
63 }
64 }
65
66 pub fn error(id: String, message: impl Into<String>) -> Self {
67 Self {
68 id,
69 success: false,
70 result: None,
71 error: Some(message.into()),
72 }
73 }
74}
75
76#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
78pub struct WebSocketEvent {
79 pub event_type: String,
80 pub data: serde_json::Value,
81}
82
83impl WebSocketEvent {
84 pub fn checkpoint_created(checkpoint_id: impl ToString, session_id: impl ToString) -> Self {
85 Self {
86 event_type: "checkpoint_created".to_string(),
87 data: serde_json::json!({
88 "checkpoint_id": checkpoint_id.to_string(),
89 "session_id": session_id.to_string(),
90 "timestamp": chrono::Utc::now().to_rfc3339(),
91 }),
92 }
93 }
94
95 pub fn checkpoint_restored(checkpoint_id: impl ToString, session_id: impl ToString) -> Self {
96 Self {
97 event_type: "checkpoint_restored".to_string(),
98 data: serde_json::json!({
99 "checkpoint_id": checkpoint_id.to_string(),
100 "session_id": session_id.to_string(),
101 "timestamp": chrono::Utc::now().to_rfc3339(),
102 }),
103 }
104 }
105
106 pub fn checkpoint_deleted(checkpoint_id: impl ToString, session_id: impl ToString) -> Self {
107 Self {
108 event_type: "checkpoint_deleted".to_string(),
109 data: serde_json::json!({
110 "checkpoint_id": checkpoint_id.to_string(),
111 "session_id": session_id.to_string(),
112 "timestamp": chrono::Utc::now().to_rfc3339(),
113 }),
114 }
115 }
116
117 pub fn checkpoints_compacted(session_id: impl ToString, remaining: usize) -> Self {
118 Self {
119 event_type: "checkpoints_compacted".to_string(),
120 data: serde_json::json!({
121 "session_id": session_id.to_string(),
122 "remaining": remaining,
123 "timestamp": chrono::Utc::now().to_rfc3339(),
124 }),
125 }
126 }
127
128 pub fn from_checkpoint_event(event: &CheckpointEvent) -> Self {
130 match event {
131 CheckpointEvent::Created { checkpoint_id, session_id, .. } => {
132 Self::checkpoint_created(checkpoint_id.to_string(), session_id.to_string())
133 }
134 CheckpointEvent::Restored { checkpoint_id, session_id } => {
135 Self::checkpoint_restored(checkpoint_id.to_string(), session_id.to_string())
136 }
137 CheckpointEvent::Deleted { checkpoint_id, session_id } => {
138 Self::checkpoint_deleted(checkpoint_id.to_string(), session_id.to_string())
139 }
140 CheckpointEvent::Compacted { session_id, remaining } => {
141 Self::checkpoints_compacted(session_id.to_string(), *remaining)
142 }
143 }
144 }
145}
146
147#[derive(Debug, Clone)]
149struct ClientState {
150 _id: String,
151 authenticated: bool,
152 subscriptions: Vec<SessionId>,
153}
154
155pub struct CheckpointWebSocketServer {
157 bind_addr: String,
158 service: Arc<CheckpointService>,
159 config: WebSocketConfig,
160 shutdown_tx: Option<broadcast::Sender<()>>,
161 clients: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<Message>>>>,
162}
163
164impl CheckpointWebSocketServer {
165 pub fn new(bind_addr: impl Into<String>, service: Arc<CheckpointService>) -> Self {
167 Self::with_config(bind_addr, service, WebSocketConfig::default())
168 }
169
170 pub fn with_config(
172 bind_addr: impl Into<String>,
173 service: Arc<CheckpointService>,
174 config: WebSocketConfig,
175 ) -> Self {
176 Self {
177 bind_addr: bind_addr.into(),
178 service,
179 config,
180 shutdown_tx: None,
181 clients: Arc::new(RwLock::new(HashMap::new())),
182 }
183 }
184
185 pub async fn start(&mut self) -> Result<SocketAddr> {
187 let listener = TcpListener::bind(&self.bind_addr).await
188 .map_err(|e| ReasoningError::Io(std::io::Error::new(
189 std::io::ErrorKind::AddrNotAvailable,
190 format!("Failed to bind: {}", e)
191 )))?;
192
193 let addr = listener.local_addr()
194 .map_err(|e| ReasoningError::Io(e))?;
195
196 let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1);
197 self.shutdown_tx = Some(shutdown_tx.clone());
198
199 let service = Arc::clone(&self.service);
200 let clients = Arc::clone(&self.clients);
201 let config = self.config.clone();
202
203 tokio::spawn(async move {
204 loop {
205 tokio::select! {
206 Ok((stream, peer_addr)) = listener.accept() => {
207 let service = Arc::clone(&service);
208 let clients = Arc::clone(&clients);
209 let config = config.clone();
210
211 tokio::spawn(async move {
212 if let Err(e) = handle_connection(
213 stream,
214 peer_addr,
215 service,
216 clients,
217 config,
218 ).await {
219 tracing::warn!("WebSocket connection error: {}", e);
220 }
221 });
222 }
223 _ = shutdown_rx.recv() => {
224 tracing::info!("WebSocket server shutting down");
225 break;
226 }
227 }
228 }
229 });
230
231 tracing::info!("WebSocket server started on {}", addr);
232 Ok(addr)
233 }
234
235 pub async fn stop(&mut self) -> Result<()> {
237 if let Some(tx) = self.shutdown_tx.take() {
238 let _ = tx.send(());
239 }
240
241 let mut clients = self.clients.write().await;
243 clients.clear();
244
245 Ok(())
246 }
247
248 pub async fn client_count(&self) -> usize {
250 self.clients.read().await.len()
251 }
252}
253
254type SubscribeCommand = (SessionId, tokio::sync::mpsc::UnboundedSender<WebSocketEvent>);
256
257async fn handle_connection(
258 stream: TcpStream,
259 peer_addr: SocketAddr,
260 service: Arc<CheckpointService>,
261 clients: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<Message>>>>,
262 config: WebSocketConfig,
263) -> Result<()> {
264 let ws_stream = accept_async(stream).await
265 .map_err(|e| ReasoningError::Io(std::io::Error::new(
266 std::io::ErrorKind::ConnectionRefused,
267 format!("WebSocket handshake failed: {}", e)
268 )))?;
269
270 let client_id = uuid::Uuid::new_v4().to_string();
271 tracing::info!("New WebSocket connection: {} from {}", client_id, peer_addr);
272
273 let (mut ws_tx, mut ws_rx) = ws_stream.split();
274 let (tx, mut rx) = mpsc::unbounded_channel();
275
276 {
278 let mut clients_guard = clients.write().await;
279 if clients_guard.len() >= config.max_connections {
280 let _ = ws_tx.send(Message::Text(
281 serde_json::to_string(&WebSocketResponse::error(
282 "init".to_string(),
283 "Server at capacity"
284 )).unwrap()
285 )).await;
286 return Ok(());
287 }
288 clients_guard.insert(client_id.clone(), tx);
289 }
290
291 let mut state = ClientState {
292 _id: client_id.clone(),
293 authenticated: !config.require_auth,
294 subscriptions: Vec::new(),
295 };
296
297 let (sub_tx, mut sub_rx) = mpsc::unbounded_channel::<SubscribeCommand>();
299
300 let client_id_clone = client_id.clone();
302 let clients_clone = Arc::clone(&clients);
303 let forward_task = tokio::spawn(async move {
304 while let Some(msg) = rx.recv().await {
305 if ws_tx.send(msg).await.is_err() {
306 break;
307 }
308 }
309 clients_clone.write().await.remove(&client_id_clone);
311 });
312
313 let service_for_events = Arc::clone(&service);
315 let clients_for_events = Arc::clone(&clients);
316 let client_id_for_events = client_id.clone();
317 let event_forward_task = tokio::spawn(async move {
318 let mut event_receivers: HashMap<SessionId, mpsc::Receiver<CheckpointEvent>> = HashMap::new();
319
320 loop {
321 tokio::select! {
322 Some((session_id, notify_tx)) = sub_rx.recv() => {
324 match service_for_events.subscribe(&session_id) {
326 Ok(rx) => {
327 event_receivers.insert(session_id, rx);
328 let _ = notify_tx.send(WebSocketEvent {
330 event_type: "subscribed".to_string(),
331 data: serde_json::json!({
332 "session_id": session_id.to_string(),
333 }),
334 });
335 }
336 Err(e) => {
337 let _ = notify_tx.send(WebSocketEvent {
338 event_type: "subscribe_error".to_string(),
339 data: serde_json::json!({
340 "session_id": session_id.to_string(),
341 "error": e.to_string(),
342 }),
343 });
344 }
345 }
346 }
347
348 Some((_session_id, event)) = async {
350 for (session_id, rx) in &mut event_receivers {
352 if let Ok(event) = rx.try_recv() {
353 return Some((*session_id, event));
354 }
355 }
356 None
357 } => {
358 let ws_event = WebSocketEvent::from_checkpoint_event(&event);
359 let msg = Message::Text(serde_json::to_string(&ws_event).unwrap_or_default());
360
361 if let Some(client_tx) = clients_for_events.read().await.get(&client_id_for_events) {
363 let _ = client_tx.send(msg);
364 }
365 }
366
367 _ = tokio::time::sleep(tokio::time::Duration::from_millis(10)) => {}
369 }
370 }
371 });
372
373 while let Some(msg) = ws_rx.next().await {
375 match msg {
376 Ok(Message::Text(text)) => {
377 let response = handle_message(
378 &text,
379 &mut state,
380 &service,
381 &config,
382 &sub_tx,
383 ).await;
384
385 let response_text = serde_json::to_string(&response)?;
386 let tx = clients.read().await.get(&client_id).cloned();
387 if let Some(tx) = tx {
388 let _ = tx.send(Message::Text(response_text));
389 }
390 }
391 Ok(Message::Close(_)) => {
392 tracing::info!("Client {} disconnected", client_id);
393 break;
394 }
395 Ok(Message::Ping(data)) => {
396 let tx = clients.read().await.get(&client_id).cloned();
397 if let Some(tx) = tx {
398 let _ = tx.send(Message::Pong(data));
399 }
400 }
401 Err(e) => {
402 tracing::warn!("WebSocket error from {}: {}", client_id, e);
403 break;
404 }
405 _ => {}
406 }
407 }
408
409 event_forward_task.abort();
411 forward_task.abort();
412 clients.write().await.remove(&client_id);
413 tracing::info!("Client {} removed", client_id);
414
415 Ok(())
416}
417
418async fn handle_message(
419 text: &str,
420 state: &mut ClientState,
421 service: &Arc<CheckpointService>,
422 config: &WebSocketConfig,
423 sub_tx: &mpsc::UnboundedSender<SubscribeCommand>,
424) -> WebSocketResponse {
425 let cmd: WebSocketCommand = match serde_json::from_str(text) {
427 Ok(cmd) => cmd,
428 Err(e) => {
429 return WebSocketResponse::error(
430 "unknown".to_string(),
431 format!("Invalid JSON: {}", e)
432 );
433 }
434 };
435
436 if config.require_auth && !state.authenticated && cmd.method != "authenticate" {
438 return WebSocketResponse::error(
439 cmd.id,
440 "Authentication required"
441 );
442 }
443
444 match cmd.method.as_str() {
446 "authenticate" => handle_authenticate(&cmd, state, config).await,
447 "create_session" => handle_create_session(&cmd, service).await,
448 "list_checkpoints" => handle_list_checkpoints(&cmd, service).await,
449 "checkpoint" => handle_checkpoint(&cmd, service).await,
450 "subscribe" => handle_subscribe(&cmd, state, sub_tx).await,
451 "metrics" => handle_metrics(&cmd, service).await,
452 _ => WebSocketResponse::error(
453 cmd.id,
454 format!("Unknown method: {}", cmd.method)
455 ),
456 }
457}
458
459async fn handle_authenticate(
460 cmd: &WebSocketCommand,
461 state: &mut ClientState,
462 config: &WebSocketConfig,
463) -> WebSocketResponse {
464 let token = cmd.params.get("token").and_then(|v| v.as_str());
465
466 match (&config.auth_token, token) {
467 (Some(expected), Some(provided)) if expected == provided => {
468 state.authenticated = true;
469 WebSocketResponse::success(cmd.id.clone(), serde_json::json!({ "authenticated": true }))
470 }
471 _ => {
472 WebSocketResponse::error(cmd.id.clone(), "Invalid authentication token")
473 }
474 }
475}
476
477async fn handle_create_session(
478 cmd: &WebSocketCommand,
479 service: &Arc<CheckpointService>,
480) -> WebSocketResponse {
481 let name = cmd.params.get("name")
482 .and_then(|v| v.as_str())
483 .unwrap_or("unnamed");
484
485 match service.create_session(name) {
486 Ok(session_id) => {
487 WebSocketResponse::success(cmd.id.clone(), session_id.to_string())
488 }
489 Err(e) => {
490 WebSocketResponse::error(cmd.id.clone(), e.to_string())
491 }
492 }
493}
494
495async fn handle_list_checkpoints(
496 cmd: &WebSocketCommand,
497 service: &Arc<CheckpointService>,
498) -> WebSocketResponse {
499 let session_id_str = match cmd.params.get("session_id").and_then(|v| v.as_str()) {
500 Some(s) => s,
501 None => {
502 return WebSocketResponse::error(cmd.id.clone(), "Missing session_id parameter");
503 }
504 };
505
506 let session_id: SessionId = match uuid::Uuid::parse_str(session_id_str) {
507 Ok(uuid) => SessionId(uuid),
508 Err(_) => {
509 return WebSocketResponse::error(cmd.id.clone(), "Invalid session_id format");
510 }
511 };
512
513 match service.list_checkpoints(&session_id) {
514 Ok(checkpoints) => {
515 WebSocketResponse::success(cmd.id.clone(), checkpoints)
516 }
517 Err(e) => {
518 WebSocketResponse::error(cmd.id.clone(), e.to_string())
519 }
520 }
521}
522
523async fn handle_checkpoint(
524 cmd: &WebSocketCommand,
525 service: &Arc<CheckpointService>,
526) -> WebSocketResponse {
527 let session_id_str = match cmd.params.get("session_id").and_then(|v| v.as_str()) {
528 Some(s) => s,
529 None => {
530 return WebSocketResponse::error(cmd.id.clone(), "Missing session_id parameter");
531 }
532 };
533
534 let session_id: SessionId = match uuid::Uuid::parse_str(session_id_str) {
535 Ok(uuid) => SessionId(uuid),
536 Err(_) => {
537 return WebSocketResponse::error(cmd.id.clone(), "Invalid session_id format");
538 }
539 };
540
541 let message = cmd.params.get("message")
542 .and_then(|v| v.as_str())
543 .unwrap_or("Checkpoint");
544
545 match service.checkpoint(&session_id, message) {
546 Ok(checkpoint_id) => {
547 WebSocketResponse::success(cmd.id.clone(), checkpoint_id.to_string())
548 }
549 Err(e) => {
550 WebSocketResponse::error(cmd.id.clone(), e.to_string())
551 }
552 }
553}
554
555async fn handle_subscribe(
556 cmd: &WebSocketCommand,
557 state: &mut ClientState,
558 sub_tx: &mpsc::UnboundedSender<SubscribeCommand>,
559) -> WebSocketResponse {
560 let session_id_str = match cmd.params.get("session_id").and_then(|v| v.as_str()) {
561 Some(s) => s,
562 None => {
563 return WebSocketResponse::error(cmd.id.clone(), "Missing session_id parameter");
564 }
565 };
566
567 let session_id: SessionId = match uuid::Uuid::parse_str(session_id_str) {
568 Ok(uuid) => SessionId(uuid),
569 Err(_) => {
570 return WebSocketResponse::error(cmd.id.clone(), "Invalid session_id format");
571 }
572 };
573
574 state.subscriptions.push(session_id);
575
576 let (notify_tx, mut notify_rx) = mpsc::unbounded_channel();
578
579 if let Err(e) = sub_tx.send((session_id, notify_tx)) {
581 return WebSocketResponse::error(
582 cmd.id.clone(),
583 format!("Failed to setup subscription: {}", e)
584 );
585 }
586
587 match tokio::time::timeout(
589 tokio::time::Duration::from_secs(5),
590 notify_rx.recv()
591 ).await {
592 Ok(Some(event)) if event.event_type == "subscribed" => {
593 WebSocketResponse::success(cmd.id.clone(), serde_json::json!({
594 "subscribed": true,
595 "session_id": session_id.to_string()
596 }))
597 }
598 Ok(Some(event)) if event.event_type == "subscribe_error" => {
599 WebSocketResponse::error(cmd.id.clone(),
600 event.data.get("error")
601 .and_then(|v| v.as_str())
602 .unwrap_or("Subscription failed"))
603 }
604 _ => {
605 WebSocketResponse::error(cmd.id.clone(), "Subscription timeout")
606 }
607 }
608}
609
610async fn handle_metrics(
611 cmd: &WebSocketCommand,
612 service: &Arc<CheckpointService>,
613) -> WebSocketResponse {
614 match service.metrics() {
615 Ok(metrics) => {
616 WebSocketResponse::success(cmd.id.clone(), metrics)
617 }
618 Err(e) => {
619 WebSocketResponse::error(cmd.id.clone(), e.to_string())
620 }
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration disables auth and allows 100 clients.
    #[tokio::test]
    async fn test_websocket_config_default() {
        let WebSocketConfig {
            require_auth,
            max_connections,
            ..
        } = WebSocketConfig::default();

        assert!(!require_auth);
        assert_eq!(max_connections, 100);
    }

    /// A success response keeps its id and carries no error.
    #[tokio::test]
    async fn test_websocket_response_success() {
        let ok = WebSocketResponse::success(String::from("test-id"), "hello");

        assert!(ok.success);
        assert_eq!(ok.id, "test-id");
        assert!(ok.error.is_none());
    }

    /// An error response keeps its id and exposes the message.
    #[tokio::test]
    async fn test_websocket_response_error() {
        let failed = WebSocketResponse::error(String::from("test-id"), "something went wrong");

        assert!(!failed.success);
        assert_eq!(failed.id, "test-id");
        assert_eq!(failed.error.unwrap(), "something went wrong");
    }
}