1use super::AppState;
21use axum::{
22 extract::{
23 Query, State, WebSocketUpgrade,
24 ws::{Message, WebSocket},
25 },
26 http::{HeaderMap, header},
27 response::IntoResponse,
28};
29use futures_util::{SinkExt, StreamExt};
30use serde::Deserialize;
31use std::sync::Arc;
32use tracing::debug;
33
34#[derive(Debug, Deserialize)]
41struct ConnectParams {
42 #[serde(rename = "type")]
43 msg_type: String,
44 #[serde(default)]
46 session_id: Option<String>,
47 #[serde(default)]
49 device_name: Option<String>,
50 #[serde(default)]
52 capabilities: Vec<String>,
53}
54
/// Application-level subprotocol selected in the upgrade response whenever the client
/// lists it in `Sec-WebSocket-Protocol`.
const WS_PROTOCOL: &str = "construct.v1";

/// Marker for a `Sec-WebSocket-Protocol` entry that smuggles a bearer token as
/// `bearer.<token>`.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
60
61#[derive(Deserialize)]
62pub struct WsQuery {
63 pub token: Option<String>,
64 pub session_id: Option<String>,
65 pub name: Option<String>,
67}
68
69fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
79 if let Some(t) = headers
81 .get(header::AUTHORIZATION)
82 .and_then(|v| v.to_str().ok())
83 .and_then(|auth| auth.strip_prefix("Bearer "))
84 {
85 if !t.is_empty() {
86 return Some(t);
87 }
88 }
89
90 if let Some(t) = headers
92 .get("sec-websocket-protocol")
93 .and_then(|v| v.to_str().ok())
94 .and_then(|protos| {
95 protos
96 .split(',')
97 .map(|p| p.trim())
98 .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
99 })
100 {
101 if !t.is_empty() {
102 return Some(t);
103 }
104 }
105
106 if let Some(t) = query_token {
108 if !t.is_empty() {
109 return Some(t);
110 }
111 }
112
113 None
114}
115
116pub async fn handle_ws_chat(
118 State(state): State<AppState>,
119 Query(params): Query<WsQuery>,
120 headers: HeaderMap,
121 ws: WebSocketUpgrade,
122) -> impl IntoResponse {
123 if state.pairing.require_pairing() {
125 let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
126 if !state.pairing.is_authenticated(token) {
127 return (
128 axum::http::StatusCode::UNAUTHORIZED,
129 "Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
130 )
131 .into_response();
132 }
133 }
134
135 let ws = if headers
137 .get("sec-websocket-protocol")
138 .and_then(|v| v.to_str().ok())
139 .map_or(false, |protos| {
140 protos.split(',').any(|p| p.trim() == WS_PROTOCOL)
141 }) {
142 ws.protocols([WS_PROTOCOL])
143 } else {
144 ws
145 };
146
147 if let Some(ref logger) = state.audit_logger {
149 let _ = logger.log_security_event("dashboard", "WebSocket chat session connected");
150 }
151
152 let session_id = params.session_id;
153 let session_name = params.name;
154 ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
155 .into_response()
156}
157
/// Prefix applied to every session key created by this WebSocket gateway (`gw_<session_id>`).
const GW_SESSION_PREFIX: &str = "gw_";
160
161async fn handle_socket(
162 socket: WebSocket,
163 state: AppState,
164 session_id: Option<String>,
165 session_name: Option<String>,
166) {
167 let (mut sender, mut receiver) = socket.split();
168
169 let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
171 let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
172
173 let config = state.config.lock().clone();
175 let mut agent = match crate::agent::Agent::from_config(&config).await {
176 Ok(a) => a,
177 Err(e) => {
178 tracing::error!(error = %e, "Agent initialization failed");
179 let err = serde_json::json!({
180 "type": "error",
181 "message": format!("Failed to initialise agent: {e}"),
182 "code": "AGENT_INIT_FAILED"
183 });
184 let _ = sender.send(Message::Text(err.to_string().into())).await;
185 let _ = sender
186 .send(Message::Close(Some(axum::extract::ws::CloseFrame {
187 code: 1011,
188 reason: axum::extract::ws::Utf8Bytes::from_static(
189 "Agent initialization failed",
190 ),
191 })))
192 .await;
193 return;
194 }
195 };
196 agent.set_memory_session_id(Some(session_id.clone()));
197
198 let mut resumed = false;
200 let mut message_count: usize = 0;
201 let mut effective_name: Option<String> = None;
202 if let Some(ref backend) = state.session_backend {
203 let messages = backend.load(&session_key);
204 if !messages.is_empty() {
205 message_count = messages.len();
206 agent.seed_history(&messages);
207 resumed = true;
208 }
209 if let Some(ref name) = session_name {
211 if !name.is_empty() {
212 let _ = backend.set_session_name(&session_key, name);
213 effective_name = Some(name.clone());
214 }
215 }
216 if effective_name.is_none() {
218 effective_name = backend.get_session_name(&session_key).unwrap_or(None);
219 }
220 }
221
222 let mut session_start = serde_json::json!({
224 "type": "session_start",
225 "session_id": session_id,
226 "resumed": resumed,
227 "message_count": message_count,
228 });
229 if let Some(ref name) = effective_name {
230 session_start["name"] = serde_json::Value::String(name.clone());
231 }
232 let _ = sender
233 .send(Message::Text(session_start.to_string().into()))
234 .await;
235
236 let mut first_msg_fallback: Option<String> = None;
243
244 match tokio::time::timeout(std::time::Duration::from_secs(5), receiver.next()).await {
248 Ok(Some(first)) => {
249 match first {
250 Ok(Message::Text(text)) => {
251 if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
252 if cp.msg_type == "connect" {
253 debug!(
254 session_id = ?cp.session_id,
255 device_name = ?cp.device_name,
256 capabilities = ?cp.capabilities,
257 "WebSocket connect params received"
258 );
259 if let Some(sid) = &cp.session_id {
261 agent.set_memory_session_id(Some(sid.clone()));
262 }
263 let ack = serde_json::json!({
264 "type": "connected",
265 "message": "Connection established"
266 });
267 let _ = sender.send(Message::Text(ack.to_string().into())).await;
268 } else {
269 first_msg_fallback = Some(text.to_string());
271 }
272 } else {
273 first_msg_fallback = Some(text.to_string());
275 }
276 }
277 Ok(Message::Close(_)) | Err(_) => return,
278 _ => {}
279 }
280 }
281 Ok(None) => return, Err(_) => {
283 debug!(session_id = %session_id, "No initial message within 5s — entering listen-only mode");
286 }
287 }
288
289 let mut broadcast_rx = state.event_tx.subscribe();
292
293 if let Some(ref text) = first_msg_fallback {
295 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
296 if parsed["type"].as_str() == Some("message") {
297 let content = parsed["content"].as_str().unwrap_or("").to_string();
298 if !content.is_empty() {
299 let page_ctx = parsed["page_context"].as_str();
300 if let Some(ref backend) = state.session_backend {
302 let user_msg = crate::providers::ChatMessage::user(&content);
303 let _ = backend.append(&session_key, &user_msg);
304 }
305 process_chat_message(
306 &state,
307 &mut agent,
308 &mut sender,
309 &content,
310 &session_key,
311 page_ctx,
312 &mut broadcast_rx,
313 )
314 .await;
315 }
316 } else {
317 let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
318 let err = serde_json::json!({
319 "type": "error",
320 "message": format!(
321 "Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
322 )
323 });
324 let _ = sender.send(Message::Text(err.to_string().into())).await;
325 }
326 } else {
327 let err = serde_json::json!({
328 "type": "error",
329 "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
330 });
331 let _ = sender.send(Message::Text(err.to_string().into())).await;
332 }
333 }
334
335 loop {
336 tokio::select! {
337 ws_msg = receiver.next() => {
339 let msg = match ws_msg {
340 Some(Ok(Message::Text(text))) => text,
341 Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
342 _ => continue,
343 };
344
345 let parsed: serde_json::Value = match serde_json::from_str(&msg) {
346 Ok(v) => v,
347 Err(e) => {
348 let err = serde_json::json!({
349 "type": "error",
350 "message": format!("Invalid JSON: {}", e),
351 "code": "INVALID_JSON"
352 });
353 let _ = sender.send(Message::Text(err.to_string().into())).await;
354 continue;
355 }
356 };
357
358 let msg_type = parsed["type"].as_str().unwrap_or("");
359 if msg_type != "message" {
360 let err = serde_json::json!({
361 "type": "error",
362 "message": format!(
363 "Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
364 ),
365 "code": "UNKNOWN_MESSAGE_TYPE"
366 });
367 let _ = sender.send(Message::Text(err.to_string().into())).await;
368 continue;
369 }
370
371 let content = parsed["content"].as_str().unwrap_or("").to_string();
372 if content.is_empty() {
373 let err = serde_json::json!({
374 "type": "error",
375 "message": "Message content cannot be empty",
376 "code": "EMPTY_CONTENT"
377 });
378 let _ = sender.send(Message::Text(err.to_string().into())).await;
379 continue;
380 }
381
382 let _session_guard = match state.session_queue.acquire(&session_key).await {
384 Ok(guard) => guard,
385 Err(e) => {
386 let err = serde_json::json!({
387 "type": "error",
388 "message": e.to_string(),
389 "code": "SESSION_BUSY"
390 });
391 let _ = sender.send(Message::Text(err.to_string().into())).await;
392 continue;
393 }
394 };
395
396 let page_ctx = parsed["page_context"].as_str();
397
398 if let Some(ref backend) = state.session_backend {
400 let user_msg = crate::providers::ChatMessage::user(&content);
401 let _ = backend.append(&session_key, &user_msg);
402 }
403
404 process_chat_message(&state, &mut agent, &mut sender, &content, &session_key, page_ctx, &mut broadcast_rx).await;
405 }
406
407 event = broadcast_rx.recv() => {
409 match event {
410 Ok(ev) if ev["type"].as_str() == Some("channel_event") => {
411 let relay = serde_json::json!({
412 "type": "agent_event",
413 "event": ev["payload"],
414 });
415 let _ = sender.send(Message::Text(relay.to_string().into())).await;
416 }
417 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
418 _ => {} }
420 }
421 }
422 }
423}
424
/// Looks up the page-specific instruction preamble for a dashboard page identifier.
///
/// The returned text is meant to be prepended to the user's message (see
/// `process_chat_message`). Matching is exact and case-sensitive; pages without a dedicated
/// hint yield `None`.
fn page_context_hint(page: &str) -> Option<&'static str> {
    const AGENT_POOL: &str = concat!(
        "[Page context: The user is on the **Agent Pool** page.\n",
        "Available tools:\n",
        "- `construct-operator__save_agent_template` — Create/update an agent\n",
        "- `construct-operator__search_agent_pool` — Search agents by query\n",
        "- `construct-operator__list_agent_templates` — List all agents (returns kref, name, role, etc.)\n\n",
        "When creating agents, collect: name, role (coder/researcher/reviewer/specialist), ",
        "expertise areas, preferred model (codex/claude), identity, soul, tone, and optionally system_hint.\n",
        "Guide the user conversationally.\n\n",
        "IMPORTANT behavioral rules:\n",
        "- A tool returning empty content or no error means SUCCESS. Verify by calling list_agent_templates after.\n",
        "- NEVER say a tool is broken or file a bug report. If something seems off, retry or verify.\n",
        "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant, handle it.\n",
        "- After creating/updating, confirm success by listing agents to show the result.]\n\n",
    );

    const AGENT_TEAMS: &str = concat!(
        "[Page context: The user is on the **Agent Teams** page.\n",
        "Available tools:\n",
        "- `construct-operator__create_team` — Create/update a team with members and edges\n",
        "- `construct-operator__list_agent_templates` — List all agents (returns kref for member_krefs)\n",
        "- `construct-operator__search_agent_pool` — Search agents by query\n",
        "- `construct-operator__list_teams` — List existing teams\n",
        "- `construct-operator__get_team` — Get team details with members and edges\n\n",
        "When creating teams: collect a name, description, and select member agents.\n",
        "Use the `kref` field from list_agent_templates for member_krefs — the system resolves names automatically.\n",
        "Define edges (SUPPORTS, DEPENDS_ON, REPORTS_TO) between members to express the team structure.\n\n",
        "IMPORTANT behavioral rules:\n",
        "- A tool returning empty content or no error means SUCCESS. Verify by calling list_teams after.\n",
        "- NEVER say a tool is broken or file a bug report. If something seems off, retry or verify.\n",
        "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant, handle it.\n",
        "- After creating a team, confirm success by calling list_teams or get_team to show the result.\n",
        "- member_krefs accepts agent names, partial krefs, or full krefs — the resolver handles matching.]\n\n",
    );

    const SKILLS: &str = concat!(
        "[Page context: The user is on the **Skills Library** page.\n",
        "Skills are reusable behavioral procedures stored in CognitiveMemory/Skills.\n",
        "Available tools:\n",
        "- `construct-operator__save_skill` — Create/update a skill (if available)\n",
        "- `construct-operator__list_agent_templates` — List agents (skills may reference agents)\n",
        "- `construct-operator__search_clawhub` — Search ClawHub public marketplace for community skills\n",
        "- `construct-operator__browse_clawhub` — Browse trending skills on ClawHub\n",
        "- `construct-operator__install_from_clawhub` — Install a skill from ClawHub by slug\n\n",
        "A skill has: name, description, content (the procedure text), domain ",
        "(Memory/Creative/Privacy/Graph/Behavioral/Other), and tags.\n",
        "Guide the user through defining skills conversationally — help them articulate ",
        "the procedure, choose the right domain, and write clear content.\n",
        "When users want to find existing skills, search ClawHub first before creating from scratch.\n\n",
        "IMPORTANT behavioral rules:\n",
        "- A tool returning empty content or no error means SUCCESS. Verify after.\n",
        "- NEVER say a tool is broken or file a bug report.\n",
        "- Do NOT ask the user to use the dashboard UI instead — YOU are the assistant.]\n\n",
    );

    const WORKFLOWS: &str = concat!(
        "[Page context: The user is on the **Workflows** page.\n",
        "Available tools: create_workflow, list_workflows, validate_workflow, run_workflow, ",
        "get_workflow_status, cancel_workflow, resume_workflow, dry_run_workflow, ",
        "recall_workflow_runs, get_workflow_run_detail, save_workflow_preset, list_workflow_presets ",
        "(all prefixed with `construct-operator__`).\n\n",
        "## Workflow schema (use this EXACTLY with create_workflow):\n",
        "```yaml\n",
        "workflow_def:\n",
        " name: my-workflow # kebab-case identifier\n",
        " description: What it does\n",
        " tags: [tag1, tag2] # optional\n",
        " inputs: # optional\n",
        " - name: task\n",
        " required: false\n",
        " default: default value\n",
        " steps:\n",
        " - id: research_step\n",
        " name: Research Phase\n",
        " action: research # research | code | review | deploy | test | build | notify | approve | summarize | task | human_input\n",
        " description: Research the topic using ${inputs.task}\n",
        " agent_hints: [researcher] # hints for operator: coder | researcher | reviewer\n",
        " depends_on: []\n",
        " - id: code_step\n",
        " name: Implementation\n",
        " action: code\n",
        " description: Implement based on ${research_step.output}\n",
        " agent_hints: [coder]\n",
        " depends_on: [research_step]\n",
        " - id: review_step\n",
        " name: Code Review\n",
        " action: review\n",
        " description: Review ${code_step.output}\n",
        " agent_hints: [reviewer]\n",
        " depends_on: [code_step]\n",
        " - id: feedback_step\n",
        " name: Get User Feedback\n",
        " action: human_input\n",
        " description: Please review the output and provide feedback\n",
        " channel: dashboard # dashboard | slack | discord\n",
        " depends_on: [review_step]\n",
        "```\n",
        "The `action` field determines which agent type runs the step:\n",
        " research → researcher (claude), code → coder (codex), review → reviewer (claude),\n",
        " deploy/test/build → codex, notify/summarize → claude, task → generic claude,\n",
        " human_input → pauses workflow and sends a prompt to a channel (dashboard/slack/discord), waits for human response.\n",
        "The `description` field is the agent's prompt — use ${step_id.output} and ${inputs.X} for interpolation.\n",
        "`agent_hints` are optional suggestions (operator auto-selects if omitted).\n",
        "For advanced use, add explicit `type` + config block (agent/shell/goto/output/human_approval).\n\n",
        "Rules:\n",
        "- create_workflow validates internally and returns {saved, path, valid, registered}. Trust it — do NOT call list_workflows or validate_workflow to verify.\n",
        "- One tool call is enough for creation. Keep it simple.\n",
        "- When the user says 'research agent', '3 agents', 'coder', etc., map to the right action.\n",
        "- When running a workflow, always provide the cwd parameter.\n",
        "- Do NOT ask the user to use the UI instead — handle it yourself.]\n\n",
    );

    const CANVAS: &str = concat!(
        "[Page context: The user is on the **Live Canvas** page.\n",
        "The canvas is YOUR primary output — render visual content IMMEDIATELY.\n\n",
        "Available tools:\n",
        "- `construct-operator__render_canvas` — Push content to the canvas (html, svg, markdown, text)\n",
        "- `construct-operator__clear_canvas` — Clear a canvas\n\n",
        "ALWAYS render to the canvas. The user opened this page to SEE visual output.\n",
        "Use it for:\n",
        "- Interactive HTML dashboards with charts, tables, and metrics\n",
        "- SVG diagrams, flowcharts, architecture maps, or data visualizations\n",
        "- Formatted reports, comparisons, or analyses\n",
        "- Any content that benefits from visual presentation\n\n",
        "CRITICAL rules:\n",
        "- ALWAYS call render_canvas — do NOT just describe what you would render.\n",
        "- For HTML: include ALL CSS inline. Use a dark theme (bg: #1a1a2e, text: #e2e8f0).\n",
        " Include modern styling with gradients, rounded corners, and clean typography.\n",
        "- For SVG: provide complete <svg> with viewBox for responsive sizing.\n",
        "- For charts: use inline CSS/HTML tables or SVG — no external JS libraries.\n",
        "- Keep content self-contained — no external resources, CDNs, or imports.\n",
        "- Default canvas_id is 'default'. You can use separate canvas_ids for multiple views.\n",
        "- If the user asks a question, answer it AND render relevant visual content.\n",
        "- Iterate: if the user gives feedback, re-render with improvements.]\n\n",
    );

    // Page identifier -> preamble. A linear scan is fine for a handful of entries.
    const HINTS: [(&str, &str); 5] = [
        ("agent_pool", AGENT_POOL),
        ("agent_teams", AGENT_TEAMS),
        ("skills", SKILLS),
        ("workflows", WORKFLOWS),
        ("canvas", CANVAS),
    ];

    HINTS
        .iter()
        .find_map(|&(key, hint)| (key == page).then_some(hint))
}
565
566async fn process_chat_message(
571 state: &AppState,
572 agent: &mut crate::agent::Agent,
573 sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
574 content: &str,
575 session_key: &str,
576 page_context: Option<&str>,
577 broadcast_rx: &mut tokio::sync::broadcast::Receiver<serde_json::Value>,
578) {
579 use crate::agent::TurnEvent;
580
581 let provider_label = state
582 .config
583 .lock()
584 .default_provider
585 .clone()
586 .unwrap_or_else(|| "unknown".to_string());
587
588 let _ = state.event_tx.send(serde_json::json!({
590 "type": "agent_start",
591 "provider": provider_label,
592 "model": state.model,
593 }));
594
595 let turn_id = uuid::Uuid::new_v4().to_string();
597 if let Some(ref backend) = state.session_backend {
598 let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
599 }
600
601 let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
603
604 let content_owned = if let Some(hint) = page_context.and_then(page_context_hint) {
610 format!("{hint}{content}")
611 } else {
612 content.to_string()
613 };
614
615 let cost_tracking_context = state.cost_tracker.clone().map(|tracker| {
619 let prices = Arc::new(state.config.lock().cost.prices.clone());
620 crate::agent::cost::ToolLoopCostTrackingContext::new(tracker, prices)
621 });
622 let turn_fut = crate::agent::loop_::TOOL_LOOP_COST_TRACKING_CONTEXT
623 .scope(cost_tracking_context, async {
624 agent.turn_streamed(&content_owned, event_tx).await
625 });
626
627 let forward_fut = async {
632 let mut turn_done = false;
633 loop {
634 if turn_done {
635 break;
636 }
637 tokio::select! {
638 event = event_rx.recv() => {
639 match event {
640 Some(event) => {
641 let ws_msg = match event {
642 TurnEvent::Chunk { delta } => {
643 serde_json::json!({ "type": "chunk", "content": delta })
644 }
645 TurnEvent::Thinking { delta } => {
646 serde_json::json!({ "type": "thinking", "content": delta })
647 }
648 TurnEvent::ToolCall { name, args } => {
649 serde_json::json!({ "type": "tool_call", "name": name, "args": args })
650 }
651 TurnEvent::ToolResult { name, output } => {
652 serde_json::json!({ "type": "tool_result", "name": name, "output": output })
653 }
654 TurnEvent::OperatorStatus { phase, detail } => {
655 serde_json::json!({ "type": "operator_status", "phase": phase, "detail": detail })
656 }
657 };
658 let _ = sender.send(Message::Text(ws_msg.to_string().into())).await;
659 }
660 None => { turn_done = true; }
661 }
662 }
663 bcast = broadcast_rx.recv() => {
664 if let Ok(ev) = bcast {
665 if ev["type"].as_str() == Some("channel_event") {
666 let relay = serde_json::json!({
667 "type": "agent_event",
668 "event": ev["payload"],
669 });
670 let _ = sender.send(Message::Text(relay.to_string().into())).await;
671 }
672 }
673 }
674 }
675 }
676 };
677
678 let (result, ()) = tokio::join!(turn_fut, forward_fut);
679
680 match result {
681 Ok(response) => {
682 if let Some(ref backend) = state.session_backend {
684 let assistant_msg = crate::providers::ChatMessage::assistant(&response);
685 let _ = backend.append(session_key, &assistant_msg);
686 }
687
688 let reset = serde_json::json!({ "type": "chunk_reset" });
691 let _ = sender.send(Message::Text(reset.to_string().into())).await;
692
693 let done = serde_json::json!({
694 "type": "done",
695 "full_response": response,
696 });
697 let _ = sender.send(Message::Text(done.to_string().into())).await;
698
699 if let Some(ref backend) = state.session_backend {
701 let _ = backend.set_session_state(session_key, "idle", None);
702 }
703
704 let _ = state.event_tx.send(serde_json::json!({
706 "type": "agent_end",
707 "provider": provider_label,
708 "model": state.model,
709 }));
710 }
711 Err(e) => {
712 if let Some(ref backend) = state.session_backend {
714 let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
715 }
716
717 tracing::error!(error = %e, "Agent turn failed");
718 let sanitized = crate::providers::sanitize_api_error(&e.to_string());
719 let error_code = if sanitized.to_lowercase().contains("api key")
720 || sanitized.to_lowercase().contains("authentication")
721 || sanitized.to_lowercase().contains("unauthorized")
722 {
723 "AUTH_ERROR"
724 } else if sanitized.to_lowercase().contains("provider")
725 || sanitized.to_lowercase().contains("model")
726 {
727 "PROVIDER_ERROR"
728 } else {
729 "AGENT_ERROR"
730 };
731 let err = serde_json::json!({
732 "type": "error",
733 "message": sanitized,
734 "code": error_code,
735 });
736 let _ = sender.send(Message::Text(err.to_string().into())).await;
737
738 let _ = state.event_tx.send(serde_json::json!({
740 "type": "error",
741 "component": "ws_chat",
742 "message": sanitized,
743 }));
744 }
745 }
746}
747
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// Builds a `HeaderMap` from `(name, value)` pairs, inserted in order; panics on an
    /// unparsable header value (acceptable in tests).
    fn headers_from(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn extract_ws_token_from_authorization_header() {
        let headers = headers_from(&[("authorization", "Bearer zc_test123")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_test123"));
    }

    #[test]
    fn extract_ws_token_from_subprotocol() {
        let headers = headers_from(&[("sec-websocket-protocol", "construct.v1, bearer.zc_sub456")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_sub456"));
    }

    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = headers_from(&[]);
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query789")),
            Some("zc_query789")
        );
    }

    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        let headers = headers_from(&[
            ("authorization", "Bearer zc_header"),
            ("sec-websocket-protocol", "bearer.zc_sub"),
        ]);
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query")),
            Some("zc_header")
        );
    }

    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        let headers = headers_from(&[("sec-websocket-protocol", "bearer.zc_sub")]);
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_sub"));
    }

    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = headers_from(&[]);
        assert_eq!(extract_ws_token(&headers, None), None);
    }

    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        let headers = headers_from(&[("authorization", "Bearer ")]);
        assert_eq!(
            extract_ws_token(&headers, Some("zc_fallback")),
            Some("zc_fallback")
        );
    }

    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = headers_from(&[]);
        assert_eq!(extract_ws_token(&headers, Some("")), None);
    }

    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        let headers = headers_from(&[("sec-websocket-protocol", "construct.v1, bearer.zc_tok, other")]);
        assert_eq!(extract_ws_token(&headers, None), Some("zc_tok"));
    }
}