use super::AppState;
use axum::{
extract::{
Query, State, WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::{HeaderMap, header},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
use tracing::debug;
/// Optional handshake frame a client may send as its very first message, e.g.
/// `{"type":"connect","session_id":"…","device_name":"…","capabilities":[…]}`.
///
/// It is only treated as a handshake when `msg_type == "connect"`; any other
/// first text frame is handled like a regular client frame.
#[derive(Debug, Deserialize)]
struct ConnectParams {
    /// Frame discriminator (JSON key `type`); `"connect"` for a handshake.
    #[serde(rename = "type")]
    msg_type: String,
    /// If present, replaces the agent's memory session id for this connection.
    #[serde(default)]
    session_id: Option<String>,
    /// Client-reported device label; currently only logged at debug level.
    #[serde(default)]
    device_name: Option<String>,
    /// Client-reported capability strings; currently only logged at debug level.
    #[serde(default)]
    capabilities: Vec<String>,
}
/// Application subprotocol; selected in the handshake response when the client offers it.
const WS_PROTOCOL: &str = "construct.v1";
/// Prefix of a `Sec-WebSocket-Protocol` entry that carries the pairing token
/// (`bearer.<token>`) — presumably because browser WebSocket clients cannot set
/// an `Authorization` header.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
/// Query-string parameters accepted by `handle_ws_chat`.
#[derive(Deserialize)]
pub struct WsQuery {
    /// Pairing token; lowest-precedence source, consulted after the
    /// `Authorization` header and the `bearer.` subprotocol.
    pub token: Option<String>,
    /// Session to resume; a random UUID is generated when absent.
    pub session_id: Option<String>,
    /// Optional display name, stored for the session when a session backend exists.
    pub name: Option<String>,
}
/// Locate the pairing token presented with a WebSocket upgrade request.
///
/// Candidate sources, highest precedence first:
/// 1. `Authorization: Bearer <token>`
/// 2. the first `bearer.<token>` entry of `Sec-WebSocket-Protocol`
/// 3. the `?token=` query parameter (`query_token`)
///
/// A source whose token is empty is skipped in favour of the next one.
/// Returns `None` when no source yields a non-empty token.
fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
    let bearer_header = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|auth| auth.strip_prefix("Bearer "));

    // Only the first `bearer.`-prefixed entry counts, even if its token is empty.
    let bearer_subprotocol = headers
        .get("sec-websocket-protocol")
        .and_then(|value| value.to_str().ok())
        .and_then(|offered| {
            offered
                .split(',')
                .map(str::trim)
                .find_map(|proto| proto.strip_prefix(BEARER_SUBPROTO_PREFIX))
        });

    bearer_header
        .into_iter()
        .chain(bearer_subprotocol)
        .chain(query_token)
        .find(|candidate| !candidate.is_empty())
}
pub async fn handle_ws_chat(
State(state): State<AppState>,
Query(params): Query<WsQuery>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
if state.pairing.require_pairing() {
let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
if !state.pairing.is_authenticated(token) {
return (
axum::http::StatusCode::UNAUTHORIZED,
"Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
)
.into_response();
}
}
let ws = if headers
.get("sec-websocket-protocol")
.and_then(|v| v.to_str().ok())
.map_or(false, |protos| {
protos.split(',').any(|p| p.trim() == WS_PROTOCOL)
}) {
ws.protocols([WS_PROTOCOL])
} else {
ws
};
if let Some(ref logger) = state.audit_logger {
let _ = logger.log_security_event("dashboard", "WebSocket chat session connected");
}
let session_id = params.session_id;
let session_name = params.name;
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
.into_response()
}
const GW_SESSION_PREFIX: &str = "gw_";
async fn handle_socket(
socket: WebSocket,
state: AppState,
session_id: Option<String>,
session_name: Option<String>,
) {
let (mut sender, mut receiver) = socket.split();
let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
let config = state.config.lock().clone();
let mut agent = match crate::agent::Agent::from_config(&config).await {
Ok(a) => a,
Err(e) => {
tracing::error!(error = %e, "Agent initialization failed");
let err = serde_json::json!({
"type": "error",
"message": format!("Failed to initialise agent: {e}"),
"code": "AGENT_INIT_FAILED"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
let _ = sender
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: 1011,
reason: axum::extract::ws::Utf8Bytes::from_static(
"Agent initialization failed",
),
})))
.await;
return;
}
};
agent.set_memory_session_id(Some(session_id.clone()));
let mut resumed = false;
let mut message_count: usize = 0;
let mut effective_name: Option<String> = None;
if let Some(ref backend) = state.session_backend {
let messages = backend.load(&session_key);
if !messages.is_empty() {
message_count = messages.len();
agent.seed_history(&messages);
resumed = true;
}
if let Some(ref name) = session_name {
if !name.is_empty() {
let _ = backend.set_session_name(&session_key, name);
effective_name = Some(name.clone());
}
}
if effective_name.is_none() {
effective_name = backend.get_session_name(&session_key).unwrap_or(None);
}
}
let mut session_start = serde_json::json!({
"type": "session_start",
"session_id": session_id,
"resumed": resumed,
"message_count": message_count,
});
if let Some(ref name) = effective_name {
session_start["name"] = serde_json::Value::String(name.clone());
}
let _ = sender
.send(Message::Text(session_start.to_string().into()))
.await;
let mut first_msg_fallback: Option<String> = None;
match tokio::time::timeout(std::time::Duration::from_secs(5), receiver.next()).await {
Ok(Some(first)) => {
match first {
Ok(Message::Text(text)) => {
if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
if cp.msg_type == "connect" {
debug!(
session_id = ?cp.session_id,
device_name = ?cp.device_name,
capabilities = ?cp.capabilities,
"WebSocket connect params received"
);
if let Some(sid) = &cp.session_id {
agent.set_memory_session_id(Some(sid.clone()));
}
let ack = serde_json::json!({
"type": "connected",
"message": "Connection established"
});
let _ = sender.send(Message::Text(ack.to_string().into())).await;
} else {
first_msg_fallback = Some(text.to_string());
}
} else {
first_msg_fallback = Some(text.to_string());
}
}
Ok(Message::Close(_)) | Err(_) => return,
_ => {}
}
}
Ok(None) => return, Err(_) => {
debug!(session_id = %session_id, "No initial message within 5s — entering listen-only mode");
}
}
let mut broadcast_rx = state.event_tx.subscribe();
if let Some(ref text) = first_msg_fallback {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
if parsed["type"].as_str() == Some("message") {
let content = parsed["content"].as_str().unwrap_or("").to_string();
if !content.is_empty() {
let page_ctx = parsed["page_context"].as_str();
if let Some(ref backend) = state.session_backend {
let user_msg = crate::providers::ChatMessage::user(&content);
let _ = backend.append(&session_key, &user_msg);
}
process_chat_message(
&state,
&mut agent,
&mut sender,
&content,
&session_key,
page_ctx,
&mut broadcast_rx,
)
.await;
}
} else {
let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
let err = serde_json::json!({
"type": "error",
"message": format!(
"Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
)
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
}
} else {
let err = serde_json::json!({
"type": "error",
"message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
}
}
loop {
tokio::select! {
ws_msg = receiver.next() => {
let msg = match ws_msg {
Some(Ok(Message::Text(text))) => text,
Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
_ => continue,
};
let parsed: serde_json::Value = match serde_json::from_str(&msg) {
Ok(v) => v,
Err(e) => {
let err = serde_json::json!({
"type": "error",
"message": format!("Invalid JSON: {}", e),
"code": "INVALID_JSON"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
};
let msg_type = parsed["type"].as_str().unwrap_or("");
if msg_type != "message" {
let err = serde_json::json!({
"type": "error",
"message": format!(
"Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
),
"code": "UNKNOWN_MESSAGE_TYPE"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
let content = parsed["content"].as_str().unwrap_or("").to_string();
if content.is_empty() {
let err = serde_json::json!({
"type": "error",
"message": "Message content cannot be empty",
"code": "EMPTY_CONTENT"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
let _session_guard = match state.session_queue.acquire(&session_key).await {
Ok(guard) => guard,
Err(e) => {
let err = serde_json::json!({
"type": "error",
"message": e.to_string(),
"code": "SESSION_BUSY"
});
let _ = sender.send(Message::Text(err.to_string().into())).await;
continue;
}
};
let page_ctx = parsed["page_context"].as_str();
if let Some(ref backend) = state.session_backend {
let user_msg = crate::providers::ChatMessage::user(&content);
let _ = backend.append(&session_key, &user_msg);
}
process_chat_message(&state, &mut agent, &mut sender, &content, &session_key, page_ctx, &mut broadcast_rx).await;
}
event = broadcast_rx.recv() => {
match event {
Ok(ev) if ev["type"].as_str() == Some("channel_event") => {
let relay = serde_json::json!({
"type": "agent_event",
"event": ev["payload"],
});
let _ = sender.send(Message::Text(relay.to_string().into())).await;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
_ => {} }
}
}
}
}
/// Return the prompt preamble for the dashboard page the user is on.
///
/// `page` is the client-supplied `page_context` identifier. Known pages are
/// `agent_pool`, `agent_teams`, `skills`, `workflows` and `canvas`; matching is
/// exact and case-sensitive. Any other value yields `None`. Every hint is a
/// bracketed block ending in `"]\n\n"`, so it can be prepended directly to the
/// user's message text.
fn page_context_hint(page: &str) -> Option<&'static str> {
    match page {
        "agent_pool" => Some(concat!(
            "[Page context: The user is on the **Agent Pool** page.\n",
            "Available tools:\n",
            "- `construct-operator__save_agent_template` — Create/update an agent\n",
            "- `construct-operator__search_agent_pool` — Search agents by query\n",
            "- `construct-operator__list_agent_templates` — List all agents (returns kref, name, role, etc.)\n\n",
            "When creating agents, collect: name, role (coder/researcher/reviewer/specialist), ",
            "expertise areas, preferred model (codex/claude), identity, soul, tone, and optionally system_hint.\n",
            "Guide the user conversationally.\n\n",
            "IMPORTANT behavioral rules:\n",
            "- A tool returning empty content or no error means SUCCESS. Verify by calling list_agent_templates after.\n",
            "- NEVER say a tool is broken or file a bug report. If something seems off, retry or verify.\n",
            "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant, handle it.\n",
            "- After creating/updating, confirm success by listing agents to show the result.]\n\n",
        )),
        "agent_teams" => Some(concat!(
            "[Page context: The user is on the **Agent Teams** page.\n",
            "Available tools:\n",
            "- `construct-operator__create_team` — Create/update a team with members and edges\n",
            "- `construct-operator__list_agent_templates` — List all agents (returns kref for member_krefs)\n",
            "- `construct-operator__search_agent_pool` — Search agents by query\n",
            "- `construct-operator__list_teams` — List existing teams\n",
            "- `construct-operator__get_team` — Get team details with members and edges\n\n",
            "When creating teams: collect a name, description, and select member agents.\n",
            "Use the `kref` field from list_agent_templates for member_krefs — the system resolves names automatically.\n",
            "Define edges (SUPPORTS, DEPENDS_ON, REPORTS_TO) between members to express the team structure.\n\n",
            "IMPORTANT behavioral rules:\n",
            "- A tool returning empty content or no error means SUCCESS. Verify by calling list_teams after.\n",
            "- NEVER say a tool is broken or file a bug report. If something seems off, retry or verify.\n",
            "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant, handle it.\n",
            "- After creating a team, confirm success by calling list_teams or get_team to show the result.\n",
            "- member_krefs accepts agent names, partial krefs, or full krefs — the resolver handles matching.]\n\n",
        )),
        "skills" => Some(concat!(
            "[Page context: The user is on the **Skills Library** page.\n",
            "Skills are reusable behavioral procedures stored in CognitiveMemory/Skills.\n",
            "Available tools:\n",
            "- `construct-operator__save_skill` — Create/update a skill (if available)\n",
            "- `construct-operator__list_agent_templates` — List agents (skills may reference agents)\n",
            "- `construct-operator__search_clawhub` — Search ClawHub public marketplace for community skills\n",
            "- `construct-operator__browse_clawhub` — Browse trending skills on ClawHub\n",
            "- `construct-operator__install_from_clawhub` — Install a skill from ClawHub by slug\n\n",
            "A skill has: name, description, content (the procedure text), domain ",
            "(Memory/Creative/Privacy/Graph/Behavioral/Other), and tags.\n",
            "Guide the user through defining skills conversationally — help them articulate ",
            "the procedure, choose the right domain, and write clear content.\n",
            "When users want to find existing skills, search ClawHub first before creating from scratch.\n\n",
            "IMPORTANT behavioral rules:\n",
            "- A tool returning empty content or no error means SUCCESS. Verify after.\n",
            "- NEVER say a tool is broken or file a bug report.\n",
            "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant.]\n\n",
        )),
        // The YAML example must be properly nested: the model is told to use it
        // EXACTLY, so flat indentation would teach it a malformed schema.
        "workflows" => Some(concat!(
            "[Page context: The user is on the **Workflows** page.\n",
            "Available tools: create_workflow, list_workflows, validate_workflow, run_workflow, ",
            "get_workflow_status, cancel_workflow, resume_workflow, dry_run_workflow, ",
            "recall_workflow_runs, get_workflow_run_detail, save_workflow_preset, list_workflow_presets ",
            "(all prefixed with `construct-operator__`).\n\n",
            "## Workflow schema (use this EXACTLY with create_workflow):\n",
            "```yaml\n",
            "workflow_def:\n",
            "  name: my-workflow # kebab-case identifier\n",
            "  description: What it does\n",
            "  tags: [tag1, tag2] # optional\n",
            "  inputs: # optional\n",
            "    - name: task\n",
            "      required: false\n",
            "      default: default value\n",
            "  steps:\n",
            "    - id: research_step\n",
            "      name: Research Phase\n",
            "      action: research # research | code | review | deploy | test | build | notify | approve | summarize | task | human_input\n",
            "      description: Research the topic using ${inputs.task}\n",
            "      agent_hints: [researcher] # hints for operator: coder | researcher | reviewer\n",
            "      depends_on: []\n",
            "    - id: code_step\n",
            "      name: Implementation\n",
            "      action: code\n",
            "      description: Implement based on ${research_step.output}\n",
            "      agent_hints: [coder]\n",
            "      depends_on: [research_step]\n",
            "    - id: review_step\n",
            "      name: Code Review\n",
            "      action: review\n",
            "      description: Review ${code_step.output}\n",
            "      agent_hints: [reviewer]\n",
            "      depends_on: [code_step]\n",
            "    - id: feedback_step\n",
            "      name: Get User Feedback\n",
            "      action: human_input\n",
            "      description: Please review the output and provide feedback\n",
            "      channel: dashboard # dashboard | slack | discord\n",
            "      depends_on: [review_step]\n",
            "```\n",
            "The `action` field determines which agent type runs the step:\n",
            " research → researcher (claude), code → coder (codex), review → reviewer (claude),\n",
            " deploy/test/build → codex, notify/summarize → claude, task → generic claude,\n",
            " human_input → pauses workflow and sends a prompt to a channel (dashboard/slack/discord), waits for human response.\n",
            "The `description` field is the agent's prompt — use ${step_id.output} and ${inputs.X} for interpolation.\n",
            "`agent_hints` are optional suggestions (operator auto-selects if omitted).\n",
            "For advanced use, add explicit `type` + config block (agent/shell/goto/output/human_approval).\n\n",
            "Rules:\n",
            "- create_workflow validates internally and returns {saved, path, valid, registered}. Trust it — do NOT call list_workflows or validate_workflow to verify.\n",
            "- One tool call is enough for creation. Keep it simple.\n",
            "- When the user says 'research agent', '3 agents', 'coder', etc., map to the right action.\n",
            "- When running a workflow, always provide the cwd parameter.\n",
            "- Do NOT ask the user to use the UI instead — handle it yourself.]\n\n",
        )),
        "canvas" => Some(concat!(
            "[Page context: The user is on the **Live Canvas** page.\n",
            "The canvas is YOUR primary output — render visual content IMMEDIATELY.\n\n",
            "Available tools:\n",
            "- `construct-operator__render_canvas` — Push content to the canvas (html, svg, markdown, text)\n",
            "- `construct-operator__clear_canvas` — Clear a canvas\n\n",
            "ALWAYS render to the canvas. The user opened this page to SEE visual output.\n",
            "Use it for:\n",
            "- Interactive HTML dashboards with charts, tables, and metrics\n",
            "- SVG diagrams, flowcharts, architecture maps, or data visualizations\n",
            "- Formatted reports, comparisons, or analyses\n",
            "- Any content that benefits from visual presentation\n\n",
            "CRITICAL rules:\n",
            "- ALWAYS call render_canvas — do NOT just describe what you would render.\n",
            "- For HTML: include ALL CSS inline. Use a dark theme (bg: #1a1a2e, text: #e2e8f0).\n",
            " Include modern styling with gradients, rounded corners, and clean typography.\n",
            "- For SVG: provide complete <svg> with viewBox for responsive sizing.\n",
            "- For charts: use inline CSS/HTML tables or SVG — no external JS libraries.\n",
            "- Keep content self-contained — no external resources, CDNs, or imports.\n",
            "- Default canvas_id is 'default'. You can use separate canvas_ids for multiple views.\n",
            "- If the user asks a question, answer it AND render relevant visual content.\n",
            "- Iterate: if the user gives feedback, re-render with improvements.]\n\n",
        )),
        _ => None,
    }
}
/// Run one agent turn for `content` and stream its progress to the client.
///
/// Side effects, in order:
/// - broadcasts `agent_start` on `state.event_tx` and marks the session `running`;
/// - streams each `TurnEvent` to the socket as a JSON frame, while concurrently
///   relaying `channel_event` broadcasts as `agent_event` frames;
/// - on success, persists the assistant reply, sends `chunk_reset` then `done`,
///   marks the session `idle` and broadcasts `agent_end`;
/// - on failure, marks the session `error`, sends a sanitised `error` frame with a
///   coarse `code` (see `classify_agent_error`), and broadcasts an `error` event.
///
/// When `page_context` names a known dashboard page, that page's hint is
/// prepended to the prompt given to the agent. The hint is not persisted here.
async fn process_chat_message(
    state: &AppState,
    agent: &mut crate::agent::Agent,
    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
    content: &str,
    session_key: &str,
    page_context: Option<&str>,
    broadcast_rx: &mut tokio::sync::broadcast::Receiver<serde_json::Value>,
) {
    use crate::agent::TurnEvent;

    let provider_label = state
        .config
        .lock()
        .default_provider
        .clone()
        .unwrap_or_else(|| "unknown".to_string());
    let _ = state.event_tx.send(serde_json::json!({
        "type": "agent_start",
        "provider": provider_label,
        "model": state.model,
    }));

    let turn_id = uuid::Uuid::new_v4().to_string();
    if let Some(ref backend) = state.session_backend {
        let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
    }

    let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
    let content_owned = match page_context.and_then(page_context_hint) {
        Some(hint) => format!("{hint}{content}"),
        None => content.to_string(),
    };
    let cost_tracking_context = state.cost_tracker.clone().map(|tracker| {
        let prices = Arc::new(state.config.lock().cost.prices.clone());
        crate::agent::cost::ToolLoopCostTrackingContext::new(tracker, prices)
    });
    // `event_tx` is moved into the turn; once every sender is dropped,
    // `event_rx.recv()` yields `None`, which ends the forwarding loop below.
    let turn_fut = crate::agent::loop_::TOOL_LOOP_COST_TRACKING_CONTEXT
        .scope(cost_tracking_context, async {
            agent.turn_streamed(&content_owned, event_tx).await
        });
    let forward_fut = async {
        let mut turn_done = false;
        while !turn_done {
            tokio::select! {
                event = event_rx.recv() => {
                    match event {
                        Some(event) => {
                            let frame = turn_event_json(event);
                            let _ = sender.send(Message::Text(frame.to_string().into())).await;
                        }
                        None => turn_done = true,
                    }
                }
                bcast = broadcast_rx.recv() => {
                    if let Ok(ev) = bcast {
                        if ev["type"].as_str() == Some("channel_event") {
                            let relay = serde_json::json!({
                                "type": "agent_event",
                                "event": ev["payload"],
                            });
                            let _ = sender.send(Message::Text(relay.to_string().into())).await;
                        }
                    }
                }
            }
        }
    };
    let (result, ()) = tokio::join!(turn_fut, forward_fut);

    match result {
        Ok(response) => {
            if let Some(ref backend) = state.session_backend {
                let assistant_msg = crate::providers::ChatMessage::assistant(&response);
                let _ = backend.append(session_key, &assistant_msg);
            }
            let reset = serde_json::json!({ "type": "chunk_reset" });
            let _ = sender.send(Message::Text(reset.to_string().into())).await;
            let done = serde_json::json!({
                "type": "done",
                "full_response": response,
            });
            let _ = sender.send(Message::Text(done.to_string().into())).await;
            if let Some(ref backend) = state.session_backend {
                let _ = backend.set_session_state(session_key, "idle", None);
            }
            let _ = state.event_tx.send(serde_json::json!({
                "type": "agent_end",
                "provider": provider_label,
                "model": state.model,
            }));
        }
        Err(e) => {
            if let Some(ref backend) = state.session_backend {
                let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
            }
            tracing::error!(error = %e, "Agent turn failed");
            let sanitized = crate::providers::sanitize_api_error(&e.to_string());
            let error_code = classify_agent_error(&sanitized);
            let err = serde_json::json!({
                "type": "error",
                "message": sanitized,
                "code": error_code,
            });
            let _ = sender.send(Message::Text(err.to_string().into())).await;
            let _ = state.event_tx.send(serde_json::json!({
                "type": "error",
                "component": "ws_chat",
                "message": sanitized,
            }));
        }
    }
}

/// Convert a streamed `TurnEvent` into the JSON frame sent to the client.
fn turn_event_json(event: crate::agent::TurnEvent) -> serde_json::Value {
    use crate::agent::TurnEvent;
    match event {
        TurnEvent::Chunk { delta } => {
            serde_json::json!({ "type": "chunk", "content": delta })
        }
        TurnEvent::Thinking { delta } => {
            serde_json::json!({ "type": "thinking", "content": delta })
        }
        TurnEvent::ToolCall { name, args } => {
            serde_json::json!({ "type": "tool_call", "name": name, "args": args })
        }
        TurnEvent::ToolResult { name, output } => {
            serde_json::json!({ "type": "tool_result", "name": name, "output": output })
        }
        TurnEvent::OperatorStatus { phase, detail } => {
            serde_json::json!({ "type": "operator_status", "phase": phase, "detail": detail })
        }
    }
}

/// Map a sanitised agent error message to a coarse client-facing error code.
///
/// Matching is a case-insensitive substring search. Authentication wording is
/// checked first (`AUTH_ERROR`), then provider/model wording (`PROVIDER_ERROR`);
/// anything else is `AGENT_ERROR`.
fn classify_agent_error(message: &str) -> &'static str {
    let lower = message.to_lowercase();
    let mentions_any = |needles: &[&str]| needles.iter().any(|needle| lower.contains(*needle));
    if mentions_any(&["api key", "authentication", "unauthorized"]) {
        "AUTH_ERROR"
    } else if mentions_any(&["provider", "model"]) {
        "PROVIDER_ERROR"
    } else {
        "AGENT_ERROR"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// Build a `HeaderMap` from `(name, value)` pairs (later duplicates replace earlier ones).
    fn headers_with(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }
        headers
    }

    #[test]
    fn extract_ws_token_from_authorization_header() {
        let headers = headers_with(&[("authorization", "Bearer zc_test123")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_test123"));
    }

    #[test]
    fn extract_ws_token_from_subprotocol() {
        let headers = headers_with(&[("sec-websocket-protocol", "construct.v1, bearer.zc_sub456")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_sub456"));
    }

    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query789")),
            Some("zc_query789")
        );
    }

    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        let headers = headers_with(&[
            ("authorization", "Bearer zc_header"),
            ("sec-websocket-protocol", "bearer.zc_sub"),
        ]);
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query")),
            Some("zc_header")
        );
    }

    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        let headers = headers_with(&[("sec-websocket-protocol", "bearer.zc_sub")]);
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_sub"));
    }

    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, None), None);
    }

    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        let headers = headers_with(&[("authorization", "Bearer ")]);
        assert_eq!(
            extract_ws_token(&headers, Some("zc_fallback")),
            Some("zc_fallback")
        );
    }

    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, Some("")), None);
    }

    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        let headers = headers_with(&[("sec-websocket-protocol", "construct.v1, bearer.zc_tok, other")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_tok"));
    }

    #[test]
    fn extract_ws_token_ignores_non_bearer_authorization() {
        let headers = headers_with(&[("authorization", "Basic dXNlcjpwYXNz")]);
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_query"));
    }

    #[test]
    fn extract_ws_token_empty_bearer_subprotocol_falls_back_to_query() {
        let headers = headers_with(&[("sec-websocket-protocol", "construct.v1, bearer.")]);
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_query"));
    }
}