use std::sync::Arc;
use axum::{
Router,
extract::{
State, WebSocketUpgrade,
ws::{Message, WebSocket},
},
response::IntoResponse,
routing::get,
};
use futures::{SinkExt, StreamExt};
use tokio::sync::broadcast::error::RecvError;
use tracing::{debug, info, warn};
use crate::state::{AppState, ReadingEvent};
/// Builds the sub-router that exposes the live-readings WebSocket endpoint.
///
/// The only route is `GET /api/ws`, which is handled by `ws_handler`.
pub fn router() -> Router<Arc<AppState>> {
    let ws_route = get(ws_handler);
    Router::new().route("/api/ws", ws_route)
}
async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state))
}
/// Drives a single upgraded WebSocket connection until either side goes away.
///
/// Lifecycle:
/// 1. Subscribe to `state.readings_tx` *before* loading the snapshot, so events
///    broadcast while the snapshot is being built are buffered rather than lost
///    (a client may therefore see the same reading twice).
/// 2. Send the latest reading of every device (`list_latest_readings`) as one
///    JSON text frame per device. If the snapshot cannot be loaded, send a
///    `{"type": "error", ...}` frame plus a Close frame and return.
/// 3. Spawn a forwarding task (broadcast events -> text frames) and a reader
///    task (drains client frames until Close or an error). Whichever finishes
///    first aborts the other.
///
/// When this receiver lags, the skipped events are counted in
/// `state.ws_messages_dropped`.
async fn handle_socket(socket: WebSocket, state: Arc<AppState>) {
    use std::sync::atomic::Ordering;

    let (mut sender, mut receiver) = socket.split();
    // Subscribe first so nothing broadcast during the snapshot load is missed.
    let mut rx = state.readings_tx.subscribe();
    // Handle moved into the forwarding task, used only to update the drop counter.
    let metrics = Arc::clone(&state);
    info!("WebSocket client connected");

    let snapshot: Vec<String> = match state
        .with_store_read(|store| {
            let mut events = Vec::new();
            for (device, reading) in store.list_latest_readings()? {
                let event = ReadingEvent {
                    device_id: device.id.clone(),
                    reading,
                };
                // Skip events that fail to serialize, but log them like the live loop does.
                match serde_json::to_string(&event) {
                    Ok(json) => events.push(json),
                    Err(e) => warn!("Failed to serialize event: {}", e),
                }
            }
            Ok(events)
        })
        .await
    {
        Ok(snapshot) => snapshot,
        Err(e) => {
            warn!("Failed to load initial WebSocket snapshot: {}", e);
            let payload = serde_json::json!({
                "type": "error",
                "error": format!("Failed to load initial snapshot: {}", e),
            })
            .to_string();
            // Best effort: the client may already be gone.
            let _ = sender.send(Message::Text(payload.into())).await;
            let _ = sender.send(Message::Close(None)).await;
            return;
        }
    };

    for json in snapshot {
        if sender.send(Message::Text(json.into())).await.is_err() {
            info!("WebSocket client disconnected during initial snapshot");
            return;
        }
    }
    debug!("Sent initial snapshot to WebSocket client");

    // Forward live broadcast events to the client.
    let mut send_task = tokio::spawn(async move {
        loop {
            match rx.recv().await {
                Ok(event) => {
                    let json = match serde_json::to_string(&event) {
                        Ok(j) => j,
                        Err(e) => {
                            warn!("Failed to serialize event: {}", e);
                            continue;
                        }
                    };
                    if sender.send(Message::Text(json.into())).await.is_err() {
                        break;
                    }
                }
                Err(RecvError::Lagged(n)) => {
                    // The receiver fell behind; the `n` oldest events are gone.
                    metrics.ws_messages_dropped.fetch_add(n, Ordering::Relaxed);
                    warn!("WebSocket client lagged, skipped {n} messages");
                }
                Err(RecvError::Closed) => {
                    // All broadcast senders are gone. Close the socket cleanly
                    // instead of just dropping the sink.
                    let _ = sender.send(Message::Close(None)).await;
                    break;
                }
            }
        }
    });

    // Drain incoming frames so Close and transport errors are noticed.
    let mut recv_task = tokio::spawn(async move {
        while let Some(result) = receiver.next().await {
            match result {
                Ok(Message::Close(_)) => break,
                // Pong replies are handled by the WebSocket layer (see axum's
                // `Message::Ping` docs), so pings are only logged here.
                Ok(Message::Ping(_)) => debug!("Received ping"),
                // This endpoint is push-only; other client frames are ignored.
                Ok(_) => {}
                Err(e) => {
                    warn!("WebSocket receive error: {}", e);
                    break;
                }
            }
        }
    });

    // Whichever half finishes first tears down the other.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }
    info!("WebSocket client disconnected");
}