agentic-coding-protocol 0.0.11

A protocol for standardizing communication between code editors and AI coding agents
Documentation
#[cfg(test)]
mod acp_tests;
mod schema;

use futures::{
    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
    StreamExt as _,
    channel::{
        mpsc::{self, UnboundedReceiver, UnboundedSender},
        oneshot,
    },
    future::LocalBoxFuture,
    io::BufReader,
    select_biased,
};
use parking_lot::Mutex;
pub use schema::*;
use semver::Comparator;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use std::{
    collections::HashMap,
    rc::Rc,
    sync::{
        Arc,
        atomic::{AtomicI32, Ordering::SeqCst},
    },
};

/// A connection to a separate agent process over the ACP protocol.
///
/// Requests arriving from the agent are decoded as [`AnyClientRequest`] and answered by the
/// local [`Client`] handler given to [`AgentConnection::connect_to_agent`]; requests sent to
/// the agent are [`AnyAgentRequest`]s.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// A connection to a separate client process over the ACP protocol.
///
/// Requests arriving from the client are decoded as [`AnyAgentRequest`] and answered by the
/// local [`Agent`] handler given to [`ClientConnection::connect_to_client`]; requests sent to
/// the client are [`AnyClientRequest`]s.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);

impl AgentConnection {
    /// Connect to an agent process, handling any incoming requests
    /// using the given handler.
    pub fn connect_to_agent<H: 'static + Client>(
        handler: H,
        outgoing_bytes: impl Unpin + AsyncWrite,
        incoming_bytes: impl Unpin + AsyncRead,
        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
    ) -> (Self, impl Future<Output = Result<(), Error>>) {
        let handler = Arc::new(handler);
        let (connection, io_task) = Connection::new(
            Box::new(move |request| {
                let handler = handler.clone();
                async move { handler.call(request).await }.boxed_local()
            }),
            outgoing_bytes,
            incoming_bytes,
            spawn,
        );
        (Self(connection), io_task)
    }

    /// Send a request to the agent and wait for a response.
    pub fn request<R: AgentRequest + 'static>(
        &self,
        params: R,
    ) -> impl Future<Output = Result<R::Response, Error>> {
        let params = params.into_any();
        let result = self.0.request(params.method_name(), params);
        async move {
            let result = result.await?;
            R::response_from_any(result)
        }
    }

    /// Send an untyped request to the agent and wait for a response.
    pub fn request_any(
        &self,
        params: AnyAgentRequest,
    ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
        self.0.request(params.method_name(), params)
    }

    /// Sends an initialization request to the Agent.
    /// This will error if the server version is incompatible with the client version.
    pub async fn initialize(
        &self,
        context_servers: HashMap<String, ContextServer>,
    ) -> Result<InitializeResponse, Error> {
        let protocol_version = ProtocolVersion::latest();
        // Check that we are on the same major version of the protocol
        let version_requirement = Comparator {
            op: semver::Op::Caret,
            major: protocol_version.major,
            minor: (protocol_version.major == 0).then_some(protocol_version.minor),
            patch: (protocol_version.major == 0 && protocol_version.minor == 0)
                .then_some(protocol_version.patch),
            pre: protocol_version.pre.clone(),
        };
        let response = self
            .request(InitializeParams {
                protocol_version,
                context_servers,
            })
            .await?;

        let server_version = &response.protocol_version;

        if version_requirement.matches(server_version) {
            Ok(response)
        } else {
            Err(Error::invalid_request().with_data(format!(
                "Incompatible versions: Server {server_version} / Client: {version_requirement}"
            )))
        }
    }
}

