Skip to main content

construct/gateway/
ws.rs

1//! WebSocket agent chat handler.
2//!
3//! Connect: `ws://host:port/ws/chat?session_id=ID&name=My+Session`
4//!
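//! Protocol:
//! ```text
//! Server -> Client: {"type":"session_start","session_id":"...","name":"...","resumed":true,"message_count":42}
//! Client -> Server: {"type":"connect","session_id":"...","device_name":"...","capabilities":[...]}  (optional; first frame only, within 5s)
//! Server -> Client: {"type":"connected","message":"Connection established"}
//! Client -> Server: {"type":"message","content":"Hello"}
//! Server -> Client: {"type":"chunk","content":"Hi! "}
//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
//! Server -> Client: {"type":"done","full_response":"..."}
//! ```
//!
//! `name` is omitted from `session_start` when the session has no name.
//! `message` frames may carry an optional `page_context` string. During a
//! turn the server may also send `thinking`, `operator_status`, and
//! `agent_event` frames. Invalid client frames are answered with an `error`
//! frame carrying a `message` (and usually a `code`).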
14//!
15//! Query params:
16//! - `session_id` — resume or create a session (default: new UUID)
17//! - `name` — optional human-readable label for the session
18//! - `token` — bearer auth token (alternative to Authorization header)
19
20use super::AppState;
21use axum::{
22    extract::{
23        Query, State, WebSocketUpgrade,
24        ws::{Message, WebSocket},
25    },
26    http::{HeaderMap, header},
27    response::IntoResponse,
28};
29use futures_util::{SinkExt, StreamExt};
30use serde::Deserialize;
31use std::sync::Arc;
32use tracing::debug;
33
34/// Optional connection parameters sent as the first WebSocket message.
35///
36/// If the first message after upgrade is `{"type":"connect",...}`, these
37/// parameters are extracted and an acknowledgement is sent back. Old clients
38/// that send `{"type":"message",...}` as the first frame still work — the
39/// message is processed normally (backward-compatible).
40#[derive(Debug, Deserialize)]
41struct ConnectParams {
42    #[serde(rename = "type")]
43    msg_type: String,
44    /// Client-chosen session ID for memory persistence
45    #[serde(default)]
46    session_id: Option<String>,
47    /// Device name for device registry tracking
48    #[serde(default)]
49    device_name: Option<String>,
50    /// Client capabilities
51    #[serde(default)]
52    capabilities: Vec<String>,
53}
54
/// Prefix of a `Sec-WebSocket-Protocol` entry that smuggles a bearer token
/// (`bearer.<token>`). Browser clients rely on it because they cannot set
/// custom headers on a WebSocket.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";

/// Chat sub-protocol name selected during the upgrade when the client offers it.
const WS_PROTOCOL: &str = "construct.v1";
60
61#[derive(Deserialize)]
62pub struct WsQuery {
63    pub token: Option<String>,
64    pub session_id: Option<String>,
65    /// Optional human-readable name for the session.
66    pub name: Option<String>,
67}
68
69/// Extract a bearer token from WebSocket-compatible sources.
70///
71/// Precedence (first non-empty wins):
72/// 1. `Authorization: Bearer <token>` header
73/// 2. `Sec-WebSocket-Protocol: bearer.<token>` subprotocol
74/// 3. `?token=<token>` query parameter
75///
76/// Browsers cannot set custom headers on `new WebSocket(url)`, so the query
77/// parameter and subprotocol paths are required for browser-based clients.
78fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
79    // 1. Authorization header
80    if let Some(t) = headers
81        .get(header::AUTHORIZATION)
82        .and_then(|v| v.to_str().ok())
83        .and_then(|auth| auth.strip_prefix("Bearer "))
84    {
85        if !t.is_empty() {
86            return Some(t);
87        }
88    }
89
90    // 2. Sec-WebSocket-Protocol: bearer.<token>
91    if let Some(t) = headers
92        .get("sec-websocket-protocol")
93        .and_then(|v| v.to_str().ok())
94        .and_then(|protos| {
95            protos
96                .split(',')
97                .map(|p| p.trim())
98                .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
99        })
100    {
101        if !t.is_empty() {
102            return Some(t);
103        }
104    }
105
106    // 3. ?token= query parameter
107    if let Some(t) = query_token {
108        if !t.is_empty() {
109            return Some(t);
110        }
111    }
112
113    None
114}
115
116/// GET /ws/chat — WebSocket upgrade for agent chat
117pub async fn handle_ws_chat(
118    State(state): State<AppState>,
119    Query(params): Query<WsQuery>,
120    headers: HeaderMap,
121    ws: WebSocketUpgrade,
122) -> impl IntoResponse {
123    // Auth: check header, subprotocol, then query param (precedence order)
124    if state.pairing.require_pairing() {
125        let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
126        if !state.pairing.is_authenticated(token) {
127            return (
128                axum::http::StatusCode::UNAUTHORIZED,
129                "Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
130            )
131                .into_response();
132        }
133    }
134
135    // Echo Sec-WebSocket-Protocol if the client requests our sub-protocol.
136    let ws = if headers
137        .get("sec-websocket-protocol")
138        .and_then(|v| v.to_str().ok())
139        .map_or(false, |protos| {
140            protos.split(',').any(|p| p.trim() == WS_PROTOCOL)
141        }) {
142        ws.protocols([WS_PROTOCOL])
143    } else {
144        ws
145    };
146
147    // Audit: log WebSocket chat connection
148    if let Some(ref logger) = state.audit_logger {
149        let _ = logger.log_security_event("dashboard", "WebSocket chat session connected");
150    }
151
152    let session_id = params.session_id;
153    let session_name = params.name;
154    ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
155        .into_response()
156}
157
158/// Gateway session key prefix to avoid collisions with channel sessions.
159const GW_SESSION_PREFIX: &str = "gw_";
160
161async fn handle_socket(
162    socket: WebSocket,
163    state: AppState,
164    session_id: Option<String>,
165    session_name: Option<String>,
166) {
167    let (mut sender, mut receiver) = socket.split();
168
169    // Resolve session ID: use provided or generate a new UUID
170    let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
171    let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
172
173    // Build a persistent Agent for this connection so history is maintained across turns.
174    let config = state.config.lock().clone();
175    let mut agent = match crate::agent::Agent::from_config(&config).await {
176        Ok(a) => a,
177        Err(e) => {
178            tracing::error!(error = %e, "Agent initialization failed");
179            let err = serde_json::json!({
180                "type": "error",
181                "message": format!("Failed to initialise agent: {e}"),
182                "code": "AGENT_INIT_FAILED"
183            });
184            let _ = sender.send(Message::Text(err.to_string().into())).await;
185            let _ = sender
186                .send(Message::Close(Some(axum::extract::ws::CloseFrame {
187                    code: 1011,
188                    reason: axum::extract::ws::Utf8Bytes::from_static(
189                        "Agent initialization failed",
190                    ),
191                })))
192                .await;
193            return;
194        }
195    };
196    agent.set_memory_session_id(Some(session_id.clone()));
197
198    // Hydrate agent from persisted session (if available)
199    let mut resumed = false;
200    let mut message_count: usize = 0;
201    let mut effective_name: Option<String> = None;
202    if let Some(ref backend) = state.session_backend {
203        let messages = backend.load(&session_key);
204        if !messages.is_empty() {
205            message_count = messages.len();
206            agent.seed_history(&messages);
207            resumed = true;
208        }
209        // Set session name if provided (non-empty) on connect
210        if let Some(ref name) = session_name {
211            if !name.is_empty() {
212                let _ = backend.set_session_name(&session_key, name);
213                effective_name = Some(name.clone());
214            }
215        }
216        // If no name was provided via query param, load the stored name
217        if effective_name.is_none() {
218            effective_name = backend.get_session_name(&session_key).unwrap_or(None);
219        }
220    }
221
222    // Send session_start message to client
223    let mut session_start = serde_json::json!({
224        "type": "session_start",
225        "session_id": session_id,
226        "resumed": resumed,
227        "message_count": message_count,
228    });
229    if let Some(ref name) = effective_name {
230        session_start["name"] = serde_json::Value::String(name.clone());
231    }
232    let _ = sender
233        .send(Message::Text(session_start.to_string().into()))
234        .await;
235
236    // ── Optional connect handshake ──────────────────────────────────
237    // The first message may be a `{"type":"connect",...}` frame carrying
238    // connection parameters.  If it is, we extract the params, send an
239    // ack, and proceed to the normal message loop.  If the first message
240    // is a regular `{"type":"message",...}` frame, we fall through and
241    // process it immediately (backward-compatible).
242    let mut first_msg_fallback: Option<String> = None;
243
244    // Wait up to 5 seconds for the first client frame.  Listen-only
245    // connections (e.g. WorkflowRunLive) may never send a message — the
246    // timeout lets them fall through to the broadcast relay loop.
247    match tokio::time::timeout(std::time::Duration::from_secs(5), receiver.next()).await {
248        Ok(Some(first)) => {
249            match first {
250                Ok(Message::Text(text)) => {
251                    if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
252                        if cp.msg_type == "connect" {
253                            debug!(
254                                session_id = ?cp.session_id,
255                                device_name = ?cp.device_name,
256                                capabilities = ?cp.capabilities,
257                                "WebSocket connect params received"
258                            );
259                            // Override session_id if provided in connect params
260                            if let Some(sid) = &cp.session_id {
261                                agent.set_memory_session_id(Some(sid.clone()));
262                            }
263                            let ack = serde_json::json!({
264                                "type": "connected",
265                                "message": "Connection established"
266                            });
267                            let _ = sender.send(Message::Text(ack.to_string().into())).await;
268                        } else {
269                            // Not a connect message — fall through to normal processing
270                            first_msg_fallback = Some(text.to_string());
271                        }
272                    } else {
273                        // Not parseable as ConnectParams — fall through
274                        first_msg_fallback = Some(text.to_string());
275                    }
276                }
277                Ok(Message::Close(_)) | Err(_) => return,
278                _ => {}
279            }
280        }
281        Ok(None) => return, // Stream ended
282        Err(_) => {
283            // Timeout — no initial message received within 5s.  Proceed to
284            // main loop so listen-only connections still receive broadcasts.
285            debug!(session_id = %session_id, "No initial message within 5s — entering listen-only mode");
286        }
287    }
288
289    // Subscribe to the broadcast channel early so we can relay operator channel
290    // events (agent.started, agent.completed, etc.) even during the first turn.
291    let mut broadcast_rx = state.event_tx.subscribe();
292
293    // Process the first message if it was not a connect frame
294    if let Some(ref text) = first_msg_fallback {
295        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
296            if parsed["type"].as_str() == Some("message") {
297                let content = parsed["content"].as_str().unwrap_or("").to_string();
298                if !content.is_empty() {
299                    let page_ctx = parsed["page_context"].as_str();
300                    // Persist user message
301                    if let Some(ref backend) = state.session_backend {
302                        let user_msg = crate::providers::ChatMessage::user(&content);
303                        let _ = backend.append(&session_key, &user_msg);
304                    }
305                    process_chat_message(
306                        &state,
307                        &mut agent,
308                        &mut sender,
309                        &content,
310                        &session_key,
311                        page_ctx,
312                        &mut broadcast_rx,
313                    )
314                    .await;
315                }
316            } else {
317                let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
318                let err = serde_json::json!({
319                    "type": "error",
320                    "message": format!(
321                        "Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
322                    )
323                });
324                let _ = sender.send(Message::Text(err.to_string().into())).await;
325            }
326        } else {
327            let err = serde_json::json!({
328                "type": "error",
329                "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
330            });
331            let _ = sender.send(Message::Text(err.to_string().into())).await;
332        }
333    }
334
335    loop {
336        tokio::select! {
337            // ── Branch 1: incoming WebSocket message from the client ──
338            ws_msg = receiver.next() => {
339                let msg = match ws_msg {
340                    Some(Ok(Message::Text(text))) => text,
341                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
342                    _ => continue,
343                };
344
345                let parsed: serde_json::Value = match serde_json::from_str(&msg) {
346                    Ok(v) => v,
347                    Err(e) => {
348                        let err = serde_json::json!({
349                            "type": "error",
350                            "message": format!("Invalid JSON: {}", e),
351                            "code": "INVALID_JSON"
352                        });
353                        let _ = sender.send(Message::Text(err.to_string().into())).await;
354                        continue;
355                    }
356                };
357
358                let msg_type = parsed["type"].as_str().unwrap_or("");
359                if msg_type != "message" {
360                    let err = serde_json::json!({
361                        "type": "error",
362                        "message": format!(
363                            "Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
364                        ),
365                        "code": "UNKNOWN_MESSAGE_TYPE"
366                    });
367                    let _ = sender.send(Message::Text(err.to_string().into())).await;
368                    continue;
369                }
370
371                let content = parsed["content"].as_str().unwrap_or("").to_string();
372                if content.is_empty() {
373                    let err = serde_json::json!({
374                        "type": "error",
375                        "message": "Message content cannot be empty",
376                        "code": "EMPTY_CONTENT"
377                    });
378                    let _ = sender.send(Message::Text(err.to_string().into())).await;
379                    continue;
380                }
381
382                // Acquire session lock to serialize concurrent turns
383                let _session_guard = match state.session_queue.acquire(&session_key).await {
384                    Ok(guard) => guard,
385                    Err(e) => {
386                        let err = serde_json::json!({
387                            "type": "error",
388                            "message": e.to_string(),
389                            "code": "SESSION_BUSY"
390                        });
391                        let _ = sender.send(Message::Text(err.to_string().into())).await;
392                        continue;
393                    }
394                };
395
396                let page_ctx = parsed["page_context"].as_str();
397
398                // Persist user message
399                if let Some(ref backend) = state.session_backend {
400                    let user_msg = crate::providers::ChatMessage::user(&content);
401                    let _ = backend.append(&session_key, &user_msg);
402                }
403
404                process_chat_message(&state, &mut agent, &mut sender, &content, &session_key, page_ctx, &mut broadcast_rx).await;
405            }
406
407            // ── Branch 2: broadcast channel event from operator ──
408            event = broadcast_rx.recv() => {
409                match event {
410                    Ok(ev) if ev["type"].as_str() == Some("channel_event") => {
411                        let relay = serde_json::json!({
412                            "type": "agent_event",
413                            "event": ev["payload"],
414                        });
415                        let _ = sender.send(Message::Text(relay.to_string().into())).await;
416                    }
417                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
418                    _ => {} // Skip non-channel events and lag errors
419                }
420            }
421        }
422    }
423}
424
425/// Build a context-aware system hint based on the dashboard page the user is viewing.
426///
427/// Returns `None` for unknown pages or the main chat — only Agent Pool and
428/// Agent Teams pages get specialised instructions.
429fn page_context_hint(page: &str) -> Option<&'static str> {
430    match page {
431        "agent_pool" => Some(concat!(
432            "[Page context: The user is on the **Agent Pool** page.\n",
433            "Available tools:\n",
434            "- `construct-operator__save_agent_template` — Create/update an agent\n",
435            "- `construct-operator__search_agent_pool` — Search agents by query\n",
436            "- `construct-operator__list_agent_templates` — List all agents (returns kref, name, role, etc.)\n\n",
437            "When creating agents, collect: name, role (coder/researcher/reviewer/specialist), ",
438            "expertise areas, preferred model (codex/claude), identity, soul, tone, and optionally system_hint.\n",
439            "Guide the user conversationally.\n\n",
440            "IMPORTANT behavioral rules:\n",
441            "- A tool returning empty content or no error means SUCCESS. Verify by calling list_agent_templates after.\n",
442            "- NEVER say a tool is broken or file a bug report. If something seems off, retry or verify.\n",
443            "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant, handle it.\n",
444            "- After creating/updating, confirm success by listing agents to show the result.]\n\n",
445        )),
446        "agent_teams" => Some(concat!(
447            "[Page context: The user is on the **Agent Teams** page.\n",
448            "Available tools:\n",
449            "- `construct-operator__create_team` — Create/update a team with members and edges\n",
450            "- `construct-operator__list_agent_templates` — List all agents (returns kref for member_krefs)\n",
451            "- `construct-operator__search_agent_pool` — Search agents by query\n",
452            "- `construct-operator__list_teams` — List existing teams\n",
453            "- `construct-operator__get_team` — Get team details with members and edges\n\n",
454            "When creating teams: collect a name, description, and select member agents.\n",
455            "Use the `kref` field from list_agent_templates for member_krefs — the system resolves names automatically.\n",
456            "Define edges (SUPPORTS, DEPENDS_ON, REPORTS_TO) between members to express the team structure.\n\n",
457            "IMPORTANT behavioral rules:\n",
458            "- A tool returning empty content or no error means SUCCESS. Verify by calling list_teams after.\n",
459            "- NEVER say a tool is broken or file a bug report. If something seems off, retry or verify.\n",
460            "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant, handle it.\n",
461            "- After creating a team, confirm success by calling list_teams or get_team to show the result.\n",
462            "- member_krefs accepts agent names, partial krefs, or full krefs — the resolver handles matching.]\n\n",
463        )),
464        "skills" => Some(concat!(
465            "[Page context: The user is on the **Skills Library** page.\n",
466            "Skills are reusable behavioral procedures stored in CognitiveMemory/Skills.\n",
467            "Available tools:\n",
468            "- `construct-operator__save_skill` — Create/update a skill (if available)\n",
469            "- `construct-operator__list_agent_templates` — List agents (skills may reference agents)\n",
470            "- `construct-operator__search_clawhub` — Search ClawHub public marketplace for community skills\n",
471            "- `construct-operator__browse_clawhub` — Browse trending skills on ClawHub\n",
472            "- `construct-operator__install_from_clawhub` — Install a skill from ClawHub by slug\n\n",
473            "A skill has: name, description, content (the procedure text), domain ",
474            "(Memory/Creative/Privacy/Graph/Behavioral/Other), and tags.\n",
475            "Guide the user through defining skills conversationally — help them articulate ",
476            "the procedure, choose the right domain, and write clear content.\n",
477            "When users want to find existing skills, search ClawHub first before creating from scratch.\n\n",
478            "IMPORTANT behavioral rules:\n",
479            "- A tool returning empty content or no error means SUCCESS. Verify after.\n",
480            "- NEVER say a tool is broken or file a bug report.\n",
481            "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant.]\n\n",
482        )),
483        "workflows" => Some(concat!(
484            "[Page context: The user is on the **Workflows** page.\n",
485            "Available tools: create_workflow, list_workflows, validate_workflow, run_workflow, ",
486            "get_workflow_status, cancel_workflow, resume_workflow, dry_run_workflow, ",
487            "recall_workflow_runs, get_workflow_run_detail, save_workflow_preset, list_workflow_presets ",
488            "(all prefixed with `construct-operator__`).\n\n",
489            "## Workflow schema (use this EXACTLY with create_workflow):\n",
490            "```yaml\n",
491            "workflow_def:\n",
492            "  name: my-workflow          # kebab-case identifier\n",
493            "  description: What it does\n",
494            "  tags: [tag1, tag2]         # optional\n",
495            "  inputs:                    # optional\n",
496            "    - name: task\n",
497            "      required: false\n",
498            "      default: default value\n",
499            "  steps:\n",
500            "    - id: research_step\n",
501            "      name: Research Phase\n",
502            "      action: research       # research | code | review | deploy | test | build | notify | approve | summarize | task | human_input\n",
503            "      description: Research the topic using ${inputs.task}\n",
504            "      agent_hints: [researcher]  # hints for operator: coder | researcher | reviewer\n",
505            "      depends_on: []\n",
506            "    - id: code_step\n",
507            "      name: Implementation\n",
508            "      action: code\n",
509            "      description: Implement based on ${research_step.output}\n",
510            "      agent_hints: [coder]\n",
511            "      depends_on: [research_step]\n",
512            "    - id: review_step\n",
513            "      name: Code Review\n",
514            "      action: review\n",
515            "      description: Review ${code_step.output}\n",
516            "      agent_hints: [reviewer]\n",
517            "      depends_on: [code_step]\n",
518            "    - id: feedback_step\n",
519            "      name: Get User Feedback\n",
520            "      action: human_input\n",
521            "      description: Please review the output and provide feedback\n",
522            "      channel: dashboard       # dashboard | slack | discord\n",
523            "      depends_on: [review_step]\n",
524            "```\n",
525            "The `action` field determines which agent type runs the step:\n",
526            "  research → researcher (claude), code → coder (codex), review → reviewer (claude),\n",
527            "  deploy/test/build → codex, notify/summarize → claude, task → generic claude,\n",
528            "  human_input → pauses workflow and sends a prompt to a channel (dashboard/slack/discord), waits for human response.\n",
529            "The `description` field is the agent's prompt — use ${step_id.output} and ${inputs.X} for interpolation.\n",
530            "`agent_hints` are optional suggestions (operator auto-selects if omitted).\n",
531            "For advanced use, add explicit `type` + config block (agent/shell/goto/output/human_approval).\n\n",
532            "Rules:\n",
533            "- create_workflow validates internally and returns {saved, path, valid, registered}. Trust it — do NOT call list_workflows or validate_workflow to verify.\n",
534            "- One tool call is enough for creation. Keep it simple.\n",
535            "- When the user says 'research agent', '3 agents', 'coder', etc., map to the right action.\n",
536            "- When running a workflow, always provide the cwd parameter.\n",
537            "- Do NOT ask the user to use the UI instead — handle it yourself.]\n\n",
538        )),
539        "canvas" => Some(concat!(
540            "[Page context: The user is on the **Live Canvas** page.\n",
541            "The canvas is YOUR primary output — render visual content IMMEDIATELY.\n\n",
542            "Available tools:\n",
543            "- `construct-operator__render_canvas` — Push content to the canvas (html, svg, markdown, text)\n",
544            "- `construct-operator__clear_canvas` — Clear a canvas\n\n",
545            "ALWAYS render to the canvas. The user opened this page to SEE visual output.\n",
546            "Use it for:\n",
547            "- Interactive HTML dashboards with charts, tables, and metrics\n",
548            "- SVG diagrams, flowcharts, architecture maps, or data visualizations\n",
549            "- Formatted reports, comparisons, or analyses\n",
550            "- Any content that benefits from visual presentation\n\n",
551            "CRITICAL rules:\n",
552            "- ALWAYS call render_canvas — do NOT just describe what you would render.\n",
553            "- For HTML: include ALL CSS inline. Use a dark theme (bg: #1a1a2e, text: #e2e8f0).\n",
554            "  Include modern styling with gradients, rounded corners, and clean typography.\n",
555            "- For SVG: provide complete <svg> with viewBox for responsive sizing.\n",
556            "- For charts: use inline CSS/HTML tables or SVG — no external JS libraries.\n",
557            "- Keep content self-contained — no external resources, CDNs, or imports.\n",
558            "- Default canvas_id is 'default'. You can use separate canvas_ids for multiple views.\n",
559            "- If the user asks a question, answer it AND render relevant visual content.\n",
560            "- Iterate: if the user gives feedback, re-render with improvements.]\n\n",
561        )),
562        _ => None,
563    }
564}
565
566/// Process a single chat message through the agent and send the response.
567///
568/// Uses [`Agent::turn_streamed`] so that intermediate text chunks, tool calls,
569/// and tool results are forwarded to the WebSocket client in real time.
570async fn process_chat_message(
571    state: &AppState,
572    agent: &mut crate::agent::Agent,
573    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
574    content: &str,
575    session_key: &str,
576    page_context: Option<&str>,
577    broadcast_rx: &mut tokio::sync::broadcast::Receiver<serde_json::Value>,
578) {
579    use crate::agent::TurnEvent;
580
581    let provider_label = state
582        .config
583        .lock()
584        .default_provider
585        .clone()
586        .unwrap_or_else(|| "unknown".to_string());
587
588    // Broadcast agent_start event
589    let _ = state.event_tx.send(serde_json::json!({
590        "type": "agent_start",
591        "provider": provider_label,
592        "model": state.model,
593    }));
594
595    // Set session state to running
596    let turn_id = uuid::Uuid::new_v4().to_string();
597    if let Some(ref backend) = state.session_backend {
598        let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
599    }
600
601    // Channel for streaming turn events from the agent.
602    let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
603
604    // Run the streamed turn concurrently: the agent produces events
605    // while we forward them to the WebSocket below.  We cannot move
606    // `agent` into a spawned task (it is `&mut`), so we use a join
607    // instead — `turn_streamed` writes to the channel and we drain it
608    // from the other branch.
609    let content_owned = if let Some(hint) = page_context.and_then(page_context_hint) {
610        format!("{hint}{content}")
611    } else {
612        content.to_string()
613    };
614
615    // Scope the tool-loop cost tracker so token usage reported mid-stream
616    // (via StreamEvent::Usage) is recorded against the global CostTracker.
617    // Without this scope, record_tool_loop_cost_usage is a no-op.
618    let cost_tracking_context = state.cost_tracker.clone().map(|tracker| {
619        let prices = Arc::new(state.config.lock().cost.prices.clone());
620        crate::agent::cost::ToolLoopCostTrackingContext::new(tracker, prices)
621    });
622    let turn_fut = crate::agent::loop_::TOOL_LOOP_COST_TRACKING_CONTEXT
623        .scope(cost_tracking_context, async {
624            agent.turn_streamed(&content_owned, event_tx).await
625        });
626
627    // Drive both futures concurrently: the agent turn produces events
628    // and we relay them over WebSocket.  Also relay broadcast channel
629    // events (agent activity from the operator) so they reach the
630    // frontend in real-time even during long-running turns.
631    let forward_fut = async {
632        let mut turn_done = false;
633        loop {
634            if turn_done {
635                break;
636            }
637            tokio::select! {
638                event = event_rx.recv() => {
639                    match event {
640                        Some(event) => {
641                            let ws_msg = match event {
642                                TurnEvent::Chunk { delta } => {
643                                    serde_json::json!({ "type": "chunk", "content": delta })
644                                }
645                                TurnEvent::Thinking { delta } => {
646                                    serde_json::json!({ "type": "thinking", "content": delta })
647                                }
648                                TurnEvent::ToolCall { name, args } => {
649                                    serde_json::json!({ "type": "tool_call", "name": name, "args": args })
650                                }
651                                TurnEvent::ToolResult { name, output } => {
652                                    serde_json::json!({ "type": "tool_result", "name": name, "output": output })
653                                }
654                                TurnEvent::OperatorStatus { phase, detail } => {
655                                    serde_json::json!({ "type": "operator_status", "phase": phase, "detail": detail })
656                                }
657                            };
658                            let _ = sender.send(Message::Text(ws_msg.to_string().into())).await;
659                        }
660                        None => { turn_done = true; }
661                    }
662                }
663                bcast = broadcast_rx.recv() => {
664                    if let Ok(ev) = bcast {
665                        if ev["type"].as_str() == Some("channel_event") {
666                            let relay = serde_json::json!({
667                                "type": "agent_event",
668                                "event": ev["payload"],
669                            });
670                            let _ = sender.send(Message::Text(relay.to_string().into())).await;
671                        }
672                    }
673                }
674            }
675        }
676    };
677
678    let (result, ()) = tokio::join!(turn_fut, forward_fut);
679
680    match result {
681        Ok(response) => {
682            // Persist assistant response
683            if let Some(ref backend) = state.session_backend {
684                let assistant_msg = crate::providers::ChatMessage::assistant(&response);
685                let _ = backend.append(session_key, &assistant_msg);
686            }
687
688            // Send chunk_reset so the client clears any accumulated draft
689            // before the authoritative done message.
690            let reset = serde_json::json!({ "type": "chunk_reset" });
691            let _ = sender.send(Message::Text(reset.to_string().into())).await;
692
693            let done = serde_json::json!({
694                "type": "done",
695                "full_response": response,
696            });
697            let _ = sender.send(Message::Text(done.to_string().into())).await;
698
699            // Set session state to idle
700            if let Some(ref backend) = state.session_backend {
701                let _ = backend.set_session_state(session_key, "idle", None);
702            }
703
704            // Broadcast agent_end event
705            let _ = state.event_tx.send(serde_json::json!({
706                "type": "agent_end",
707                "provider": provider_label,
708                "model": state.model,
709            }));
710        }
711        Err(e) => {
712            // Set session state to error
713            if let Some(ref backend) = state.session_backend {
714                let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
715            }
716
717            tracing::error!(error = %e, "Agent turn failed");
718            let sanitized = crate::providers::sanitize_api_error(&e.to_string());
719            let error_code = if sanitized.to_lowercase().contains("api key")
720                || sanitized.to_lowercase().contains("authentication")
721                || sanitized.to_lowercase().contains("unauthorized")
722            {
723                "AUTH_ERROR"
724            } else if sanitized.to_lowercase().contains("provider")
725                || sanitized.to_lowercase().contains("model")
726            {
727                "PROVIDER_ERROR"
728            } else {
729                "AGENT_ERROR"
730            };
731            let err = serde_json::json!({
732                "type": "error",
733                "message": sanitized,
734                "code": error_code,
735            });
736            let _ = sender.send(Message::Text(err.to_string().into())).await;
737
738            // Broadcast error event
739            let _ = state.event_tx.send(serde_json::json!({
740                "type": "error",
741                "component": "ws_chat",
742                "message": sanitized,
743            }));
744        }
745    }
746}
747
#[cfg(test)]
mod tests {
    //! Unit tests for `extract_ws_token`.
    //!
    //! Covers each token source the tests exercise:
    //! - the `Authorization: Bearer …` header;
    //! - a `bearer.<token>` entry in `Sec-WebSocket-Protocol`;
    //! - the `token` query-parameter argument.
    //!
    //! Also covers their relative precedence and the fallbacks taken when a
    //! source is present but empty.

    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    /// Build a `HeaderMap` from `(name, value)` fixture pairs.
    ///
    /// Panics if a value is not a valid header value; fixtures are static,
    /// so that would be a bug in the test itself.
    fn header_map(entries: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, raw) in entries {
            let value: HeaderValue = raw.parse().unwrap();
            map.insert(name, value);
        }
        map
    }

    #[test]
    fn extract_ws_token_from_authorization_header() {
        let headers = header_map(&[("authorization", "Bearer zc_test123")]);
        let token = extract_ws_token(&headers, None);
        assert_eq!(token, Some("zc_test123"));
    }

    #[test]
    fn extract_ws_token_from_subprotocol() {
        let headers = header_map(&[("sec-websocket-protocol", "construct.v1, bearer.zc_sub456")]);
        let token = extract_ws_token(&headers, None);
        assert_eq!(token, Some("zc_sub456"));
    }

    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = header_map(&[]);
        let token = extract_ws_token(&headers, Some("zc_query789"));
        assert_eq!(token, Some("zc_query789"));
    }

    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        // All three sources are populated; the Authorization header wins.
        let headers = header_map(&[
            ("authorization", "Bearer zc_header"),
            ("sec-websocket-protocol", "bearer.zc_sub"),
        ]);
        let token = extract_ws_token(&headers, Some("zc_query"));
        assert_eq!(token, Some("zc_header"));
    }

    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        // No Authorization header: the subprotocol beats the query param.
        let headers = header_map(&[("sec-websocket-protocol", "bearer.zc_sub")]);
        let token = extract_ws_token(&headers, Some("zc_query"));
        assert_eq!(token, Some("zc_sub"));
    }

    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = header_map(&[]);
        let token = extract_ws_token(&headers, None);
        assert_eq!(token, None);
    }

    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        // A bare "Bearer " with no token must not shadow the query fallback.
        let headers = header_map(&[("authorization", "Bearer ")]);
        let token = extract_ws_token(&headers, Some("zc_fallback"));
        assert_eq!(token, Some("zc_fallback"));
    }

    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = header_map(&[]);
        let token = extract_ws_token(&headers, Some(""));
        assert_eq!(token, None);
    }

    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        // The bearer entry sits in the middle of a comma-separated list.
        let headers = header_map(&[(
            "sec-websocket-protocol",
            "construct.v1, bearer.zc_tok, other",
        )]);
        let token = extract_ws_token(&headers, None);
        assert_eq!(token, Some("zc_tok"));
    }
}