use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State
},
response::Response
};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::sync::mpsc;
use crate::agent::spawn_agent;
use crate::config::Config;
use crate::session::{extract_session_info, short_reason, try_load_session};
/// Protocol version sent as `protocolVersion` in the agent `initialize` request.
const PROTOCOL_VERSION: u32 = 1;
/// Axum handler that upgrades the request to a WebSocket bridged to an agent.
///
/// Recognised query parameters (both trimmed; blank values count as absent):
/// * `session` – id of a previous session to try to resume.
/// * `cwd` – working directory for the agent session; when absent the server's
///   current directory is used.
///
/// Errors from the connection are only logged to stderr: the handler runs after
/// the upgrade, so there is no HTTP response left to report them on.
pub(crate) async fn ws_upgrade(
    ws: WebSocketUpgrade,
    Query(params): Query<HashMap<String, String>>,
    State(cfg): State<Arc<Config>>
) -> Response {
    // Treat `?session=` / `?cwd=` (or whitespace-only values) as not provided,
    // so a blank session id starts a new session directly instead of making a
    // doomed resume attempt that surfaces a spurious "Starting a new one" notice.
    let non_blank = |key: &str| {
        params
            .get(key)
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
    };
    let resume = non_blank("session");
    let cwd_override = non_blank("cwd");
    ws.on_upgrade(move |socket| async move {
        if let Err(e) = handle_ws(socket, cfg, resume, cwd_override).await {
            eprintln!("WebSocket session ended: {e:?}");
        }
    })
}
/// Wraps a JSON value, serialized to its compact string form, in a WebSocket
/// text frame.
fn text_msg(value: Value) -> Message {
    let encoded: String = value.to_string();
    Message::Text(encoded)
}
async fn handle_ws(
ws: WebSocket,
cfg: Arc<Config>,
resume_session_id: Option<String>,
cwd_override: Option<String>
) -> Result<()> {
let (mut sink, mut stream) = ws.split();
let (to_ws_tx, mut to_ws_rx) = mpsc::unbounded_channel::<Message>();
let writer = tokio::spawn(async move {
while let Some(msg) = to_ws_rx.recv().await {
if sink.send(msg).await.is_err() {
break;
}
}
});
let agent = match spawn_agent(&cfg).await {
Ok(a) => Arc::new(a),
Err(e) => {
let _ = to_ws_tx.send(text_msg(json!({ "type": "error", "message": format!("{e}") })));
drop(to_ws_tx);
let _ = writer.await;
return Ok(());
}
};
let initialize_result = agent
.request(
"initialize",
json!({
"protocolVersion": PROTOCOL_VERSION,
"clientCapabilities": {
"fs": { "readTextFile": false, "writeTextFile": false }
}
})
)
.await
.context("Failed to initialize agent")?;
let prompt_capabilities = initialize_result
.get("agentCapabilities")
.and_then(|c| c.get("promptCapabilities"))
.cloned()
.unwrap_or_else(|| json!({}));
let cwd_str = match cwd_override {
Some(c) => c,
None => std::env::current_dir()?.to_string_lossy().to_string()
};
let (session_id, resumed, session_info) = match resume_session_id {
Some(sid) => match try_load_session(&agent, &sid, &cwd_str).await {
Ok(value) => (sid, true, extract_session_info(&value)),
Err(err_str) => {
eprintln!("Session load failed: {err_str}. Falling back to a new session.");
let _ = to_ws_tx.send(text_msg(json!({
"type": "append",
"role": "sys",
"text": format!(
"\n[{} — Starting a new one.]\n",
short_reason(&err_str)
)
})));
let new_session = agent
.request(
"session/new",
json!({ "cwd": cwd_str, "mcpServers": [] })
)
.await
.context("Failed to start fallback session")?;
let sid = new_session
.get("sessionId")
.and_then(Value::as_str)
.ok_or_else(|| anyhow!("Session creation returned no session id"))?
.to_string();
(sid, false, extract_session_info(&new_session))
}
},
None => {
let new_session = agent
.request(
"session/new",
json!({ "cwd": cwd_str, "mcpServers": [] })
)
.await
.context("Failed to start new session")?;
let sid = new_session
.get("sessionId")
.and_then(Value::as_str)
.ok_or_else(|| anyhow!("Session creation returned no session id"))?
.to_string();
(sid, false, extract_session_info(&new_session))
}
};
let _ = to_ws_tx.send(text_msg(json!({
"type": "ready",
"sessionId": session_id,
"resumed": resumed,
"cwd": cwd_str,
"promptCapabilities": prompt_capabilities,
"buildId": env!("MEZAME_BUILD_ID")
})));
if let Some(info) = session_info {
let _ = to_ws_tx.send(text_msg(json!({
"type": "session_info",
"info": info
})));
}
let mut suppress_session_updates = resumed;
let mut updates_rx = agent.take_updates();
loop {
tokio::select! {
Some(Ok(msg)) = stream.next() => {
let text = match msg {
Message::Text(t) => t,
Message::Close(_) => break,
_ => continue
};
let v: Value = match serde_json::from_str(&text) {
Ok(v) => v,
Err(_) => continue
};
match v.get("type").and_then(Value::as_str) {
Some("prompt") => {
let prompt_blocks: Vec<Value> = if let Some(blocks) = v.get("blocks").and_then(Value::as_array) {
blocks.clone()
} else if let Some(user_text) = v.get("text").and_then(Value::as_str) {
vec![json!({ "type": "text", "text": user_text })]
} else {
continue;
};
if prompt_blocks.is_empty() {
continue;
}
suppress_session_updates = false;
let agent = agent.clone();
let to_ws = to_ws_tx.clone();
let sid = session_id.clone();
tokio::spawn(async move {
let res = agent
.request(
"session/prompt",
json!({
"sessionId": sid,
"prompt": prompt_blocks
})
)
.await;
if let Err(e) = res {
let _ = to_ws.send(text_msg(json!({ "type": "error", "message": format!("{e}") })));
}
let _ = to_ws.send(text_msg(json!({ "type": "prompt_done" })));
});
}
Some("permission_response") => {
let Some(id) = v.get("id").cloned() else {
continue;
};
let option_id = v
.get("optionId")
.and_then(Value::as_str)
.unwrap_or("")
.to_string();
let agent = agent.clone();
let to_ws = to_ws_tx.clone();
tokio::spawn(async move {
if let Err(e) = agent
.respond(
id,
json!({ "outcome": { "outcome": "selected", "optionId": option_id } })
)
.await
{
let _ = to_ws.send(text_msg(json!({
"type": "error",
"message": format!("Permission reply failed: {e}")
})));
}
});
}
Some("cancel") => {
let agent = agent.clone();
let sid = session_id.clone();
tokio::spawn(async move {
let _ = agent
.notify(
"session/cancel",
json!({ "sessionId": sid })
)
.await;
});
}
Some("set_mode") => {
let Some(mode_id) = v.get("modeId").and_then(Value::as_str) else {
continue;
};
let mode_id = mode_id.to_string();
let agent = agent.clone();
let sid = session_id.clone();
let to_ws = to_ws_tx.clone();
tokio::spawn(async move {
if let Err(e) = agent
.request(
"session/set_mode",
json!({ "sessionId": sid, "modeId": mode_id })
)
.await
{
let _ = to_ws.send(text_msg(json!({
"type": "error",
"message": format!("Failed to change agent mode: {e}")
})));
}
});
}
Some("set_model") => {
let Some(model_id) = v.get("modelId").and_then(Value::as_str) else {
continue;
};
let model_id = model_id.to_string();
let agent = agent.clone();
let sid = session_id.clone();
let to_ws = to_ws_tx.clone();
tokio::spawn(async move {
if let Err(e) = agent
.request(
"session/set_model",
json!({ "sessionId": sid, "modelId": model_id })
)
.await
{
let _ = to_ws.send(text_msg(json!({
"type": "error",
"message": format!("Failed to change model: {e}")
})));
}
});
}
_ => continue
}
}
Some(agent_msg) = updates_rx.recv() => {
handle_agent_message(&to_ws_tx, agent_msg, suppress_session_updates).await;
}
else => break
}
}
agent.shutdown(Some(&session_id)).await;
drop(to_ws_tx);
let _ = writer.await;
Ok(())
}
/// Translates one JSON message from the agent into frames for the browser.
///
/// * `tx` – outbound frame queue for the WebSocket.
/// * `msg` – message from the agent, dispatched on its `method` field.
/// * `suppress_session_updates` – when true, `session/update` messages are
///   dropped entirely.
///
/// Handled methods:
/// * `_kiro.dev/commands/available` → a `commands` frame (`commands` and
///   `prompts` default to `[]`).
/// * `session/update` → `append` frames for agent/user/thought text chunks and
///   `tool_call` frames for tool activity. Tool updates lacking a `toolCallId`
///   are rendered as a plain `sys` text line instead.
/// * `session/request_permission` → a `permission_request` frame carrying the
///   message `id`, which the client echoes back in its `permission_response`.
///
/// Anything else is ignored, and send failures are silently discarded.
async fn handle_agent_message(
    tx: &mpsc::UnboundedSender<Message>,
    msg: Value,
    suppress_session_updates: bool
) {
    let emit = |payload: Value| {
        let _ = tx.send(text_msg(payload));
    };
    match msg.get("method").and_then(Value::as_str) {
        Some("_kiro.dev/commands/available") => {
            let Some(params) = msg.get("params") else {
                return;
            };
            let empty = || Value::Array(Vec::new());
            emit(json!({
                "type": "commands",
                "commands": params.get("commands").cloned().unwrap_or_else(empty),
                "prompts": params.get("prompts").cloned().unwrap_or_else(empty)
            }));
        }
        Some("session/update") => {
            if suppress_session_updates {
                return;
            }
            let update = msg
                .get("params")
                .and_then(|p| p.get("update"))
                .cloned()
                .unwrap_or(Value::Null);
            // Text payload of a `*_chunk` update, if it carries one.
            let chunk_text = || {
                update
                    .get("content")
                    .and_then(|c| c.get("text"))
                    .and_then(Value::as_str)
            };
            match update.get("sessionUpdate").and_then(Value::as_str).unwrap_or("") {
                "agent_message_chunk" => {
                    if let Some(text) = chunk_text() {
                        emit(json!({ "type": "append", "role": "agent", "text": text }));
                    }
                }
                "user_message_chunk" => {
                    if let Some(text) = chunk_text() {
                        emit(json!({
                            "type": "append",
                            "role": "user",
                            "text": format!("> {text}\n")
                        }));
                    }
                }
                "agent_thought_chunk" => {
                    if let Some(text) = chunk_text() {
                        emit(json!({ "type": "append", "role": "sys", "text": format!("(thinking) {text}") }));
                    }
                }
                "tool_call" | "tool_call_update" => {
                    let field = |key: &str| update.get(key).cloned().unwrap_or(Value::Null);
                    let tool_call_id = field("toolCallId");
                    if tool_call_id.is_null() {
                        // No id to correlate with: fall back to a one-line summary.
                        let title = update.get("title").and_then(Value::as_str).unwrap_or("tool");
                        let status = update.get("status").and_then(Value::as_str).unwrap_or("");
                        let line = match status {
                            "" => format!("\n[{title}]\n"),
                            s => format!("\n[{title}: {s}]\n")
                        };
                        emit(json!({ "type": "append", "role": "sys", "text": line }));
                    } else {
                        emit(json!({
                            "type": "tool_call",
                            "toolCallId": tool_call_id,
                            "title": field("title"),
                            "status": field("status"),
                            "kind": field("kind"),
                            "rawInput": field("rawInput"),
                            "content": field("content"),
                            "locations": field("locations")
                        }));
                    }
                }
                _ => {}
            }
        }
        Some("session/request_permission") => {
            let Some(params) = msg.get("params") else {
                return;
            };
            let id = msg.get("id").cloned().unwrap_or(Value::Null);
            let title = params
                .get("toolCall")
                .and_then(|tc| tc.get("title").or_else(|| tc.get("name")))
                .and_then(Value::as_str)
                .unwrap_or("tool");
            emit(json!({
                "type": "permission_request",
                "id": id,
                "title": title,
                "options": params.get("options").cloned().unwrap_or_else(|| Value::Array(Vec::new()))
            }));
        }
        _ => {}
    }
}