// agent_client_protocol/acp.rs

1mod agent;
2mod client;
3mod content;
4mod error;
5mod plan;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod tool_call;
10mod version;
11
12pub use agent::*;
13pub use client::*;
14pub use content::*;
15pub use error::*;
16pub use plan::*;
17pub use tool_call::*;
18pub use version::*;
19
20use anyhow::Result;
21use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::value::RawValue;
25use std::{fmt, sync::Arc};
26
27use crate::rpc::{MessageHandler, RpcConnection, Side};
28
29#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
30#[serde(transparent)]
31pub struct SessionId(pub Arc<str>);
32
33impl fmt::Display for SessionId {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        write!(f, "{}", self.0)
36    }
37}
38
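// Client side: the connection a client uses to call into the agent, plus the
// decoding and dispatch of messages the client receives from the agent.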
40
41pub struct ClientSideConnection {
42    conn: RpcConnection<ClientSide, AgentSide>,
43}
44
45impl ClientSideConnection {
46    pub fn new(
47        client: impl MessageHandler<ClientSide> + 'static,
48        outgoing_bytes: impl Unpin + AsyncWrite,
49        incoming_bytes: impl Unpin + AsyncRead,
50        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
51    ) -> (Self, impl Future<Output = Result<()>>) {
52        let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
53        (Self { conn }, io_task)
54    }
55}
56
57impl Agent for ClientSideConnection {
58    async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
59        self.conn
60            .request(
61                INITIALIZE_METHOD_NAME,
62                Some(ClientRequest::InitializeRequest(arguments)),
63            )
64            .await
65    }
66
67    async fn authenticate(&self, arguments: AuthenticateRequest) -> Result<(), Error> {
68        self.conn
69            .request(
70                AUTHENTICATE_METHOD_NAME,
71                Some(ClientRequest::AuthenticateRequest(arguments)),
72            )
73            .await
74    }
75
76    async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
77        self.conn
78            .request(
79                SESSION_NEW_METHOD_NAME,
80                Some(ClientRequest::NewSessionRequest(arguments)),
81            )
82            .await
83    }
84
85    async fn load_session(
86        &self,
87        arguments: LoadSessionRequest,
88    ) -> Result<LoadSessionResponse, Error> {
89        self.conn
90            .request(
91                SESSION_LOAD_METHOD_NAME,
92                Some(ClientRequest::LoadSessionRequest(arguments)),
93            )
94            .await
95    }
96
97    async fn prompt(&self, arguments: PromptRequest) -> Result<(), Error> {
98        self.conn
99            .request(
100                SESSION_PROMPT_METHOD_NAME,
101                Some(ClientRequest::PromptRequest(arguments)),
102            )
103            .await
104    }
105
106    async fn cancelled(&self, notification: CancelledNotification) -> Result<(), Error> {
107        self.conn.notify(
108            SESSION_CANCELLED_METHOD_NAME,
109            Some(ClientNotification::CancelledNotification(notification)),
110        )
111    }
112}
113
114pub struct ClientSide;
115
116impl Side for ClientSide {
117    type InNotification = AgentNotification;
118    type InRequest = AgentRequest;
119    type OutResponse = ClientResponse;
120
121    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
122        let params = params.ok_or_else(Error::invalid_params)?;
123
124        match method {
125            SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
126                .map(AgentRequest::RequestPermissionRequest)
127                .map_err(Into::into),
128            FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
129                .map(AgentRequest::WriteTextFileRequest)
130                .map_err(Into::into),
131            FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
132                .map(AgentRequest::ReadTextFileRequest)
133                .map_err(Into::into),
134            _ => Err(Error::method_not_found()),
135        }
136    }
137
138    fn decode_notification(
139        method: &str,
140        params: Option<&RawValue>,
141    ) -> Result<AgentNotification, Error> {
142        let params = params.ok_or_else(Error::invalid_params)?;
143
144        match method {
145            SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
146                .map(AgentNotification::SessionNotification)
147                .map_err(Into::into),
148            _ => Err(Error::method_not_found()),
149        }
150    }
151}
152
153impl<T: Client> MessageHandler<ClientSide> for T {
154    async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
155        match request {
156            AgentRequest::RequestPermissionRequest(args) => {
157                let response = self.request_permission(args).await?;
158                Ok(ClientResponse::RequestPermissionResponse(response))
159            }
160            AgentRequest::WriteTextFileRequest(args) => {
161                self.write_text_file(args).await?;
162                Ok(ClientResponse::WriteTextFileResponse)
163            }
164            AgentRequest::ReadTextFileRequest(args) => {
165                let response = self.read_text_file(args).await?;
166                Ok(ClientResponse::ReadTextFileResponse(response))
167            }
168        }
169    }
170
171    async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
172        match notification {
173            AgentNotification::SessionNotification(notification) => {
174                self.session_notification(notification).await?;
175            }
176        }
177        Ok(())
178    }
179}
180
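// Agent side: the connection an agent uses to call into the client, plus the
// decoding and dispatch of messages the agent receives from the client.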
182
183pub struct AgentSideConnection {
184    conn: RpcConnection<AgentSide, ClientSide>,
185}
186
187impl AgentSideConnection {
188    pub fn new(
189        agent: impl MessageHandler<AgentSide> + 'static,
190        outgoing_bytes: impl Unpin + AsyncWrite,
191        incoming_bytes: impl Unpin + AsyncRead,
192        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
193    ) -> (Self, impl Future<Output = Result<()>>) {
194        let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
195        (Self { conn }, io_task)
196    }
197}
198
199impl Client for AgentSideConnection {
200    async fn request_permission(
201        &self,
202        arguments: RequestPermissionRequest,
203    ) -> Result<RequestPermissionResponse, Error> {
204        self.conn
205            .request(
206                SESSION_REQUEST_PERMISSION_METHOD_NAME,
207                Some(AgentRequest::RequestPermissionRequest(arguments)),
208            )
209            .await
210    }
211
212    async fn write_text_file(&self, arguments: WriteTextFileRequest) -> Result<(), Error> {
213        self.conn
214            .request(
215                FS_WRITE_TEXT_FILE_METHOD_NAME,
216                Some(AgentRequest::WriteTextFileRequest(arguments)),
217            )
218            .await
219    }
220
221    async fn read_text_file(
222        &self,
223        arguments: ReadTextFileRequest,
224    ) -> Result<ReadTextFileResponse, Error> {
225        self.conn
226            .request(
227                FS_READ_TEXT_FILE_METHOD_NAME,
228                Some(AgentRequest::ReadTextFileRequest(arguments)),
229            )
230            .await
231    }
232
233    async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
234        self.conn.notify(
235            SESSION_UPDATE_NOTIFICATION,
236            Some(AgentNotification::SessionNotification(notification)),
237        )
238    }
239}
240
241pub struct AgentSide;
242
243impl Side for AgentSide {
244    type InRequest = ClientRequest;
245    type InNotification = ClientNotification;
246    type OutResponse = AgentResponse;
247
248    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
249        let params = params.ok_or_else(Error::invalid_params)?;
250
251        match method {
252            INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
253                .map(ClientRequest::InitializeRequest)
254                .map_err(Into::into),
255            AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
256                .map(ClientRequest::AuthenticateRequest)
257                .map_err(Into::into),
258            SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
259                .map(ClientRequest::NewSessionRequest)
260                .map_err(Into::into),
261            SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
262                .map(ClientRequest::LoadSessionRequest)
263                .map_err(Into::into),
264            SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
265                .map(ClientRequest::PromptRequest)
266                .map_err(Into::into),
267            _ => Err(Error::method_not_found()),
268        }
269    }
270
271    fn decode_notification(
272        method: &str,
273        params: Option<&RawValue>,
274    ) -> Result<ClientNotification, Error> {
275        let params = params.ok_or_else(Error::invalid_params)?;
276
277        match method {
278            SESSION_CANCELLED_METHOD_NAME => serde_json::from_str(params.get())
279                .map(ClientNotification::CancelledNotification)
280                .map_err(Into::into),
281            _ => Err(Error::method_not_found()),
282        }
283    }
284}
285
286impl<T: Agent> MessageHandler<AgentSide> for T {
287    async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
288        match request {
289            ClientRequest::InitializeRequest(args) => {
290                let response = self.initialize(args).await?;
291                Ok(AgentResponse::InitializeResponse(response))
292            }
293            ClientRequest::AuthenticateRequest(args) => {
294                self.authenticate(args).await?;
295                Ok(AgentResponse::AuthenticateResponse)
296            }
297            ClientRequest::NewSessionRequest(args) => {
298                let response = self.new_session(args).await?;
299                Ok(AgentResponse::NewSessionResponse(response))
300            }
301            ClientRequest::LoadSessionRequest(args) => {
302                let response = self.load_session(args).await?;
303                Ok(AgentResponse::LoadSessionResponse(response))
304            }
305            ClientRequest::PromptRequest(args) => {
306                self.prompt(args).await?;
307                Ok(AgentResponse::PromptResponse)
308            }
309        }
310    }
311
312    async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
313        match notification {
314            ClientNotification::CancelledNotification(notification) => {
315                self.cancelled(notification).await?;
316            }
317        }
318        Ok(())
319    }
320}