agentic_coding_protocol/
acp.rs

1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7    StreamExt as _,
8    channel::{
9        mpsc::{self, UnboundedReceiver, UnboundedSender},
10        oneshot,
11    },
12    future::LocalBoxFuture,
13    io::BufReader,
14    select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use serde::{Deserialize, Serialize};
19use serde_json::value::RawValue;
20use std::{
21    collections::HashMap,
22    fmt::Display,
23    rc::Rc,
24    sync::{
25        Arc,
26        atomic::{AtomicI32, Ordering::SeqCst},
27    },
28};
29
30/// A connection to a separate agent process over the ACP protocol.
31pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);
32
33/// A connection to a separate client process over the ACP protocol.
34pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37    /// Connect to an agent process, handling any incoming requests
38    /// using the given handler.
39    pub fn connect_to_agent<H: 'static + Client>(
40        handler: H,
41        outgoing_bytes: impl Unpin + AsyncWrite,
42        incoming_bytes: impl Unpin + AsyncRead,
43        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44    ) -> (Self, impl Future<Output = Result<(), Error>>) {
45        let handler = Arc::new(handler);
46        let (connection, io_task) = Connection::new(
47            Box::new(move |request| {
48                let handler = handler.clone();
49                async move { handler.call(request).await }.boxed_local()
50            }),
51            outgoing_bytes,
52            incoming_bytes,
53            spawn,
54        );
55        (Self(connection), io_task)
56    }
57
58    /// Send a request to the agent and wait for a response.
59    pub fn request<R: AgentRequest + 'static>(
60        &self,
61        params: R,
62    ) -> impl Future<Output = Result<R::Response, Error>> {
63        let params = params.into_any();
64        let result = self.0.request(params.method_name(), params);
65        async move {
66            let result = result.await?;
67            R::response_from_any(result)
68        }
69    }
70}
71
72impl ClientConnection {
73    pub fn connect_to_client<H: 'static + Agent>(
74        handler: H,
75        outgoing_bytes: impl Unpin + AsyncWrite,
76        incoming_bytes: impl Unpin + AsyncRead,
77        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
78    ) -> (Self, impl Future<Output = Result<(), Error>>) {
79        let handler = Arc::new(handler);
80        let (connection, io_task) = Connection::new(
81            Box::new(move |request| {
82                let handler = handler.clone();
83                async move { handler.call(request).await }.boxed_local()
84            }),
85            outgoing_bytes,
86            incoming_bytes,
87            spawn,
88        );
89        (Self(connection), io_task)
90    }
91
92    pub fn request<R: ClientRequest>(
93        &self,
94        params: R,
95    ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
96        let params = params.into_any();
97        let result = self.0.request(params.method_name(), params);
98        async move {
99            let result = result.await?;
100            R::response_from_any(result)
101        }
102    }
103}
104
105struct Connection<In, Out>
106where
107    In: AnyRequest,
108    Out: AnyRequest,
109{
110    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
111    response_senders: ResponseSenders<Out::Response>,
112    next_id: AtomicI32,
113}
114
115type ResponseSenders<T> =
116    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
117
118#[derive(Debug, Deserialize)]
119struct IncomingMessage<'a> {
120    id: i32,
121    method: Option<&'a str>,
122    params: Option<&'a RawValue>,
123    result: Option<&'a RawValue>,
124    error: Option<Error>,
125}
126
127#[derive(Serialize)]
128#[serde(untagged)]
129enum OutgoingMessage<Req, Resp> {
130    Request {
131        id: i32,
132        method: Box<str>,
133        params: Req,
134    },
135    OkResponse {
136        id: i32,
137        result: Resp,
138    },
139    ErrorResponse {
140        id: i32,
141        error: Error,
142    },
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct Error {
147    pub code: i32,
148    pub message: String,
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub data: Option<ErrorData>,
151}
152
153impl Error {
154    pub fn new(code: i32, message: impl Into<String>) -> Self {
155        Error {
156            code,
157            message: message.into(),
158            data: None,
159        }
160    }
161
162    pub fn with_details(mut self, details: impl Into<String>) -> Self {
163        self.data = Some(ErrorData::new(details));
164        self
165    }
166
167    /// Invalid JSON was received by the server. An error occurred on the server while parsing the JSON text.
168    pub fn parse_error() -> Self {
169        Error::new(-32700, "Parse error")
170    }
171
172    /// The JSON sent is not a valid Request object.
173    pub fn invalid_request() -> Self {
174        Error::new(-32600, "Invalid Request")
175    }
176
177    /// The method does not exist / is not available.
178    pub fn method_not_found() -> Self {
179        Error::new(-32601, "Method not found")
180    }
181
182    /// Invalid method parameter(s).
183    pub fn invalid_params() -> Self {
184        Error::new(-32602, "Invalid params")
185    }
186
187    /// Internal JSON-RPC error.
188    pub fn internal_error() -> Self {
189        Error::new(-32603, "Internal error")
190    }
191}
192
193impl std::error::Error for Error {}
194impl Display for Error {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        if self.message.is_empty() {
197            write!(f, "{}", self.code)?;
198        } else {
199            write!(f, "{}", self.message)?;
200        }
201
202        if let Some(data) = &self.data {
203            write!(f, ": {}", data.details)?;
204        }
205
206        Ok(())
207    }
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct ErrorData {
212    pub details: String,
213}
214
215impl ErrorData {
216    pub fn new(details: impl Into<String>) -> Self {
217        ErrorData {
218            details: details.into(),
219        }
220    }
221}
222
223#[derive(Serialize)]
224pub struct JsonRpcMessage<Req, Resp> {
225    pub jsonrpc: &'static str,
226    #[serde(flatten)]
227    message: OutgoingMessage<Req, Resp>,
228}
229
230type ResponseHandler<In, Resp> =
231    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
232
233impl<In, Out> Connection<In, Out>
234where
235    In: AnyRequest,
236    Out: AnyRequest,
237{
238    fn new(
239        request_handler: ResponseHandler<In, In::Response>,
240        outgoing_bytes: impl Unpin + AsyncWrite,
241        incoming_bytes: impl Unpin + AsyncRead,
242        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
243    ) -> (Self, impl Future<Output = Result<(), Error>>) {
244        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
245        let (incoming_tx, incoming_rx) = mpsc::unbounded();
246        let this = Self {
247            response_senders: ResponseSenders::default(),
248            outgoing_tx: outgoing_tx.clone(),
249            next_id: AtomicI32::new(0),
250        };
251        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
252        let io_task = Self::handle_io(
253            outgoing_rx,
254            incoming_tx,
255            this.response_senders.clone(),
256            outgoing_bytes,
257            incoming_bytes,
258        );
259        (this, io_task)
260    }
261
262    fn request(
263        &self,
264        method: &'static str,
265        params: Out,
266    ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
267        let (tx, rx) = oneshot::channel();
268        let id = self.next_id.fetch_add(1, SeqCst);
269        self.response_senders.lock().insert(id, (method, tx));
270        if self
271            .outgoing_tx
272            .unbounded_send(OutgoingMessage::Request {
273                id,
274                method: method.into(),
275                params,
276            })
277            .is_err()
278        {
279            self.response_senders.lock().remove(&id);
280        }
281        async move {
282            rx.await
283                .map_err(|e| Error::internal_error().with_details(e.to_string()))?
284        }
285    }
286
287    async fn handle_io(
288        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
289        incoming_tx: UnboundedSender<(i32, In)>,
290        response_senders: ResponseSenders<Out::Response>,
291        mut outgoing_bytes: impl Unpin + AsyncWrite,
292        incoming_bytes: impl Unpin + AsyncRead,
293    ) -> Result<(), Error> {
294        let mut output_reader = BufReader::new(incoming_bytes);
295        let mut outgoing_line = Vec::new();
296        let mut incoming_line = String::new();
297        loop {
298            select_biased! {
299                message = outgoing_rx.next() => {
300                    if let Some(message) = message {
301                        outgoing_line.clear();
302                        serde_json::to_writer(&mut outgoing_line, &message).map_err(|e| Error::internal_error()
303                            .with_details(e.to_string()))?;
304                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
305                        outgoing_line.push(b'\n');
306                        outgoing_bytes.write_all(&outgoing_line).await.ok();
307                    } else {
308                        break;
309                    }
310                }
311                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
312                    if bytes_read.map_err(|e| Error::internal_error().with_details(e.to_string()))? == 0 {
313                        break
314                    }
315                    log::trace!("recv: {}", &incoming_line);
316                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
317                        Ok(message) => {
318                            if let Some(method) = message.method {
319                                match In::from_method_and_params(method, message.params.unwrap_or(RawValue::NULL)) {
320                                    Ok(params) => {
321                                        incoming_tx.unbounded_send((message.id, params)).ok();
322                                    }
323                                    Err(error) => {
324                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
325                                    }
326                                }
327                            } else if let Some(error) = message.error {
328                                if let Some((_, tx)) = response_senders.lock().remove(&message.id) {
329                                    tx.send(Err(error)).ok();
330                                }
331                            } else {
332                                let result = message.result.unwrap_or(RawValue::NULL);
333                                if let Some((method, tx)) = response_senders.lock().remove(&message.id) {
334                                    match Out::response_from_method_and_result(method, result) {
335                                        Ok(result) => {
336                                            tx.send(Ok(result)).ok();
337                                        }
338                                        Err(error) => {
339                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
340                                        }
341                                    }
342                                }
343                            }
344                        }
345                        Err(error) => {
346                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
347                        }
348                    }
349                    incoming_line.clear();
350                }
351            }
352        }
353        response_senders.lock().clear();
354        Ok(())
355    }
356
357    fn handle_incoming(
358        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
359        mut incoming_rx: UnboundedReceiver<(i32, In)>,
360        incoming_handler: ResponseHandler<In, In::Response>,
361        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
362    ) {
363        let spawn = Rc::new(spawn);
364        let spawn2 = spawn.clone();
365        spawn(
366            async move {
367                while let Some((id, params)) = incoming_rx.next().await {
368                    let result = incoming_handler(params);
369                    let outgoing_tx = outgoing_tx.clone();
370                    spawn2(
371                        async move {
372                            let result = result.await;
373                            match result {
374                                Ok(result) => {
375                                    outgoing_tx
376                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
377                                        .ok();
378                                }
379                                Err(error) => {
380                                    outgoing_tx
381                                        .unbounded_send(OutgoingMessage::ErrorResponse {
382                                            id,
383                                            error: Error::internal_error()
384                                                .with_details(error.to_string()),
385                                        })
386                                        .ok();
387                                }
388                            }
389                        }
390                        .boxed_local(),
391                    )
392                }
393            }
394            .boxed_local(),
395        )
396    }
397}