guts_node/
realtime_api.rs1use axum::{
24 extract::{
25 ws::{Message, WebSocket, WebSocketUpgrade},
26 State,
27 },
28 response::IntoResponse,
29 routing::get,
30 Json, Router,
31};
32use futures_util::{SinkExt, StreamExt};
33use guts_realtime::{ClientCommand, EventHub, ServerMessage};
34use serde::Serialize;
35use std::sync::Arc;
36use tracing::{debug, error, info};
37
38use crate::api::AppState;
39
40pub fn realtime_routes() -> Router<AppState> {
42 Router::new()
43 .route("/ws", get(ws_handler))
44 .route("/api/realtime/stats", get(get_stats))
45}
46
47async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
49 ws.on_upgrade(move |socket| handle_socket(socket, state.realtime.clone()))
50}
51
52async fn handle_socket(socket: WebSocket, hub: Arc<EventHub>) {
54 let (client, mut receiver) = match hub.connect() {
56 Ok(c) => c,
57 Err(e) => {
58 error!("Failed to connect client: {}", e);
59 return;
60 }
61 };
62
63 let client_id = client.id.clone();
64 info!(client_id = %client_id, "WebSocket client connected");
65
66 let (mut ws_sender, mut ws_receiver) = socket.split();
68
69 let client_id_clone = client_id.clone();
71 let send_task = tokio::spawn(async move {
72 while let Some(msg) = receiver.recv().await {
73 if ws_sender.send(Message::Text(msg.into())).await.is_err() {
74 break;
75 }
76 }
77 debug!(client_id = %client_id_clone, "Send task ended");
78 });
79
80 while let Some(msg) = ws_receiver.next().await {
82 match msg {
83 Ok(Message::Text(text)) => {
84 let text_str: &str = &text;
85 match serde_json::from_str::<ClientCommand>(text_str) {
86 Ok(cmd) => {
87 let response = hub.handle_command(&client, cmd);
88 match response {
89 Ok(msg) => {
90 if let Ok(json) = serde_json::to_string(&msg) {
91 let _ = client.send(json);
92 }
93 }
94 Err(e) => {
95 let error_msg = ServerMessage::Error {
96 message: e.to_string(),
97 };
98 if let Ok(json) = serde_json::to_string(&error_msg) {
99 let _ = client.send(json);
100 }
101 }
102 }
103 }
104 Err(e) => {
105 debug!(client_id = %client_id, error = %e, "Invalid message format");
106 let error_msg = ServerMessage::Error {
107 message: format!("Invalid message format: {}", e),
108 };
109 if let Ok(json) = serde_json::to_string(&error_msg) {
110 let _ = client.send(json);
111 }
112 }
113 }
114 }
115 Ok(Message::Close(_)) => {
116 debug!(client_id = %client_id, "WebSocket close received");
117 break;
118 }
119 Ok(Message::Ping(data)) => {
120 debug!(client_id = %client_id, "Ping received, len={}", data.len());
122 }
123 Ok(Message::Pong(_)) => {
124 }
126 Ok(Message::Binary(_)) => {
127 debug!(client_id = %client_id, "Binary message ignored");
129 }
130 Err(e) => {
131 error!(client_id = %client_id, error = %e, "WebSocket error");
132 break;
133 }
134 }
135 }
136
137 send_task.abort();
139 hub.disconnect(&client_id);
140 info!(client_id = %client_id, "WebSocket client disconnected");
141}
142
/// JSON body returned by `GET /api/realtime/stats` (built in `get_stats`).
///
/// Each field is copied verbatim from the realtime hub's `stats()` result.
/// Field declaration order determines the key order of the serialized JSON.
#[derive(Serialize)]
struct StatsResponse {
    /// Copied from the hub's `stats().current_connections`.
    current_connections: usize,
    /// Copied from the hub's `stats().total_connections`.
    total_connections: u64,
    /// Copied from the hub's `stats().total_subscriptions`.
    total_subscriptions: u64,
    /// Copied from the hub's `stats().total_events`.
    total_events: u64,
}
151
152async fn get_stats(State(state): State<AppState>) -> impl IntoResponse {
154 let stats = state.realtime.stats();
155 Json(StatsResponse {
156 current_connections: stats.current_connections,
157 total_connections: stats.total_connections,
158 total_subscriptions: stats.total_subscriptions,
159 total_events: stats.total_events,
160 })
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// `StatsResponse` must serialize to an object holding exactly its four
    /// counters under their field names. The whole value is compared, so a
    /// renamed, dropped, or mis-typed field fails this test; previously only
    /// two fields were checked via substring matching.
    #[test]
    fn test_stats_serialization() {
        let stats = StatsResponse {
            current_connections: 10,
            total_connections: 100,
            total_subscriptions: 50,
            total_events: 1000,
        };

        let value = serde_json::to_value(&stats).expect("StatsResponse has only integer fields");
        assert_eq!(
            value,
            serde_json::json!({
                "current_connections": 10,
                "total_connections": 100,
                "total_subscriptions": 50,
                "total_events": 1000,
            })
        );
    }
}