use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
};
use axum::response::IntoResponse;
use axum::Json;
use futures_util::{SinkExt, StreamExt as FuturesStreamExt};
use serde::{Deserialize, Serialize};
use oxios_gateway::message::IncomingMessage;
use crate::error::AppError;
use crate::server::AppState;
/// JSON request body accepted by [`handle_chat`].
#[derive(Debug, Deserialize)]
pub(crate) struct ChatRequest {
/// Message text; `handle_chat` rejects it when longer than 64 KiB.
content: String,
/// Sender id; filled from [`default_user`] when the field is omitted.
#[serde(default = "default_user")]
user_id: String,
/// Session to continue; defaults to the empty string when omitted, which
/// `handle_chat` treats as "no session supplied".
#[serde(default)]
session_id: String,
}
/// Serde default for `ChatRequest::user_id`: the literal user id `"default"`.
pub(crate) fn default_user() -> String {
    String::from("default")
}
/// JSON reply body returned by [`handle_chat`].
#[derive(Debug, Serialize)]
pub(crate) struct ChatResponse {
/// Id of the message that was forwarded to the gateway.
id: String,
/// The request's `content`, echoed back unchanged.
echo: String,
/// The gateway's reply text.
reply: String,
/// Session id for this exchange; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
/// Phase taken from the reply's metadata; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
phase: Option<String>,
}
/// Handles a single, non-streaming chat request.
///
/// Forwards `body.content` to the gateway as a `"web"` [`IncomingMessage`] and waits for the
/// reply. It then records the user message and the reply in the kernel's session store, on a
/// best-effort basis. It returns the reply together with an echo of the input.
///
/// The session the exchange is stored under is picked, in order of preference, from:
/// 1. the `session_id` in the reply's metadata,
/// 2. the `session_id` the client sent (if non-empty),
/// 3. the id of the freshly created message.
///
/// That effective id is always returned in `ChatResponse::session_id`, so the client can
/// continue the same conversation.
///
/// # Errors
///
/// * [`AppError::PayloadTooLarge`] if `content` is longer than 64 KiB.
/// * [`AppError::Internal`] if the gateway fails to produce a reply.
///
/// Failures to load or save the session are logged and do not fail the request.
pub(crate) async fn handle_chat(
    state: State<Arc<AppState>>,
    Json(body): Json<ChatRequest>,
) -> Result<Json<ChatResponse>, AppError> {
    const MAX_CHAT_LENGTH: usize = 64 * 1024;
    if body.content.len() > MAX_CHAT_LENGTH {
        return Err(AppError::PayloadTooLarge {
            size: body.content.len(),
            limit: MAX_CHAT_LENGTH,
        });
    }

    // The message text may be sensitive: log only its size at info level.
    tracing::info!(
        content_len = body.content.len(),
        user = %body.user_id,
        "Chat message received"
    );
    tracing::debug!(content = %body.content, "Chat message content");

    let mut msg = IncomingMessage::new("web", &body.user_id, &body.content);
    if !body.session_id.is_empty() {
        msg.metadata
            .insert("session_id".to_owned(), body.session_id.clone());
    }
    let msg_id = msg.id.to_string();

    let response = match state.channel.send_and_wait(msg).await {
        Ok(response) => response,
        Err(e) => {
            tracing::error!(error = %e, "Failed to get response from gateway");
            return Err(AppError::Internal("gateway response failed".into()));
        }
    };
    tracing::info!(reply_len = response.content.len(), "Chat response received");

    // Prefer the gateway's session id, then the client's own, and only then mint one from
    // the message id. Without the middle step a client-supplied session would be split off
    // whenever the gateway omits the id from its metadata.
    let session_key = response
        .metadata
        .get("session_id")
        .cloned()
        .or_else(|| (!body.session_id.is_empty()).then(|| body.session_id.clone()))
        .unwrap_or_else(|| msg_id.clone());

    // Built once and moved into whichever branch below ends up storing it.
    let agent_response = oxios_kernel::state_store::AgentResponse {
        content: response.content.clone(),
        session_id: Some(session_key.clone()),
        seed_id: response.metadata.get("seed_id").cloned(),
        phase_reached: response.metadata.get("phase").cloned(),
        evaluation_passed: response
            .metadata
            .get("evaluation_passed")
            .and_then(|v| v.parse().ok()),
        timestamp: chrono::Utc::now(),
    };

    let session_id = oxios_kernel::state_store::SessionId(session_key.clone());
    match state.kernel.state.load_session(&session_id).await {
        Ok(Some(mut session)) => {
            session.add_user_message(&body.content);
            session.add_agent_response(agent_response);
            if let Err(e) = state.kernel.state.save_session(&session).await {
                tracing::warn!(error = %e, "Failed to persist session");
            }
        }
        Ok(None) => {
            let mut session = oxios_kernel::state_store::Session::new(body.user_id.clone());
            session.id = oxios_kernel::state_store::SessionId(session_key.clone());
            session.add_user_message(&body.content);
            session.add_agent_response(agent_response);
            if let Err(e) = state.kernel.state.save_session(&session).await {
                tracing::warn!(error = %e, "Failed to create session");
            }
        }
        Err(e) => tracing::warn!(error = %e, "Failed to load/create session"),
    }

    Ok(Json(ChatResponse {
        id: msg_id,
        echo: body.content,
        reply: response.content,
        // Always report the id the exchange was stored under, so the client can resume it.
        session_id: Some(session_key),
        phase: response.metadata.get("phase").cloned(),
    }))
}
/// Query-string parameters accepted by [`handle_chat_stream`].
#[derive(Debug, serde::Deserialize)]
pub(crate) struct WsParams {
/// Auth token, checked only when `security.auth_enabled` is set; a missing token is
/// validated as the empty string.
/// NOTE(review): tokens in URLs can end up in proxy/access logs — consider a header.
token: Option<String>,
}
pub(crate) async fn handle_chat_stream(
ws: WebSocketUpgrade,
state: State<Arc<AppState>>,
Query(params): Query<WsParams>,
) -> impl axum::response::IntoResponse {
if state.config.read().security.auth_enabled {
let token = params.token.as_deref().unwrap_or("");
if !state.kernel.security.validate_token(token) {
return axum::http::StatusCode::UNAUTHORIZED.into_response();
}
}
ws.on_upgrade(move |socket| handle_chat_websocket(socket, state.0))
}
/// Drives one chat WebSocket connection until either direction stops.
///
/// Two tasks run concurrently:
/// * an outbound pump, which serializes every message received from
///   `state.kernel.infra.subscribe()` to JSON and writes it to the socket as a text frame;
/// * an inbound pump, which turns each client text frame into a `"web"` `IncomingMessage`
///   and pushes it onto `state.channel.incoming_tx`.
///
/// Text frames longer than 64 KiB are dropped with a warning, matching the cap that
/// `handle_chat` enforces on the REST endpoint. When either pump finishes, the other is
/// aborted, so neither socket half nor the subscription outlives the connection.
pub(crate) async fn handle_chat_websocket(socket: WebSocket, state: Arc<AppState>) {
    // Same per-message cap as the REST chat endpoint.
    const MAX_CHAT_LENGTH: usize = 64 * 1024;

    let (mut ws_tx, mut ws_rx) = socket.split();
    let mut outgoing_rx = state.kernel.infra.subscribe();

    // NOTE(review): every connection forwards every message published on this subscription;
    // nothing here filters by user or session — confirm this fan-out is intended.
    // NOTE(review): if this is a tokio broadcast receiver, a `Lagged` error also ends the
    // loop (and thus the connection); consider skipping lag instead of disconnecting.
    let mut recv_task = tokio::spawn(async move {
        while let Ok(msg) = outgoing_rx.recv().await {
            let json = match serde_json::to_string(&msg) {
                Ok(j) => j,
                Err(e) => {
                    tracing::error!(error = %e, "Failed to serialize outgoing message");
                    continue;
                }
            };
            if ws_tx.send(Message::Text(json.into())).await.is_err() {
                break;
            }
        }
    });

    let send_tx = state.channel.incoming_tx.clone();
    let mut send_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = FuturesStreamExt::next(&mut ws_rx).await {
            match msg {
                Message::Text(text) => {
                    let content = text.to_string();
                    // Client frames are untrusted input; apply the same cap as `handle_chat`.
                    if content.len() > MAX_CHAT_LENGTH {
                        tracing::warn!(
                            size = content.len(),
                            limit = MAX_CHAT_LENGTH,
                            "Dropping oversized WebSocket chat message"
                        );
                        continue;
                    }
                    // NOTE(review): every WebSocket message is attributed to user "session".
                    let incoming =
                        oxios_gateway::message::IncomingMessage::new("web", "session", content);
                    if send_tx.send(incoming).await.is_err() {
                        break;
                    }
                }
                Message::Close(_) => break,
                _ => {}
            }
        }
    });

    // Dropping a `JoinHandle` detaches its task rather than cancelling it, so the pump that
    // is still running must be aborted explicitly once the other one has finished.
    tokio::select! {
        _ = &mut recv_task => send_task.abort(),
        _ = &mut send_task => recv_task.abort(),
    }
}