use std::sync::Arc;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use futures::{SinkExt, StreamExt};
use tokio::time::{Duration, interval};
use super::protocol::{WebCommand, WebEvent};
use super::server::AppHandle;
/// Query-string parameters accepted by `ws_handler` on the WebSocket upgrade request.
#[derive(serde::Deserialize)]
pub struct WsQuery {
/// Session token passed to `session_store.validate`; the upgrade is rejected
/// with 401 when validation returns `None`.
token: String,
}
/// Upgrades an HTTP request to a WebSocket connection for a web client.
///
/// The `token` query parameter must be accepted by the session store's
/// `validate`; otherwise the request is answered with `401 Unauthorized` and
/// no upgrade takes place. Each accepted connection is assigned a fresh
/// random (UUID v4) session id, which is handed to `handle_socket` to route
/// session-targeted events and to tag commands coming from this client.
pub async fn ws_handler(
    Query(query): Query<WsQuery>,
    State(state): State<Arc<AppHandle>>,
    ws: axum::extract::WebSocketUpgrade,
) -> impl IntoResponse {
    // Binding the result in its own statement drops the session-store guard
    // before the upgrade below.
    let authorized = state
        .session_store
        .lock()
        .await
        .validate(&query.token)
        .is_some();
    if !authorized {
        return StatusCode::UNAUTHORIZED.into_response();
    }

    let session_id = uuid::Uuid::new_v4().to_string();
    tracing::info!(session_id = %session_id, "web client connecting");

    ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
        .into_response()
}
/// Serves one authenticated web client until either side goes away.
///
/// 1. Sends a `SyncInit` built from the current state snapshot. If that send
///    fails, the client is dropped immediately.
/// 2. Loops over three sources with `tokio::select!`:
///    - broadcast `WebEvent`s, skipping those targeted at another session;
///    - text frames from the client, parsed as `WebCommand` and forwarded to
///      `state.web_cmd_tx` together with this connection's `session_id`
///      (parse failures are answered with `WebEvent::Error`);
///    - a 30-second timer that sends a WebSocket ping.
///
/// The loop ends when the client closes or errors, when any send to the client
/// fails, when the broadcast channel closes, or when forwarding a command fails.
async fn handle_socket(
    socket: axum::extract::ws::WebSocket,
    state: Arc<AppHandle>,
    session_id: String,
) {
    let (mut ws_tx, mut ws_rx) = socket.split();
    // Subscribe before reading the snapshot so events published while the
    // snapshot is being built are still received by this client.
    let mut broadcast_rx = state.broadcaster.subscribe();

    let sync_init = build_sync_init_from_snapshot(&state);
    tracing::info!(session_id = %session_id, "sending SyncInit");
    if send_json(&mut ws_tx, &sync_init).await.is_err() {
        tracing::warn!(session_id = %session_id, "failed to send SyncInit");
        return;
    }
    tracing::info!(session_id = %session_id, "SyncInit sent, entering event loop");

    let mut ping_interval = interval(Duration::from_secs(30));
    // If a slow send delays the loop past one or more ticks, resume the normal
    // cadence instead of firing the missed pings back-to-back (tokio's default
    // `Burst` behaviour).
    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    // A tokio interval's first tick completes immediately; consume it so the
    // first ping is sent one full period after connecting.
    ping_interval.tick().await;

    loop {
        tokio::select! {
            event = broadcast_rx.recv() => {
                match event {
                    Ok(web_event) => {
                        if is_targeted_to_other(&web_event, &session_id) {
                            continue;
                        }
                        if send_json(&mut ws_tx, &web_event).await.is_err() {
                            break;
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                        // The skipped events are gone for good, so this client's view
                        // may now be stale; report it above debug level.
                        tracing::warn!(session_id = %session_id, lagged = n, "web client lagged, skipping events");
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                }
            }
            msg = ws_rx.next() => {
                match msg {
                    Some(Ok(axum::extract::ws::Message::Text(text))) => {
                        match serde_json::from_str::<WebCommand>(&text) {
                            Ok(cmd) => {
                                // NOTE(review): assumes `web_cmd_tx` is an mpsc-style sender
                                // whose `send` fails only once the receiver is gone. In that
                                // case no further command from this client can be handled,
                                // so stop instead of silently dropping them.
                                if state.web_cmd_tx.send((cmd, session_id.clone())).is_err() {
                                    tracing::warn!(session_id = %session_id, "web command channel closed");
                                    break;
                                }
                            }
                            Err(e) => {
                                tracing::debug!(session_id = %session_id, error = %e, "invalid web command");
                                let err = WebEvent::Error { message: format!("invalid command: {e}") };
                                // Treat a failed reply like every other failed send.
                                if send_json(&mut ws_tx, &err).await.is_err() {
                                    break;
                                }
                            }
                        }
                    }
                    Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
                    Some(Err(e)) => {
                        tracing::debug!(session_id = %session_id, error = %e, "ws recv error");
                        break;
                    }
                    // Binary, Ping and Pong frames carry no commands; ignore them.
                    _ => {}
                }
            }
            _ = ping_interval.tick() => {
                if ws_tx.send(axum::extract::ws::Message::Ping(vec![].into())).await.is_err() {
                    break;
                }
            }
        }
    }
    tracing::info!(session_id = %session_id, "web client disconnected");
}
fn build_sync_init_from_snapshot(state: &AppHandle) -> WebEvent {
if let Some(ref snapshot) = state.web_state_snapshot
&& let Ok(snap) = snapshot.read()
{
return WebEvent::SyncInit {
buffers: snap.buffers.clone(),
connections: snap.connections.clone(),
mention_count: snap.mention_count,
active_buffer_id: snap.active_buffer_id.clone(),
timestamp_format: snap.timestamp_format.clone(),
};
}
WebEvent::SyncInit {
buffers: Vec::new(),
connections: Vec::new(),
mention_count: 0,
active_buffer_id: None,
timestamp_format: crate::config::WebConfig::default().timestamp_format,
}
}
/// Reports whether `event` is addressed to a session other than `session_id`.
///
/// Only `Messages`, `NickList` and `MentionsList` can carry a target session.
/// Those variants with no target, and every other variant, are broadcast to
/// all clients, so this returns `false` for them.
fn is_targeted_to_other(event: &WebEvent, session_id: &str) -> bool {
    let (WebEvent::Messages { session_id: recipient, .. }
    | WebEvent::NickList { session_id: recipient, .. }
    | WebEvent::MentionsList { session_id: recipient, .. }) = event
    else {
        return false;
    };
    matches!(recipient.as_deref(), Some(r) if r != session_id)
}
/// Serializes `event` to JSON and sends it to the client as one text frame.
///
/// Returns `Err(())` if serialization fails (logged at warn level) or if the
/// sink rejects the frame. The underlying error is not propagated.
async fn send_json(
    tx: &mut futures::stream::SplitSink<axum::extract::ws::WebSocket, axum::extract::ws::Message>,
    event: &WebEvent,
) -> Result<(), ()> {
    let payload = serde_json::to_string(event).map_err(|e| {
        tracing::warn!("failed to serialize WebEvent: {e}");
    })?;
    let frame = axum::extract::ws::Message::Text(payload.into());
    tx.send(frame).await.map_err(|_| ())
}