use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State,
},
response::Response,
};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::sync::mpsc;
use crate::agent::{spawn_agent, Agent};
use crate::config::Config;
use crate::session::{extract_session_info, short_reason, try_load_session};
/// Protocol version this client advertises as `protocolVersion` in the agent's
/// `initialize` request.
const PROTOCOL_VERSION: u32 = 1;
/// Asks the agent to create a brand-new session rooted at `cwd`, with no MCP servers.
///
/// Returns the new session id together with whatever metadata
/// `extract_session_info` finds in the `session/new` response.
///
/// # Errors
/// Fails if the `session/new` request fails, or if its response has no string
/// `sessionId` field.
async fn start_new_session(agent: &Agent, cwd: &str) -> Result<(String, Option<Value>)> {
    let params = json!({ "cwd": cwd, "mcpServers": [] });
    let response = agent
        .request("session/new", params)
        .await
        .context("Failed to start new session")?;

    let Some(session_id) = response.get("sessionId").and_then(Value::as_str) else {
        return Err(anyhow!("Session creation returned no session id"));
    };
    let session_id = session_id.to_owned();

    let info = extract_session_info(&response);
    Ok((session_id, info))
}
/// Runs `fut` on a detached Tokio task.
///
/// If the future resolves to an error, an `error` frame is queued for the
/// client. The frame's message is `"{error_prefix}: {error}"`. On success
/// nothing is sent.
fn spawn_with_error_report(
    to_ws: mpsc::UnboundedSender<Message>,
    error_prefix: &'static str,
    fut: impl Future<Output = Result<()>> + Send + 'static,
) {
    tokio::spawn(async move {
        let Err(failure) = fut.await else {
            return;
        };
        let message = format!("{error_prefix}: {failure}");
        // A send error only means the outbound channel is closed; nothing useful to do then.
        let _ = to_ws.send(text_msg(json!({
            "type": "error",
            "message": message
        })));
    });
}
/// Axum handler that upgrades an HTTP request to a WebSocket and hands it to `handle_ws`.
///
/// Query parameters:
/// - `session`: optional id of a session to resume.
/// - `cwd`: optional working-directory override. Surrounding whitespace is
///   trimmed, and a blank value is treated as absent.
///
/// Any error returned by `handle_ws` is logged to stderr.
pub(crate) async fn ws_upgrade(
    ws: WebSocketUpgrade,
    Query(params): Query<HashMap<String, String>>,
    State(cfg): State<Arc<Config>>,
) -> Response {
    let resume_session = params.get("session").cloned();
    let cwd_override = match params.get("cwd").map(|raw| raw.trim()) {
        Some(trimmed) if !trimmed.is_empty() => Some(trimmed.to_owned()),
        _ => None,
    };

    ws.on_upgrade(move |socket| async move {
        let outcome = handle_ws(socket, cfg, resume_session, cwd_override).await;
        if let Err(e) = outcome {
            eprintln!("WebSocket session ended: {e:?}");
        }
    })
}
/// Serializes a JSON value into a WebSocket text frame using compact JSON (`Value`'s `Display`).
fn text_msg(value: Value) -> Message {
    let payload: String = value.to_string();
    Message::Text(payload)
}
async fn handle_ws(
ws: WebSocket,
cfg: Arc<Config>,
resume_session_id: Option<String>,
cwd_override: Option<String>,
) -> Result<()> {
let (mut sink, mut stream) = ws.split();
let (to_ws_tx, mut to_ws_rx) = mpsc::unbounded_channel::<Message>();
let writer = tokio::spawn(async move {
while let Some(msg) = to_ws_rx.recv().await {
if sink.send(msg).await.is_err() {
break;
}
}
});
let (agent, mut updates_rx) = match spawn_agent(&cfg).await {
Ok((a, rx)) => (Arc::new(a), rx),
Err(e) => {
let _ = to_ws_tx.send(text_msg(
json!({ "type": "error", "message": format!("{e}") }),
));
drop(to_ws_tx);
let _ = writer.await;
return Ok(());
}
};
let initialize_result = agent
.request(
"initialize",
json!({
"protocolVersion": PROTOCOL_VERSION,
"clientCapabilities": {
"fs": { "readTextFile": false, "writeTextFile": false }
}
}),
)
.await
.context("Failed to initialize agent")?;
let prompt_capabilities = initialize_result
.get("agentCapabilities")
.and_then(|c| c.get("promptCapabilities"))
.cloned()
.unwrap_or_else(|| json!({}));
let cwd_str = match cwd_override {
Some(c) => c,
None => std::env::current_dir()?.to_string_lossy().to_string(),
};
let (session_id, resumed, session_info) = match resume_session_id {
Some(sid) => match try_load_session(&agent, &sid, &cwd_str).await {
Ok(value) => (sid, true, extract_session_info(&value)),
Err(err_str) => {
eprintln!("Session load failed: {err_str}. Falling back to a new session.");
let _ = to_ws_tx.send(text_msg(json!({
"type": "append",
"role": "sys",
"text": format!(
"\n[{} — Starting a new one.]\n",
short_reason(&err_str)
)
})));
let (sid, info) = start_new_session(&agent, &cwd_str).await?;
(sid, false, info)
}
},
None => {
let (sid, info) = start_new_session(&agent, &cwd_str).await?;
(sid, false, info)
}
};
let _ = to_ws_tx.send(text_msg(json!({
"type": "ready",
"sessionId": session_id,
"resumed": resumed,
"cwd": cwd_str,
"promptCapabilities": prompt_capabilities,
"buildId": env!("MEZAME_BUILD_ID")
})));
if let Some(info) = session_info {
let _ = to_ws_tx.send(text_msg(json!({
"type": "session_info",
"info": info
})));
}
let mut suppress_session_updates = resumed;
loop {
tokio::select! {
ws_msg = stream.next() => {
let text = match ws_msg {
None => break, Some(Err(_)) => break, Some(Ok(Message::Close(_))) => break, Some(Ok(Message::Text(t))) => t,
Some(Ok(_)) => continue, };
let v: Value = match serde_json::from_str(&text) {
Ok(v) => v,
Err(_) => continue
};
match v.get("type").and_then(Value::as_str) {
Some("prompt") => {
let prompt_blocks: Vec<Value> = if let Some(blocks) = v.get("blocks").and_then(Value::as_array) {
blocks.clone()
} else if let Some(user_text) = v.get("text").and_then(Value::as_str) {
vec![json!({ "type": "text", "text": user_text })]
} else {
continue;
};
if prompt_blocks.is_empty() {
continue;
}
suppress_session_updates = false;
let agent = agent.clone();
let to_ws = to_ws_tx.clone();
let sid = session_id.clone();
tokio::spawn(async move {
let res = agent
.request(
"session/prompt",
json!({
"sessionId": sid,
"prompt": prompt_blocks
})
)
.await;
if let Err(e) = res {
let _ = to_ws.send(text_msg(json!({ "type": "error", "message": format!("{e}") })));
}
let _ = to_ws.send(text_msg(json!({ "type": "prompt_done" })));
});
}
Some("permission_response") => {
let Some(id) = v.get("id").cloned() else {
continue;
};
let option_id = v
.get("optionId")
.and_then(Value::as_str)
.unwrap_or("")
.to_string();
let agent = agent.clone();
spawn_with_error_report(
to_ws_tx.clone(),
"Permission reply failed",
async move {
agent
.respond(
id,
json!({
"outcome": {
"outcome": "selected",
"optionId": option_id
}
}),
)
.await
},
);
}
Some("cancel") => {
let agent = agent.clone();
let sid = session_id.clone();
tokio::spawn(async move {
let _ = agent
.notify(
"session/cancel",
json!({ "sessionId": sid })
)
.await;
});
}
Some("set_mode") => {
let Some(mode_id) = v.get("modeId").and_then(Value::as_str) else {
continue;
};
let mode_id = mode_id.to_string();
let agent = agent.clone();
let sid = session_id.clone();
spawn_with_error_report(
to_ws_tx.clone(),
"Failed to change agent mode",
async move {
agent
.request(
"session/set_mode",
json!({ "sessionId": sid, "modeId": mode_id }),
)
.await?;
Ok(())
},
);
}
Some("set_model") => {
let Some(model_id) = v.get("modelId").and_then(Value::as_str) else {
continue;
};
let model_id = model_id.to_string();
let agent = agent.clone();
let sid = session_id.clone();
spawn_with_error_report(
to_ws_tx.clone(),
"Failed to change model",
async move {
agent
.request(
"session/set_model",
json!({ "sessionId": sid, "modelId": model_id }),
)
.await?;
Ok(())
},
);
}
_ => continue
}
}
agent_msg = updates_rx.recv() => {
match agent_msg {
Some(msg) => handle_agent_message(&to_ws_tx, msg, suppress_session_updates).await,
None => break, }
}
else => break
}
}
agent.shutdown(Some(&session_id)).await;
drop(to_ws_tx);
let _ = writer.await;
Ok(())
}
/// Translates one message from the agent into zero or more client frames.
///
/// Handled methods:
/// - `_kiro.dev/commands/available` → a `commands` frame carrying the
///   `commands` and `prompts` values (empty arrays when missing).
/// - `session/update` → `append` or `tool_call` frames. When
///   `suppress_session_updates` is set, the update is dropped instead.
/// - `session/request_permission` → a `permission_request` frame carrying the
///   request `id`, a tool title and the offered `options`.
///
/// Messages without `params` (for the commands/permission methods) and every
/// other method are ignored. Send failures on `tx` are ignored as well.
async fn handle_agent_message(
    tx: &mpsc::UnboundedSender<Message>,
    msg: Value,
    suppress_session_updates: bool,
) {
    // Returns `content.text` of a streamed chunk update when it is a string.
    fn chunk_text(update: &Value) -> Option<&str> {
        update.pointer("/content/text").and_then(Value::as_str)
    }

    // Copies `key` out of `source`, using JSON null when the key is absent.
    fn field_or_null(source: &Value, key: &str) -> Value {
        source.get(key).cloned().unwrap_or(Value::Null)
    }

    let emit = |frame: Value| {
        let _ = tx.send(text_msg(frame));
    };

    match msg.get("method").and_then(Value::as_str).unwrap_or("") {
        "_kiro.dev/commands/available" => {
            let Some(params) = msg.get("params") else {
                return;
            };
            let commands = params
                .get("commands")
                .cloned()
                .unwrap_or_else(|| Value::Array(Vec::new()));
            let prompts = params
                .get("prompts")
                .cloned()
                .unwrap_or_else(|| Value::Array(Vec::new()));
            emit(json!({
                "type": "commands",
                "commands": commands,
                "prompts": prompts
            }));
        }
        "session/update" if suppress_session_updates => {}
        "session/update" => {
            let update = msg
                .pointer("/params/update")
                .cloned()
                .unwrap_or(Value::Null);
            let kind = update
                .get("sessionUpdate")
                .and_then(Value::as_str)
                .unwrap_or("");
            match kind {
                "agent_message_chunk" => {
                    if let Some(text) = chunk_text(&update) {
                        emit(json!({ "type": "append", "role": "agent", "text": text }));
                    }
                }
                "user_message_chunk" => {
                    if let Some(text) = chunk_text(&update) {
                        emit(json!({
                            "type": "append",
                            "role": "user",
                            "text": format!("> {text}\n")
                        }));
                    }
                }
                "agent_thought_chunk" => {
                    // NOTE(review): every streamed chunk gets its own "(thinking) " prefix.
                    if let Some(text) = chunk_text(&update) {
                        emit(json!({
                            "type": "append",
                            "role": "sys",
                            "text": format!("(thinking) {text}")
                        }));
                    }
                }
                "tool_call" | "tool_call_update" => match update.get("toolCallId") {
                    Some(tool_call_id) if !tool_call_id.is_null() => emit(json!({
                        "type": "tool_call",
                        "toolCallId": tool_call_id,
                        "title": field_or_null(&update, "title"),
                        "status": field_or_null(&update, "status"),
                        "kind": field_or_null(&update, "kind"),
                        "rawInput": field_or_null(&update, "rawInput"),
                        "content": field_or_null(&update, "content"),
                        "locations": field_or_null(&update, "locations")
                    })),
                    // Without a usable id the client cannot track the call, so
                    // fall back to a one-line textual note.
                    _ => {
                        let title = update
                            .get("title")
                            .and_then(Value::as_str)
                            .unwrap_or("tool");
                        let line = match update.get("status").and_then(Value::as_str) {
                            Some(status) if !status.is_empty() => {
                                format!("\n[{title}: {status}]\n")
                            }
                            _ => format!("\n[{title}]\n"),
                        };
                        emit(json!({ "type": "append", "role": "sys", "text": line }));
                    }
                },
                _ => {}
            }
        }
        "session/request_permission" => {
            let Some(params) = msg.get("params") else {
                return;
            };
            let id = field_or_null(&msg, "id");
            // `title` wins over `name` whenever the key exists, even if it is not a string.
            let title = params
                .get("toolCall")
                .and_then(|tc| tc.get("title").or_else(|| tc.get("name")))
                .and_then(Value::as_str)
                .unwrap_or("tool");
            let options = params
                .get("options")
                .cloned()
                .unwrap_or_else(|| Value::Array(Vec::new()));
            emit(json!({
                "type": "permission_request",
                "id": id,
                "title": title,
                "options": options
            }));
        }
        _ => {}
    }
}