avl_console/
websocket.rs

1//! WebSocket support for real-time updates
2
3use crate::state::AppState;
4use axum::{
5    extract::{
6        ws::{Message, WebSocket},
7        State, WebSocketUpgrade,
8    },
9    response::IntoResponse,
10    routing::get,
11    Router,
12};
13use futures::{sink::SinkExt, stream::StreamExt};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use tokio::time::{interval, Duration};
17
18pub fn routes(state: Arc<AppState>) -> Router {
19    Router::new()
20        .route("/", get(websocket_handler))
21        .with_state(state)
22}
23
24async fn websocket_handler(
25    ws: WebSocketUpgrade,
26    State(state): State<Arc<AppState>>,
27) -> impl IntoResponse {
28    ws.on_upgrade(move |socket| handle_socket(socket, state))
29}
30
31async fn handle_socket(socket: WebSocket, state: Arc<AppState>) {
32    let (mut sender, mut receiver) = socket.split();
33    let user_id = "user_001".to_string(); // TODO: Extract from auth
34
35    // Check connection limit
36    if !state.can_create_ws_connection(&user_id).await {
37        let _ = sender
38            .send(Message::Text(
39                serde_json::to_string(&WsMessage::error(
40                    "Maximum WebSocket connections reached",
41                ))
42                .unwrap(),
43            ))
44            .await;
45        return;
46    }
47
48    state.increment_ws_connection(user_id.clone()).await;
49
50    // Send welcome message
51    let welcome = WsMessage::connected("Welcome to AVL Console");
52    if sender
53        .send(Message::Text(serde_json::to_string(&welcome).unwrap()))
54        .await
55        .is_err()
56    {
57        state.decrement_ws_connection(&user_id).await;
58        return;
59    }
60
61    // Spawn ping task
62    let mut ping_interval = interval(Duration::from_secs(state.config.ws_ping_interval));
63    let (ping_tx, mut ping_rx) = tokio::sync::mpsc::channel(10);
64
65    tokio::spawn(async move {
66        loop {
67            ping_interval.tick().await;
68            if ping_tx.send(()).await.is_err() {
69                break;
70            }
71        }
72    });
73
74    // Handle messages
75    loop {
76        tokio::select! {
77            msg = receiver.next() => {
78                match msg {
79                    Some(Ok(Message::Text(text))) => {
80                        if let Ok(ws_msg) = serde_json::from_str::<WsMessage>(&text) {
81                            handle_message(ws_msg, &mut sender, &state).await;
82                        }
83                    }
84                    Some(Ok(Message::Close(_))) | None => {
85                        break;
86                    }
87                    _ => {}
88                }
89            }
90            _ = ping_rx.recv() => {
91                if sender.send(Message::Ping(vec![])).await.is_err() {
92                    break;
93                }
94            }
95        }
96    }
97
98    state.decrement_ws_connection(&user_id).await;
99}
100
101async fn handle_message(
102    msg: WsMessage,
103    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
104    _state: &Arc<AppState>,
105) {
106    match msg.msg_type.as_str() {
107        "subscribe" => {
108            let response = WsMessage::subscribed(&msg.payload.unwrap_or_default());
109            let _ = sender
110                .send(Message::Text(serde_json::to_string(&response).unwrap()))
111                .await;
112        }
113        "ping" => {
114            let response = WsMessage::pong();
115            let _ = sender
116                .send(Message::Text(serde_json::to_string(&response).unwrap()))
117                .await;
118        }
119        _ => {}
120    }
121}
122
123#[derive(Debug, Serialize, Deserialize)]
124struct WsMessage {
125    #[serde(rename = "type")]
126    msg_type: String,
127    payload: Option<String>,
128}
129
130impl WsMessage {
131    fn connected(msg: &str) -> Self {
132        Self {
133            msg_type: "connected".to_string(),
134            payload: Some(msg.to_string()),
135        }
136    }
137
138    fn error(msg: &str) -> Self {
139        Self {
140            msg_type: "error".to_string(),
141            payload: Some(msg.to_string()),
142        }
143    }
144
145    fn subscribed(topic: &str) -> Self {
146        Self {
147            msg_type: "subscribed".to_string(),
148            payload: Some(topic.to_string()),
149        }
150    }
151
152    fn pong() -> Self {
153        Self {
154            msg_type: "pong".to_string(),
155            payload: None,
156        }
157    }
158}