agentic_coding_protocol/
acp.rs

1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use anyhow::{Result, anyhow};
6use futures::{
7    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
8    StreamExt as _,
9    channel::{
10        mpsc::{self, UnboundedReceiver, UnboundedSender},
11        oneshot,
12    },
13    future::LocalBoxFuture,
14    io::BufReader,
15    select_biased,
16};
17use parking_lot::Mutex;
18pub use schema::*;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22    collections::HashMap,
23    fmt::Display,
24    rc::Rc,
25    sync::{
26        Arc,
27        atomic::{AtomicI32, Ordering::SeqCst},
28    },
29};
30
31/// A connection to a separate agent process over the ACP protocol.
32pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);
33
34/// A connection to a separate client process over the ACP protocol.
35pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
36
37impl AgentConnection {
38    /// Connect to an agent process, handling any incoming requests
39    /// using the given handler.
40    pub fn connect_to_agent<H: 'static + Client>(
41        handler: H,
42        outgoing_bytes: impl Unpin + AsyncWrite,
43        incoming_bytes: impl Unpin + AsyncRead,
44        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
45    ) -> (Self, impl Future<Output = Result<()>>) {
46        let handler = Arc::new(handler);
47        let (connection, io_task) = Connection::new(
48            Box::new(move |request| {
49                let handler = handler.clone();
50                async move { handler.call(request).await }.boxed_local()
51            }),
52            outgoing_bytes,
53            incoming_bytes,
54            spawn,
55        );
56        (Self(connection), io_task)
57    }
58
59    /// Send a request to the agent and wait for a response.
60    pub fn request<R: AgentRequest + 'static>(
61        &self,
62        params: R,
63    ) -> impl Future<Output = Result<R::Response>> {
64        let params = params.into_any();
65        let result = self.0.request(params.method_name(), params);
66        async move {
67            let result = result.await?;
68            R::response_from_any(result).ok_or_else(|| {
69                anyhow!(crate::Error {
70                    code: -32700,
71                    message: "Unexpected Response".to_string(),
72                })
73            })
74        }
75    }
76}
77
78impl ClientConnection {
79    pub fn connect_to_client<H: 'static + Agent>(
80        handler: H,
81        outgoing_bytes: impl Unpin + AsyncWrite,
82        incoming_bytes: impl Unpin + AsyncRead,
83        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
84    ) -> (Self, impl Future<Output = Result<()>>) {
85        let handler = Arc::new(handler);
86        let (connection, io_task) = Connection::new(
87            Box::new(move |request| {
88                let handler = handler.clone();
89                async move { handler.call(request).await }.boxed_local()
90            }),
91            outgoing_bytes,
92            incoming_bytes,
93            spawn,
94        );
95        (Self(connection), io_task)
96    }
97
98    pub fn request<R: ClientRequest>(
99        &self,
100        params: R,
101    ) -> impl use<R> + Future<Output = Result<R::Response>> {
102        let params = params.into_any();
103        let result = self.0.request(params.method_name(), params);
104        async move {
105            let result = result.await?;
106            R::response_from_any(result).ok_or_else(|| {
107                anyhow!(Error {
108                    code: -32700,
109                    message: "Could not parse".to_string(),
110                })
111            })
112        }
113    }
114}
115
116struct Connection<In, Out>
117where
118    In: AnyRequest,
119    Out: AnyRequest,
120{
121    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
122    response_senders: ResponseSenders<Out::Response>,
123    next_id: AtomicI32,
124}
125
126type ResponseSenders<T> =
127    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, crate::Error>>)>>>;
128
129#[derive(Debug, Deserialize)]
130struct IncomingMessage<'a> {
131    id: i32,
132    method: Option<&'a str>,
133    params: Option<&'a RawValue>,
134    result: Option<&'a RawValue>,
135    error: Option<Error>,
136}
137
138#[derive(Serialize)]
139#[serde(untagged)]
140enum OutgoingMessage<Req, Resp> {
141    Request {
142        id: i32,
143        method: Box<str>,
144        params: Req,
145    },
146    OkResponse {
147        id: i32,
148        result: Resp,
149    },
150    ErrorResponse {
151        id: i32,
152        error: Error,
153    },
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct Error {
158    pub code: i32,
159    pub message: String,
160}
161
162impl std::error::Error for Error {}
163impl Display for Error {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        if self.message.is_empty() {
166            write!(f, "{}", self.code)
167        } else {
168            write!(f, "{}", self.message)
169        }
170    }
171}
172
173#[derive(Serialize)]
174pub struct JsonRpcMessage<Req, Resp> {
175    pub jsonrpc: &'static str,
176    #[serde(flatten)]
177    message: OutgoingMessage<Req, Resp>,
178}
179
180impl<In, Out> Connection<In, Out>
181where
182    In: AnyRequest,
183    Out: AnyRequest,
184{
185    fn new(
186        request_handler: Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>>,
187        outgoing_bytes: impl Unpin + AsyncWrite,
188        incoming_bytes: impl Unpin + AsyncRead,
189        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
190    ) -> (Self, impl Future<Output = Result<()>>) {
191        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
192        let (incoming_tx, incoming_rx) = mpsc::unbounded();
193        let this = Self {
194            response_senders: ResponseSenders::default(),
195            outgoing_tx: outgoing_tx.clone(),
196            next_id: AtomicI32::new(0),
197        };
198        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
199        let io_task = Self::handle_io(
200            outgoing_rx,
201            incoming_tx,
202            this.response_senders.clone(),
203            outgoing_bytes,
204            incoming_bytes,
205        );
206        (this, io_task)
207    }
208
209    fn request(
210        &self,
211        method: &'static str,
212        params: Out,
213    ) -> impl use<In, Out> + Future<Output = Result<Out::Response>> {
214        let (tx, rx) = oneshot::channel();
215        let id = self.next_id.fetch_add(1, SeqCst);
216        self.response_senders.lock().insert(id, (method, tx));
217        if self
218            .outgoing_tx
219            .unbounded_send(OutgoingMessage::Request {
220                id,
221                method: method.into(),
222                params,
223            })
224            .is_err()
225        {
226            self.response_senders.lock().remove(&id);
227        }
228        async move { rx.await?.map_err(|e| anyhow!(e)) }
229    }
230
231    async fn handle_io(
232        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
233        incoming_tx: UnboundedSender<(i32, In)>,
234        response_senders: ResponseSenders<Out::Response>,
235        mut outgoing_bytes: impl Unpin + AsyncWrite,
236        incoming_bytes: impl Unpin + AsyncRead,
237    ) -> Result<()> {
238        let mut output_reader = BufReader::new(incoming_bytes);
239        let mut outgoing_line = Vec::new();
240        let mut incoming_line = String::new();
241        loop {
242            select_biased! {
243                message = outgoing_rx.next() => {
244                    if let Some(message) = message {
245                        outgoing_line.clear();
246                        serde_json::to_writer(&mut outgoing_line, &message)?;
247                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
248                        outgoing_line.push(b'\n');
249                        outgoing_bytes.write_all(&outgoing_line).await.ok();
250                    } else {
251                        break;
252                    }
253                }
254                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
255                    if bytes_read? == 0 {
256                        break
257                    }
258                    log::trace!("recv: {}", &incoming_line);
259                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
260                        Ok(message) => {
261                            if let Some(method) = message.method {
262                                match In::from_method_and_params(method, message.params.unwrap_or(RawValue::NULL)) {
263                                    Ok(params) => {
264                                        incoming_tx.unbounded_send((message.id, params)).ok();
265                                    }
266                                    Err(error) => {
267                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
268                                    }
269                                }
270                            } else if let Some(error) = message.error {
271                                if let Some((_, tx)) = response_senders.lock().remove(&message.id) {
272                                    tx.send(Err(error)).ok();
273                                }
274                            } else {
275                                let result = message.result.unwrap_or(RawValue::NULL);
276                                if let Some((method, tx)) = response_senders.lock().remove(&message.id) {
277                                    match Out::response_from_method_and_result(method, result) {
278                                        Ok(result) => {
279                                            tx.send(Ok(result)).ok();
280                                        }
281                                        Err(error) => {
282                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
283                                        }
284                                    }
285                                }
286                            }
287                        }
288                        Err(error) => {
289                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
290                        }
291                    }
292                    incoming_line.clear();
293                }
294            }
295        }
296        response_senders.lock().clear();
297        Ok(())
298    }
299
300    fn handle_incoming(
301        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
302        mut incoming_rx: UnboundedReceiver<(i32, In)>,
303        incoming_handler: Box<
304            dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>,
305        >,
306        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
307    ) {
308        let spawn = Rc::new(spawn);
309        let spawn2 = spawn.clone();
310        spawn(
311            async move {
312                while let Some((id, params)) = incoming_rx.next().await {
313                    let result = incoming_handler(params);
314                    let outgoing_tx = outgoing_tx.clone();
315                    spawn2(
316                        async move {
317                            let result = result.await;
318                            match result {
319                                Ok(result) => {
320                                    outgoing_tx
321                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
322                                        .ok();
323                                }
324                                Err(error) => {
325                                    outgoing_tx
326                                        .unbounded_send(OutgoingMessage::ErrorResponse {
327                                            id,
328                                            error: Error {
329                                                code: -32603,
330                                                message: error.to_string(),
331                                            },
332                                        })
333                                        .ok();
334                                }
335                            }
336                        }
337                        .boxed_local(),
338                    )
339                }
340            }
341            .boxed_local(),
342        )
343    }
344}