agentic_coding_protocol/
acp.rs

1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7    StreamExt as _,
8    channel::{
9        mpsc::{self, UnboundedReceiver, UnboundedSender},
10        oneshot,
11    },
12    future::LocalBoxFuture,
13    io::BufReader,
14    select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use semver::Comparator;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22    collections::HashMap,
23    rc::Rc,
24    sync::{
25        Arc,
26        atomic::{AtomicI32, Ordering::SeqCst},
27    },
28};
29
30/// A connection to a separate agent process over the ACP protocol.
31pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);
32
33/// A connection to a separate client process over the ACP protocol.
34pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37    /// Connect to an agent process, handling any incoming requests
38    /// using the given handler.
39    pub fn connect_to_agent<H: 'static + Client>(
40        handler: H,
41        outgoing_bytes: impl Unpin + AsyncWrite,
42        incoming_bytes: impl Unpin + AsyncRead,
43        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44    ) -> (Self, impl Future<Output = Result<(), Error>>) {
45        let handler = Arc::new(handler);
46        let (connection, io_task) = Connection::new(
47            Box::new(move |request| {
48                let handler = handler.clone();
49                async move { handler.call(request).await }.boxed_local()
50            }),
51            outgoing_bytes,
52            incoming_bytes,
53            spawn,
54        );
55        (Self(connection), io_task)
56    }
57
58    /// Send a request to the agent and wait for a response.
59    pub fn request<R: AgentRequest + 'static>(
60        &self,
61        params: R,
62    ) -> impl Future<Output = Result<R::Response, Error>> {
63        let params = params.into_any();
64        let result = self.0.request(params.method_name(), params);
65        async move {
66            let result = result.await?;
67            R::response_from_any(result)
68        }
69    }
70
71    /// Send an untyped request to the agent and wait for a response.
72    pub fn request_any(
73        &self,
74        params: AnyAgentRequest,
75    ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
76        self.0.request(params.method_name(), params)
77    }
78
79    /// Sends an initialization request to the Agent.
80    /// This will error if the server version is incompatible with the client version.
81    pub async fn initialize(
82        &self,
83        context_servers: HashMap<String, ContextServer>,
84    ) -> Result<InitializeResponse, Error> {
85        let protocol_version = ProtocolVersion::latest();
86        // Check that we are on the same major version of the protocol
87        let version_requirement = Comparator {
88            op: semver::Op::Caret,
89            major: protocol_version.major,
90            minor: (protocol_version.major == 0).then_some(protocol_version.minor),
91            patch: (protocol_version.major == 0 && protocol_version.minor == 0)
92                .then_some(protocol_version.patch),
93            pre: protocol_version.pre.clone(),
94        };
95        let response = self
96            .request(InitializeParams {
97                protocol_version,
98                context_servers,
99            })
100            .await?;
101
102        let server_version = &response.protocol_version;
103
104        if version_requirement.matches(server_version) {
105            Ok(response)
106        } else {
107            Err(Error::invalid_request().with_data(format!(
108                "Incompatible versions: Server {server_version} / Client: {version_requirement}"
109            )))
110        }
111    }
112}
113
114impl ClientConnection {
115    pub fn connect_to_client<H: 'static + Agent>(
116        handler: H,
117        outgoing_bytes: impl Unpin + AsyncWrite,
118        incoming_bytes: impl Unpin + AsyncRead,
119        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
120    ) -> (Self, impl Future<Output = Result<(), Error>>) {
121        let handler = Arc::new(handler);
122        let (connection, io_task) = Connection::new(
123            Box::new(move |request| {
124                let handler = handler.clone();
125                async move { handler.call(request).await }.boxed_local()
126            }),
127            outgoing_bytes,
128            incoming_bytes,
129            spawn,
130        );
131        (Self(connection), io_task)
132    }
133
134    pub fn request<R: ClientRequest>(
135        &self,
136        params: R,
137    ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
138        let params = params.into_any();
139        let result = self.0.request(params.method_name(), params);
140        async move {
141            let result = result.await?;
142            R::response_from_any(result)
143        }
144    }
145
146    /// Send an untyped request to the client and wait for a response.
147    pub fn request_any(
148        &self,
149        method: &'static str,
150        params: AnyClientRequest,
151    ) -> impl Future<Output = Result<AnyClientResult, Error>> {
152        self.0.request(method, params)
153    }
154}
155
156struct Connection<In, Out>
157where
158    In: AnyRequest,
159    Out: AnyRequest,
160{
161    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
162    response_senders: ResponseSenders<Out::Response>,
163    next_id: AtomicI32,
164}
165
166type ResponseSenders<T> =
167    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
168
169#[derive(Debug, Deserialize)]
170struct IncomingMessage<'a> {
171    id: i32,
172    method: Option<&'a str>,
173    params: Option<&'a RawValue>,
174    result: Option<&'a RawValue>,
175    error: Option<Error>,
176}
177
178#[derive(Serialize)]
179#[serde(untagged)]
180enum OutgoingMessage<Req, Resp> {
181    Request {
182        id: i32,
183        method: Box<str>,
184        #[serde(skip_serializing_if = "is_none_or_null")]
185        params: Option<Req>,
186    },
187    OkResponse {
188        id: i32,
189        result: Resp,
190    },
191    ErrorResponse {
192        id: i32,
193        error: Error,
194    },
195}
196
197fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
198    match opt {
199        None => true,
200        Some(value) => {
201            matches!(serde_json::to_value(value), Ok(serde_json::Value::Null))
202        }
203    }
204}
205
206#[derive(Debug, Deserialize, Serialize)]
207enum JsonSchemaVersion {
208    #[serde(rename = "2.0")]
209    V2,
210}
211
212#[derive(Serialize)]
213struct OutJsonRpcMessage<Req, Resp> {
214    jsonrpc: JsonSchemaVersion,
215    #[serde(flatten)]
216    message: OutgoingMessage<Req, Resp>,
217}
218
219type ResponseHandler<In, Resp> =
220    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
221
222impl<In, Out> Connection<In, Out>
223where
224    In: AnyRequest,
225    Out: AnyRequest,
226{
227    fn new(
228        request_handler: ResponseHandler<In, In::Response>,
229        outgoing_bytes: impl Unpin + AsyncWrite,
230        incoming_bytes: impl Unpin + AsyncRead,
231        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
232    ) -> (Self, impl Future<Output = Result<(), Error>>) {
233        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
234        let (incoming_tx, incoming_rx) = mpsc::unbounded();
235        let this = Self {
236            response_senders: ResponseSenders::default(),
237            outgoing_tx: outgoing_tx.clone(),
238            next_id: AtomicI32::new(0),
239        };
240        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
241        let io_task = Self::handle_io(
242            outgoing_rx,
243            incoming_tx,
244            this.response_senders.clone(),
245            outgoing_bytes,
246            incoming_bytes,
247        );
248        (this, io_task)
249    }
250
251    fn request(
252        &self,
253        method: &'static str,
254        params: Out,
255    ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
256        let (tx, rx) = oneshot::channel();
257        let id = self.next_id.fetch_add(1, SeqCst);
258        self.response_senders.lock().insert(id, (method, tx));
259        if self
260            .outgoing_tx
261            .unbounded_send(OutgoingMessage::Request {
262                id,
263                method: method.into(),
264                params: Some(params),
265            })
266            .is_err()
267        {
268            self.response_senders.lock().remove(&id);
269        }
270        async move {
271            rx.await
272                .map_err(|e| Error::internal_error().with_data(e.to_string()))?
273        }
274    }
275
276    async fn handle_io(
277        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
278        incoming_tx: UnboundedSender<(i32, In)>,
279        response_senders: ResponseSenders<Out::Response>,
280        mut outgoing_bytes: impl Unpin + AsyncWrite,
281        incoming_bytes: impl Unpin + AsyncRead,
282    ) -> Result<(), Error> {
283        let mut output_reader = BufReader::new(incoming_bytes);
284        let mut outgoing_line = Vec::new();
285        let mut incoming_line = String::new();
286        loop {
287            select_biased! {
288                message = outgoing_rx.next() => {
289                    if let Some(message) = message {
290                        let message = OutJsonRpcMessage {
291                            jsonrpc: JsonSchemaVersion::V2,
292                            message,
293                        };
294                        outgoing_line.clear();
295                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
296                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
297                        outgoing_line.push(b'\n');
298                        outgoing_bytes.write_all(&outgoing_line).await.ok();
299                    } else {
300                        break;
301                    }
302                }
303                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
304                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
305                        break
306                    }
307                    log::trace!("recv: {}", &incoming_line);
308                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
309                        Ok(IncomingMessage { id, method, params, result, error }) => {
310                            if let Some(method) = method {
311                                match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
312                                    Ok(params) => {
313                                        incoming_tx.unbounded_send((id, params)).ok();
314                                    }
315                                    Err(error) => {
316                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
317                                    }
318                                }
319                            } else if let Some(error) = error {
320                                if let Some((_, tx)) = response_senders.lock().remove(&id) {
321                                    tx.send(Err(error)).ok();
322                                }
323                            } else {
324                                let result = result.unwrap_or(RawValue::NULL);
325                                if let Some((method, tx)) = response_senders.lock().remove(&id) {
326                                    match Out::response_from_method_and_result(method, result) {
327                                        Ok(result) => {
328                                            tx.send(Ok(result)).ok();
329                                        }
330                                        Err(error) => {
331                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
332                                        }
333                                    }
334                                }
335                            }
336                        }
337                        Err(error) => {
338                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
339                        }
340                    }
341                    incoming_line.clear();
342                }
343            }
344        }
345        response_senders.lock().clear();
346        Ok(())
347    }
348
349    fn handle_incoming(
350        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
351        mut incoming_rx: UnboundedReceiver<(i32, In)>,
352        incoming_handler: ResponseHandler<In, In::Response>,
353        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
354    ) {
355        let spawn = Rc::new(spawn);
356        let spawn2 = spawn.clone();
357        spawn(
358            async move {
359                while let Some((id, params)) = incoming_rx.next().await {
360                    let result = incoming_handler(params);
361                    let outgoing_tx = outgoing_tx.clone();
362                    spawn2(
363                        async move {
364                            let result = result.await;
365                            match result {
366                                Ok(result) => {
367                                    outgoing_tx
368                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
369                                        .ok();
370                                }
371                                Err(error) => {
372                                    outgoing_tx
373                                        .unbounded_send(OutgoingMessage::ErrorResponse {
374                                            id,
375                                            error: Error::into_internal_error(error),
376                                        })
377                                        .ok();
378                                }
379                            }
380                        }
381                        .boxed_local(),
382                    )
383                }
384            }
385            .boxed_local(),
386        )
387    }
388}