// agentic_coding_protocol/acp.rs

1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7    StreamExt as _,
8    channel::{
9        mpsc::{self, UnboundedReceiver, UnboundedSender},
10        oneshot,
11    },
12    future::LocalBoxFuture,
13    io::BufReader,
14    select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use semver::Comparator;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22    collections::HashMap,
23    rc::Rc,
24    sync::{
25        Arc,
26        atomic::{AtomicI32, Ordering::SeqCst},
27    },
28};
29
/// A connection to a separate agent process over the ACP protocol.
///
/// Requests sent through this handle are agent requests; requests received from the
/// agent are client requests, answered by the [`Client`] handler passed to
/// [`AgentConnection::connect_to_agent`].
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// A connection to a separate client process over the ACP protocol.
///
/// Requests sent through this handle are client requests; requests received from the
/// client are agent requests, answered by the [`Agent`] handler passed to
/// [`ClientConnection::connect_to_client`].
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37    /// Connect to an agent process, handling any incoming requests
38    /// using the given handler.
39    pub fn connect_to_agent<H: 'static + Client>(
40        handler: H,
41        outgoing_bytes: impl Unpin + AsyncWrite,
42        incoming_bytes: impl Unpin + AsyncRead,
43        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44    ) -> (Self, impl Future<Output = Result<(), Error>>) {
45        let handler = Arc::new(handler);
46        let (connection, io_task) = Connection::new(
47            Box::new(move |request| {
48                let handler = handler.clone();
49                async move { handler.call(request).await }.boxed_local()
50            }),
51            outgoing_bytes,
52            incoming_bytes,
53            spawn,
54        );
55        (Self(connection), io_task)
56    }
57
58    /// Send a request to the agent and wait for a response.
59    pub fn request<R: AgentRequest + 'static>(
60        &self,
61        params: R,
62    ) -> impl Future<Output = Result<R::Response, Error>> {
63        let params = params.into_any();
64        let result = self.0.request(params.method_name(), params);
65        async move {
66            let result = result.await?;
67            R::response_from_any(result)
68        }
69    }
70
71    /// Send an untyped request to the agent and wait for a response.
72    pub fn request_any(
73        &self,
74        params: AnyAgentRequest,
75    ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
76        self.0.request(params.method_name(), params)
77    }
78
79    /// Sends an initialization request to the Agent.
80    /// This will error if the server version is incompatible with the client version.
81    pub async fn initialize(&self) -> Result<InitializeResponse, Error> {
82        let protocol_version = ProtocolVersion::latest();
83        // Check that we are on the same major version of the protocol
84        let version_requirement = Comparator {
85            op: semver::Op::Caret,
86            major: protocol_version.major,
87            minor: (protocol_version.major == 0).then_some(protocol_version.minor),
88            patch: (protocol_version.major == 0 && protocol_version.minor == 0)
89                .then_some(protocol_version.patch),
90            pre: protocol_version.pre.clone(),
91        };
92        let response = self.request(InitializeParams { protocol_version }).await?;
93
94        let server_version = &response.protocol_version;
95
96        if version_requirement.matches(server_version) {
97            Ok(response)
98        } else {
99            Err(Error::invalid_request().with_data(format!(
100                "Incompatible versions: Server {server_version} / Client: {version_requirement}"
101            )))
102        }
103    }
104}
105
106impl ClientConnection {
107    pub fn connect_to_client<H: 'static + Agent>(
108        handler: H,
109        outgoing_bytes: impl Unpin + AsyncWrite,
110        incoming_bytes: impl Unpin + AsyncRead,
111        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
112    ) -> (Self, impl Future<Output = Result<(), Error>>) {
113        let handler = Arc::new(handler);
114        let (connection, io_task) = Connection::new(
115            Box::new(move |request| {
116                let handler = handler.clone();
117                async move { handler.call(request).await }.boxed_local()
118            }),
119            outgoing_bytes,
120            incoming_bytes,
121            spawn,
122        );
123        (Self(connection), io_task)
124    }
125
126    pub fn request<R: ClientRequest>(
127        &self,
128        params: R,
129    ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
130        let params = params.into_any();
131        let result = self.0.request(params.method_name(), params);
132        async move {
133            let result = result.await?;
134            R::response_from_any(result)
135        }
136    }
137
138    /// Send an untyped request to the client and wait for a response.
139    pub fn request_any(
140        &self,
141        method: &'static str,
142        params: AnyClientRequest,
143    ) -> impl Future<Output = Result<AnyClientResult, Error>> {
144        self.0.request(method, params)
145    }
146}
147
148struct Connection<In, Out>
149where
150    In: AnyRequest,
151    Out: AnyRequest,
152{
153    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
154    response_senders: ResponseSenders<Out::Response>,
155    next_id: AtomicI32,
156}
157
/// Requests that are still waiting for a response, keyed by JSON-RPC request id.
///
/// Each entry holds the request's method name, which is needed to decode the untyped
/// `result` of the reply, and the oneshot sender that completes the caller's future.
/// The map is shared between a `Connection` and its I/O task.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
160
/// One JSON-RPC message read from the peer: either a request (`method` is set) or a
/// reply to one of our requests (`error` or `result` is set).
///
/// Fields borrow from the input line. `params` and `result` stay as raw JSON until the
/// method is known and they can be decoded into a concrete type.
///
/// NOTE(review): `id` is a required `i32`, so messages without a numeric id (such as
/// JSON-RPC notifications or string ids) fail to deserialize. serde_json also cannot
/// borrow a `&str` from a JSON string containing escape sequences, so an escaped
/// `method` makes the whole message fail to parse — consider `Cow<'a, str>` with
/// `#[serde(borrow)]`.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    id: i32,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
169
/// A JSON-RPC message this side writes to the peer, before the `jsonrpc` version field
/// is added by `OutJsonRpcMessage`.
///
/// Because of `#[serde(untagged)]` no variant name is emitted; each variant serializes
/// as just its fields.
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request we send. `params` is omitted when absent or when it serializes to `null`.
    Request {
        id: i32,
        method: Box<str>,
        #[serde(skip_serializing_if = "is_none_or_null")]
        params: Option<Req>,
    },
    /// A successful reply to the peer's request `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed reply to the peer's request `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
