agent_client_protocol/
lib.rs

1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16    StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19// Client to Agent
20
21/// A client-side connection to an agent.
22///
23/// This struct provides the client's view of an ACP connection, allowing
24/// clients (such as code editors) to communicate with agents. It implements
25/// the [`Agent`] trait to provide methods for initializing sessions, sending
26/// prompts, and managing the agent lifecycle.
27///
28/// See protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)
29pub struct ClientSideConnection {
30    conn: RpcConnection<ClientSide, AgentSide>,
31}
32
33impl ClientSideConnection {
34    /// Creates a new client-side connection to an agent.
35    ///
36    /// This establishes the communication channel between a client and agent
37    /// following the ACP specification.
38    ///
39    /// # Arguments
40    ///
41    /// * `client` - A handler that implements the [`Client`] trait to process incoming agent requests
42    /// * `outgoing_bytes` - The stream for sending data to the agent (typically stdout)
43    /// * `incoming_bytes` - The stream for receiving data from the agent (typically stdin)
44    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
45    ///
46    /// # Returns
47    ///
48    /// Returns a tuple containing:
49    /// - The connection instance for making requests to the agent
50    /// - An I/O future that must be spawned to handle the underlying communication
51    ///
52    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
53    pub fn new(
54        client: impl MessageHandler<ClientSide> + 'static,
55        outgoing_bytes: impl Unpin + AsyncWrite,
56        incoming_bytes: impl Unpin + AsyncRead,
57        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
58    ) -> (Self, impl Future<Output = Result<()>>) {
59        let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
60        (Self { conn }, io_task)
61    }
62
63    /// Subscribe to receive stream updates from the agent.
64    ///
65    /// This allows the client to receive real-time notifications about
66    /// agent activities, such as tool calls, content updates, and progress reports.
67    ///
68    /// # Returns
69    ///
70    /// A [`StreamReceiver`] that can be used to receive stream messages.
71    pub fn subscribe(&self) -> StreamReceiver {
72        self.conn.subscribe()
73    }
74}
75
76#[async_trait::async_trait(?Send)]
77impl Agent for ClientSideConnection {
78    async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
79        self.conn
80            .request(
81                AGENT_METHOD_NAMES.initialize,
82                Some(ClientRequest::InitializeRequest(args)),
83            )
84            .await
85    }
86
87    async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
88        self.conn
89            .request::<Option<_>>(
90                AGENT_METHOD_NAMES.authenticate,
91                Some(ClientRequest::AuthenticateRequest(args)),
92            )
93            .await
94            .map(Option::unwrap_or_default)
95    }
96
97    async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
98        self.conn
99            .request(
100                AGENT_METHOD_NAMES.session_new,
101                Some(ClientRequest::NewSessionRequest(args)),
102            )
103            .await
104    }
105
106    async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
107        self.conn
108            .request::<Option<_>>(
109                AGENT_METHOD_NAMES.session_load,
110                Some(ClientRequest::LoadSessionRequest(args)),
111            )
112            .await
113            .map(Option::unwrap_or_default)
114    }
115
116    async fn set_session_mode(
117        &self,
118        args: SetSessionModeRequest,
119    ) -> Result<SetSessionModeResponse> {
120        self.conn
121            .request(
122                AGENT_METHOD_NAMES.session_set_mode,
123                Some(ClientRequest::SetSessionModeRequest(args)),
124            )
125            .await
126    }
127
128    async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
129        self.conn
130            .request(
131                AGENT_METHOD_NAMES.session_prompt,
132                Some(ClientRequest::PromptRequest(args)),
133            )
134            .await
135    }
136
137    async fn cancel(&self, args: CancelNotification) -> Result<()> {
138        self.conn.notify(
139            AGENT_METHOD_NAMES.session_cancel,
140            Some(ClientNotification::CancelNotification(args)),
141        )
142    }
143
144    #[cfg(feature = "unstable")]
145    async fn set_session_model(
146        &self,
147        args: SetSessionModelRequest,
148    ) -> Result<SetSessionModelResponse> {
149        self.conn
150            .request(
151                AGENT_METHOD_NAMES.session_set_model,
152                Some(ClientRequest::SetSessionModelRequest(args)),
153            )
154            .await
155    }
156
157    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
158        self.conn
159            .request(
160                format!("_{}", args.method),
161                Some(ClientRequest::ExtMethodRequest(args)),
162            )
163            .await
164    }
165
166    async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
167        self.conn.notify(
168            format!("_{}", args.method),
169            Some(ClientNotification::ExtNotification(args)),
170        )
171    }
172}
173
/// Marker type representing the client side of an ACP connection.
///
/// This type is used by the RPC layer to determine which messages
/// are incoming vs outgoing from the client's perspective. It carries no
/// data, so it is `Copy` and has a trivial `Debug` representation.
///
/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
#[derive(Clone, Copy, Debug)]
pub struct ClientSide;
182
183impl Side for ClientSide {
184    type InNotification = AgentNotification;
185    type InRequest = AgentRequest;
186    type OutResponse = ClientResponse;
187
188    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
189        let params = params.ok_or_else(Error::invalid_params)?;
190
191        match method {
192            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
193                serde_json::from_str(params.get())
194                    .map(AgentRequest::RequestPermissionRequest)
195                    .map_err(Into::into)
196            }
197            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
198                .map(AgentRequest::WriteTextFileRequest)
199                .map_err(Into::into),
200            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
201                .map(AgentRequest::ReadTextFileRequest)
202                .map_err(Into::into),
203            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
204                .map(AgentRequest::CreateTerminalRequest)
205                .map_err(Into::into),
206            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
207                .map(AgentRequest::TerminalOutputRequest)
208                .map_err(Into::into),
209            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
210                .map(AgentRequest::KillTerminalCommandRequest)
211                .map_err(Into::into),
212            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
213                .map(AgentRequest::ReleaseTerminalRequest)
214                .map_err(Into::into),
215            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
216                serde_json::from_str(params.get())
217                    .map(AgentRequest::WaitForTerminalExitRequest)
218                    .map_err(Into::into)
219            }
220            _ => {
221                if let Some(custom_method) = method.strip_prefix('_') {
222                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
223                        method: custom_method.into(),
224                        params: params.to_owned().into(),
225                    }))
226                } else {
227                    Err(Error::method_not_found())
228                }
229            }
230        }
231    }
232
233    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
234        let params = params.ok_or_else(Error::invalid_params)?;
235
236        match method {
237            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
238                .map(AgentNotification::SessionNotification)
239                .map_err(Into::into),
240            _ => {
241                if let Some(custom_method) = method.strip_prefix('_') {
242                    Ok(AgentNotification::ExtNotification(ExtNotification {
243                        method: custom_method.into(),
244                        params: RawValue::from_string(params.get().to_string())?.into(),
245                    }))
246                } else {
247                    Err(Error::method_not_found())
248                }
249            }
250        }
251    }
252}
253
254impl<T: Client> MessageHandler<ClientSide> for T {
255    async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
256        match request {
257            AgentRequest::RequestPermissionRequest(args) => {
258                let response = self.request_permission(args).await?;
259                Ok(ClientResponse::RequestPermissionResponse(response))
260            }
261            AgentRequest::WriteTextFileRequest(args) => {
262                let response = self.write_text_file(args).await?;
263                Ok(ClientResponse::WriteTextFileResponse(response))
264            }
265            AgentRequest::ReadTextFileRequest(args) => {
266                let response = self.read_text_file(args).await?;
267                Ok(ClientResponse::ReadTextFileResponse(response))
268            }
269            AgentRequest::CreateTerminalRequest(args) => {
270                let response = self.create_terminal(args).await?;
271                Ok(ClientResponse::CreateTerminalResponse(response))
272            }
273            AgentRequest::TerminalOutputRequest(args) => {
274                let response = self.terminal_output(args).await?;
275                Ok(ClientResponse::TerminalOutputResponse(response))
276            }
277            AgentRequest::ReleaseTerminalRequest(args) => {
278                let response = self.release_terminal(args).await?;
279                Ok(ClientResponse::ReleaseTerminalResponse(response))
280            }
281            AgentRequest::WaitForTerminalExitRequest(args) => {
282                let response = self.wait_for_terminal_exit(args).await?;
283                Ok(ClientResponse::WaitForTerminalExitResponse(response))
284            }
285            AgentRequest::KillTerminalCommandRequest(args) => {
286                let response = self.kill_terminal_command(args).await?;
287                Ok(ClientResponse::KillTerminalResponse(response))
288            }
289            AgentRequest::ExtMethodRequest(args) => {
290                let response = self.ext_method(args).await?;
291                Ok(ClientResponse::ExtMethodResponse(response))
292            }
293        }
294    }
295
296    async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
297        match notification {
298            AgentNotification::SessionNotification(args) => {
299                self.session_notification(args).await?;
300            }
301            AgentNotification::ExtNotification(args) => {
302                self.ext_notification(args).await?;
303            }
304        }
305        Ok(())
306    }
307}
308
309// Agent to Client
310
311/// An agent-side connection to a client.
312///
313/// This struct provides the agent's view of an ACP connection, allowing
314/// agents to communicate with clients. It implements the [`Client`] trait
315/// to provide methods for requesting permissions, accessing the file system,
316/// and sending session updates.
317///
318/// See protocol docs: [Agent](https://agentclientprotocol.com/protocol/overview#agent)
319pub struct AgentSideConnection {
320    conn: RpcConnection<AgentSide, ClientSide>,
321}
322
323impl AgentSideConnection {
324    /// Creates a new agent-side connection to a client.
325    ///
326    /// This establishes the communication channel from the agent's perspective
327    /// following the ACP specification.
328    ///
329    /// # Arguments
330    ///
331    /// * `agent` - A handler that implements the [`Agent`] trait to process incoming client requests
332    /// * `outgoing_bytes` - The stream for sending data to the client (typically stdout)
333    /// * `incoming_bytes` - The stream for receiving data from the client (typically stdin)
334    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
335    ///
336    /// # Returns
337    ///
338    /// Returns a tuple containing:
339    /// - The connection instance for making requests to the client
340    /// - An I/O future that must be spawned to handle the underlying communication
341    ///
342    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
343    pub fn new(
344        agent: impl MessageHandler<AgentSide> + 'static,
345        outgoing_bytes: impl Unpin + AsyncWrite,
346        incoming_bytes: impl Unpin + AsyncRead,
347        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
348    ) -> (Self, impl Future<Output = Result<()>>) {
349        let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
350        (Self { conn }, io_task)
351    }
352
353    /// Subscribe to receive stream updates from the client.
354    ///
355    /// This allows the agent to receive real-time notifications about
356    /// client activities and cancellation requests.
357    ///
358    /// # Returns
359    ///
360    /// A [`StreamReceiver`] that can be used to receive stream messages.
361    pub fn subscribe(&self) -> StreamReceiver {
362        self.conn.subscribe()
363    }
364}
365
366#[async_trait::async_trait(?Send)]
367impl Client for AgentSideConnection {
368    async fn request_permission(
369        &self,
370        args: RequestPermissionRequest,
371    ) -> Result<RequestPermissionResponse> {
372        self.conn
373            .request(
374                CLIENT_METHOD_NAMES.session_request_permission,
375                Some(AgentRequest::RequestPermissionRequest(args)),
376            )
377            .await
378    }
379
380    async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
381        self.conn
382            .request::<Option<_>>(
383                CLIENT_METHOD_NAMES.fs_write_text_file,
384                Some(AgentRequest::WriteTextFileRequest(args)),
385            )
386            .await
387            .map(Option::unwrap_or_default)
388    }
389
390    async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
391        self.conn
392            .request(
393                CLIENT_METHOD_NAMES.fs_read_text_file,
394                Some(AgentRequest::ReadTextFileRequest(args)),
395            )
396            .await
397    }
398
399    async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
400        self.conn
401            .request(
402                CLIENT_METHOD_NAMES.terminal_create,
403                Some(AgentRequest::CreateTerminalRequest(args)),
404            )
405            .await
406    }
407
408    async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
409        self.conn
410            .request(
411                CLIENT_METHOD_NAMES.terminal_output,
412                Some(AgentRequest::TerminalOutputRequest(args)),
413            )
414            .await
415    }
416
417    async fn release_terminal(
418        &self,
419        args: ReleaseTerminalRequest,
420    ) -> Result<ReleaseTerminalResponse> {
421        self.conn
422            .request::<Option<_>>(
423                CLIENT_METHOD_NAMES.terminal_release,
424                Some(AgentRequest::ReleaseTerminalRequest(args)),
425            )
426            .await
427            .map(Option::unwrap_or_default)
428    }
429
430    async fn wait_for_terminal_exit(
431        &self,
432        args: WaitForTerminalExitRequest,
433    ) -> Result<WaitForTerminalExitResponse> {
434        self.conn
435            .request(
436                CLIENT_METHOD_NAMES.terminal_wait_for_exit,
437                Some(AgentRequest::WaitForTerminalExitRequest(args)),
438            )
439            .await
440    }
441
442    async fn kill_terminal_command(
443        &self,
444        args: KillTerminalCommandRequest,
445    ) -> Result<KillTerminalCommandResponse> {
446        self.conn
447            .request::<Option<_>>(
448                CLIENT_METHOD_NAMES.terminal_kill,
449                Some(AgentRequest::KillTerminalCommandRequest(args)),
450            )
451            .await
452            .map(Option::unwrap_or_default)
453    }
454
455    async fn session_notification(&self, args: SessionNotification) -> Result<()> {
456        self.conn.notify(
457            CLIENT_METHOD_NAMES.session_update,
458            Some(AgentNotification::SessionNotification(args)),
459        )
460    }
461
462    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
463        self.conn
464            .request(
465                format!("_{}", args.method),
466                Some(AgentRequest::ExtMethodRequest(args)),
467            )
468            .await
469    }
470
471    async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
472        self.conn.notify(
473            format!("_{}", args.method),
474            Some(AgentNotification::ExtNotification(args)),
475        )
476    }
477}
478
/// Marker type representing the agent side of an ACP connection.
///
/// This type is used by the RPC layer to determine which messages
/// are incoming vs outgoing from the agent's perspective. It carries no
/// data, so it is `Copy` and has a trivial `Debug` representation.
///
/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
#[derive(Clone, Copy, Debug)]
pub struct AgentSide;
487
488impl Side for AgentSide {
489    type InRequest = ClientRequest;
490    type InNotification = ClientNotification;
491    type OutResponse = AgentResponse;
492
493    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
494        let params = params.ok_or_else(Error::invalid_params)?;
495
496        match method {
497            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
498                .map(ClientRequest::InitializeRequest)
499                .map_err(Into::into),
500            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
501                .map(ClientRequest::AuthenticateRequest)
502                .map_err(Into::into),
503            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
504                .map(ClientRequest::NewSessionRequest)
505                .map_err(Into::into),
506            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
507                .map(ClientRequest::LoadSessionRequest)
508                .map_err(Into::into),
509            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
510                .map(ClientRequest::SetSessionModeRequest)
511                .map_err(Into::into),
512            #[cfg(feature = "unstable")]
513            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
514                .map(ClientRequest::SetSessionModelRequest)
515                .map_err(Into::into),
516            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
517                .map(ClientRequest::PromptRequest)
518                .map_err(Into::into),
519            _ => {
520                if let Some(custom_method) = method.strip_prefix('_') {
521                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
522                        method: custom_method.into(),
523                        params: params.to_owned().into(),
524                    }))
525                } else {
526                    Err(Error::method_not_found())
527                }
528            }
529        }
530    }
531
532    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
533        let params = params.ok_or_else(Error::invalid_params)?;
534
535        match method {
536            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
537                .map(ClientNotification::CancelNotification)
538                .map_err(Into::into),
539            _ => {
540                if let Some(custom_method) = method.strip_prefix('_') {
541                    Ok(ClientNotification::ExtNotification(ExtNotification {
542                        method: custom_method.into(),
543                        params: RawValue::from_string(params.get().to_string())?.into(),
544                    }))
545                } else {
546                    Err(Error::method_not_found())
547                }
548            }
549        }
550    }
551}
552
553impl<T: Agent> MessageHandler<AgentSide> for T {
554    async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
555        match request {
556            ClientRequest::InitializeRequest(args) => {
557                let response = self.initialize(args).await?;
558                Ok(AgentResponse::InitializeResponse(response))
559            }
560            ClientRequest::AuthenticateRequest(args) => {
561                let response = self.authenticate(args).await?;
562                Ok(AgentResponse::AuthenticateResponse(response))
563            }
564            ClientRequest::NewSessionRequest(args) => {
565                let response = self.new_session(args).await?;
566                Ok(AgentResponse::NewSessionResponse(response))
567            }
568            ClientRequest::LoadSessionRequest(args) => {
569                let response = self.load_session(args).await?;
570                Ok(AgentResponse::LoadSessionResponse(response))
571            }
572            ClientRequest::PromptRequest(args) => {
573                let response = self.prompt(args).await?;
574                Ok(AgentResponse::PromptResponse(response))
575            }
576            ClientRequest::SetSessionModeRequest(args) => {
577                let response = self.set_session_mode(args).await?;
578                Ok(AgentResponse::SetSessionModeResponse(response))
579            }
580            #[cfg(feature = "unstable")]
581            ClientRequest::SetSessionModelRequest(args) => {
582                let response = self.set_session_model(args).await?;
583                Ok(AgentResponse::SetSessionModelResponse(response))
584            }
585            ClientRequest::ExtMethodRequest(args) => {
586                let response = self.ext_method(args).await?;
587                Ok(AgentResponse::ExtMethodResponse(response))
588            }
589        }
590    }
591
592    async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
593        match notification {
594            ClientNotification::CancelNotification(args) => {
595                self.cancel(args).await?;
596            }
597            ClientNotification::ExtNotification(args) => {
598                self.ext_notification(args).await?;
599            }
600        }
601        Ok(())
602    }
603}