use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use axum::Router;
use axum::extract::State;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::http::StatusCode;
use axum::response::Json;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::{get, post};
use futures::SinkExt;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex as TokioMutex;
use tokio_stream::wrappers::BroadcastStream;
use uuid::Uuid;
use agent_code_lib::query::{QueryEngine, StreamSink};
/// Events broadcast to streaming clients while a turn is running.
///
/// Serialized as an internally tagged JSON object: the `type` field holds the
/// snake_case variant name (`text_delta`, `tool_start`, ..., `done`) and the
/// variant's own fields sit next to it.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SseEvent {
    /// A chunk of response text.
    TextDelta { text: String },
    /// A tool invocation has started.
    ToolStart { name: String },
    /// A tool invocation finished; `is_error` mirrors the tool result's flag.
    ToolResult { name: String, is_error: bool },
    /// A chunk of "thinking" text.
    Thinking { text: String },
    /// The turn with the given number completed.
    TurnComplete { turn: usize },
    /// Token usage figures reported during the turn.
    Usage {
        input_tokens: u64,
        output_tokens: u64,
    },
    /// An error reported during the turn.
    Error { message: String },
    /// A compaction happened and freed this many tokens.
    Compact { freed_tokens: u64 },
    /// A non-fatal warning.
    Warning { message: String },
    /// Final summary, sent by the request handlers after the engine call returns.
    Done {
        response: String,
        turn_count: usize,
        tools_used: Vec<String>,
        cost_usd: f64,
    },
}
/// State shared by every route handler (wrapped in an `Arc` by `run_server`).
pub struct ServerState {
    /// The query engine; handlers lock it for the duration of each access.
    pub engine: TokioMutex<QueryEngine>,
    /// Broadcast sender for the current event stream, or `None` when no
    /// stream is installed.
    pub event_tx: tokio::sync::RwLock<Option<tokio::sync::broadcast::Sender<SseEvent>>>,
    /// Token a WebSocket client must present in its first message.
    pub auth_token: String,
    /// Pending permission requests keyed by id; the stored sender receives
    /// the decision string.
    pub permission_requests: TokioMutex<HashMap<String, tokio::sync::oneshot::Sender<String>>>,
}
/// JSON request body accepted by `POST /message`.
#[derive(Debug, Deserialize)]
pub struct MessageRequest {
/// Prompt text handed to the engine for this turn.
pub content: String,
}
/// JSON body returned by `POST /message` when the turn succeeds.
#[derive(Debug, Serialize)]
pub struct MessageResponse {
/// Text accumulated by the sink during the turn (streamed text plus any
/// `[Error: ...]` lines the sink appended).
pub response: String,
/// Engine turn counter read after the turn.
pub turn_count: usize,
/// Distinct tool names started during the turn, in first-use order.
pub tools_used: Vec<String>,
/// Engine's `total_cost_usd` read after the turn.
pub cost_usd: f64,
}
/// JSON body returned by `GET /status`; a snapshot of the engine state.
#[derive(Debug, Serialize)]
pub struct StatusResponse {
pub session_id: String,
/// Model name taken from the engine's API configuration.
pub model: String,
pub cwd: String,
pub turn_count: usize,
/// Number of messages currently held in the engine state.
pub message_count: usize,
pub cost_usd: f64,
pub plan_mode: bool,
/// `CARGO_PKG_VERSION` of this crate, captured at compile time.
pub version: String,
/// `true` while `ServerState::event_tx` holds a sender.
pub streaming: bool,
}
/// JSON body returned by `GET /messages`.
#[derive(Debug, Serialize)]
pub struct MessagesResponse {
/// Flattened conversation history, one entry per engine message.
pub messages: Vec<MessageEntry>,
}
/// One flattened conversation message as exposed by `GET /messages`.
#[derive(Debug, Serialize)]
pub struct MessageEntry {
/// `"user"`, `"assistant"`, or `"system"` (used for every other message kind).
pub role: String,
/// Text extracted from the message's content blocks.
pub content: String,
/// Number of tool-use blocks in an assistant message; `0` for other roles.
pub tool_calls: usize,
}
/// `StreamSink` implementation that republishes engine callbacks as
/// `SseEvent`s on a broadcast channel while recording a per-turn summary.
struct SseBroadcastSink {
/// Sender side of the broadcast channel the events are published on.
tx: tokio::sync::broadcast::Sender<SseEvent>,
/// Response text accumulated from `on_text` / `on_error`.
text: std::sync::Mutex<String>,
/// Distinct tool names seen in `on_tool_start`, in first-use order.
tools: std::sync::Mutex<Vec<String>>,
}
impl SseBroadcastSink {
    /// Builds a sink on top of a fresh broadcast channel (capacity 256).
    ///
    /// Returns the sink together with the channel's initial receiver.
    fn new() -> (Self, tokio::sync::broadcast::Receiver<SseEvent>) {
        let (tx, initial_rx) = tokio::sync::broadcast::channel(256);
        let sink = Self {
            tx,
            text: Default::default(),
            tools: Default::default(),
        };
        (sink, initial_rx)
    }

