guts_node/
realtime_api.rs

1//! Real-time WebSocket API for live updates.
2//!
3//! This module provides WebSocket endpoints for real-time communication:
4//!
5//! - `/ws` - Main WebSocket endpoint for real-time updates
6//! - `/api/realtime/stats` - Statistics about real-time connections
7//!
8//! ## WebSocket Protocol
9//!
10//! Clients can subscribe to channels to receive real-time events:
11//!
12//! ```json
13//! // Subscribe to a repository
14//! {"type": "subscribe", "channel": "repo:owner/name"}
15//!
16//! // Unsubscribe from a channel
17//! {"type": "unsubscribe", "channel": "repo:owner/name"}
18//!
19//! // Ping for keepalive
20//! {"type": "ping"}
21//! ```
22
23use axum::{
24    extract::{
25        ws::{Message, WebSocket, WebSocketUpgrade},
26        State,
27    },
28    response::IntoResponse,
29    routing::get,
30    Json, Router,
31};
32use futures_util::{SinkExt, StreamExt};
33use guts_realtime::{ClientCommand, EventHub, ServerMessage};
34use serde::Serialize;
35use std::sync::Arc;
36use tracing::{debug, error, info};
37
38use crate::api::AppState;
39
40/// Create the real-time API routes.
41pub fn realtime_routes() -> Router<AppState> {
42    Router::new()
43        .route("/ws", get(ws_handler))
44        .route("/api/realtime/stats", get(get_stats))
45}
46
47/// WebSocket upgrade handler.
48async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
49    ws.on_upgrade(move |socket| handle_socket(socket, state.realtime.clone()))
50}
51
52/// Handle a WebSocket connection.
53async fn handle_socket(socket: WebSocket, hub: Arc<EventHub>) {
54    // Connect the client to the hub
55    let (client, mut receiver) = match hub.connect() {
56        Ok(c) => c,
57        Err(e) => {
58            error!("Failed to connect client: {}", e);
59            return;
60        }
61    };
62
63    let client_id = client.id.clone();
64    info!(client_id = %client_id, "WebSocket client connected");
65
66    // Split the WebSocket into sender and receiver
67    let (mut ws_sender, mut ws_receiver) = socket.split();
68
69    // Spawn a task to forward messages from the hub to the WebSocket
70    let client_id_clone = client_id.clone();
71    let send_task = tokio::spawn(async move {
72        while let Some(msg) = receiver.recv().await {
73            if ws_sender.send(Message::Text(msg.into())).await.is_err() {
74                break;
75            }
76        }
77        debug!(client_id = %client_id_clone, "Send task ended");
78    });
79
80    // Handle incoming messages from the WebSocket
81    while let Some(msg) = ws_receiver.next().await {
82        match msg {
83            Ok(Message::Text(text)) => {
84                let text_str: &str = &text;
85                match serde_json::from_str::<ClientCommand>(text_str) {
86                    Ok(cmd) => {
87                        let response = hub.handle_command(&client, cmd);
88                        match response {
89                            Ok(msg) => {
90                                if let Ok(json) = serde_json::to_string(&msg) {
91                                    let _ = client.send(json);
92                                }
93                            }
94                            Err(e) => {
95                                let error_msg = ServerMessage::Error {
96                                    message: e.to_string(),
97                                };
98                                if let Ok(json) = serde_json::to_string(&error_msg) {
99                                    let _ = client.send(json);
100                                }
101                            }
102                        }
103                    }
104                    Err(e) => {
105                        debug!(client_id = %client_id, error = %e, "Invalid message format");
106                        let error_msg = ServerMessage::Error {
107                            message: format!("Invalid message format: {}", e),
108                        };
109                        if let Ok(json) = serde_json::to_string(&error_msg) {
110                            let _ = client.send(json);
111                        }
112                    }
113                }
114            }
115            Ok(Message::Close(_)) => {
116                debug!(client_id = %client_id, "WebSocket close received");
117                break;
118            }
119            Ok(Message::Ping(data)) => {
120                // Axum handles pong automatically, but log it
121                debug!(client_id = %client_id, "Ping received, len={}", data.len());
122            }
123            Ok(Message::Pong(_)) => {
124                // Ignore pong
125            }
126            Ok(Message::Binary(_)) => {
127                // We don't support binary messages
128                debug!(client_id = %client_id, "Binary message ignored");
129            }
130            Err(e) => {
131                error!(client_id = %client_id, error = %e, "WebSocket error");
132                break;
133            }
134        }
135    }
136
137    // Clean up
138    send_task.abort();
139    hub.disconnect(&client_id);
140    info!(client_id = %client_id, "WebSocket client disconnected");
141}
142
/// JSON body returned by the `/api/realtime/stats` endpoint.
///
/// `get_stats` fills each field from the same-named field of the value
/// returned by the hub's `stats()` call. The field names below are the JSON
/// keys clients see, so renaming or reordering them changes the wire format.
#[derive(Serialize)]
struct StatsResponse {
    /// Copied from the hub's `current_connections`
    /// (presumably clients connected right now — confirm in `guts_realtime`).
    current_connections: usize,
    /// Copied from the hub's `total_connections`
    /// (presumably a cumulative counter — confirm in `guts_realtime`).
    total_connections: u64,
    /// Copied from the hub's `total_subscriptions`
    /// (presumably a cumulative counter — confirm in `guts_realtime`).
    total_subscriptions: u64,
    /// Copied from the hub's `total_events`
    /// (presumably a cumulative counter — confirm in `guts_realtime`).
    total_events: u64,
}
151
152/// Get real-time connection statistics.
153async fn get_stats(State(state): State<AppState>) -> impl IntoResponse {
154    let stats = state.realtime.stats();
155    Json(StatsResponse {
156        current_connections: stats.current_connections,
157        total_connections: stats.total_connections,
158        total_subscriptions: stats.total_subscriptions,
159        total_events: stats.total_events,
160    })
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing a `StatsResponse` must emit the counters under their field names.
    #[test]
    fn test_stats_serialization() {
        let response = StatsResponse {
            current_connections: 10,
            total_connections: 100,
            total_subscriptions: 50,
            total_events: 1000,
        };

        let serialized = serde_json::to_string(&response).unwrap();

        // Each fragment is an exact `"key":value` pair expected in the output.
        let expected_fragments = ["\"current_connections\":10", "\"total_events\":1000"];
        for fragment in expected_fragments {
            assert!(serialized.contains(fragment));
        }
    }
}