Skip to main content

koda_cli/
server.rs

//! ACP server over stdio JSON-RPC.
//!
//! Reads newline-delimited JSON from stdin and writes JSON-RPC messages to stdout.
//! Implements the ACP lifecycle: Initialize → Authenticate → NewSession → Prompt → Cancel.
5
6use crate::acp_adapter::{self, AcpOutgoing, PendingApproval};
7use acp::Side;
8use agent_client_protocol_schema as acp;
9use anyhow::Result;
10use koda_core::agent::KodaAgent;
11use koda_core::approval::ApprovalMode;
12use koda_core::config::KodaConfig;
13use koda_core::db::{Database, Role};
14use koda_core::engine::EngineCommand;
15use koda_core::persistence::Persistence;
16use koda_core::session::KodaSession;
17use std::collections::HashMap;
18use std::path::PathBuf;
19use std::sync::atomic::AtomicI64;
20use std::sync::{Arc, Mutex};
21use tokio::io::{AsyncBufReadExt, BufReader};
22use tokio::sync::mpsc;
23use tokio_util::sync::CancellationToken;
24
25/// An active prompt session with its running task handle.
26struct ActiveSession {
27    session: KodaSession,
28    cmd_tx: mpsc::Sender<EngineCommand>,
29    cancel: CancellationToken,
30}
31
32/// Server state shared across the event loop.
33struct ServerState {
34    agent: Arc<KodaAgent>,
35    config: KodaConfig,
36    db: Database,
37    project_root: PathBuf,
38    active: Option<ActiveSession>,
39    /// Maps outgoing JSON-RPC request IDs to engine approval IDs.
40    pending_approvals: Arc<Mutex<HashMap<acp::RequestId, PendingApproval>>>,
41    /// Counter for outgoing JSON-RPC request IDs.
42    next_rpc_id: Arc<AtomicI64>,
43}
44
45/// Run the ACP server over stdio.
46///
47/// Reads newline-delimited JSON-RPC from stdin, dispatches to handlers,
48/// and writes JSON-RPC responses/notifications to stdout.
49pub async fn run_stdio_server(project_root: PathBuf, mut config: KodaConfig) -> Result<()> {
50    // Initialize database
51    let db = Database::init(&koda_core::db::config_dir()?).await?;
52
53    // Query actual model capabilities before building agent
54    let tmp_provider = koda_core::providers::create_provider(&config);
55    config
56        .query_and_apply_capabilities(tmp_provider.as_ref())
57        .await;
58
59    // Build agent (tools, system prompt)
60    let agent = Arc::new(KodaAgent::new(&config, project_root.clone(), &[]).await?);
61
62    let pending_approvals = Arc::new(Mutex::new(HashMap::new()));
63    let next_rpc_id = Arc::new(AtomicI64::new(1));
64
65    let mut state = ServerState {
66        agent,
67        config,
68        db,
69        project_root,
70        active: None,
71        pending_approvals,
72        next_rpc_id,
73    };
74
75    // Channel for outgoing messages → stdout writer task
76    let (out_tx, mut out_rx) = mpsc::channel::<String>(256);
77
78    // Spawn stdout writer task
79    tokio::spawn(async move {
80        use tokio::io::AsyncWriteExt;
81        let mut stdout = tokio::io::stdout();
82        while let Some(line) = out_rx.recv().await {
83            if stdout.write_all(line.as_bytes()).await.is_err() {
84                break;
85            }
86            if stdout.write_all(b"\n").await.is_err() {
87                break;
88            }
89            let _ = stdout.flush().await;
90        }
91    });
92
93    // Read stdin line by line
94    let stdin = tokio::io::stdin();
95    let mut reader = BufReader::new(stdin);
96    let mut line = String::new();
97
98    loop {
99        line.clear();
100        let n = reader.read_line(&mut line).await?;
101        if n == 0 {
102            // EOF — client disconnected
103            break;
104        }
105
106        let trimmed = line.trim();
107        if trimmed.is_empty() {
108            continue;
109        }
110
111        // Parse raw JSON to determine message type
112        let raw: serde_json::Value = match serde_json::from_str(trimmed) {
113            Ok(v) => v,
114            Err(e) => {
115                let err =
116                    make_error_response(acp::RequestId::Null, -32700, &format!("Parse error: {e}"));
117                send_json(&out_tx, &err).await;
118                continue;
119            }
120        };
121
122        let has_method = raw.get("method").and_then(|m| m.as_str()).is_some();
123        let has_id = raw.get("id").is_some();
124        let has_result = raw.get("result").is_some();
125        let has_error = raw.get("error").is_some();
126
127        if has_method && has_id {
128            // Request from client
129            handle_request(&raw, &mut state, &out_tx).await;
130        } else if has_method && !has_id {
131            // Notification from client
132            handle_notification(&raw, &mut state).await;
133        } else if has_id && (has_result || has_error) {
134            // Response to our outgoing request (permission response)
135            handle_response(&raw, &mut state).await;
136        } else {
137            let err = make_error_response(acp::RequestId::Null, -32600, "Invalid JSON-RPC message");
138            send_json(&out_tx, &err).await;
139        }
140    }
141
142    Ok(())
143}
144
145/// Handle an incoming JSON-RPC request.
146async fn handle_request(
147    raw: &serde_json::Value,
148    state: &mut ServerState,
149    out_tx: &mpsc::Sender<String>,
150) {
151    let id = parse_request_id(raw);
152    let method = raw["method"].as_str().unwrap_or("");
153
154    // Extract params as RawValue for ACP decoder
155    let params_raw = raw
156        .get("params")
157        .map(|p| serde_json::value::to_raw_value(p).unwrap());
158
159    let decoded = acp::AgentSide::decode_request(method, params_raw.as_deref());
160
161    let request = match decoded {
162        Ok(r) => r,
163        Err(e) => {
164            let err = make_error_response(id, -32601, &format!("Unknown method '{method}': {e}"));
165            send_json(out_tx, &err).await;
166            return;
167        }
168    };
169
170    match request {
171        acp::ClientRequest::InitializeRequest(req) => {
172            handle_initialize(id, req, out_tx).await;
173        }
174        acp::ClientRequest::AuthenticateRequest(_req) => {
175            handle_authenticate(id, out_tx).await;
176        }
177        acp::ClientRequest::NewSessionRequest(req) => {
178            handle_new_session(id, req, state, out_tx).await;
179        }
180        acp::ClientRequest::PromptRequest(req) => {
181            handle_prompt(id, req, state, out_tx).await;
182        }
183        _ => {
184            let err = make_error_response(
185                id,
186                -32601,
187                &format!("Method '{method}' not yet implemented"),
188            );
189            send_json(out_tx, &err).await;
190        }
191    }
192}
193
194/// Handle an incoming JSON-RPC notification (no response expected).
195async fn handle_notification(raw: &serde_json::Value, state: &mut ServerState) {
196    let method = raw["method"].as_str().unwrap_or("");
197    let params_raw = raw
198        .get("params")
199        .map(|p| serde_json::value::to_raw_value(p).unwrap());
200
201    let decoded = acp::AgentSide::decode_notification(method, params_raw.as_deref());
202
203    if let Ok(acp::ClientNotification::CancelNotification(_cancel)) = decoded
204        && let Some(ref active) = state.active
205    {
206        active.cancel.cancel();
207    }
208}
209
210/// Handle a JSON-RPC response (to our outgoing permission request).
211async fn handle_response(raw: &serde_json::Value, state: &mut ServerState) {
212    let rpc_id = parse_request_id(raw);
213
214    // Check if this is a permission response
215    if let Some(result) = raw.get("result")
216        && let Ok(perm_resp) =
217            serde_json::from_value::<acp::RequestPermissionResponse>(result.clone())
218        && let Some(ref active) = state.active
219    {
220        acp_adapter::resolve_permission_response(
221            &state.pending_approvals,
222            &rpc_id,
223            &perm_resp.outcome,
224            &active.cmd_tx,
225        );
226    }
227}
228
229/// Handle `initialize` request.
230async fn handle_initialize(
231    id: acp::RequestId,
232    req: acp::InitializeRequest,
233    out_tx: &mpsc::Sender<String>,
234) {
235    let response = acp::InitializeResponse::new(req.protocol_version)
236        .agent_info(acp::Implementation::new("koda", env!("CARGO_PKG_VERSION")));
237
238    let resp = wrap_response(id, acp::AgentResponse::InitializeResponse(response));
239    send_json(out_tx, &resp).await;
240}
241
242/// Handle `authenticate` request (no-op for local agent).
243async fn handle_authenticate(id: acp::RequestId, out_tx: &mpsc::Sender<String>) {
244    let response = acp::AuthenticateResponse::default();
245    let resp = wrap_response(id, acp::AgentResponse::AuthenticateResponse(response));
246    send_json(out_tx, &resp).await;
247}
248
249/// Handle `session/new` request.
250async fn handle_new_session(
251    id: acp::RequestId,
252    _req: acp::NewSessionRequest,
253    state: &mut ServerState,
254    out_tx: &mpsc::Sender<String>,
255) {
256    let session_id = match state
257        .db
258        .create_session(&state.config.agent_name, &state.project_root)
259        .await
260    {
261        Ok(sid) => sid,
262        Err(e) => {
263            let err = make_error_response(id, -32000, &format!("Failed to create session: {e}"));
264            send_json(out_tx, &err).await;
265            return;
266        }
267    };
268
269    let (cmd_tx, _cmd_rx) = mpsc::channel::<EngineCommand>(32);
270    let cancel = CancellationToken::new();
271
272    let session = KodaSession::new(
273        session_id.clone(),
274        state.agent.clone(),
275        state.db.clone(),
276        &state.config,
277        ApprovalMode::Auto,
278    )
279    .await;
280
281    state.active = Some(ActiveSession {
282        session,
283        cmd_tx,
284        cancel,
285    });
286
287    let response = acp::NewSessionResponse::new(session_id);
288    let resp = wrap_response(id, acp::AgentResponse::NewSessionResponse(response));
289    send_json(out_tx, &resp).await;
290}
291
292/// Handle `session/prompt` request.
293async fn handle_prompt(
294    id: acp::RequestId,
295    req: acp::PromptRequest,
296    state: &mut ServerState,
297    out_tx: &mpsc::Sender<String>,
298) {
299    // Extract text from prompt content blocks
300    let mut text_parts = Vec::new();
301    for block in &req.prompt {
302        if let acp::ContentBlock::Text(tc) = block {
303            text_parts.push(tc.text.clone());
304        }
305    }
306    let user_text = text_parts.join("\n");
307
308    // Ensure we have an active session
309    let active = match state.active.as_mut() {
310        Some(a) => a,
311        None => {
312            let err = make_error_response(id, -32000, "No active session. Call session/new first.");
313            send_json(out_tx, &err).await;
314            return;
315        }
316    };
317
318    let session_id = active.session.id.clone();
319
320    // Insert user message into DB
321    if let Err(e) = active
322        .session
323        .db
324        .insert_message(&session_id, &Role::User, Some(&user_text), None, None, None)
325        .await
326    {
327        let err = make_error_response(id, -32000, &format!("Failed to insert message: {e}"));
328        send_json(out_tx, &err).await;
329        return;
330    }
331
332    // Create a fresh cancel token for this prompt
333    active.cancel = CancellationToken::new();
334    active.session.cancel = active.cancel.clone();
335
336    // Create new cmd channel for this prompt
337    let (cmd_tx, mut cmd_rx) = mpsc::channel::<EngineCommand>(32);
338    active.cmd_tx = cmd_tx.clone();
339
340    // Create AcpSink
341    let (acp_tx, mut acp_rx) = mpsc::channel::<AcpOutgoing>(256);
342    let sink = acp_adapter::AcpSink::new(
343        session_id,
344        acp_tx,
345        cmd_tx,
346        state.pending_approvals.clone(),
347        state.next_rpc_id.clone(),
348    );
349
350    // Spawn background task to stream ACP events to stdout
351    let out_tx_events = out_tx.clone();
352    let streaming_task = tokio::spawn(async move {
353        while let Some(outgoing) = acp_rx.recv().await {
354            let json = match &outgoing {
355                AcpOutgoing::Notification(notification) => {
356                    let msg = acp::OutgoingMessage::<acp::AgentSide, acp::ClientSide>::Notification(
357                        acp::Notification {
358                            method: "session/update".into(),
359                            params: Some(acp::AgentNotification::SessionNotification(
360                                notification.clone(),
361                            )),
362                        },
363                    );
364                    let wrapped = acp::JsonRpcMessage::wrap(msg);
365                    serde_json::to_string(&wrapped).ok()
366                }
367                AcpOutgoing::PermissionRequest { rpc_id, request } => {
368                    let msg = acp::OutgoingMessage::<acp::AgentSide, acp::ClientSide>::Request(
369                        acp::Request {
370                            id: rpc_id.clone(),
371                            method: "session/request_permission".into(),
372                            params: Some(acp::AgentRequest::RequestPermissionRequest(
373                                request.clone(),
374                            )),
375                        },
376                    );
377                    let wrapped = acp::JsonRpcMessage::wrap(msg);
378                    serde_json::to_string(&wrapped).ok()
379                }
380            };
381            if let Some(json) = json {
382                let _ = out_tx_events.send(json).await;
383            }
384        }
385    });
386
387    // Run inference on the current task (blocks stdin reading, but that's fine
388    // for the initial single-session implementation)
389    let active = state.active.as_mut().unwrap();
390    let config = state.config.clone();
391    let result = active
392        .session
393        .run_turn(&config, None, &sink, &mut cmd_rx)
394        .await;
395
396    // Drop the sink so the streaming task finishes
397    drop(sink);
398    let _ = streaming_task.await;
399
400    // Determine stop reason
401    let stop_reason = match result {
402        Ok(()) => acp::StopReason::EndTurn,
403        Err(_) => acp::StopReason::EndTurn,
404    };
405
406    let response = acp::PromptResponse::new(stop_reason);
407    let resp = wrap_response(id, acp::AgentResponse::PromptResponse(response));
408    send_json(out_tx, &resp).await;
409}
410
// ── Helpers ─────────────────────────────────────────────────
412
413/// Parse a JSON-RPC request ID from a raw JSON value.
414fn parse_request_id(raw: &serde_json::Value) -> acp::RequestId {
415    match raw.get("id") {
416        Some(serde_json::Value::Number(n)) => acp::RequestId::Number(n.as_i64().unwrap_or(0)),
417        Some(serde_json::Value::String(s)) => acp::RequestId::Str(s.clone()),
418        Some(serde_json::Value::Null) | None => acp::RequestId::Null,
419        _ => acp::RequestId::Null,
420    }
421}
422
423/// Send a JSON string over the output channel.
424async fn send_json(out_tx: &mpsc::Sender<String>, value: &serde_json::Value) {
425    if let Ok(json) = serde_json::to_string(value) {
426        let _ = out_tx.send(json).await;
427    }
428}
429
430/// Wrap an ACP agent response into a JSON-RPC response value.
431fn wrap_response(id: acp::RequestId, response: acp::AgentResponse) -> serde_json::Value {
432    serde_json::json!({
433        "jsonrpc": "2.0",
434        "id": id,
435        "result": response,
436    })
437}
438
439/// Create a JSON-RPC error response.
440fn make_error_response(id: acp::RequestId, code: i32, message: &str) -> serde_json::Value {
441    serde_json::json!({
442        "jsonrpc": "2.0",
443        "id": id,
444        "error": {
445            "code": code,
446            "message": message,
447        },
448    })
449}