impl ClientConnection {
    pub fn connect_to_client<H: 'static + Agent>(
        handler: H,
        outgoing_bytes: impl Unpin + AsyncWrite,
        incoming_bytes: impl Unpin + AsyncRead,
        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
    ) -> (Self, impl Future<Output = Result<(), Error>>) {
        let handler = Arc::new(handler);
        let (connection, io_task) = Connection::new(
            Box::new(move |request| {
                let handler = handler.clone();
                async move { handler.call(request).await }.boxed_local()
            }),
            outgoing_bytes,
            incoming_bytes,
            spawn,
        );
        (Self(connection), io_task)
    }

    pub fn request<R: ClientRequest>(
        &self,
        params: R,
    ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
        let params = params.into_any();
        let result = self.0.request(params.method_name(), params);
        async move {
            let result = result.await?;
            R::response_from_any(result)
        }
    }

    /// Send an untyped request to the client and wait for a response.
    pub fn request_any(
        &self,
        method: &'static str,
        params: AnyClientRequest,
    ) -> impl Future<Output = Result<AnyClientResult, Error>> {
        self.0.request(method, params)
    }
}

/// One end of a newline-delimited JSON-RPC 2.0 connection, generic over the request type it
/// receives from the peer (`In`) and the request type it sends (`Out`).
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    /// Queue drained by the I/O task: our own requests (`Out`) and our replies to the peer's
    /// requests (`In::Response`).
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    /// Requests we have sent and are still awaiting a response for, keyed by request id.
    response_senders: ResponseSenders<Out::Response>,
    /// Id assigned to the next outgoing request; incremented once per request.
    next_id: AtomicI32,
}

/// Shared table of in-flight outgoing requests, keyed by request id.
///
/// Each entry keeps the request's method name, which is needed to decode the untyped result
/// of the matching response, and the oneshot sender that completes the caller's future.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;

/// Any message read from the peer, deserialized while borrowing strings and raw JSON from
/// the input line.
///
/// The fields present tell the kinds apart: a `method` marks a request, an `error` marks an
/// error response, and anything else is treated as a successful response carrying `result`.
/// `params` and `result` stay as raw JSON until the method is known.
///
/// NOTE(review): `id` is a required `i32`, so messages without a numeric id (for example
/// JSON-RPC notifications, or error responses with `"id": null`) fail to deserialize. A
/// borrowed `&str` also cannot represent a JSON string that contains escape sequences.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    id: i32,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}

/// A message to be written to the peer.
///
/// Serialized `untagged`, so each variant becomes a plain JSON object containing only its
/// own fields; the `jsonrpc` marker is added by the `OutJsonRpcMessage` envelope.
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request we initiate. `params` is omitted entirely when absent or when it
    /// serializes to `null`.
    Request {
        id: i32,
        method: Box<str>,
        #[serde(skip_serializing_if = "is_none_or_null")]
        params: Option<Req>,
    },
    /// A successful reply to the peer's request with the same `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed reply to the peer's request with the same `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}

/// `skip_serializing_if` predicate: returns `true` when the optional value should be left
/// out of the output, i.e. it is absent or it serializes to JSON `null`.
///
/// The value is serialized once just to inspect it; a serialization error counts as
/// "not null", so the field is kept and the real serialization reports the failure.
fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
    let Some(inner) = opt.as_ref() else {
        return true;
    };
    let as_json = serde_json::to_value(inner);
    matches!(as_json, Ok(serde_json::Value::Null))
}

/// Value of the `jsonrpc` field on outgoing messages; serializes as the string `"2.0"`.
///
/// Despite the name, this is the JSON-RPC protocol version, not a JSON Schema version.
#[derive(Debug, Deserialize, Serialize)]
enum JsonSchemaVersion {
    #[serde(rename = "2.0")]
    V2,
}

/// Envelope that adds the `jsonrpc` version field to an outgoing message; the message's own
/// fields are flattened into the same JSON object.
#[derive(Serialize)]
struct OutJsonRpcMessage<Req, Resp> {
    jsonrpc: JsonSchemaVersion,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}

/// Callback invoked for each request received from the peer; its future resolves to the
/// reply (`Resp`) or an [`Error`] to send back.
///
/// Despite the name, it handles incoming *requests* and produces their responses.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;

