agentic_coding_protocol/
acp.rs

1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7    StreamExt as _,
8    channel::{
9        mpsc::{self, UnboundedReceiver, UnboundedSender},
10        oneshot,
11    },
12    future::LocalBoxFuture,
13    io::BufReader,
14    select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use semver::Comparator;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22    collections::HashMap,
23    rc::Rc,
24    sync::{
25        Arc,
26        atomic::{AtomicI32, Ordering::SeqCst},
27    },
28};
29
/// A connection to a separate agent process over the ACP protocol.
///
/// Used on the client side: requests sent through it are `AnyAgentRequest`s
/// addressed to the agent, and requests received from the agent are
/// `AnyClientRequest`s answered by the client's handler.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// A connection to a separate client process over the ACP protocol.
///
/// Used on the agent side: requests sent through it are `AnyClientRequest`s
/// addressed to the client, and requests received from the client are
/// `AnyAgentRequest`s answered by the agent's handler.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37    /// Connect to an agent process, handling any incoming requests
38    /// using the given handler.
39    pub fn connect_to_agent<H: 'static + Client>(
40        handler: H,
41        outgoing_bytes: impl Unpin + AsyncWrite,
42        incoming_bytes: impl Unpin + AsyncRead,
43        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44    ) -> (Self, impl Future<Output = Result<(), Error>>) {
45        let handler = Arc::new(handler);
46        let (connection, io_task) = Connection::new(
47            Box::new(move |request| {
48                let handler = handler.clone();
49                async move { handler.call(request).await }.boxed_local()
50            }),
51            outgoing_bytes,
52            incoming_bytes,
53            spawn,
54        );
55        (Self(connection), io_task)
56    }
57
58    /// Send a request to the agent and wait for a response.
59    pub fn request<R: AgentRequest + 'static>(
60        &self,
61        params: R,
62    ) -> impl Future<Output = Result<R::Response, Error>> {
63        let params = params.into_any();
64        let result = self.0.request(params.method_name(), params);
65        async move {
66            let result = result.await?;
67            R::response_from_any(result)
68        }
69    }
70
71    /// Send an untyped request to the agent and wait for a response.
72    pub fn request_any(
73        &self,
74        params: AnyAgentRequest,
75    ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
76        self.0.request(params.method_name(), params)
77    }
78
79    /// Sends an initialization request to the Agent.
80    /// This will error if the server version is incompatible with the client version.
81    pub async fn initialize(&self) -> Result<InitializeResponse, Error> {
82        let protocol_version = ProtocolVersion::latest();
83        let version_requirement = Comparator {
84            op: semver::Op::Caret,
85            major: protocol_version.major,
86            minor: Some(protocol_version.minor),
87            patch: Some(protocol_version.patch),
88            pre: protocol_version.pre.clone(),
89        };
90        let response = self.request(InitializeParams { protocol_version }).await?;
91
92        let server_version = &response.protocol_version;
93
94        if version_requirement.matches(server_version) {
95            Ok(response)
96        } else {
97            Err(Error::invalid_request().with_data(format!(
98                "Incompatible versions: Server {server_version} / Client: {version_requirement}"
99            )))
100        }
101    }
102}
103
104impl ClientConnection {
105    pub fn connect_to_client<H: 'static + Agent>(
106        handler: H,
107        outgoing_bytes: impl Unpin + AsyncWrite,
108        incoming_bytes: impl Unpin + AsyncRead,
109        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
110    ) -> (Self, impl Future<Output = Result<(), Error>>) {
111        let handler = Arc::new(handler);
112        let (connection, io_task) = Connection::new(
113            Box::new(move |request| {
114                let handler = handler.clone();
115                async move { handler.call(request).await }.boxed_local()
116            }),
117            outgoing_bytes,
118            incoming_bytes,
119            spawn,
120        );
121        (Self(connection), io_task)
122    }
123
124    pub fn request<R: ClientRequest>(
125        &self,
126        params: R,
127    ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
128        let params = params.into_any();
129        let result = self.0.request(params.method_name(), params);
130        async move {
131            let result = result.await?;
132            R::response_from_any(result)
133        }
134    }
135
136    /// Send an untyped request to the client and wait for a response.
137    pub fn request_any(
138        &self,
139        method: &'static str,
140        params: AnyClientRequest,
141    ) -> impl Future<Output = Result<AnyClientResult, Error>> {
142        self.0.request(method, params)
143    }
144}
145
/// Bidirectional JSON-RPC connection shared by [`AgentConnection`] and
/// [`ClientConnection`].
///
/// `In` is the type of requests received from the peer; `Out` is the type of
/// requests sent to it.
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    /// Queue of outgoing messages (our requests, and our responses to the peer's
    /// requests); the I/O task drains it and writes each message to the peer.
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    /// Our requests that are still waiting for the peer's response, keyed by id.
    /// Shared with the I/O task, which completes or drops the entries.
    response_senders: ResponseSenders<Out::Response>,
    /// Id assigned to the next outgoing request.
    next_id: AtomicI32,
}
155
/// Pending outgoing requests, keyed by request id.
///
/// Each entry holds the request's method name, which is needed to decode the
/// untyped result. It also holds the one-shot channel that delivers the response to
/// the waiting caller.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
158
/// A single JSON-RPC message received from the peer, borrowing from the raw line.
///
/// Requests carry `method` and optionally `params`. Responses carry no `method`,
/// and instead carry `error` or `result`. All fields except `id` are optional, so
/// one struct covers both shapes.
///
/// NOTE(review): `id` is a required `i32`. A message without an `id` (e.g. a
/// JSON-RPC notification), or one with a string id, fails to deserialize entirely.
/// `method` is a borrowed `&str`, which serde_json cannot produce for strings
/// containing escape sequences. Presumably method names never contain escapes;
/// confirm.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    id: i32,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
167
/// A JSON-RPC message to send to the peer: either one of our requests, or a reply
/// to one of the peer's requests.
///
/// `#[serde(untagged)]` serializes each variant as a plain object of its fields.
/// The `jsonrpc` version field is added by [`OutJsonRpcMessage`].
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request we initiate. `params` is omitted when it is `None` or serializes
    /// to JSON `null`.
    Request {
        id: i32,
        method: Box<str>,
        #[serde(skip_serializing_if = "is_none_or_null")]
        params: Option<Req>,
    },
    /// A successful reply to the peer's request with the same `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed reply to the peer's request with the same `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