188
189fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
190    match opt {
191        None => true,
192        Some(value) => {
193            matches!(serde_json::to_value(value), Ok(serde_json::Value::Null))
194        }
195    }
196}
197
/// Value of the `jsonrpc` field on every outgoing message; serializes as `"2.0"`.
///
/// Despite its name, this is the JSON-RPC protocol version, not a JSON Schema version.
#[derive(Debug, Deserialize, Serialize)]
enum JsonSchemaVersion {
    #[serde(rename = "2.0")]
    V2,
}
203
/// Wire envelope for an outgoing message: the `jsonrpc` version tag plus the fields of
/// `message`, flattened into a single JSON object.
#[derive(Serialize)]
struct OutJsonRpcMessage<Req, Resp> {
    jsonrpc: JsonSchemaVersion,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
210
/// Handler for requests received from the peer: turns a decoded request into a future
/// that resolves to the reply (or error) to send back.
///
/// Despite the name, it handles incoming *requests*, not responses.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
213
214impl<In, Out> Connection<In, Out>
215where
216    In: AnyRequest,
217    Out: AnyRequest,
218{
219    fn new(
220        request_handler: ResponseHandler<In, In::Response>,
221        outgoing_bytes: impl Unpin + AsyncWrite,
222        incoming_bytes: impl Unpin + AsyncRead,
223        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
224    ) -> (Self, impl Future<Output = Result<(), Error>>) {
225        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
226        let (incoming_tx, incoming_rx) = mpsc::unbounded();
227        let this = Self {
228            response_senders: ResponseSenders::default(),
229            outgoing_tx: outgoing_tx.clone(),
230            next_id: AtomicI32::new(0),
231        };
232        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
233        let io_task = Self::handle_io(
234            outgoing_rx,
235            incoming_tx,
236            this.response_senders.clone(),
237            outgoing_bytes,
238            incoming_bytes,
239        );
240        (this, io_task)
241    }
242
243    fn request(
244        &self,
245        method: &'static str,
246        params: Out,
247    ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
248        let (tx, rx) = oneshot::channel();
249        let id = self.next_id.fetch_add(1, SeqCst);
250        self.response_senders.lock().insert(id, (method, tx));
251        if self
252            .outgoing_tx
253            .unbounded_send(OutgoingMessage::Request {
254                id,
255                method: method.into(),
256                params: Some(params),
257            })
258            .is_err()
259        {
260            self.response_senders.lock().remove(&id);
261        }
262        async move {
263            rx.await
264                .map_err(|e| Error::internal_error().with_data(e.to_string()))?
265        }
266    }
267
268    async fn handle_io(
269        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
270        incoming_tx: UnboundedSender<(i32, In)>,
271        response_senders: ResponseSenders<Out::Response>,
272        mut outgoing_bytes: impl Unpin + AsyncWrite,
273        incoming_bytes: impl Unpin + AsyncRead,
274    ) -> Result<(), Error> {
275        let mut output_reader = BufReader::new(incoming_bytes);
276        let mut outgoing_line = Vec::new();
277        let mut incoming_line = String::new();
278        loop {
279            select_biased! {
280                message = outgoing_rx.next() => {
281                    if let Some(message) = message {
282                        let message = OutJsonRpcMessage {
283                            jsonrpc: JsonSchemaVersion::V2,
284                            message,
285                        };
286                        outgoing_line.clear();
287                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
288                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
289                        outgoing_line.push(b'\n');
290                        outgoing_bytes.write_all(&outgoing_line).await.ok();
291                    } else {
292                        break;
293                    }
294                }
295                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
296                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
297                        break
298                    }
299                    log::trace!("recv: {}", &incoming_line);
300                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
301                        Ok(IncomingMessage { id, method, params, result, error }) => {
302                            if let Some(method) = method {
303                                match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
304                                    Ok(params) => {
305                                        incoming_tx.unbounded_send((id, params)).ok();
306                                    }
307                                    Err(error) => {
308                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
309                                    }
310                                }
311                            } else if let Some(error) = error {
312                                if let Some((_, tx)) = response_senders.lock().remove(&id) {
313                                    tx.send(Err(error)).ok();
314                                }
315                            } else {
316                                let result = result.unwrap_or(RawValue::NULL);
317                                if let Some((method, tx)) = response_senders.lock().remove(&id) {
318                                    match Out::response_from_method_and_result(method, result) {
319                                        Ok(result) => {
320                                            tx.send(Ok(result)).ok();
321                                        }
322                                        Err(error) => {
323                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
324                                        }
325                                    }
326                                }
327                            }
328                        }
329                        Err(error) => {
330                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
331                        }
332                    }
333                    incoming_line.clear();
334                }
335            }
336        }
337        response_senders.lock().clear();
338        Ok(())
339    }
340
341    fn handle_incoming(
342        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
343        mut incoming_rx: UnboundedReceiver<(i32, In)>,
344        incoming_handler: ResponseHandler<In, In::Response>,
345        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
346    ) {
347        let spawn = Rc::new(spawn);
348        let spawn2 = spawn.clone();
349        spawn(
350            async move {
351                while let Some((id, params)) = incoming_rx.next().await {
352                    let result = incoming_handler(params);
353                    let outgoing_tx = outgoing_tx.clone();
354                    spawn2(
355                        async move {
356                            let result = result.await;
357                            match result {
358                                Ok(result) => {
359                                    outgoing_tx
360                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
361                                        .ok();
362                                }
363                                Err(error) => {
364                                    outgoing_tx
365                                        .unbounded_send(OutgoingMessage::ErrorResponse {
366                                            id,
367                                            error: Error::into_internal_error(error),
368                                        })
369                                        .ok();
370                                }
371                            }
372                        }
373                        .boxed_local(),
374                    )
375                }
376            }
377            .boxed_local(),
378        )
379    }
380}