Skip to main content

vtcode_acp/
transport.rs

//! Generic JSON-RPC-over-stdio transport for subprocess agents.
//!
//! [`StdioTransport`] handles the low-level framing of newline-delimited JSON
//! over a child process's stdin/stdout pair. It is intentionally protocol-agnostic:
//! it knows nothing about Copilot, ACP sessions, or any other higher-level concept.
//!
//! ## Message routing
//!
//! The internal reader task inspects each incoming line and dispatches it as follows:
//!
//! - **Response** (has a `result` or `error` field together with an `id`): the `id` is
//!   looked up in the pending table populated by [`StdioTransport::call`] and the outcome
//!   is delivered to the waiting caller via a [`tokio::sync::oneshot`] channel. A response
//!   whose `id` matches no pending call is discarded.
//! - **Request / notification** (anything else): forwarded to the closure registered
//!   via [`StdioTransport::set_notification_handler`].
//!
//! Stderr lines are forwarded to `tracing::debug!` under the
//! `vtcode.stdio_transport.stderr` target.
19
20use std::fmt;
21use std::sync::atomic::{AtomicI64, Ordering};
22use std::sync::{Arc, Mutex as StdMutex};
23
24use hashbrown::HashMap;
25use std::time::Duration;
26
27use serde_json::Value;
28use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
29use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
30use tokio::sync::{mpsc, oneshot};
31use tokio::time::timeout;
32
33use crate::error::{AcpError, AcpResult};
34
35/// Callback type for incoming server→client requests and notifications.
36///
37/// The handler receives the raw JSON-RPC message value. It should return
38/// `Ok(())` on success; errors are logged as warnings by the transport.
39type NotificationHandler = Arc<dyn Fn(Value) -> anyhow::Result<()> + Send + Sync>;
40
/// Wire-format settings applied to every message a [`StdioTransport`] sends.
#[derive(Debug, Clone, Copy)]
pub struct StdioTransportOptions {
    /// When `true`, outgoing messages keep their `"jsonrpc": "2.0"` member;
    /// when `false`, that member is removed before the message is queued.
    pub include_jsonrpc_version: bool,
}
45
46impl Default for StdioTransportOptions {
47    fn default() -> Self {
48        Self {
49            include_jsonrpc_version: true,
50        }
51    }
52}
53
54/// Generic JSON-RPC-over-stdio transport for local subprocess agents.
55///
56/// Wraps a child process and provides:
57/// - [`call`](Self::call): send a request and await its response.
58/// - [`notify`](Self::notify): send a fire-and-forget notification.
59/// - [`respond`](Self::respond) / [`respond_error`](Self::respond_error): reply to
60///   incoming server-initiated requests.
61/// - [`set_notification_handler`](Self::set_notification_handler): register the handler
62///   that receives all incoming server→client messages.
63///
64/// The child process is killed when this struct is dropped.
65pub struct StdioTransport {
66    write_tx: mpsc::UnboundedSender<String>,
67    pending: Arc<StdMutex<HashMap<String, oneshot::Sender<AcpResult<Value>>>>>,
68    request_counter: AtomicI64,
69    notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
70    child: StdMutex<Option<Child>>,
71    rpc_timeout: Duration,
72    options: StdioTransportOptions,
73}
74
75impl StdioTransport {
76    /// Wire up transport from a spawned subprocess's stdin/stdout/stderr.
77    ///
78    /// Spawns background tasks for the writer (stdin), stderr logger, and the
79    /// reader (stdout) that dispatches JSON-RPC messages.
80    pub fn from_child(
81        child: Child,
82        stdin: ChildStdin,
83        stdout: ChildStdout,
84        stderr: ChildStderr,
85        rpc_timeout: Duration,
86    ) -> Self {
87        Self::from_child_with_options(
88            child,
89            stdin,
90            stdout,
91            stderr,
92            rpc_timeout,
93            StdioTransportOptions::default(),
94        )
95    }
96
97    pub fn from_child_with_options(
98        child: Child,
99        stdin: ChildStdin,
100        stdout: ChildStdout,
101        stderr: ChildStderr,
102        rpc_timeout: Duration,
103        options: StdioTransportOptions,
104    ) -> Self {
105        let (write_tx, write_rx) = mpsc::unbounded_channel();
106        let pending = Arc::new(StdMutex::new(HashMap::new()));
107        let notification_handler = Arc::new(StdMutex::new(None));
108
109        spawn_writer(write_rx, stdin);
110        spawn_stderr_logger(stderr);
111        spawn_reader(
112            stdout,
113            Arc::clone(&pending),
114            Arc::clone(&notification_handler),
115        );
116
117        Self {
118            write_tx,
119            pending,
120            request_counter: AtomicI64::new(1),
121            notification_handler,
122            child: StdMutex::new(Some(child)),
123            rpc_timeout,
124            options,
125        }
126    }
127
128    /// Construct a transport with a pre-wired channel for unit tests.
129    ///
130    /// No subprocess is spawned and no background tasks are started. The caller
131    /// can drive the mock by reading from the paired receiver.
132    #[cfg(test)]
133    pub fn new_for_testing(write_tx: mpsc::UnboundedSender<String>, rpc_timeout: Duration) -> Self {
134        Self::new_for_testing_with_options(write_tx, rpc_timeout, StdioTransportOptions::default())
135    }
136
137    #[cfg(test)]
138    pub fn new_for_testing_with_options(
139        write_tx: mpsc::UnboundedSender<String>,
140        rpc_timeout: Duration,
141        options: StdioTransportOptions,
142    ) -> Self {
143        Self {
144            write_tx,
145            pending: Arc::new(StdMutex::new(HashMap::new())),
146            request_counter: AtomicI64::new(1),
147            notification_handler: Arc::new(StdMutex::new(None)),
148            child: StdMutex::new(None),
149            rpc_timeout,
150            options,
151        }
152    }
153
154    /// Register a handler for incoming server→client requests and notifications.
155    ///
156    /// Must be called once after construction. Subsequent calls overwrite the
157    /// previous handler. The handler receives the raw JSON message value for
158    /// every incoming message that is **not** a response to a pending [`call`](Self::call).
159    pub fn set_notification_handler(&self, handler: NotificationHandler) {
160        if let Ok(mut guard) = self.notification_handler.lock() {
161            *guard = Some(handler);
162        }
163    }
164
165    /// Send a JSON-RPC request and wait for its response.
166    ///
167    /// Assigns a monotonically increasing `id`, inserts it into the pending
168    /// table, serialises the message, and awaits the reply up to `rpc_timeout`.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`AcpError::Timeout`] if the peer does not reply in time, or
173    /// [`AcpError::Internal`] if the transport is shut down.
174    pub async fn call(&self, method: &str, params: Value) -> AcpResult<Value> {
175        let id = self.request_counter.fetch_add(1, Ordering::Relaxed);
176        let id_value = Value::from(id);
177        let pending_key = response_id_key(&id_value);
178        let (tx, rx) = oneshot::channel();
179        self.pending
180            .lock()
181            .map_err(|_err| AcpError::Internal("stdio transport pending mutex poisoned".into()))?
182            .insert(pending_key.clone(), tx);
183
184        let mut payload = serde_json::json!({
185            "jsonrpc": "2.0",
186            "id": id,
187            "method": method,
188            "params": params,
189        });
190        maybe_strip_jsonrpc_field(&mut payload, self.options);
191        if let Err(e) = self.send_raw(payload) {
192            // Clean up the pending entry so it doesn't linger until timeout.
193            self.pending.lock().ok().map(|mut g| g.remove(&pending_key));
194            return Err(e);
195        }
196
197        timeout(self.rpc_timeout, rx)
198            .await
199            .map_err(|_err| AcpError::Timeout(format!("{method} timed out")))?
200            .map_err(|_err| AcpError::Internal(format!("{method} response channel closed")))
201            .and_then(|r| r)
202    }
203
204    /// Send a JSON-RPC notification (no response expected).
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if serialisation fails or the writer task has shut down.
209    pub fn notify(&self, method: &str, params: Value) -> AcpResult<()> {
210        let mut payload = serde_json::json!({
211            "jsonrpc": "2.0",
212            "method": method,
213            "params": params,
214        });
215        maybe_strip_jsonrpc_field(&mut payload, self.options);
216        self.send_raw(payload)
217    }
218
219    /// Send a JSON-RPC success response to an incoming server request.
220    ///
221    /// Use this to reply to messages received by the notification handler when
222    /// they carry an `id` field (i.e. they expect a response).
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if serialisation fails or the writer task has shut down.
227    pub fn respond(&self, id: i64, result: Value) -> AcpResult<()> {
228        self.respond_value(Value::from(id), result)
229    }
230
231    pub fn respond_value(&self, id: Value, result: Value) -> AcpResult<()> {
232        let mut payload = serde_json::json!({
233            "jsonrpc": "2.0",
234            "id": id,
235            "result": result,
236        });
237        maybe_strip_jsonrpc_field(&mut payload, self.options);
238        self.send_raw(payload)
239    }
240
241    /// Send a JSON-RPC error response to an incoming server request.
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if serialisation fails or the writer task has shut down.
246    pub fn respond_error(&self, id: i64, code: i32, message: impl Into<String>) -> AcpResult<()> {
247        self.respond_error_value(Value::from(id), code, message)
248    }
249
250    pub fn respond_error_value(
251        &self,
252        id: Value,
253        code: i32,
254        message: impl Into<String>,
255    ) -> AcpResult<()> {
256        let mut payload = serde_json::json!({
257            "jsonrpc": "2.0",
258            "id": id,
259            "error": {
260                "code": code,
261                "message": message.into(),
262            },
263        });
264        maybe_strip_jsonrpc_field(&mut payload, self.options);
265        self.send_raw(payload)
266    }
267
268    fn send_raw(&self, payload: Value) -> AcpResult<()> {
269        let text = serde_json::to_string(&payload)?;
270        self.write_tx
271            .send(text)
272            .map_err(|_err| AcpError::Internal("stdio transport writer channel closed".into()))
273    }
274}
275
276impl fmt::Debug for StdioTransport {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        f.debug_struct("StdioTransport")
279            .field(
280                "request_counter",
281                &self.request_counter.load(Ordering::Relaxed),
282            )
283            .field("rpc_timeout", &self.rpc_timeout)
284            .finish_non_exhaustive()
285    }
286}
287
288impl Drop for StdioTransport {
289    fn drop(&mut self) {
290        if let Ok(mut child) = self.child.lock()
291            && let Some(child) = child.as_mut()
292        {
293            let _ = child.start_kill();
294        }
295    }
296}
297
298// ============================================================================
299// Background tasks
300// ============================================================================
301
302fn spawn_writer(mut write_rx: mpsc::UnboundedReceiver<String>, mut stdin: ChildStdin) {
303    tokio::spawn(async move {
304        while let Some(payload) = write_rx.recv().await {
305            if stdin.write_all(payload.as_bytes()).await.is_err()
306                || stdin.write_all(b"\n").await.is_err()
307                || stdin.flush().await.is_err()
308            {
309                tracing::warn!(
310                    target: "vtcode.stdio_transport",
311                    "stdin write failed; writer task exiting"
312                );
313                break;
314            }
315        }
316    });
317}
318
319fn spawn_stderr_logger(stderr: ChildStderr) {
320    tokio::spawn(async move {
321        let mut reader = BufReader::new(stderr);
322        let mut line = String::new();
323        loop {
324            line.clear();
325            match reader.read_line(&mut line).await {
326                Ok(0) | Err(_) => break,
327                Ok(_) => {
328                    tracing::debug!(target: "vtcode.stdio_transport.stderr", "{}", line.trim_end())
329                }
330            }
331        }
332    });
333}
334
335fn spawn_reader(
336    stdout: ChildStdout,
337    pending: Arc<StdMutex<HashMap<String, oneshot::Sender<AcpResult<Value>>>>>,
338    notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
339) {
340    tokio::spawn(async move {
341        let mut reader = BufReader::new(stdout).lines();
342        while let Ok(Some(line)) = reader.next_line().await {
343            if line.trim().is_empty() {
344                continue;
345            }
346            let message: Value = match serde_json::from_str(&line) {
347                Ok(v) => v,
348                Err(e) => {
349                    tracing::warn!("stdio transport: JSON decode failed: {e}");
350                    continue;
351                }
352            };
353
354            // Dispatch JSON-RPC responses to pending callers.
355            // Extract tx before releasing the lock so `tx.send` runs lock-free.
356            if let Some(id) = response_id(&message) {
357                let result = extract_rpc_result(&message);
358                let tx = pending
359                    .lock()
360                    .ok()
361                    .and_then(|mut g| g.remove(&response_id_key(&id)));
362                if let Some(tx) = tx {
363                    let _ = tx.send(result);
364                }
365                continue;
366            }
367
368            // Clone the handler Arc out of the lock so the lock is released
369            // before the handler runs (prevents re-entrancy / call-site latency).
370            if let Some(handler) = notification_handler
371                .lock()
372                .ok()
373                .and_then(|g| g.as_ref().cloned())
374                && let Err(e) = handler(message)
375            {
376                tracing::warn!("stdio transport: notification handler error: {e}");
377            }
378        }
379    });
380}
381
382// ============================================================================
383// Helpers
384// ============================================================================
385
386/// Returns the `id` if the message is a JSON-RPC *response* (has `result` or `error`).
387fn response_id(message: &Value) -> Option<Value> {
388    if message.get("result").is_some() || message.get("error").is_some() {
389        message.get("id").cloned()
390    } else {
391        None
392    }
393}
394
395fn response_id_key(id: &Value) -> String {
396    serde_json::to_string(id).unwrap_or_else(|_| "null".to_string())
397}
398
399fn maybe_strip_jsonrpc_field(payload: &mut Value, options: StdioTransportOptions) {
400    if options.include_jsonrpc_version {
401        return;
402    }
403
404    if let Some(object) = payload.as_object_mut() {
405        object.remove("jsonrpc");
406    }
407}
408
409fn extract_rpc_result(message: &Value) -> AcpResult<Value> {
410    if let Some(error) = message.get("error") {
411        let code = error
412            .get("code")
413            .and_then(Value::as_i64)
414            .unwrap_or_default();
415        let detail = error
416            .get("message")
417            .and_then(Value::as_str)
418            .unwrap_or("unknown error");
419        Err(AcpError::RemoteError {
420            agent_id: "stdio".into(),
421            message: format!("rpc error {code}: {detail}"),
422            code: Some(code as i32),
423        })
424    } else {
425        Ok(message.get("result").cloned().unwrap_or(Value::Null))
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// RPC timeout used by every transport built in these tests.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Build a default-options transport backed by an in-memory channel and
    /// return it together with the receiving end of that channel.
    fn channel_transport() -> (StdioTransport, mpsc::UnboundedReceiver<String>) {
        let (tx, rx) = mpsc::unbounded_channel();
        (StdioTransport::new_for_testing(tx, TEST_TIMEOUT), rx)
    }

    /// Pop the next queued message and decode it as JSON.
    fn next_payload(rx: &mut mpsc::UnboundedReceiver<String>) -> Value {
        let raw = rx.try_recv().expect("queued payload");
        serde_json::from_str(&raw).expect("payload is valid JSON")
    }

    #[test]
    fn response_id_requires_result_or_error() {
        // Pure notification: neither result nor error.
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "some/notification",
            "params": {}
        });
        assert!(response_id(&notification).is_none());

        // Server-initiated request: has an id but no result.
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 7,
            "method": "permission.request",
            "params": {}
        });
        assert!(response_id(&request).is_none());

        // Successful response.
        let success = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 3,
            "result": { "ok": true }
        });
        assert_eq!(response_id(&success), Some(Value::from(3)));

        // Error response.
        let failure = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 5,
            "error": { "code": -32601, "message": "method not found" }
        });
        assert_eq!(response_id(&failure), Some(Value::from(5)));
    }

    #[test]
    fn extract_rpc_result_propagates_error() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": { "code": -32600, "message": "invalid request" }
        });
        let outcome = extract_rpc_result(&response);
        assert!(outcome.is_err());
        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("invalid request"));
    }

    #[test]
    fn extract_rpc_result_returns_result_value() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": { "sessionId": "abc" }
        });
        let value = extract_rpc_result(&response).unwrap();
        assert_eq!(value["sessionId"], "abc");
    }

    #[test]
    fn notify_serialises_payload_to_write_channel() {
        let (transport, mut rx) = channel_transport();

        transport
            .notify("session/cancel", serde_json::json!({ "sessionId": "s1" }))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["method"], "session/cancel");
        assert_eq!(payload["params"]["sessionId"], "s1");
        assert!(
            payload.get("id").is_none(),
            "notifications must not have id"
        );
    }

    #[test]
    fn respond_writes_jsonrpc_result() {
        let (transport, mut rx) = channel_transport();

        transport
            .respond(42, serde_json::json!({ "ok": true }))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["jsonrpc"], "2.0");
        assert_eq!(payload["id"], 42);
        assert_eq!(payload["result"]["ok"], true);
    }

    #[test]
    fn respond_error_writes_jsonrpc_error() {
        let (transport, mut rx) = channel_transport();

        transport
            .respond_error(9, -32601, "method not found")
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["id"], 9);
        assert_eq!(payload["error"]["code"], -32601);
        assert_eq!(payload["error"]["message"], "method not found");
    }

    #[test]
    fn respond_value_supports_string_ids() {
        let (transport, mut rx) = channel_transport();

        let id = Value::String("request-1".to_string());
        transport
            .respond_value(id, serde_json::json!({ "ok": true }))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["id"], "request-1");
        assert_eq!(payload["result"]["ok"], true);
    }

    #[test]
    fn can_omit_jsonrpc_field_for_codex_mode() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let options = StdioTransportOptions {
            include_jsonrpc_version: false,
        };
        let transport = StdioTransport::new_for_testing_with_options(tx, TEST_TIMEOUT, options);

        transport
            .notify("initialized", serde_json::json!({}))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert!(payload.get("jsonrpc").is_none());
        assert_eq!(payload["method"], "initialized");
    }
}