impl<In, Out> Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    fn new(
        request_handler: ResponseHandler<In, In::Response>,
        outgoing_bytes: impl Unpin + AsyncWrite,
        incoming_bytes: impl Unpin + AsyncRead,
        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
    ) -> (Self, impl Future<Output = Result<(), Error>>) {
        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
        let (incoming_tx, incoming_rx) = mpsc::unbounded();
        let this = Self {
            response_senders: ResponseSenders::default(),
            outgoing_tx: outgoing_tx.clone(),
            next_id: AtomicI32::new(0),
        };
        Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
        let io_task = Self::handle_io(
            outgoing_rx,
            incoming_tx,
            this.response_senders.clone(),
            outgoing_bytes,
            incoming_bytes,
        );
        (this, io_task)
    }

    fn request(
        &self,
        method: &'static str,
        params: Out,
    ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
        let (tx, rx) = oneshot::channel();
        let id = self.next_id.fetch_add(1, SeqCst);
        self.response_senders.lock().insert(id, (method, tx));
        if self
            .outgoing_tx
            .unbounded_send(OutgoingMessage::Request {
                id,
                method: method.into(),
                params: Some(params),
            })
            .is_err()
        {
            self.response_senders.lock().remove(&id);
        }
        async move {
            rx.await
                .map_err(|e| Error::internal_error().with_data(e.to_string()))?
        }
    }

    async fn handle_io(
        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
        incoming_tx: UnboundedSender<(i32, In)>,
        response_senders: ResponseSenders<Out::Response>,
        mut outgoing_bytes: impl Unpin + AsyncWrite,
        incoming_bytes: impl Unpin + AsyncRead,
    ) -> Result<(), Error> {
        let mut output_reader = BufReader::new(incoming_bytes);
        let mut outgoing_line = Vec::new();
        let mut incoming_line = String::new();
        loop {
            select_biased! {
                message = outgoing_rx.next() => {
                    if let Some(message) = message {
                        let message = OutJsonRpcMessage {
                            jsonrpc: JsonSchemaVersion::V2,
                            message,
                        };
                        outgoing_line.clear();
                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
                        outgoing_line.push(b'\n');
                        outgoing_bytes.write_all(&outgoing_line).await.ok();
                    } else {
                        break;
                    }
                }
                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
                        break
                    }
                    log::trace!("recv: {}", &incoming_line);
                    match serde_json::from_str::<IncomingMessage>(&incoming_line) {
                        Ok(IncomingMessage { id, method, params, result, error }) => {
                            if let Some(method) = method {
                                match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
                                    Ok(params) => {
                                        incoming_tx.unbounded_send((id, params)).ok();
                                    }
                                    Err(error) => {
                                        log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
                                    }
                                }
                            } else if let Some(error) = error {
                                if let Some((_, tx)) = response_senders.lock().remove(&id) {
                                    tx.send(Err(error)).ok();
                                }
                            } else {
                                let result = result.unwrap_or(RawValue::NULL);
                                if let Some((method, tx)) = response_senders.lock().remove(&id) {
                                    match Out::response_from_method_and_result(method, result) {
                                        Ok(result) => {
                                            tx.send(Ok(result)).ok();
                                        }
                                        Err(error) => {
                                            log::error!("failed to parse {method} message result: {error}. Raw: {result}");
                                        }
                                    }
                                }
                            }
                        }
                        Err(error) => {
                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
                        }
                    }
                    incoming_line.clear();
                }
            }
        }
        response_senders.lock().clear();
        Ok(())
    }

    fn handle_incoming(
        outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
        mut incoming_rx: UnboundedReceiver<(i32, In)>,
        incoming_handler: ResponseHandler<In, In::Response>,
        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
    ) {
        let spawn = Rc::new(spawn);
        let spawn2 = spawn.clone();
        spawn(
            async move {
                while let Some((id, params)) = incoming_rx.next().await {
                    let result = incoming_handler(params);
                    let outgoing_tx = outgoing_tx.clone();
                    spawn2(
                        async move {
                            let result = result.await;
                            match result {
                                Ok(result) => {
                                    outgoing_tx
                                        .unbounded_send(OutgoingMessage::OkResponse { id, result })
                                        .ok();
                                }
                                Err(error) => {
                                    outgoing_tx
                                        .unbounded_send(OutgoingMessage::ErrorResponse {
                                            id,
                                            error: Error::into_internal_error(error),
                                        })
                                        .ok();
                                }
                            }
                        }
                        .boxed_local(),
                    )
                }
            }
            .boxed_local(),
        )
    }
}