// agent_client_protocol/acp.rs

1mod agent;
2mod client;
3mod content;
4mod error;
5mod plan;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod tool_call;
10mod version;
11
12pub use agent::*;
13pub use client::*;
14pub use content::*;
15pub use error::*;
16pub use plan::*;
17pub use tool_call::*;
18pub use version::*;
19
20use anyhow::Result;
21use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::value::RawValue;
25use std::{fmt, sync::Arc};
26
27use crate::rpc::{MessageHandler, RpcConnection, Side};
28
29#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
30#[serde(transparent)]
31pub struct SessionId(pub Arc<str>);
32
33impl fmt::Display for SessionId {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        write!(f, "{}", self.0)
36    }
37}
38
39// Client to Agent
40
41pub struct ClientSideConnection {
42    conn: RpcConnection<ClientSide, AgentSide>,
43}
44
45impl ClientSideConnection {
46    pub fn new(
47        client: impl MessageHandler<ClientSide> + 'static,
48        outgoing_bytes: impl Unpin + AsyncWrite,
49        incoming_bytes: impl Unpin + AsyncRead,
50        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
51    ) -> (Self, impl Future<Output = Result<()>>) {
52        let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
53        (Self { conn }, io_task)
54    }
55}
56
57impl Agent for ClientSideConnection {
58    async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
59        self.conn
60            .request(
61                INITIALIZE_METHOD_NAME,
62                Some(ClientRequest::InitializeRequest(arguments)),
63            )
64            .await
65    }
66
67    async fn authenticate(&self, arguments: AuthenticateRequest) -> Result<(), Error> {
68        self.conn
69            .request(
70                AUTHENTICATE_METHOD_NAME,
71                Some(ClientRequest::AuthenticateRequest(arguments)),
72            )
73            .await
74    }
75
76    async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
77        self.conn
78            .request(
79                SESSION_NEW_METHOD_NAME,
80                Some(ClientRequest::NewSessionRequest(arguments)),
81            )
82            .await
83    }
84
85    async fn load_session(&self, arguments: LoadSessionRequest) -> Result<(), Error> {
86        self.conn
87            .request(
88                SESSION_LOAD_METHOD_NAME,
89                Some(ClientRequest::LoadSessionRequest(arguments)),
90            )
91            .await
92    }
93
94    async fn prompt(&self, arguments: PromptRequest) -> Result<PromptResponse, Error> {
95        self.conn
96            .request(
97                SESSION_PROMPT_METHOD_NAME,
98                Some(ClientRequest::PromptRequest(arguments)),
99            )
100            .await
101    }
102
103    async fn cancel(&self, notification: CancelNotification) -> Result<(), Error> {
104        self.conn.notify(
105            SESSION_CANCEL_METHOD_NAME,
106            Some(ClientNotification::CancelNotification(notification)),
107        )
108    }
109}
110
111pub struct ClientSide;
112
113impl Side for ClientSide {
114    type InNotification = AgentNotification;
115    type InRequest = AgentRequest;
116    type OutResponse = ClientResponse;
117
118    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
119        let params = params.ok_or_else(Error::invalid_params)?;
120
121        match method {
122            SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
123                .map(AgentRequest::RequestPermissionRequest)
124                .map_err(Into::into),
125            FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
126                .map(AgentRequest::WriteTextFileRequest)
127                .map_err(Into::into),
128            FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
129                .map(AgentRequest::ReadTextFileRequest)
130                .map_err(Into::into),
131            _ => Err(Error::method_not_found()),
132        }
133    }
134
135    fn decode_notification(
136        method: &str,
137        params: Option<&RawValue>,
138    ) -> Result<AgentNotification, Error> {
139        let params = params.ok_or_else(Error::invalid_params)?;
140
141        match method {
142            SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
143                .map(AgentNotification::SessionNotification)
144                .map_err(Into::into),
145            _ => Err(Error::method_not_found()),
146        }
147    }
148}
149
150impl<T: Client> MessageHandler<ClientSide> for T {
151    async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
152        match request {
153            AgentRequest::RequestPermissionRequest(args) => {
154                let response = self.request_permission(args).await?;
155                Ok(ClientResponse::RequestPermissionResponse(response))
156            }
157            AgentRequest::WriteTextFileRequest(args) => {
158                self.write_text_file(args).await?;
159                Ok(ClientResponse::WriteTextFileResponse)
160            }
161            AgentRequest::ReadTextFileRequest(args) => {
162                let response = self.read_text_file(args).await?;
163                Ok(ClientResponse::ReadTextFileResponse(response))
164            }
165        }
166    }
167
168    async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
169        match notification {
170            AgentNotification::SessionNotification(notification) => {
171                self.session_notification(notification).await?;
172            }
173        }
174        Ok(())
175    }
176}
177
178// Agent to Client
179
180pub struct AgentSideConnection {
181    conn: RpcConnection<AgentSide, ClientSide>,
182}
183
184impl AgentSideConnection {
185    pub fn new(
186        agent: impl MessageHandler<AgentSide> + 'static,
187        outgoing_bytes: impl Unpin + AsyncWrite,
188        incoming_bytes: impl Unpin + AsyncRead,
189        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
190    ) -> (Self, impl Future<Output = Result<()>>) {
191        let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
192        (Self { conn }, io_task)
193    }
194}
195
196impl Client for AgentSideConnection {
197    async fn request_permission(
198        &self,
199        arguments: RequestPermissionRequest,
200    ) -> Result<RequestPermissionResponse, Error> {
201        self.conn
202            .request(
203                SESSION_REQUEST_PERMISSION_METHOD_NAME,
204                Some(AgentRequest::RequestPermissionRequest(arguments)),
205            )
206            .await
207    }
208
209    async fn write_text_file(&self, arguments: WriteTextFileRequest) -> Result<(), Error> {
210        self.conn
211            .request(
212                FS_WRITE_TEXT_FILE_METHOD_NAME,
213                Some(AgentRequest::WriteTextFileRequest(arguments)),
214            )
215            .await
216    }
217
218    async fn read_text_file(
219        &self,
220        arguments: ReadTextFileRequest,
221    ) -> Result<ReadTextFileResponse, Error> {
222        self.conn
223            .request(
224                FS_READ_TEXT_FILE_METHOD_NAME,
225                Some(AgentRequest::ReadTextFileRequest(arguments)),
226            )
227            .await
228    }
229
230    async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
231        self.conn.notify(
232            SESSION_UPDATE_NOTIFICATION,
233            Some(AgentNotification::SessionNotification(notification)),
234        )
235    }
236}
237
238pub struct AgentSide;
239
240impl Side for AgentSide {
241    type InRequest = ClientRequest;
242    type InNotification = ClientNotification;
243    type OutResponse = AgentResponse;
244
245    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
246        let params = params.ok_or_else(Error::invalid_params)?;
247
248        match method {
249            INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
250                .map(ClientRequest::InitializeRequest)
251                .map_err(Into::into),
252            AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
253                .map(ClientRequest::AuthenticateRequest)
254                .map_err(Into::into),
255            SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
256                .map(ClientRequest::NewSessionRequest)
257                .map_err(Into::into),
258            SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
259                .map(ClientRequest::LoadSessionRequest)
260                .map_err(Into::into),
261            SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
262                .map(ClientRequest::PromptRequest)
263                .map_err(Into::into),
264            _ => Err(Error::method_not_found()),
265        }
266    }
267
268    fn decode_notification(
269        method: &str,
270        params: Option<&RawValue>,
271    ) -> Result<ClientNotification, Error> {
272        let params = params.ok_or_else(Error::invalid_params)?;
273
274        match method {
275            SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
276                .map(ClientNotification::CancelNotification)
277                .map_err(Into::into),
278            _ => Err(Error::method_not_found()),
279        }
280    }
281}
282
283impl<T: Agent> MessageHandler<AgentSide> for T {
284    async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
285        match request {
286            ClientRequest::InitializeRequest(args) => {
287                let response = self.initialize(args).await?;
288                Ok(AgentResponse::InitializeResponse(response))
289            }
290            ClientRequest::AuthenticateRequest(args) => {
291                self.authenticate(args).await?;
292                Ok(AgentResponse::AuthenticateResponse)
293            }
294            ClientRequest::NewSessionRequest(args) => {
295                let response = self.new_session(args).await?;
296                Ok(AgentResponse::NewSessionResponse(response))
297            }
298            ClientRequest::LoadSessionRequest(args) => {
299                self.load_session(args).await?;
300                Ok(AgentResponse::LoadSessionResponse)
301            }
302            ClientRequest::PromptRequest(args) => {
303                let response = self.prompt(args).await?;
304                Ok(AgentResponse::PromptResponse(response))
305            }
306        }
307    }
308
309    async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
310        match notification {
311            ClientNotification::CancelNotification(notification) => {
312                self.cancel(notification).await?;
313            }
314        }
315        Ok(())
316    }
317}