186
187fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
188    match opt {
189        None => true,
190        Some(value) => {
191            matches!(serde_json::to_value(value), Ok(serde_json::Value::Null))
192        }
193    }
194}
195
/// The JSON-RPC protocol version tag; its only variant serializes as the string `"2.0"`.
///
/// NOTE(review): despite the name, this is the JSON-RPC version, not a JSON Schema
/// version.
#[derive(Debug, Deserialize, Serialize)]
enum JsonSchemaVersion {
    #[serde(rename = "2.0")]
    V2,
}
201
/// Wire envelope for an outgoing message.
///
/// It serializes the `jsonrpc` version field next to the fields of [`OutgoingMessage`].
/// Those fields are flattened into the same JSON object.
#[derive(Serialize)]
struct OutJsonRpcMessage<Req, Resp> {
    jsonrpc: JsonSchemaVersion,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
208
/// Handler for requests received from the peer.
///
/// It maps a parsed request to a future resolving to the result or error that is
/// sent back as the reply.
///
/// NOTE(review): despite the name, this handles incoming *requests* and produces
/// their responses.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
211
212impl<In, Out> Connection<In, Out>
213where
214    In: AnyRequest,
215    Out: AnyRequest,
216{
217    fn new(
218        request_handler: ResponseHandler<In, In::Response>,
219        outgoing_bytes: impl Unpin + AsyncWrite,
220        incoming_bytes: impl Unpin + AsyncRead,
221        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
222    ) -> (Self, impl Future<Output = Result<(), Error>>) {
223        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
224        let (incoming_tx, incoming_rx) = mpsc::unbounded();
225        let this = Self {
226            response_senders: ResponseSenders::default(),
227            outgoing_tx: outgoing_tx.clone(),
228            next_id: AtomicI32::new(0),
229        };
230        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
231        let io_task = Self::handle_io(
232            outgoing_rx,
233            incoming_tx,
234            this.response_senders.clone(),
235            outgoing_bytes,
236            incoming_bytes,
237        );
238        (this, io_task)
239    }
240
241    fn request(
242        &self,
243        method: &'static str,
244        params: Out,
245    ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
246        let (tx, rx) = oneshot::channel();
247        let id = self.next_id.fetch_add(1, SeqCst);
248        self.response_senders.lock().insert(id, (method, tx));
249        if self
250            .outgoing_tx
251            .unbounded_send(OutgoingMessage::Request {
252                id,
253                method: method.into(),
254                params: Some(params),
255            })
256            .is_err()
257        {
258            self.response_senders.lock().remove(&id);
259        }
260        async move {
261            rx.await
262                .map_err(|e| Error::internal_error().with_data(e.to_string()))?
263        }
264    }
265
266    async fn handle_io(
267        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
268        incoming_tx: UnboundedSender<(i32, In)>,
269        response_senders: ResponseSenders<Out::Response>,
270        mut outgoing_bytes: impl Unpin + AsyncWrite,
271        incoming_bytes: impl Unpin + AsyncRead,
272    ) -> Result<(), Error> {
273        let mut output_reader = BufReader::new(incoming_bytes);
274        let mut outgoing_line = Vec::new();
275        let mut incoming_line = String::new();
276        loop {
277            select_biased! {
278                message = outgoing_rx.next() => {
279                    if let Some(message) = message {
280                        let message = OutJsonRpcMessage {
281                            jsonrpc: JsonSchemaVersion::V2,
282                            message,
283                        };
284                        outgoing_line.clear();
285                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
286                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
287                        outgoing_line.push(b'\n');
288                        outgoing_bytes.write_all(&outgoing_line).await.ok();
289                    } else {
290                        break;
291                    }
292                }
293                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
294                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
295                        break
296                    }
297                    log::trace!("recv: {}", &incoming_line);
298                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
299                        Ok(IncomingMessage { id, method, params, result, error }) => {
300                            if let Some(method) = method {
301                                match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
302                                    Ok(params) => {
303                                        incoming_tx.unbounded_send((id, params)).ok();
304                                    }
305                                    Err(error) => {
306                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
307                                    }
308                                }
309                            } else if let Some(error) = error {
310                                if let Some((_, tx)) = response_senders.lock().remove(&id) {
311                                    tx.send(Err(error)).ok();
312                                }
313                            } else {
314                                let result = result.unwrap_or(RawValue::NULL);
315                                if let Some((method, tx)) = response_senders.lock().remove(&id) {
316                                    match Out::response_from_method_and_result(method, result) {
317                                        Ok(result) => {
318                                            tx.send(Ok(result)).ok();
319                                        }
320                                        Err(error) => {
321                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
322                                        }
323                                    }
324                                }
325                            }
326                        }
327                        Err(error) => {
328                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
329                        }
330                    }
331                    incoming_line.clear();
332                }
333            }
334        }
335        response_senders.lock().clear();
336        Ok(())
337    }
338
339    fn handle_incoming(
340        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
341        mut incoming_rx: UnboundedReceiver<(i32, In)>,
342        incoming_handler: ResponseHandler<In, In::Response>,
343        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
344    ) {
345        let spawn = Rc::new(spawn);
346        let spawn2 = spawn.clone();
347        spawn(
348            async move {
349                while let Some((id, params)) = incoming_rx.next().await {
350                    let result = incoming_handler(params);
351                    let outgoing_tx = outgoing_tx.clone();
352                    spawn2(
353                        async move {
354                            let result = result.await;
355                            match result {
356                                Ok(result) => {
357                                    outgoing_tx
358                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
359                                        .ok();
360                                }
361                                Err(error) => {
362                                    outgoing_tx
363                                        .unbounded_send(OutgoingMessage::ErrorResponse {
364                                            id,
365                                            error: Error::into_internal_error(error),
366                                        })
367                                        .ok();
368                                }
369                            }
370                        }
371                        .boxed_local(),
372                    )
373                }
374            }
375            .boxed_local(),
376        )
377    }
378}