agentic_coding_protocol/
acp.rs

1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7    StreamExt as _,
8    channel::{
9        mpsc::{self, UnboundedReceiver, UnboundedSender},
10        oneshot,
11    },
12    future::LocalBoxFuture,
13    io::BufReader,
14    select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use serde::{Deserialize, Serialize};
19use serde_json::value::RawValue;
20use std::{
21    collections::HashMap,
22    rc::Rc,
23    sync::{
24        Arc,
25        atomic::{AtomicI32, Ordering::SeqCst},
26    },
27};
28
/// A connection to a separate agent process over the ACP protocol.
///
/// This is the client-side end of the link: requests arriving from the agent
/// are decoded as `AnyClientRequest` and answered by a `Client` handler,
/// while requests sent through this connection are `AnyAgentRequest`s.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);
31
/// A connection to a separate client process over the ACP protocol.
///
/// This is the agent-side end of the link: requests arriving from the client
/// are decoded as `AnyAgentRequest` and answered by an `Agent` handler,
/// while requests sent through this connection are `AnyClientRequest`s.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
34
35impl AgentConnection {
36    /// Connect to an agent process, handling any incoming requests
37    /// using the given handler.
38    pub fn connect_to_agent<H: 'static + Client>(
39        handler: H,
40        outgoing_bytes: impl Unpin + AsyncWrite,
41        incoming_bytes: impl Unpin + AsyncRead,
42        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
43    ) -> (Self, impl Future<Output = Result<(), Error>>) {
44        let handler = Arc::new(handler);
45        let (connection, io_task) = Connection::new(
46            Box::new(move |request| {
47                let handler = handler.clone();
48                async move { handler.call(request).await }.boxed_local()
49            }),
50            outgoing_bytes,
51            incoming_bytes,
52            spawn,
53        );
54        (Self(connection), io_task)
55    }
56
57    /// Send a request to the agent and wait for a response.
58    pub fn request<R: AgentRequest + 'static>(
59        &self,
60        params: R,
61    ) -> impl Future<Output = Result<R::Response, Error>> {
62        let params = params.into_any();
63        let result = self.0.request(params.method_name(), params);
64        async move {
65            let result = result.await?;
66            R::response_from_any(result)
67        }
68    }
69}
70
71impl ClientConnection {
72    pub fn connect_to_client<H: 'static + Agent>(
73        handler: H,
74        outgoing_bytes: impl Unpin + AsyncWrite,
75        incoming_bytes: impl Unpin + AsyncRead,
76        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
77    ) -> (Self, impl Future<Output = Result<(), Error>>) {
78        let handler = Arc::new(handler);
79        let (connection, io_task) = Connection::new(
80            Box::new(move |request| {
81                let handler = handler.clone();
82                async move { handler.call(request).await }.boxed_local()
83            }),
84            outgoing_bytes,
85            incoming_bytes,
86            spawn,
87        );
88        (Self(connection), io_task)
89    }
90
91    pub fn request<R: ClientRequest>(
92        &self,
93        params: R,
94    ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
95        let params = params.into_any();
96        let result = self.0.request(params.method_name(), params);
97        async move {
98            let result = result.await?;
99            R::response_from_any(result)
100        }
101    }
102}
103
/// Direction-agnostic core wrapped by `AgentConnection` and `ClientConnection`.
///
/// `In` is the kind of request this side receives and answers; `Out` is the
/// kind of request this side sends to the peer.
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    /// Queue feeding the I/O task: our requests (`Out`) and our responses to
    /// the peer's requests (`In::Response`).
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    /// Completion channels for requests that are still awaiting a response,
    /// keyed by request id. Shared with the I/O task.
    response_senders: ResponseSenders<Out::Response>,
    /// Id handed to the next outgoing request; starts at 0 and increments.
    next_id: AtomicI32,
}
113
/// Shared table of in-flight requests: request id -> (method name, channel
/// used to deliver the response or error to the waiting caller).
///
/// The method name is kept so that the response can be decoded with the
/// right type when it arrives.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
116
/// One JSON message received from the peer, borrowing from the text it was
/// parsed from.
///
/// A single shape covers both directions: the I/O loop treats a message with
/// a `method` as a request, otherwise one with an `error` as a failed
/// response, and anything else as a successful response.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    /// Request id, or for a response the id of the request being answered.
    /// Required: a message without an `id` fails to deserialize.
    id: i32,
    /// Method name; present only on requests.
    method: Option<&'a str>,
    /// Unparsed request parameters; decoded later based on `method`.
    params: Option<&'a RawValue>,
    /// Unparsed response payload; decoded later based on the request's method.
    result: Option<&'a RawValue>,
    /// Error reported by the peer in place of a result.
    error: Option<Error>,
}
125
/// A message queued for sending to the peer.
///
/// Serialized untagged, so each variant becomes a plain JSON object made of
/// just its own fields (no variant name on the wire).
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request from this side to the peer.
    Request {
        id: i32,
        method: Box<str>,
        params: Req,
    },
    /// A successful answer to the peer's request with the same `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed answer to the peer's request with the same `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
