agent_client_protocol/
lib.rs

1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16    StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19// Client to Agent
20
/// A client-side connection to an agent.
///
/// This struct provides the client's view of an ACP connection, allowing
/// clients (such as code editors) to communicate with agents. It implements
/// the [`Agent`] trait, so each agent operation (initializing, authenticating,
/// creating or loading sessions, prompting, cancelling, extension calls) is
/// performed by calling the corresponding trait method on this value.
///
/// See protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)
#[derive(Debug)]
pub struct ClientSideConnection {
    // Underlying RPC connection. Incoming traffic is decoded through
    // `ClientSide`; outgoing requests use the message types `AgentSide` accepts.
    conn: RpcConnection<ClientSide, AgentSide>,
}
33
34impl ClientSideConnection {
35    /// Creates a new client-side connection to an agent.
36    ///
37    /// This establishes the communication channel between a client and agent
38    /// following the ACP specification.
39    ///
40    /// # Arguments
41    ///
42    /// * `client` - A handler that implements the [`Client`] trait to process incoming agent requests
43    /// * `outgoing_bytes` - The stream for sending data to the agent (typically stdout)
44    /// * `incoming_bytes` - The stream for receiving data from the agent (typically stdin)
45    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
46    ///
47    /// # Returns
48    ///
49    /// Returns a tuple containing:
50    /// - The connection instance for making requests to the agent
51    /// - An I/O future that must be spawned to handle the underlying communication
52    ///
53    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
54    pub fn new(
55        client: impl MessageHandler<ClientSide> + 'static,
56        outgoing_bytes: impl Unpin + AsyncWrite,
57        incoming_bytes: impl Unpin + AsyncRead,
58        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59    ) -> (Self, impl Future<Output = Result<()>>) {
60        let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61        (Self { conn }, io_task)
62    }
63
64    /// Subscribe to receive stream updates from the agent.
65    ///
66    /// This allows the client to receive real-time notifications about
67    /// agent activities, such as tool calls, content updates, and progress reports.
68    ///
69    /// # Returns
70    ///
71    /// A [`StreamReceiver`] that can be used to receive stream messages.
72    pub fn subscribe(&self) -> StreamReceiver {
73        self.conn.subscribe()
74    }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79    async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80        self.conn
81            .request(
82                AGENT_METHOD_NAMES.initialize,
83                Some(ClientRequest::InitializeRequest(args)),
84            )
85            .await
86    }
87
88    async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89        self.conn
90            .request::<Option<_>>(
91                AGENT_METHOD_NAMES.authenticate,
92                Some(ClientRequest::AuthenticateRequest(args)),
93            )
94            .await
95            .map(Option::unwrap_or_default)
96    }
97
98    async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99        self.conn
100            .request(
101                AGENT_METHOD_NAMES.session_new,
102                Some(ClientRequest::NewSessionRequest(args)),
103            )
104            .await
105    }
106
107    async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108        self.conn
109            .request::<Option<_>>(
110                AGENT_METHOD_NAMES.session_load,
111                Some(ClientRequest::LoadSessionRequest(args)),
112            )
113            .await
114            .map(Option::unwrap_or_default)
115    }
116
117    async fn set_session_mode(
118        &self,
119        args: SetSessionModeRequest,
120    ) -> Result<SetSessionModeResponse> {
121        self.conn
122            .request(
123                AGENT_METHOD_NAMES.session_set_mode,
124                Some(ClientRequest::SetSessionModeRequest(args)),
125            )
126            .await
127    }
128
129    async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130        self.conn
131            .request(
132                AGENT_METHOD_NAMES.session_prompt,
133                Some(ClientRequest::PromptRequest(args)),
134            )
135            .await
136    }
137
138    async fn cancel(&self, args: CancelNotification) -> Result<()> {
139        self.conn.notify(
140            AGENT_METHOD_NAMES.session_cancel,
141            Some(ClientNotification::CancelNotification(args)),
142        )
143    }
144
145    #[cfg(feature = "unstable")]
146    async fn set_session_model(
147        &self,
148        args: SetSessionModelRequest,
149    ) -> Result<SetSessionModelResponse> {
150        self.conn
151            .request(
152                AGENT_METHOD_NAMES.session_set_model,
153                Some(ClientRequest::SetSessionModelRequest(args)),
154            )
155            .await
156    }
157
158    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
159        self.conn
160            .request(
161                format!("_{}", args.method),
162                Some(ClientRequest::ExtMethodRequest(args)),
163            )
164            .await
165    }
166
167    async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
168        self.conn.notify(
169            format!("_{}", args.method),
170            Some(ClientNotification::ExtNotification(args)),
171        )
172    }
173}
174
/// Marker type representing the client side of an ACP connection.
///
/// This type is used by the RPC layer to determine which messages
/// are incoming vs outgoing from the client's perspective: its [`Side`]
/// implementation decodes requests and notifications sent by the agent.
///
/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
#[derive(Clone, Debug)]
pub struct ClientSide;
183
184impl Side for ClientSide {
185    type InNotification = AgentNotification;
186    type InRequest = AgentRequest;
187    type OutResponse = ClientResponse;
188
189    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
190        let params = params.ok_or_else(Error::invalid_params)?;
191
192        match method {
193            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
194                serde_json::from_str(params.get())
195                    .map(AgentRequest::RequestPermissionRequest)
196                    .map_err(Into::into)
197            }
198            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
199                .map(AgentRequest::WriteTextFileRequest)
200                .map_err(Into::into),
201            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
202                .map(AgentRequest::ReadTextFileRequest)
203                .map_err(Into::into),
204            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
205                .map(AgentRequest::CreateTerminalRequest)
206                .map_err(Into::into),
207            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
208                .map(AgentRequest::TerminalOutputRequest)
209                .map_err(Into::into),
210            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
211                .map(AgentRequest::KillTerminalCommandRequest)
212                .map_err(Into::into),
213            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
214                .map(AgentRequest::ReleaseTerminalRequest)
215                .map_err(Into::into),
216            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
217                serde_json::from_str(params.get())
218                    .map(AgentRequest::WaitForTerminalExitRequest)
219                    .map_err(Into::into)
220            }
221            _ => {
222                if let Some(custom_method) = method.strip_prefix('_') {
223                    Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
224                        custom_method,
225                        params.to_owned().into(),
226                    )))
227                } else {
228                    Err(Error::method_not_found())
229                }
230            }
231        }
232    }
233
234    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
235        let params = params.ok_or_else(Error::invalid_params)?;
236
237        match method {
238            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
239                .map(AgentNotification::SessionNotification)
240                .map_err(Into::into),
241            _ => {
242                if let Some(custom_method) = method.strip_prefix('_') {
243                    Ok(AgentNotification::ExtNotification(ExtNotification::new(
244                        custom_method,
245                        RawValue::from_string(params.get().to_string())?.into(),
246                    )))
247                } else {
248                    Err(Error::method_not_found())
249                }
250            }
251        }
252    }
253}
254
255impl<T: Client> MessageHandler<ClientSide> for T {
256    async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
257        match request {
258            AgentRequest::RequestPermissionRequest(args) => {
259                let response = self.request_permission(args).await?;
260                Ok(ClientResponse::RequestPermissionResponse(response))
261            }
262            AgentRequest::WriteTextFileRequest(args) => {
263                let response = self.write_text_file(args).await?;
264                Ok(ClientResponse::WriteTextFileResponse(response))
265            }
266            AgentRequest::ReadTextFileRequest(args) => {
267                let response = self.read_text_file(args).await?;
268                Ok(ClientResponse::ReadTextFileResponse(response))
269            }
270            AgentRequest::CreateTerminalRequest(args) => {
271                let response = self.create_terminal(args).await?;
272                Ok(ClientResponse::CreateTerminalResponse(response))
273            }
274            AgentRequest::TerminalOutputRequest(args) => {
275                let response = self.terminal_output(args).await?;
276                Ok(ClientResponse::TerminalOutputResponse(response))
277            }
278            AgentRequest::ReleaseTerminalRequest(args) => {
279                let response = self.release_terminal(args).await?;
280                Ok(ClientResponse::ReleaseTerminalResponse(response))
281            }
282            AgentRequest::WaitForTerminalExitRequest(args) => {
283                let response = self.wait_for_terminal_exit(args).await?;
284                Ok(ClientResponse::WaitForTerminalExitResponse(response))
285            }
286            AgentRequest::KillTerminalCommandRequest(args) => {
287                let response = self.kill_terminal_command(args).await?;
288                Ok(ClientResponse::KillTerminalResponse(response))
289            }
290            AgentRequest::ExtMethodRequest(args) => {
291                let response = self.ext_method(args).await?;
292                Ok(ClientResponse::ExtMethodResponse(response))
293            }
294            _ => Err(Error::method_not_found()),
295        }
296    }
297
298    async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
299        match notification {
300            AgentNotification::SessionNotification(args) => {
301                self.session_notification(args).await?;
302            }
303            AgentNotification::ExtNotification(args) => {
304                self.ext_notification(args).await?;
305            }
306            // Ignore unknown notifications
307            _ => {}
308        }
309        Ok(())
310    }
311}
312
313// Agent to Client
314
/// An agent-side connection to a client.
///
/// This struct provides the agent's view of an ACP connection, allowing
/// agents to communicate with clients. It implements the [`Client`] trait,
/// so each client operation (requesting permissions, reading and writing
/// text files, managing terminals, sending session updates, extension calls)
/// is performed by calling the corresponding trait method on this value.
///
/// See protocol docs: [Agent](https://agentclientprotocol.com/protocol/overview#agent)
#[derive(Debug)]
pub struct AgentSideConnection {
    // Underlying RPC connection. Incoming traffic is decoded through
    // `AgentSide`; outgoing requests use the message types `ClientSide` accepts.
    conn: RpcConnection<AgentSide, ClientSide>,
}
327
328impl AgentSideConnection {
329    /// Creates a new agent-side connection to a client.
330    ///
331    /// This establishes the communication channel from the agent's perspective
332    /// following the ACP specification.
333    ///
334    /// # Arguments
335    ///
336    /// * `agent` - A handler that implements the [`Agent`] trait to process incoming client requests
337    /// * `outgoing_bytes` - The stream for sending data to the client (typically stdout)
338    /// * `incoming_bytes` - The stream for receiving data from the client (typically stdin)
339    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
340    ///
341    /// # Returns
342    ///
343    /// Returns a tuple containing:
344    /// - The connection instance for making requests to the client
345    /// - An I/O future that must be spawned to handle the underlying communication
346    ///
347    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
348    pub fn new(
349        agent: impl MessageHandler<AgentSide> + 'static,
350        outgoing_bytes: impl Unpin + AsyncWrite,
351        incoming_bytes: impl Unpin + AsyncRead,
352        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
353    ) -> (Self, impl Future<Output = Result<()>>) {
354        let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
355        (Self { conn }, io_task)
356    }
357
358    /// Subscribe to receive stream updates from the client.
359    ///
360    /// This allows the agent to receive real-time notifications about
361    /// client activities and cancellation requests.
362    ///
363    /// # Returns
364    ///
365    /// A [`StreamReceiver`] that can be used to receive stream messages.
366    pub fn subscribe(&self) -> StreamReceiver {
367        self.conn.subscribe()
368    }
369}
370
371#[async_trait::async_trait(?Send)]
372impl Client for AgentSideConnection {
373    async fn request_permission(
374        &self,
375        args: RequestPermissionRequest,
376    ) -> Result<RequestPermissionResponse> {
377        self.conn
378            .request(
379                CLIENT_METHOD_NAMES.session_request_permission,
380                Some(AgentRequest::RequestPermissionRequest(args)),
381            )
382            .await
383    }
384
385    async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
386        self.conn
387            .request::<Option<_>>(
388                CLIENT_METHOD_NAMES.fs_write_text_file,
389                Some(AgentRequest::WriteTextFileRequest(args)),
390            )
391            .await
392            .map(Option::unwrap_or_default)
393    }
394
395    async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
396        self.conn
397            .request(
398                CLIENT_METHOD_NAMES.fs_read_text_file,
399                Some(AgentRequest::ReadTextFileRequest(args)),
400            )
401            .await
402    }
403
404    async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
405        self.conn
406            .request(
407                CLIENT_METHOD_NAMES.terminal_create,
408                Some(AgentRequest::CreateTerminalRequest(args)),
409            )
410            .await
411    }
412
413    async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
414        self.conn
415            .request(
416                CLIENT_METHOD_NAMES.terminal_output,
417                Some(AgentRequest::TerminalOutputRequest(args)),
418            )
419            .await
420    }
421
422    async fn release_terminal(
423        &self,
424        args: ReleaseTerminalRequest,
425    ) -> Result<ReleaseTerminalResponse> {
426        self.conn
427            .request::<Option<_>>(
428                CLIENT_METHOD_NAMES.terminal_release,
429                Some(AgentRequest::ReleaseTerminalRequest(args)),
430            )
431            .await
432            .map(Option::unwrap_or_default)
433    }
434
435    async fn wait_for_terminal_exit(
436        &self,
437        args: WaitForTerminalExitRequest,
438    ) -> Result<WaitForTerminalExitResponse> {
439        self.conn
440            .request(
441                CLIENT_METHOD_NAMES.terminal_wait_for_exit,
442                Some(AgentRequest::WaitForTerminalExitRequest(args)),
443            )
444            .await
445    }
446
447    async fn kill_terminal_command(
448        &self,
449        args: KillTerminalCommandRequest,
450    ) -> Result<KillTerminalCommandResponse> {
451        self.conn
452            .request::<Option<_>>(
453                CLIENT_METHOD_NAMES.terminal_kill,
454                Some(AgentRequest::KillTerminalCommandRequest(args)),
455            )
456            .await
457            .map(Option::unwrap_or_default)
458    }
459
460    async fn session_notification(&self, args: SessionNotification) -> Result<()> {
461        self.conn.notify(
462            CLIENT_METHOD_NAMES.session_update,
463            Some(AgentNotification::SessionNotification(args)),
464        )
465    }
466
467    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
468        self.conn
469            .request(
470                format!("_{}", args.method),
471                Some(AgentRequest::ExtMethodRequest(args)),
472            )
473            .await
474    }
475
476    async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
477        self.conn.notify(
478            format!("_{}", args.method),
479            Some(AgentNotification::ExtNotification(args)),
480        )
481    }
482}
483
/// Marker type representing the agent side of an ACP connection.
///
/// This type is used by the RPC layer to determine which messages
/// are incoming vs outgoing from the agent's perspective: its [`Side`]
/// implementation decodes requests and notifications sent by the client.
///
/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
#[derive(Clone, Debug)]
pub struct AgentSide;
492
493impl Side for AgentSide {
494    type InRequest = ClientRequest;
495    type InNotification = ClientNotification;
496    type OutResponse = AgentResponse;
497
498    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
499        let params = params.ok_or_else(Error::invalid_params)?;
500
501        match method {
502            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
503                .map(ClientRequest::InitializeRequest)
504                .map_err(Into::into),
505            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
506                .map(ClientRequest::AuthenticateRequest)
507                .map_err(Into::into),
508            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
509                .map(ClientRequest::NewSessionRequest)
510                .map_err(Into::into),
511            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
512                .map(ClientRequest::LoadSessionRequest)
513                .map_err(Into::into),
514            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
515                .map(ClientRequest::SetSessionModeRequest)
516                .map_err(Into::into),
517            #[cfg(feature = "unstable")]
518            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
519                .map(ClientRequest::SetSessionModelRequest)
520                .map_err(Into::into),
521            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
522                .map(ClientRequest::PromptRequest)
523                .map_err(Into::into),
524            _ => {
525                if let Some(custom_method) = method.strip_prefix('_') {
526                    Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
527                        custom_method,
528                        params.to_owned().into(),
529                    )))
530                } else {
531                    Err(Error::method_not_found())
532                }
533            }
534        }
535    }
536
537    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
538        let params = params.ok_or_else(Error::invalid_params)?;
539
540        match method {
541            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
542                .map(ClientNotification::CancelNotification)
543                .map_err(Into::into),
544            _ => {
545                if let Some(custom_method) = method.strip_prefix('_') {
546                    Ok(ClientNotification::ExtNotification(ExtNotification::new(
547                        custom_method,
548                        RawValue::from_string(params.get().to_string())?.into(),
549                    )))
550                } else {
551                    Err(Error::method_not_found())
552                }
553            }
554        }
555    }
556}
557
558impl<T: Agent> MessageHandler<AgentSide> for T {
559    async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
560        match request {
561            ClientRequest::InitializeRequest(args) => {
562                let response = self.initialize(args).await?;
563                Ok(AgentResponse::InitializeResponse(response))
564            }
565            ClientRequest::AuthenticateRequest(args) => {
566                let response = self.authenticate(args).await?;
567                Ok(AgentResponse::AuthenticateResponse(response))
568            }
569            ClientRequest::NewSessionRequest(args) => {
570                let response = self.new_session(args).await?;
571                Ok(AgentResponse::NewSessionResponse(response))
572            }
573            ClientRequest::LoadSessionRequest(args) => {
574                let response = self.load_session(args).await?;
575                Ok(AgentResponse::LoadSessionResponse(response))
576            }
577            ClientRequest::PromptRequest(args) => {
578                let response = self.prompt(args).await?;
579                Ok(AgentResponse::PromptResponse(response))
580            }
581            ClientRequest::SetSessionModeRequest(args) => {
582                let response = self.set_session_mode(args).await?;
583                Ok(AgentResponse::SetSessionModeResponse(response))
584            }
585            #[cfg(feature = "unstable")]
586            ClientRequest::SetSessionModelRequest(args) => {
587                let response = self.set_session_model(args).await?;
588                Ok(AgentResponse::SetSessionModelResponse(response))
589            }
590            ClientRequest::ExtMethodRequest(args) => {
591                let response = self.ext_method(args).await?;
592                Ok(AgentResponse::ExtMethodResponse(response))
593            }
594            _ => Err(Error::method_not_found()),
595        }
596    }
597
598    async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
599        match notification {
600            ClientNotification::CancelNotification(args) => {
601                self.cancel(args).await?;
602            }
603            ClientNotification::ExtNotification(args) => {
604                self.ext_notification(args).await?;
605            }
606            // Ignore unknown notifications
607            _ => {}
608        }
609        Ok(())
610    }
611}