#[cfg(feature = "http-api")]
use std::sync::Arc;
#[cfg(feature = "http-api")]
use axum::{
extract::{
ws::{Message, WebSocket},
Query, State, WebSocketUpgrade,
},
http::StatusCode,
response::IntoResponse,
};
#[cfg(feature = "http-api")]
use serde::Deserialize;
#[cfg(feature = "http-api")]
use subtle::ConstantTimeEq;
#[cfg(feature = "http-api")]
use tokio::sync::mpsc;
#[cfg(feature = "http-api")]
use super::coordinator::{CoordinatorSession, CoordinatorState};
#[cfg(feature = "http-api")]
use super::ws_types::{ClientMessage, ServerMessage};
#[cfg(feature = "http-api")]
#[derive(Debug, Deserialize)]
pub struct WsChatParams {
token: Option<String>,
}
#[cfg(feature = "http-api")]
fn validate_token(token: &str, key_store: Option<&Arc<super::api_keys::ApiKeyStore>>) -> bool {
if let Some(store) = key_store {
if store.has_records() {
return store.validate_key(token).is_some();
}
}
match std::env::var("SYMBIONT_API_TOKEN") {
Ok(expected) => bool::from(token.as_bytes().ct_eq(expected.as_bytes())),
Err(_) => false,
}
}
#[cfg(feature = "http-api")]
pub async fn ws_chat_handler(
ws: WebSocketUpgrade,
State(coordinator_state): State<Arc<CoordinatorState>>,
Query(params): Query<WsChatParams>,
key_store: Option<axum::Extension<Arc<super::api_keys::ApiKeyStore>>>,
) -> Result<impl IntoResponse, StatusCode> {
let token = params.token.as_deref().ok_or(StatusCode::UNAUTHORIZED)?;
let store_ref = key_store.as_ref().map(|ext| &ext.0);
if !validate_token(token, store_ref) {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(ws.on_upgrade(move |socket| handle_socket(socket, coordinator_state)))
}
#[cfg(feature = "http-api")]
async fn handle_socket(socket: WebSocket, state: Arc<CoordinatorState>) {
let (mut ws_writer, mut ws_reader) = socket.split();
let (out_tx, mut out_rx) = mpsc::channel::<ServerMessage>(64);
let mut session = CoordinatorSession::new(state, out_tx.clone());
use axum::extract::ws::Message as WsMessage;
use futures::SinkExt;
let writer_handle = tokio::spawn(async move {
while let Some(msg) = out_rx.recv().await {
match serde_json::to_string(&msg) {
Ok(json) => {
if ws_writer.send(WsMessage::Text(json)).await.is_err() {
break;
}
}
Err(e) => {
tracing::error!("Failed to serialize ServerMessage: {}", e);
}
}
}
});
let heartbeat_tx = out_tx.clone();
let heartbeat_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
loop {
interval.tick().await;
if heartbeat_tx.send(ServerMessage::Pong).await.is_err() {
break;
}
}
});
use futures::StreamExt;
while let Some(msg) = ws_reader.next().await {
let msg = match msg {
Ok(m) => m,
Err(e) => {
tracing::debug!("WebSocket read error: {}", e);
break;
}
};
match msg {
Message::Text(text) => match serde_json::from_str::<ClientMessage>(&text) {
Ok(ClientMessage::ChatSend { content, .. }) => {
session.handle_chat(content).await;
}
Ok(ClientMessage::Ping) => {
let _ = out_tx.send(ServerMessage::Pong).await;
}
Err(e) => {
let _ = out_tx
.send(ServerMessage::Error {
request_id: None,
code: "PARSE_ERROR".into(),
message: format!("Invalid message: {}", e),
})
.await;
}
},
Message::Close(_) => break,
_ => {} }
}
heartbeat_handle.abort();
drop(out_tx);
let _ = writer_handle.await;
tracing::info!("WebSocket connection closed");
}