143
/// An [`OutgoingMessage`] together with a `jsonrpc` version field.
///
/// The inner message is flattened, so its fields are serialized alongside
/// `jsonrpc` in the same JSON object.
///
/// NOTE(review): the send path visible in this file serializes
/// `OutgoingMessage` directly rather than this wrapper; presumably this type
/// is used elsewhere (e.g. tests or callers) - confirm.
#[derive(Serialize)]
pub struct JsonRpcMessage<Req, Resp> {
    /// Protocol version string to emit in the `jsonrpc` field.
    pub jsonrpc: &'static str,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
150
/// Callback that turns an incoming request into a future resolving to its
/// response (or an [`Error`] to report back to the peer).
///
/// The future is a `LocalBoxFuture`, so it is not required to be `Send`.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
153
154impl<In, Out> Connection<In, Out>
155where
156    In: AnyRequest,
157    Out: AnyRequest,
158{
159    fn new(
160        request_handler: ResponseHandler<In, In::Response>,
161        outgoing_bytes: impl Unpin + AsyncWrite,
162        incoming_bytes: impl Unpin + AsyncRead,
163        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
164    ) -> (Self, impl Future<Output = Result<(), Error>>) {
165        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
166        let (incoming_tx, incoming_rx) = mpsc::unbounded();
167        let this = Self {
168            response_senders: ResponseSenders::default(),
169            outgoing_tx: outgoing_tx.clone(),
170            next_id: AtomicI32::new(0),
171        };
172        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
173        let io_task = Self::handle_io(
174            outgoing_rx,
175            incoming_tx,
176            this.response_senders.clone(),
177            outgoing_bytes,
178            incoming_bytes,
179        );
180        (this, io_task)
181    }
182
183    fn request(
184        &self,
185        method: &'static str,
186        params: Out,
187    ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
188        let (tx, rx) = oneshot::channel();
189        let id = self.next_id.fetch_add(1, SeqCst);
190        self.response_senders.lock().insert(id, (method, tx));
191        if self
192            .outgoing_tx
193            .unbounded_send(OutgoingMessage::Request {
194                id,
195                method: method.into(),
196                params,
197            })
198            .is_err()
199        {
200            self.response_senders.lock().remove(&id);
201        }
202        async move {
203            rx.await
204                .map_err(|e| Error::internal_error().with_data(e.to_string()))?
205        }
206    }
207
208    async fn handle_io(
209        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
210        incoming_tx: UnboundedSender<(i32, In)>,
211        response_senders: ResponseSenders<Out::Response>,
212        mut outgoing_bytes: impl Unpin + AsyncWrite,
213        incoming_bytes: impl Unpin + AsyncRead,
214    ) -> Result<(), Error> {
215        let mut output_reader = BufReader::new(incoming_bytes);
216        let mut outgoing_line = Vec::new();
217        let mut incoming_line = String::new();
218        loop {
219            select_biased! {
220                message = outgoing_rx.next() => {
221                    if let Some(message) = message {
222                        outgoing_line.clear();
223                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
224                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
225                        outgoing_line.push(b'\n');
226                        outgoing_bytes.write_all(&outgoing_line).await.ok();
227                    } else {
228                        break;
229                    }
230                }
231                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
232                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
233                        break
234                    }
235                    log::trace!("recv: {}", &incoming_line);
236                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
237                        Ok(message) => {
238                            if let Some(method) = message.method {
239                                match In::from_method_and_params(method, message.params.unwrap_or(RawValue::NULL)) {
240                                    Ok(params) => {
241                                        incoming_tx.unbounded_send((message.id, params)).ok();
242                                    }
243                                    Err(error) => {
244                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
245                                    }
246                                }
247                            } else if let Some(error) = message.error {
248                                if let Some((_, tx)) = response_senders.lock().remove(&message.id) {
249                                    tx.send(Err(error)).ok();
250                                }
251                            } else {
252                                let result = message.result.unwrap_or(RawValue::NULL);
253                                if let Some((method, tx)) = response_senders.lock().remove(&message.id) {
254                                    match Out::response_from_method_and_result(method, result) {
255                                        Ok(result) => {
256                                            tx.send(Ok(result)).ok();
257                                        }
258                                        Err(error) => {
259                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
260                                        }
261                                    }
262                                }
263                            }
264                        }
265                        Err(error) => {
266                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
267                        }
268                    }
269                    incoming_line.clear();
270                }
271            }
272        }
273        response_senders.lock().clear();
274        Ok(())
275    }
276
277    fn handle_incoming(
278        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
279        mut incoming_rx: UnboundedReceiver<(i32, In)>,
280        incoming_handler: ResponseHandler<In, In::Response>,
281        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
282    ) {
283        let spawn = Rc::new(spawn);
284        let spawn2 = spawn.clone();
285        spawn(
286            async move {
287                while let Some((id, params)) = incoming_rx.next().await {
288                    let result = incoming_handler(params);
289                    let outgoing_tx = outgoing_tx.clone();
290                    spawn2(
291                        async move {
292                            let result = result.await;
293                            match result {
294                                Ok(result) => {
295                                    outgoing_tx
296                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
297                                        .ok();
298                                }
299                                Err(error) => {
300                                    outgoing_tx
301                                        .unbounded_send(OutgoingMessage::ErrorResponse {
302                                            id,
303                                            error: Error::into_internal_error(error),
304                                        })
305                                        .ok();
306                                }
307                            }
308                        }
309                        .boxed_local(),
310                    )
311                }
312            }
313            .boxed_local(),
314        )
315    }
316}