use super::AppState;
use axum::{
extract::{
Query, State, WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::{HeaderMap, header},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use tracing::debug;
/// Optional handshake frame a client may send as its first text message.
///
/// `handle_socket` only treats the first frame as a handshake when it
/// deserializes into this struct and `type` equals `"connect"`; otherwise the
/// frame is handled as a regular chat message.
#[derive(Debug, Deserialize)]
struct ConnectParams {
/// Frame discriminator, read from the JSON key `type`.
#[serde(rename = "type")]
msg_type: String,
/// Session id supplied by the client. When present it replaces the agent's
/// memory session id.
#[serde(default)]
session_id: Option<String>,
/// Client-reported device name; only written to the debug log.
#[serde(default)]
device_name: Option<String>,
/// Client-reported capability list; only written to the debug log.
#[serde(default)]
capabilities: Vec<String>,
}
/// Application sub-protocol name. It is echoed back on upgrade when the client
/// lists it in `Sec-WebSocket-Protocol`.
const WS_PROTOCOL: &str = "zeroclaw.v1";
/// Prefix of a `Sec-WebSocket-Protocol` entry that carries an auth token
/// (`bearer.<token>`).
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
/// Query-string parameters accepted by the WebSocket chat endpoint.
#[derive(Deserialize)]
pub struct WsQuery {
/// Auth token. Used only as a fallback after the `Authorization` header and
/// the `bearer.` sub-protocol entry.
pub token: Option<String>,
/// Session to attach to; a fresh UUID is generated when absent.
pub session_id: Option<String>,
/// Session display name; stored through the session backend when non-empty.
pub name: Option<String>,
}
/// Picks the auth token for a WebSocket upgrade request.
///
/// Sources are consulted in this order and the first non-empty token wins:
///
/// 1. `Authorization: Bearer <token>` — the scheme name is matched
///    case-insensitively (RFC 9110 §11.1).
/// 2. A `Sec-WebSocket-Protocol` entry of the form `bearer.<token>`; empty
///    `bearer.` entries are skipped so a later entry can still supply a token.
/// 3. The `?token=` query parameter, passed in as `query_token`.
///
/// Returns `None` when no source yields a non-empty token. Header values that
/// are not valid visible ASCII (`to_str` fails) are treated as absent.
fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
    let from_authorization = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        // Split "<scheme> <credentials>" on the first space only, so the
        // credentials are returned exactly as sent.
        .and_then(|value| value.split_once(' '))
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
        .map(|(_, credentials)| credentials)
        .filter(|token| !token.is_empty());

    from_authorization
        .or_else(|| {
            headers
                .get("sec-websocket-protocol")
                .and_then(|value| value.to_str().ok())
                .and_then(|offered| {
                    offered
                        .split(',')
                        .filter_map(|entry| entry.trim().strip_prefix(BEARER_SUBPROTO_PREFIX))
                        .find(|token| !token.is_empty())
                })
        })
        .or_else(|| query_token.filter(|token| !token.is_empty()))
}
pub async fn handle_ws_chat(
State(state): State<AppState>,
Query(params): Query<WsQuery>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
if state.pairing.require_pairing() {
let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
if !state.pairing.is_authenticated(token) {
return (
axum::http::StatusCode::UNAUTHORIZED,
"Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
)
.into_response();
}
}
let ws = if headers
.get("sec-websocket-protocol")
.and_then(|v| v.to_str().ok())
.map_or(false, |protos| {
protos.split(',').any(|p| p.trim() == WS_PROTOCOL)
}) {
ws.protocols([WS_PROTOCOL])
} else {
ws
};
let session_id = params.session_id;
let session_name = params.name;
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
.into_response()
}
/// Prefix prepended to the session id to build the key used with the session
/// backend and the session queue.
const GW_SESSION_PREFIX: &str = "gw_";
async fn handle_socket(
socket: WebSocket,
state: AppState,
session_id: Option<String>,
session_name: Option<String>,
) {
let (mut sender, mut receiver) = socket.split();
let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
let config = state.config.lock().clone();
let mut agent = match crate::agent::Agent::from_config(&config).await {
Ok(a) => a,
Err(e) => {
tracing::error!(error = %e, "Agent initialization failed");
let err = serde_json::json!({
"type": "error",
"message": format!("Failed to initialise agent: {e}"),
"code": "AGENT_INIT_FAILED"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
let _ = sender
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: 1011,
reason: axum::extract::ws::Utf8Bytes::from_static(
"Agent initialization failed",
),
})))
.await;
return;
}
};
agent.set_memory_session_id(Some(session_id.clone()));
let mut resumed = false;
let mut message_count: usize = 0;
let mut effective_name: Option<String> = None;
if let Some(ref backend) = state.session_backend {
let messages = backend.load(&session_key);
if !messages.is_empty() {
message_count = messages.len();
agent.seed_history(&messages);
resumed = true;
}
if let Some(ref name) = session_name {
if !name.is_empty() {
let _ = backend.set_session_name(&session_key, name);
effective_name = Some(name.clone());
}
}
if effective_name.is_none() {
effective_name = backend.get_session_name(&session_key).unwrap_or(None);
}
}
let mut session_start = serde_json::json!({
"type": "session_start",
"session_id": session_id,
"resumed": resumed,
"message_count": message_count,
});
if let Some(ref name) = effective_name {
session_start["name"] = serde_json::Value::String(name.clone());
}
let _ = sender
.send(Message::Text(session_start.to_string().into()))
.await;
let mut first_msg_fallback: Option<String> = None;
if let Some(first) = receiver.next().await {
match first {
Ok(Message::Text(text)) => {
if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
if cp.msg_type == "connect" {
debug!(
session_id = ?cp.session_id,
device_name = ?cp.device_name,
capabilities = ?cp.capabilities,
"WebSocket connect params received"
);
if let Some(sid) = &cp.session_id {
agent.set_memory_session_id(Some(sid.clone()));
}
let ack = serde_json::json!({
"type": "connected",
"message": "Connection established"
});
let _ = sender.send(Message::Text(ack.to_string().into())).await;
} else {
first_msg_fallback = Some(text.to_string());
}
} else {
first_msg_fallback = Some(text.to_string());
}
}
Ok(Message::Close(_)) | Err(_) => return,
_ => {}
}
}
if let Some(ref text) = first_msg_fallback {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
if parsed["type"].as_str() == Some("message") {
let content = parsed["content"].as_str().unwrap_or("").to_string();
if !content.is_empty() {
if let Some(ref backend) = state.session_backend {
let user_msg = crate::providers::ChatMessage::user(&content);
let _ = backend.append(&session_key, &user_msg);
}
process_chat_message(&state, &mut agent, &mut sender, &content, &session_key)
.await;
}
} else {
let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
let err = serde_json::json!({
"type": "error",
"message": format!(
"Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
)
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
}
} else {
let err = serde_json::json!({
"type": "error",
"message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
}
}
let mut broadcast_rx = state.event_tx.subscribe();
loop {
tokio::select! {
client_msg = receiver.next() => {
let Some(msg) = client_msg else { break };
let msg = match msg {
Ok(Message::Text(text)) => text,
Ok(Message::Close(_)) | Err(_) => break,
_ => continue,
};
let parsed: serde_json::Value = match serde_json::from_str(&msg) {
Ok(v) => v,
Err(e) => {
let err = serde_json::json!({
"type": "error",
"message": format!("Invalid JSON: {}", e),
"code": "INVALID_JSON"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
};
let msg_type = parsed["type"].as_str().unwrap_or("");
if msg_type != "message" {
let err = serde_json::json!({
"type": "error",
"message": format!(
"Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
),
"code": "UNKNOWN_MESSAGE_TYPE"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
let content = parsed["content"].as_str().unwrap_or("").to_string();
if content.is_empty() {
let err = serde_json::json!({
"type": "error",
"message": "Message content cannot be empty",
"code": "EMPTY_CONTENT"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
let _session_guard = match state.session_queue.acquire(&session_key).await {
Ok(guard) => guard,
Err(e) => {
let err = serde_json::json!({
"type": "error",
"message": e.to_string(),
"code": "SESSION_BUSY"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
};
if let Some(ref backend) = state.session_backend {
let user_msg = crate::providers::ChatMessage::user(&content);
let _ = backend.append(&session_key, &user_msg);
}
process_chat_message(&state, &mut agent, &mut sender, &content, &session_key).await;
}
event = broadcast_rx.recv() => {
if let Ok(event) = event {
let _ = sender.send(Message::Text(event.to_string().into())).await;
}
}
}
}
}
/// Runs one agent turn for `content` and streams the outcome to the client.
///
/// Sequence:
///
/// 1. Broadcast `agent_start` and mark the session `running` with a fresh turn id.
/// 2. Run `agent.turn_streamed` while concurrently forwarding every
///    `TurnEvent` to the socket as a `chunk` / `thinking` / `tool_call` /
///    `tool_result` frame.
/// 3. On success: persist the assistant reply, optionally spawn memory
///    consolidation (when `state.auto_save` is set), send `chunk_reset` and
///    `done`, mark the session `idle`, and broadcast `agent_end`.
/// 4. On failure: mark the session `error`, log the error, and send a
///    sanitized `error` frame (also broadcast) with a code chosen by
///    `ws_error_code_for`.
///
/// Socket sends, backend writes and broadcasts are best-effort; their errors
/// are intentionally ignored.
///
/// # Parameters
/// - `state`: shared application state.
/// - `agent`: the agent bound to this connection.
/// - `sender`: write half of the WebSocket.
/// - `content`: the user's message text.
/// - `session_key`: backend key of the session this turn belongs to.
async fn process_chat_message(
    state: &AppState,
    agent: &mut crate::agent::Agent,
    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
    content: &str,
    session_key: &str,
) {
    use crate::agent::TurnEvent;

    let provider_label = state
        .config
        .lock()
        .default_provider
        .clone()
        .unwrap_or_else(|| "unknown".to_string());
    let _ = state.event_tx.send(serde_json::json!({
        "type": "agent_start",
        "provider": provider_label,
        "model": state.model,
    }));

    let turn_id = uuid::Uuid::new_v4().to_string();
    if let Some(ref backend) = state.session_backend {
        let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
    }

    let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
    let content_owned = content.to_string();
    // `event_tx` is moved into the turn; the forwarder below ends once
    // `event_rx.recv()` returns `None`.
    let turn_fut = async { agent.turn_streamed(&content_owned, event_tx).await };
    let forward_fut = async {
        while let Some(event) = event_rx.recv().await {
            let ws_msg = match event {
                TurnEvent::Chunk { delta } => {
                    serde_json::json!({ "type": "chunk", "content": delta })
                }
                TurnEvent::Thinking { delta } => {
                    serde_json::json!({ "type": "thinking", "content": delta })
                }
                TurnEvent::ToolCall { name, args } => {
                    serde_json::json!({ "type": "tool_call", "name": name, "args": args })
                }
                TurnEvent::ToolResult { name, output } => {
                    serde_json::json!({ "type": "tool_result", "name": name, "output": output })
                }
            };
            // Keep draining even when the write fails, so the bounded channel
            // never fills up.
            let _ = sender.send(Message::Text(ws_msg.to_string().into())).await;
        }
    };
    let (result, ()) = tokio::join!(turn_fut, forward_fut);

    match result {
        Ok(response) => {
            if let Some(ref backend) = state.session_backend {
                let assistant_msg = crate::providers::ChatMessage::assistant(&response);
                let _ = backend.append(session_key, &assistant_msg);
            }

            if state.auto_save {
                // Consolidation runs detached, so it gets owned copies of
                // everything it needs.
                let mem = state.mem.clone();
                let provider = state.provider.clone();
                let model = state.model.clone();
                let user_msg = content.to_string();
                let assistant_resp = response.clone();
                tokio::spawn(async move {
                    if let Err(e) = crate::memory::consolidation::consolidate_turn(
                        provider.as_ref(),
                        &model,
                        mem.as_ref(),
                        &user_msg,
                        &assistant_resp,
                    )
                    .await
                    {
                        tracing::debug!("WS memory consolidation skipped: {e}");
                    }
                });
            }

            let reset = serde_json::json!({ "type": "chunk_reset" });
            let _ = sender.send(Message::Text(reset.to_string().into())).await;
            let done = serde_json::json!({
                "type": "done",
                "full_response": response,
            });
            let _ = sender.send(Message::Text(done.to_string().into())).await;

            if let Some(ref backend) = state.session_backend {
                let _ = backend.set_session_state(session_key, "idle", None);
            }
            let _ = state.event_tx.send(serde_json::json!({
                "type": "agent_end",
                "provider": provider_label,
                "model": state.model,
            }));
        }
        Err(e) => {
            if let Some(ref backend) = state.session_backend {
                let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
            }
            tracing::error!(error = %e, "Agent turn failed");

            // Only the sanitized text is sent to the client or broadcast.
            let sanitized = crate::providers::sanitize_api_error(&e.to_string());
            let error_code = ws_error_code_for(&sanitized);
            let err = serde_json::json!({
                "type": "error",
                "message": sanitized,
                "code": error_code,
            });
            let _ = sender.send(Message::Text(err.to_string().into())).await;
            let _ = state.event_tx.send(serde_json::json!({
                "type": "error",
                "component": "ws_chat",
                "message": sanitized,
            }));
        }
    }
}

/// Maps a sanitized agent error message to the `code` reported to the client.
///
/// Matching is a case-insensitive substring search, checked in this order:
/// - `"api key"`, `"authentication"` or `"unauthorized"` → `"AUTH_ERROR"`
/// - `"provider"` or `"model"` → `"PROVIDER_ERROR"`
/// - anything else → `"AGENT_ERROR"`
fn ws_error_code_for(sanitized: &str) -> &'static str {
    // Lower-case once instead of once per keyword.
    let lowered = sanitized.to_lowercase();
    let mentions_any = |needles: &[&str]| needles.iter().any(|needle| lowered.contains(*needle));

    if mentions_any(&["api key", "authentication", "unauthorized"]) {
        "AUTH_ERROR"
    } else if mentions_any(&["provider", "model"]) {
        "PROVIDER_ERROR"
    } else {
        "AGENT_ERROR"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// Builds a header map from `(name, value)` pairs.
    fn headers_of(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for (name, value) in pairs {
            map.insert(*name, value.parse().unwrap());
        }
        map
    }

    // --- single source -----------------------------------------------------

    #[test]
    fn extract_ws_token_from_authorization_header() {
        let headers = headers_of(&[("authorization", "Bearer zc_test123")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_test123"));
    }

    #[test]
    fn extract_ws_token_from_subprotocol() {
        let headers = headers_of(&[("sec-websocket-protocol", "zeroclaw.v1, bearer.zc_sub456")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_sub456"));
    }

    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = headers_of(&[]);
        let found = extract_ws_token(&headers, Some("zc_query789"));
        assert_eq!(found, Some("zc_query789"));
    }

    // --- precedence --------------------------------------------------------

    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        let headers = headers_of(&[
            ("authorization", "Bearer zc_header"),
            ("sec-websocket-protocol", "bearer.zc_sub"),
        ]);
        let found = extract_ws_token(&headers, Some("zc_query"));
        assert_eq!(found, Some("zc_header"));
    }

    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        let headers = headers_of(&[("sec-websocket-protocol", "bearer.zc_sub")]);
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_sub"));
    }

    // --- empty / missing values ---------------------------------------------

    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = headers_of(&[]);
        assert_eq!(extract_ws_token(&headers, None), None);
    }

    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        let headers = headers_of(&[("authorization", "Bearer ")]);
        let found = extract_ws_token(&headers, Some("zc_fallback"));
        assert_eq!(found, Some("zc_fallback"));
    }

    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = headers_of(&[]);
        assert_eq!(extract_ws_token(&headers, Some("")), None);
    }

    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        let headers = headers_of(&[("sec-websocket-protocol", "zeroclaw.v1, bearer.zc_tok, other")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_tok"));
    }
}