1use crate::state::AppState;
4use axum::{
5 extract::{
6 ws::{Message, WebSocket},
7 State, WebSocketUpgrade,
8 },
9 response::IntoResponse,
10 routing::get,
11 Router,
12};
13use futures::{sink::SinkExt, stream::StreamExt};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use tokio::time::{interval, Duration};
17
18pub fn routes(state: Arc<AppState>) -> Router {
19 Router::new()
20 .route("/", get(websocket_handler))
21 .with_state(state)
22}
23
24async fn websocket_handler(
25 ws: WebSocketUpgrade,
26 State(state): State<Arc<AppState>>,
27) -> impl IntoResponse {
28 ws.on_upgrade(move |socket| handle_socket(socket, state))
29}
30
31async fn handle_socket(socket: WebSocket, state: Arc<AppState>) {
32 let (mut sender, mut receiver) = socket.split();
33 let user_id = "user_001".to_string(); if !state.can_create_ws_connection(&user_id).await {
37 let _ = sender
38 .send(Message::Text(
39 serde_json::to_string(&WsMessage::error(
40 "Maximum WebSocket connections reached",
41 ))
42 .unwrap(),
43 ))
44 .await;
45 return;
46 }
47
48 state.increment_ws_connection(user_id.clone()).await;
49
50 let welcome = WsMessage::connected("Welcome to AVL Console");
52 if sender
53 .send(Message::Text(serde_json::to_string(&welcome).unwrap()))
54 .await
55 .is_err()
56 {
57 state.decrement_ws_connection(&user_id).await;
58 return;
59 }
60
61 let mut ping_interval = interval(Duration::from_secs(state.config.ws_ping_interval));
63 let (ping_tx, mut ping_rx) = tokio::sync::mpsc::channel(10);
64
65 tokio::spawn(async move {
66 loop {
67 ping_interval.tick().await;
68 if ping_tx.send(()).await.is_err() {
69 break;
70 }
71 }
72 });
73
74 loop {
76 tokio::select! {
77 msg = receiver.next() => {
78 match msg {
79 Some(Ok(Message::Text(text))) => {
80 if let Ok(ws_msg) = serde_json::from_str::<WsMessage>(&text) {
81 handle_message(ws_msg, &mut sender, &state).await;
82 }
83 }
84 Some(Ok(Message::Close(_))) | None => {
85 break;
86 }
87 _ => {}
88 }
89 }
90 _ = ping_rx.recv() => {
91 if sender.send(Message::Ping(vec![])).await.is_err() {
92 break;
93 }
94 }
95 }
96 }
97
98 state.decrement_ws_connection(&user_id).await;
99}
100
101async fn handle_message(
102 msg: WsMessage,
103 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
104 _state: &Arc<AppState>,
105) {
106 match msg.msg_type.as_str() {
107 "subscribe" => {
108 let response = WsMessage::subscribed(&msg.payload.unwrap_or_default());
109 let _ = sender
110 .send(Message::Text(serde_json::to_string(&response).unwrap()))
111 .await;
112 }
113 "ping" => {
114 let response = WsMessage::pong();
115 let _ = sender
116 .send(Message::Text(serde_json::to_string(&response).unwrap()))
117 .await;
118 }
119 _ => {}
120 }
121}
122
/// JSON envelope for every application message on the console WebSocket.
/// It is deserialized from incoming text frames and serialized for replies.
///
/// Wire shape: `{"type": "...", "payload": ...}`.
/// - On output, `payload` is written as `null` when absent.
/// - On input, serde's derive treats a missing `payload` field as `None`.
#[derive(Debug, Serialize, Deserialize)]
struct WsMessage {
    /// Message kind, e.g. `"subscribe"`, `"ping"`, `"connected"`, `"error"`;
    /// appears under the JSON key `type`.
    #[serde(rename = "type")]
    msg_type: String,
    /// Optional text payload, e.g. a subscription topic or a human-readable message.
    payload: Option<String>,
}
129
130impl WsMessage {
131 fn connected(msg: &str) -> Self {
132 Self {
133 msg_type: "connected".to_string(),
134 payload: Some(msg.to_string()),
135 }
136 }
137
138 fn error(msg: &str) -> Self {
139 Self {
140 msg_type: "error".to_string(),
141 payload: Some(msg.to_string()),
142 }
143 }
144
145 fn subscribed(topic: &str) -> Self {
146 Self {
147 msg_type: "subscribed".to_string(),
148 payload: Some(topic.to_string()),
149 }
150 }
151
152 fn pong() -> Self {
153 Self {
154 msg_type: "pong".to_string(),
155 payload: None,
156 }
157 }
158}