// agent_client_protocol/lib.rs

1use anyhow::Result;
2use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
3use rpc::{MessageHandler, RpcConnection, Side};
4
5mod agent;
6mod client;
7mod rpc;
8#[cfg(test)]
9mod rpc_tests;
10mod stream_broadcast;
11
12pub use agent::*;
13pub use agent_client_protocol_schema::*;
14pub use client::*;
15pub use stream_broadcast::{
16    StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
// Client to Agent: the connection a client holds in order to talk to an agent.
20
21/// A client-side connection to an agent.
22///
23/// This struct provides the client's view of an ACP connection, allowing
24/// clients (such as code editors) to communicate with agents. It implements
25/// the [`Agent`] trait to provide methods for initializing sessions, sending
26/// prompts, and managing the agent lifecycle.
27///
28/// See protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)
29pub struct ClientSideConnection {
30    conn: RpcConnection<ClientSide, AgentSide>,
31}
32
33impl ClientSideConnection {
34    /// Creates a new client-side connection to an agent.
35    ///
36    /// This establishes the communication channel between a client and agent
37    /// following the ACP specification.
38    ///
39    /// # Arguments
40    ///
41    /// * `client` - A handler that implements the [`Client`] trait to process incoming agent requests
42    /// * `outgoing_bytes` - The stream for sending data to the agent (typically stdout)
43    /// * `incoming_bytes` - The stream for receiving data from the agent (typically stdin)
44    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
45    ///
46    /// # Returns
47    ///
48    /// Returns a tuple containing:
49    /// - The connection instance for making requests to the agent
50    /// - An I/O future that must be spawned to handle the underlying communication
51    ///
52    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
53    pub fn new(
54        client: impl MessageHandler<ClientSide> + 'static,
55        outgoing_bytes: impl Unpin + AsyncWrite,
56        incoming_bytes: impl Unpin + AsyncRead,
57        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
58    ) -> (Self, impl Future<Output = Result<()>>) {
59        let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
60        (Self { conn }, io_task)
61    }
62
63    /// Subscribe to receive stream updates from the agent.
64    ///
65    /// This allows the client to receive real-time notifications about
66    /// agent activities, such as tool calls, content updates, and progress reports.
67    ///
68    /// # Returns
69    ///
70    /// A [`StreamReceiver`] that can be used to receive stream messages.
71    pub fn subscribe(&self) -> StreamReceiver {
72        self.conn.subscribe()
73    }
74}
75
76#[async_trait::async_trait(?Send)]
77impl Agent for ClientSideConnection {
78    async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse, Error> {
79        self.conn
80            .request(
81                AGENT_METHOD_NAMES.initialize,
82                Some(ClientRequest::InitializeRequest(args)),
83            )
84            .await
85    }
86
87    async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse, Error> {
88        self.conn
89            .request::<Option<_>>(
90                AGENT_METHOD_NAMES.authenticate,
91                Some(ClientRequest::AuthenticateRequest(args)),
92            )
93            .await
94            .map(Option::unwrap_or_default)
95    }
96
97    async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse, Error> {
98        self.conn
99            .request(
100                AGENT_METHOD_NAMES.session_new,
101                Some(ClientRequest::NewSessionRequest(args)),
102            )
103            .await
104    }
105
106    async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse, Error> {
107        self.conn
108            .request::<Option<_>>(
109                AGENT_METHOD_NAMES.session_load,
110                Some(ClientRequest::LoadSessionRequest(args)),
111            )
112            .await
113            .map(Option::unwrap_or_default)
114    }
115
116    async fn set_session_mode(
117        &self,
118        args: SetSessionModeRequest,
119    ) -> Result<SetSessionModeResponse, Error> {
120        self.conn
121            .request(
122                AGENT_METHOD_NAMES.session_set_mode,
123                Some(ClientRequest::SetSessionModeRequest(args)),
124            )
125            .await
126    }
127
128    async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse, Error> {
129        self.conn
130            .request(
131                AGENT_METHOD_NAMES.session_prompt,
132                Some(ClientRequest::PromptRequest(args)),
133            )
134            .await
135    }
136
137    async fn cancel(&self, args: CancelNotification) -> Result<(), Error> {
138        self.conn.notify(
139            AGENT_METHOD_NAMES.session_cancel,
140            Some(ClientNotification::CancelNotification(args)),
141        )
142    }
143
144    #[cfg(feature = "unstable")]
145    async fn set_session_model(
146        &self,
147        args: SetSessionModelRequest,
148    ) -> Result<SetSessionModelResponse, Error> {
149        self.conn
150            .request(
151                AGENT_METHOD_NAMES.session_set_model,
152                Some(ClientRequest::SetSessionModelRequest(args)),
153            )
154            .await
155    }
156
157    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
158        self.conn
159            .request(
160                format!("_{}", args.method),
161                Some(ClientRequest::ExtMethodRequest(args)),
162            )
163            .await
164    }
165
166    async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
167        self.conn.notify(
168            format!("_{}", args.method),
169            Some(ClientNotification::ExtNotification(args)),
170        )
171    }
172}
173
/// Marker type representing the client side of an ACP connection.
///
/// The RPC layer is parameterized by this type; its [`Side`] implementation
/// determines which messages count as incoming from the client's perspective.
///
/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
#[derive(Clone, Debug)]
pub struct ClientSide;
182
183impl Side for ClientSide {
184    type InNotification = AgentNotification;
185    type InRequest = AgentRequest;
186    type OutResponse = ClientResponse;
187
188    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
189        let params = params.ok_or_else(Error::invalid_params)?;
190
191        match method {
192            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
193                serde_json::from_str(params.get())
194                    .map(AgentRequest::RequestPermissionRequest)
195                    .map_err(Into::into)
196            }
197            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
198                .map(AgentRequest::WriteTextFileRequest)
199                .map_err(Into::into),
200            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
201                .map(AgentRequest::ReadTextFileRequest)
202                .map_err(Into::into),
203            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
204                .map(AgentRequest::CreateTerminalRequest)
205                .map_err(Into::into),
206            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
207                .map(AgentRequest::TerminalOutputRequest)
208                .map_err(Into::into),
209            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
210                .map(AgentRequest::KillTerminalCommandRequest)
211                .map_err(Into::into),
212            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
213                .map(AgentRequest::ReleaseTerminalRequest)
214                .map_err(Into::into),
215            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
216                serde_json::from_str(params.get())
217                    .map(AgentRequest::WaitForTerminalExitRequest)
218                    .map_err(Into::into)
219            }
220            _ => {
221                if let Some(custom_method) = method.strip_prefix('_') {
222                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
223                        method: custom_method.into(),
224                        params: RawValue::from_string(params.get().to_string())?.into(),
225                    }))
226                } else {
227                    Err(Error::method_not_found())
228                }
229            }
230        }
231    }
232
233    fn decode_notification(
234        method: &str,
235        params: Option<&RawValue>,
236    ) -> Result<AgentNotification, Error> {
237        let params = params.ok_or_else(Error::invalid_params)?;
238
239        match method {
240            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
241                .map(AgentNotification::SessionNotification)
242                .map_err(Into::into),
243            _ => {
244                if let Some(custom_method) = method.strip_prefix('_') {
245                    Ok(AgentNotification::ExtNotification(ExtNotification {
246                        method: custom_method.into(),
247                        params: RawValue::from_string(params.get().to_string())?.into(),
248                    }))
249                } else {
250                    Err(Error::method_not_found())
251                }
252            }
253        }
254    }
255}
256
257impl<T: Client> MessageHandler<ClientSide> for T {
258    async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
259        match request {
260            AgentRequest::RequestPermissionRequest(args) => {
261                let response = self.request_permission(args).await?;
262                Ok(ClientResponse::RequestPermissionResponse(response))
263            }
264            AgentRequest::WriteTextFileRequest(args) => {
265                let response = self.write_text_file(args).await?;
266                Ok(ClientResponse::WriteTextFileResponse(response))
267            }
268            AgentRequest::ReadTextFileRequest(args) => {
269                let response = self.read_text_file(args).await?;
270                Ok(ClientResponse::ReadTextFileResponse(response))
271            }
272            AgentRequest::CreateTerminalRequest(args) => {
273                let response = self.create_terminal(args).await?;
274                Ok(ClientResponse::CreateTerminalResponse(response))
275            }
276            AgentRequest::TerminalOutputRequest(args) => {
277                let response = self.terminal_output(args).await?;
278                Ok(ClientResponse::TerminalOutputResponse(response))
279            }
280            AgentRequest::ReleaseTerminalRequest(args) => {
281                let response = self.release_terminal(args).await?;
282                Ok(ClientResponse::ReleaseTerminalResponse(response))
283            }
284            AgentRequest::WaitForTerminalExitRequest(args) => {
285                let response = self.wait_for_terminal_exit(args).await?;
286                Ok(ClientResponse::WaitForTerminalExitResponse(response))
287            }
288            AgentRequest::KillTerminalCommandRequest(args) => {
289                let response = self.kill_terminal_command(args).await?;
290                Ok(ClientResponse::KillTerminalResponse(response))
291            }
292            AgentRequest::ExtMethodRequest(args) => {
293                let response = self.ext_method(args).await?;
294                Ok(ClientResponse::ExtMethodResponse(response))
295            }
296        }
297    }
298
299    async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
300        match notification {
301            AgentNotification::SessionNotification(args) => {
302                self.session_notification(args).await?;
303            }
304            AgentNotification::ExtNotification(args) => {
305                self.ext_notification(args).await?;
306            }
307        }
308        Ok(())
309    }
310}
311
// Agent to Client: the connection an agent holds in order to talk to a client.
313
314/// An agent-side connection to a client.
315///
316/// This struct provides the agent's view of an ACP connection, allowing
317/// agents to communicate with clients. It implements the [`Client`] trait
318/// to provide methods for requesting permissions, accessing the file system,
319/// and sending session updates.
320///
321/// See protocol docs: [Agent](https://agentclientprotocol.com/protocol/overview#agent)
322pub struct AgentSideConnection {
323    conn: RpcConnection<AgentSide, ClientSide>,
324}
325
326impl AgentSideConnection {
327    /// Creates a new agent-side connection to a client.
328    ///
329    /// This establishes the communication channel from the agent's perspective
330    /// following the ACP specification.
331    ///
332    /// # Arguments
333    ///
334    /// * `agent` - A handler that implements the [`Agent`] trait to process incoming client requests
335    /// * `outgoing_bytes` - The stream for sending data to the client (typically stdout)
336    /// * `incoming_bytes` - The stream for receiving data from the client (typically stdin)
337    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
338    ///
339    /// # Returns
340    ///
341    /// Returns a tuple containing:
342    /// - The connection instance for making requests to the client
343    /// - An I/O future that must be spawned to handle the underlying communication
344    ///
345    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
346    pub fn new(
347        agent: impl MessageHandler<AgentSide> + 'static,
348        outgoing_bytes: impl Unpin + AsyncWrite,
349        incoming_bytes: impl Unpin + AsyncRead,
350        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
351    ) -> (Self, impl Future<Output = Result<()>>) {
352        let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
353        (Self { conn }, io_task)
354    }
355
356    /// Subscribe to receive stream updates from the client.
357    ///
358    /// This allows the agent to receive real-time notifications about
359    /// client activities and cancellation requests.
360    ///
361    /// # Returns
362    ///
363    /// A [`StreamReceiver`] that can be used to receive stream messages.
364    pub fn subscribe(&self) -> StreamReceiver {
365        self.conn.subscribe()
366    }
367}
368
369#[async_trait::async_trait(?Send)]
370impl Client for AgentSideConnection {
371    async fn request_permission(
372        &self,
373        args: RequestPermissionRequest,
374    ) -> Result<RequestPermissionResponse, Error> {
375        self.conn
376            .request(
377                CLIENT_METHOD_NAMES.session_request_permission,
378                Some(AgentRequest::RequestPermissionRequest(args)),
379            )
380            .await
381    }
382
383    async fn write_text_file(
384        &self,
385        args: WriteTextFileRequest,
386    ) -> Result<WriteTextFileResponse, Error> {
387        self.conn
388            .request::<Option<_>>(
389                CLIENT_METHOD_NAMES.fs_write_text_file,
390                Some(AgentRequest::WriteTextFileRequest(args)),
391            )
392            .await
393            .map(Option::unwrap_or_default)
394    }
395
396    async fn read_text_file(
397        &self,
398        args: ReadTextFileRequest,
399    ) -> Result<ReadTextFileResponse, Error> {
400        self.conn
401            .request(
402                CLIENT_METHOD_NAMES.fs_read_text_file,
403                Some(AgentRequest::ReadTextFileRequest(args)),
404            )
405            .await
406    }
407
408    async fn create_terminal(
409        &self,
410        args: CreateTerminalRequest,
411    ) -> Result<CreateTerminalResponse, Error> {
412        self.conn
413            .request(
414                CLIENT_METHOD_NAMES.terminal_create,
415                Some(AgentRequest::CreateTerminalRequest(args)),
416            )
417            .await
418    }
419
420    async fn terminal_output(
421        &self,
422        args: TerminalOutputRequest,
423    ) -> Result<TerminalOutputResponse, Error> {
424        self.conn
425            .request(
426                CLIENT_METHOD_NAMES.terminal_output,
427                Some(AgentRequest::TerminalOutputRequest(args)),
428            )
429            .await
430    }
431
432    async fn release_terminal(
433        &self,
434        args: ReleaseTerminalRequest,
435    ) -> Result<ReleaseTerminalResponse, Error> {
436        self.conn
437            .request::<Option<_>>(
438                CLIENT_METHOD_NAMES.terminal_release,
439                Some(AgentRequest::ReleaseTerminalRequest(args)),
440            )
441            .await
442            .map(Option::unwrap_or_default)
443    }
444
445    async fn wait_for_terminal_exit(
446        &self,
447        args: WaitForTerminalExitRequest,
448    ) -> Result<WaitForTerminalExitResponse, Error> {
449        self.conn
450            .request(
451                CLIENT_METHOD_NAMES.terminal_wait_for_exit,
452                Some(AgentRequest::WaitForTerminalExitRequest(args)),
453            )
454            .await
455    }
456
457    async fn kill_terminal_command(
458        &self,
459        args: KillTerminalCommandRequest,
460    ) -> Result<KillTerminalCommandResponse, Error> {
461        self.conn
462            .request::<Option<_>>(
463                CLIENT_METHOD_NAMES.terminal_kill,
464                Some(AgentRequest::KillTerminalCommandRequest(args)),
465            )
466            .await
467            .map(Option::unwrap_or_default)
468    }
469
470    async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> {
471        self.conn.notify(
472            CLIENT_METHOD_NAMES.session_update,
473            Some(AgentNotification::SessionNotification(args)),
474        )
475    }
476
477    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
478        self.conn
479            .request(
480                format!("_{}", args.method),
481                Some(AgentRequest::ExtMethodRequest(args)),
482            )
483            .await
484    }
485
486    async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
487        self.conn.notify(
488            format!("_{}", args.method),
489            Some(AgentNotification::ExtNotification(args)),
490        )
491    }
492}
493
/// Marker type representing the agent side of an ACP connection.
///
/// The RPC layer is parameterized by this type; its [`Side`] implementation
/// determines which messages count as incoming from the agent's perspective.
///
/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
#[derive(Clone, Debug)]
pub struct AgentSide;
502
503impl Side for AgentSide {
504    type InRequest = ClientRequest;
505    type InNotification = ClientNotification;
506    type OutResponse = AgentResponse;
507
508    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
509        let params = params.ok_or_else(Error::invalid_params)?;
510
511        match method {
512            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
513                .map(ClientRequest::InitializeRequest)
514                .map_err(Into::into),
515            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
516                .map(ClientRequest::AuthenticateRequest)
517                .map_err(Into::into),
518            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
519                .map(ClientRequest::NewSessionRequest)
520                .map_err(Into::into),
521            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
522                .map(ClientRequest::LoadSessionRequest)
523                .map_err(Into::into),
524            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
525                .map(ClientRequest::SetSessionModeRequest)
526                .map_err(Into::into),
527            #[cfg(feature = "unstable")]
528            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
529                .map(ClientRequest::SetSessionModelRequest)
530                .map_err(Into::into),
531            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
532                .map(ClientRequest::PromptRequest)
533                .map_err(Into::into),
534            _ => {
535                if let Some(custom_method) = method.strip_prefix('_') {
536                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
537                        method: custom_method.into(),
538                        params: RawValue::from_string(params.get().to_string())?.into(),
539                    }))
540                } else {
541                    Err(Error::method_not_found())
542                }
543            }
544        }
545    }
546
547    fn decode_notification(
548        method: &str,
549        params: Option<&RawValue>,
550    ) -> Result<ClientNotification, Error> {
551        let params = params.ok_or_else(Error::invalid_params)?;
552
553        match method {
554            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
555                .map(ClientNotification::CancelNotification)
556                .map_err(Into::into),
557            _ => {
558                if let Some(custom_method) = method.strip_prefix('_') {
559                    Ok(ClientNotification::ExtNotification(ExtNotification {
560                        method: custom_method.into(),
561                        params: RawValue::from_string(params.get().to_string())?.into(),
562                    }))
563                } else {
564                    Err(Error::method_not_found())
565                }
566            }
567        }
568    }
569}
570
571impl<T: Agent> MessageHandler<AgentSide> for T {
572    async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
573        match request {
574            ClientRequest::InitializeRequest(args) => {
575                let response = self.initialize(args).await?;
576                Ok(AgentResponse::InitializeResponse(response))
577            }
578            ClientRequest::AuthenticateRequest(args) => {
579                let response = self.authenticate(args).await?;
580                Ok(AgentResponse::AuthenticateResponse(response))
581            }
582            ClientRequest::NewSessionRequest(args) => {
583                let response = self.new_session(args).await?;
584                Ok(AgentResponse::NewSessionResponse(response))
585            }
586            ClientRequest::LoadSessionRequest(args) => {
587                let response = self.load_session(args).await?;
588                Ok(AgentResponse::LoadSessionResponse(response))
589            }
590            ClientRequest::PromptRequest(args) => {
591                let response = self.prompt(args).await?;
592                Ok(AgentResponse::PromptResponse(response))
593            }
594            ClientRequest::SetSessionModeRequest(args) => {
595                let response = self.set_session_mode(args).await?;
596                Ok(AgentResponse::SetSessionModeResponse(response))
597            }
598            #[cfg(feature = "unstable")]
599            ClientRequest::SetSessionModelRequest(args) => {
600                let response = self.set_session_model(args).await?;
601                Ok(AgentResponse::SetSessionModelResponse(response))
602            }
603            ClientRequest::ExtMethodRequest(args) => {
604                let response = self.ext_method(args).await?;
605                Ok(AgentResponse::ExtMethodResponse(response))
606            }
607        }
608    }
609
610    async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
611        match notification {
612            ClientNotification::CancelNotification(args) => {
613                self.cancel(args).await?;
614            }
615            ClientNotification::ExtNotification(args) => {
616                self.ext_notification(args).await?;
617            }
618        }
619        Ok(())
620    }
621}