use crate::api::events::EventBus;
use crate::api::server::AppState;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::{Query, State};
use std::collections::HashMap;
use std::sync::Arc;
/// Upgrades an authenticated HTTP request to a WebSocket that streams events from the
/// application's event bus.
///
/// The credential is read from the `auth` query parameter. It is accepted when it equals
/// `state.api_token`, or when `crate::api::auth::validate_jwt` accepts it against
/// `state.jwt_secret`.
///
/// # Responses
/// - `401 Unauthorized` when `auth` is missing or not accepted.
/// - `503 Service Unavailable` when no permit can be taken from `state.ws_semaphore`.
/// - Otherwise the WebSocket upgrade response. The semaphore permit moves into the
///   connection task and is released when that task finishes.
///
/// NOTE(review): the credential travels in the URL, so it may end up in proxy/access
/// logs. Confirm that is acceptable for this deployment.
pub async fn ws_events(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> axum::response::Response {
    let is_authenticated = params
        .get("auth")
        .map(|token| {
            // The static API token is a shared secret. Compare it without short-circuiting
            // so response latency does not reveal how many leading bytes of a guess matched.
            constant_time_eq(token.as_bytes(), state.api_token.as_bytes())
                || crate::api::auth::validate_jwt(token, &state.jwt_secret).is_ok()
        })
        .unwrap_or(false);
    if !is_authenticated {
        return axum::response::Response::builder()
            .status(axum::http::StatusCode::UNAUTHORIZED)
            .body(axum::body::Body::from("Missing or invalid auth token"))
            .expect("response build is infallible");
    }
    // Reject instead of waiting when the connection limit is reached.
    let permit = match state.ws_semaphore.clone().try_acquire_owned() {
        Ok(permit) => permit,
        Err(_) => {
            return axum::response::Response::builder()
                .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                .body(axum::body::Body::from("Too many WebSocket connections"))
                .expect("response build is infallible")
        }
    };
    let event_bus = state.event_bus.clone();
    ws.on_upgrade(move |socket| handle_ws(socket, event_bus, permit))
}

/// Compares two byte strings without stopping at the first mismatching byte.
///
/// Only the length check returns early. Timing can therefore reveal whether the lengths
/// differ, but not where the contents differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // `black_box` on every step keeps the accumulator opaque to the optimizer, which
    // discourages it from rewriting the fold into an early-exit loop.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| std::hint::black_box(acc | (x ^ y)));
    diff == 0
}
/// Sends every event published on `event_bus` to `socket` as a JSON text frame until
/// either side goes away.
///
/// The loop ends when:
/// - sending a frame to the client fails;
/// - the client sends a Close frame, the inbound stream ends, or receiving fails;
/// - the event-bus subscription returns an error other than `Lagged`.
///
/// `_permit` is never read. Holding it keeps the connection's semaphore slot occupied,
/// and the slot is released when this function returns and the permit is dropped.
async fn handle_ws(
    mut socket: WebSocket,
    event_bus: EventBus,
    _permit: tokio::sync::OwnedSemaphorePermit,
) {
    let mut rx = event_bus.subscribe();
    loop {
        tokio::select! {
            event = rx.recv() => {
                match event {
                    Ok(e) => {
                        // An event that fails to serialize is skipped; the stream stays open.
                        let json = match serde_json::to_string(&e) {
                            Ok(j) => j,
                            Err(_) => continue,
                        };
                        if socket.send(Message::Text(json.into())).await.is_err() {
                            break;
                        }
                    }
                    // This subscriber fell behind and missed some events; keep streaming
                    // from where the receiver now is.
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
                    // The remaining variant is `Closed`: no more events will arrive.
                    Err(_) => break,
                }
            }
            msg = socket.recv() => {
                match msg {
                    // Stop on a client Close, end of stream, or transport error. Do not keep
                    // polling a socket that has already failed.
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    // Any other inbound frame is ignored.
                    Some(Ok(_)) => {}
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::server::AppState;

    // Compile-time guard: `ws_events` must keep the (upgrade, state, query-map) argument
    // order and types that the router relies on.
    #[test]
    fn test_ws_handler_compiles() {
        use axum::extract::Query;
        use std::collections::HashMap;

        let _handler: fn(
            WebSocketUpgrade,
            State<Arc<AppState>>,
            Query<HashMap<String, String>>,
        ) -> _ = ws_events;
    }

    // Draining every permit must make further acquisition fail; returning one permit
    // must free exactly one slot.
    #[test]
    fn test_ws_semaphore_exhaustion_reduces_permits() {
        let limit = AppState::MAX_WS_CONNECTIONS;
        let semaphore = Arc::new(tokio::sync::Semaphore::new(limit));

        let mut held: Vec<_> = (0..limit)
            .map(|_| {
                Arc::clone(&semaphore)
                    .try_acquire_owned()
                    .expect("permit available")
            })
            .collect();

        assert_eq!(semaphore.available_permits(), 0);
        assert!(Arc::clone(&semaphore).try_acquire_owned().is_err());

        drop(held.pop());
        assert_eq!(semaphore.available_permits(), 1);
    }
}