    /// Publishes `event` on the channel.
    ///
    /// Delivery is best-effort: a failed send is deliberately ignored.
    fn send(&self, event: SseEvent) {
        let _ = self.tx.send(event);
    }
}
impl StreamSink for SseBroadcastSink {
fn on_text(&self, text: &str) {
if let Ok(mut t) = self.text.lock() {
t.push_str(text);
}
self.send(SseEvent::TextDelta {
text: text.to_string(),
});
}
fn on_tool_start(&self, name: &str, _input: &serde_json::Value) {
if let Ok(mut tools) = self.tools.lock()
&& !tools.contains(&name.to_string())
{
tools.push(name.to_string());
}
self.send(SseEvent::ToolStart {
name: name.to_string(),
});
}
fn on_tool_result(&self, name: &str, result: &agent_code_lib::tools::ToolResult) {
self.send(SseEvent::ToolResult {
name: name.to_string(),
is_error: result.is_error,
});
}
fn on_thinking(&self, text: &str) {
self.send(SseEvent::Thinking {
text: text.to_string(),
});
}
fn on_turn_complete(&self, turn: usize) {
self.send(SseEvent::TurnComplete { turn });
}
fn on_error(&self, error: &str) {
if let Ok(mut t) = self.text.lock() {
t.push_str(&format!("\n[Error: {error}]"));
}
self.send(SseEvent::Error {
message: error.to_string(),
});
}
fn on_usage(&self, usage: &agent_code_lib::llm::message::Usage) {
self.send(SseEvent::Usage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
});
}
fn on_compact(&self, freed_tokens: u64) {
self.send(SseEvent::Compact { freed_tokens });
}
fn on_warning(&self, msg: &str) {
self.send(SseEvent::Warning {
message: msg.to_string(),
});
}
}
/// Runs the HTTP/SSE/WebSocket server on `127.0.0.1:{port}` until shutdown.
///
/// Generates a fresh UUID auth token, builds the shared [`ServerState`]
/// around `engine`, registers all routes, writes a bridge lock file for the
/// actually bound port (read back from the listener), prints usage hints to
/// stderr and serves until `shutdown_signal()` resolves.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound, its local address cannot
/// be read, or the server itself fails. The lock file is removed in every
/// case once serving has ended, including when serving ended with an error.
///
/// NOTE(review): only the WebSocket handshake checks `auth_token`; the plain
/// HTTP routes are registered without any auth layer here — confirm that is
/// intended.
pub async fn run_server(engine: QueryEngine, port: u16) -> anyhow::Result<()> {
    let auth_token = Uuid::new_v4().to_string();
    let state = Arc::new(ServerState {
        engine: tokio::sync::Mutex::new(engine),
        event_tx: tokio::sync::RwLock::new(None),
        auth_token: auth_token.clone(),
        permission_requests: TokioMutex::new(HashMap::new()),
    });
    let cwd = std::env::current_dir()
        .map(|p| p.display().to_string())
        .unwrap_or_default();
    let app = Router::new()
        .route("/message", post(handle_message))
        .route("/events", get(handle_events))
        .route("/status", get(handle_status))
        .route("/messages", get(handle_messages))
        .route("/health", get(handle_health))
        .route("/ws", get(handle_ws))
        .route("/permission", post(handle_permission))
        .with_state(state);
    let addr = format!("127.0.0.1:{port}");
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    let actual_port = listener.local_addr()?.port();
    // Best-effort: the server still runs when the lock file cannot be written.
    let lock_file = agent_code_lib::services::bridge::write_lock_file(actual_port, &cwd).ok();
    eprintln!("agent-code server listening on http://127.0.0.1:{actual_port}");
    eprintln!("POST /message — send a prompt");
    eprintln!("GET /events — SSE event stream");
    eprintln!("GET /ws — WebSocket (JSON-RPC)");
    eprintln!("GET /status — session status");
    eprintln!("GET /messages — conversation history");
    eprintln!("GET /health — health check");
    eprintln!("POST /permission — respond to permission request");
    eprintln!();
    eprintln!("Auth token: {auth_token}");
    eprintln!("Press Ctrl+C to stop.");
    // Hold on to the outcome instead of using `?` right away, so the lock
    // file is cleaned up even when serving ends with an error.
    let served = axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await;
    if let Some(ref lf) = lock_file {
        agent_code_lib::services::bridge::remove_lock_file(lf);
    }
    served?;
    eprintln!("\nServer stopped.");
    Ok(())
}
async fn handle_message(
State(state): State<Arc<ServerState>>,
Json(req): Json<MessageRequest>,
) -> Result<Json<MessageResponse>, (StatusCode, String)> {
let (sink, _rx) = SseBroadcastSink::new();
let sink = Arc::new(sink);
{
let mut event_tx = state.event_tx.write().await;
*event_tx = Some(sink.tx.clone());
}
let sink_ref: &dyn StreamSink = &*sink;
let mut engine = state.engine.lock().await;
let turn_result = engine.run_turn_with_sink(&req.content, sink_ref).await;
let response_text = sink.text.lock().map(|t| t.clone()).unwrap_or_default();
let tools_used = sink.tools.lock().map(|t| t.clone()).unwrap_or_default();
let state_ref = engine.state();
let turn_count = state_ref.turn_count;
let cost_usd = state_ref.total_cost_usd;
sink.send(SseEvent::Done {
response: response_text.clone(),
turn_count,
tools_used: tools_used.clone(),
cost_usd,
});
{
let mut event_tx = state.event_tx.write().await;
*event_tx = None;
}
turn_result.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(MessageResponse {
response: response_text,
turn_count,
tools_used,
cost_usd,
}))
}
/// `GET /events`: Server-Sent Events stream of [`SseEvent`]s as JSON.
///
/// Subscribes to the sender currently stored in `state.event_tx`, installing
/// a fresh channel (capacity 256) first when none is present. Each event is
/// sent as one SSE `data:` payload; receive errors (a lagging subscriber) are
/// skipped. Default keep-alive is enabled.
async fn handle_events(
    State(state): State<Arc<ServerState>>,
) -> Sse<impl futures::stream::Stream<Item = Result<Event, Infallible>>> {
    // Check and install under a single write lock. Releasing a read lock and
    // then writing unconditionally could overwrite a sender that a starting
    // turn installed in between.
    let rx = {
        let mut slot = state.event_tx.write().await;
        slot.get_or_insert_with(|| tokio::sync::broadcast::channel(256).0)
            .subscribe()
    };
    let stream = BroadcastStream::new(rx).filter_map(|result| {
        futures::future::ready(match result {
            Ok(event) => {
                let data = serde_json::to_string(&event).unwrap_or_default();
                Some(Ok(Event::default().data(data)))
            }
            Err(_) => None,
        })
    });
    Sse::new(stream).keep_alive(KeepAlive::default())
}
/// `GET /status`: snapshot of the engine session.
///
/// Waits for the engine lock, so the reply is delayed while another handler
/// holds the engine.
async fn handle_status(State(state): State<Arc<ServerState>>) -> Json<StatusResponse> {
    let engine = state.engine.lock().await;
    let session = engine.state();
    // "Streaming" means an event sender is currently installed.
    let streaming = state.event_tx.read().await.is_some();

    let body = StatusResponse {
        session_id: session.session_id.clone(),
        model: session.config.api.model.clone(),
        cwd: session.cwd.clone(),
        turn_count: session.turn_count,
        message_count: session.messages.len(),
        cost_usd: session.total_cost_usd,
        plan_mode: session.plan_mode,
        version: env!("CARGO_PKG_VERSION").to_string(),
        streaming,
    };
    Json(body)
}
/// `GET /messages`: flattened view of the conversation history.
///
/// * User messages: text blocks and tool-result contents joined with `"\n"`.
/// * Assistant messages: text blocks concatenated with no separator, plus
///   the number of tool-use blocks.
/// * Any other message kind: an empty `"system"` entry.
async fn handle_messages(State(state): State<Arc<ServerState>>) -> Json<MessagesResponse> {
    use agent_code_lib::llm::message::{ContentBlock, Message as ChatMessage};

    let engine = state.engine.lock().await;
    let mut messages: Vec<MessageEntry> = Vec::new();

    for msg in engine.state().messages.iter() {
        let entry = match msg {
            ChatMessage::User(user) => {
                let mut parts: Vec<&str> = Vec::new();
                for block in user.content.iter() {
                    match block {
                        ContentBlock::Text { text } => parts.push(text.as_str()),
                        ContentBlock::ToolResult { content, .. } => parts.push(content.as_str()),
                        _ => {}
                    }
                }
                MessageEntry {
                    role: "user".into(),
                    content: parts.join("\n"),
                    tool_calls: 0,
                }
            }
            ChatMessage::Assistant(assistant) => {
                // Single pass: gather the text and count tool-use blocks together.
                let mut content = String::new();
                let mut tool_calls = 0usize;
                for block in assistant.content.iter() {
                    match block {
                        ContentBlock::Text { text } => content.push_str(text.as_str()),
                        ContentBlock::ToolUse { .. } => tool_calls += 1,
                        _ => {}
                    }
                }
                MessageEntry {
                    role: "assistant".into(),
                    content,
                    tool_calls,
                }
            }
            _ => MessageEntry {
                role: "system".into(),
                content: String::new(),
                tool_calls: 0,
            },
        };
        messages.push(entry);
    }

    Json(MessagesResponse { messages })
}
/// `GET /health`: liveness probe; always answers with the plain text `ok`.
async fn handle_health() -> &'static str {
"ok"
}
/// `GET /ws`: upgrades the request to a WebSocket and hands the socket,
/// together with the shared state, to `handle_ws_connection`.
async fn handle_ws(
    State(state): State<Arc<ServerState>>,
    ws: WebSocketUpgrade,
) -> impl axum::response::IntoResponse {
    ws.on_upgrade(move |socket| async move { handle_ws_connection(socket, state).await })
}
/// Drives one WebSocket connection speaking a JSON-RPC 2.0 style protocol.
///
/// 1. The first frame must be a text frame containing JSON whose `auth`
///    string equals `state.auth_token`; otherwise an "Unauthorized" error is
///    sent and the connection handling ends.
/// 2. Afterwards every text frame is parsed as JSON (unparseable frames and
///    non-text frames are ignored, a close frame ends the loop):
///    * non-null `id` **and** `method` → a request, dispatched to
///      `handle_ws_request`; its value is sent back as `result`;
///    * non-null `id` without `method` → the client's answer to a permission
///      request; `result.decision` (default `"deny"`) is forwarded to the
///      pending request stored under that id, if any.
///
/// Outgoing frames go through an unbounded queue drained by a writer task,
/// so event notifications and responses can share the socket's write half.
///
/// NOTE(review): requests are awaited one at a time, so further frames
/// (including `cancel`) are not read while a `message` request is running.
async fn handle_ws_connection(mut socket: WebSocket, state: Arc<ServerState>) {
    let authed = match socket.recv().await {
        Some(Ok(Message::Text(ref text))) => {
            serde_json::from_str::<serde_json::Value>(text.as_str())
                .ok()
                .is_some_and(|json| {
                    json.get("auth").and_then(|v| v.as_str()) == Some(state.auth_token.as_str())
                })
        }
        _ => false,
    };
    if !authed {
        let _ = socket
            .send(Message::Text(
                serde_json::json!({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Unauthorized"}}).to_string().into(),
            ))
            .await;
        return;
    }

    let (mut ws_tx, mut ws_rx) = socket.split();
    let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    // Writer task: the only owner of the socket's write half.
    let out_task = tokio::spawn(async move {
        while let Some(text) = out_rx.recv().await {
            if ws_tx.send(Message::Text(text.into())).await.is_err() {
                break;
            }
        }
    });
    let out_tx = Arc::new(out_tx);

    while let Some(Ok(msg)) = ws_rx.next().await {
        let text = match msg {
            Message::Text(ref t) => t.to_string(),
            Message::Close(_) => break,
            _ => continue,
        };
        let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) else {
            continue;
        };

        let has_id = json.get("id").is_some_and(|id| !id.is_null());
        let has_method = json.get("method").is_some();

        if has_id && has_method {
            let id = json["id"].clone();
            let method = json["method"].as_str().unwrap_or("").to_string();
            let params = json
                .get("params")
                .cloned()
                .unwrap_or(serde_json::Value::Null);
            let result = handle_ws_request(&state, &method, &params, Arc::clone(&out_tx)).await;
            let response = serde_json::json!({
                "jsonrpc": "2.0",
                "id": id,
                "result": result,
            });
            let _ = out_tx.send(response.to_string());
        } else if has_id && !has_method {
            // Permission ids are matched as strings; any other id type maps to "".
            let id = json["id"].as_str().unwrap_or("").to_string();
            let decision = json
                .get("result")
                .and_then(|r| r.get("decision"))
                .and_then(|d| d.as_str())
                .unwrap_or("deny")
                .to_string();
            let mut pending = state.permission_requests.lock().await;
            if let Some(tx) = pending.remove(&id) {
                let _ = tx.send(decision);
            }
        }
    }
    out_task.abort();
}
async fn handle_ws_request(
state: &Arc<ServerState>,
method: &str,
params: &serde_json::Value,
out_tx: Arc<tokio::sync::mpsc::UnboundedSender<String>>,
) -> serde_json::Value {
match method {
"message" => {
let content = params
.get("content")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
let (sink, _rx) = SseBroadcastSink::new();
let sink = Arc::new(sink);
{
let mut event_tx = state.event_tx.write().await;
*event_tx = Some(sink.tx.clone());
}
let out_tx_clone = Arc::clone(&out_tx);
let mut rx = sink.tx.subscribe();
let forward_task = tokio::spawn(async move {
while let Ok(event) = rx.recv().await {
let method_name = match &event {
SseEvent::TextDelta { .. } => "events/text_delta",
SseEvent::ToolStart { .. } => "events/tool_start",
SseEvent::ToolResult { .. } => "events/tool_result",
SseEvent::Thinking { .. } => "events/thinking",
SseEvent::TurnComplete { .. } => "events/turn_complete",
SseEvent::Usage { .. } => "events/usage",
SseEvent::Error { .. } => "events/error",
SseEvent::Compact { .. } => "events/compact",
SseEvent::Warning { .. } => "events/warning",
SseEvent::Done { .. } => "events/done",
};
let notification = serde_json::json!({
"jsonrpc": "2.0",
"method": method_name,
"params": event,
});
if out_tx_clone.send(notification.to_string()).is_err() {
break;
}
}
});
let sink_ref: &dyn StreamSink = &*sink;
let mut engine = state.engine.lock().await;
let _ = engine.run_turn_with_sink(&content, sink_ref).await;
let response_text = sink.text.lock().map(|t| t.clone()).unwrap_or_default();
let tools_used = sink.tools.lock().map(|t| t.clone()).unwrap_or_default();
let turn_count = engine.state().turn_count;
let cost_usd = engine.state().total_cost_usd;
sink.send(SseEvent::Done {
response: response_text.clone(),
turn_count,
tools_used: tools_used.clone(),
cost_usd,
});
{
let mut event_tx = state.event_tx.write().await;
*event_tx = None;
}
forward_task.abort();
serde_json::json!({
"response": response_text,
"turn_count": turn_count,
"tools_used": tools_used,
"cost_usd": cost_usd,
})
}
"status" => {
let engine = state.engine.lock().await;
let s = engine.state();
serde_json::json!({
"session_id": s.session_id,
"model": s.config.api.model,
"cwd": s.cwd,
"turn_count": s.turn_count,
"message_count": s.messages.len(),
"cost_usd": s.total_cost_usd,
"plan_mode": s.plan_mode,
"version": env!("CARGO_PKG_VERSION"),
})
}
"cancel" => {
let engine = state.engine.lock().await;
engine.cancel();
serde_json::json!({"cancelled": true})
}
_ => {
serde_json::json!({"error": format!("Unknown method: {method}")})
}
}
}
async fn handle_permission(
State(state): State<Arc<ServerState>>,
Json(req): Json<PermissionResponse>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
let mut pending = state.permission_requests.lock().await;
if let Some(tx) = pending.remove(&req.id) {
let _ = tx.send(req.decision.clone());
Ok(Json(serde_json::json!({"ok": true})))
} else {
Err((
StatusCode::NOT_FOUND,
format!("No pending permission request with id: {}", req.id),
))
}
}
/// JSON request body accepted by `POST /permission`.
#[derive(Debug, Deserialize)]
struct PermissionResponse {
/// Key of the pending request in `ServerState::permission_requests`.
id: String,
/// Decision string forwarded unchanged to the waiting request.
decision: String,
}
async fn shutdown_signal() {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for ctrl+c");
}