Skip to main content

agent_client_protocol/
lib.rs

1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16    StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19// Client to Agent
20
21/// A client-side connection to an agent.
22///
23/// This struct provides the client's view of an ACP connection, allowing
24/// clients (such as code editors) to communicate with agents. It implements
25/// the [`Agent`] trait to provide methods for initializing sessions, sending
26/// prompts, and managing the agent lifecycle.
27///
28/// See protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)
29#[derive(Debug)]
30pub struct ClientSideConnection {
31    conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35    /// Creates a new client-side connection to an agent.
36    ///
37    /// This establishes the communication channel between a client and agent
38    /// following the ACP specification.
39    ///
40    /// # Arguments
41    ///
42    /// * `client` - A handler that implements the [`Client`] trait to process incoming agent requests
43    /// * `outgoing_bytes` - The stream for sending data to the agent (typically stdout)
44    /// * `incoming_bytes` - The stream for receiving data from the agent (typically stdin)
45    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
46    ///
47    /// # Returns
48    ///
49    /// Returns a tuple containing:
50    /// - The connection instance for making requests to the agent
51    /// - An I/O future that must be spawned to handle the underlying communication
52    ///
53    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
54    pub fn new(
55        client: impl MessageHandler<ClientSide> + 'static,
56        outgoing_bytes: impl Unpin + AsyncWrite,
57        incoming_bytes: impl Unpin + AsyncRead,
58        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59    ) -> (Self, impl Future<Output = Result<()>>) {
60        let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61        (Self { conn }, io_task)
62    }
63
64    /// Subscribe to receive stream updates from the agent.
65    ///
66    /// This allows the client to receive real-time notifications about
67    /// agent activities, such as tool calls, content updates, and progress reports.
68    ///
69    /// # Returns
70    ///
71    /// A [`StreamReceiver`] that can be used to receive stream messages.
72    pub fn subscribe(&self) -> StreamReceiver {
73        self.conn.subscribe()
74    }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79    async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80        self.conn
81            .request(
82                AGENT_METHOD_NAMES.initialize,
83                Some(ClientRequest::InitializeRequest(args)),
84            )
85            .await
86    }
87
88    async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89        self.conn
90            .request::<Option<_>>(
91                AGENT_METHOD_NAMES.authenticate,
92                Some(ClientRequest::AuthenticateRequest(args)),
93            )
94            .await
95            .map(Option::unwrap_or_default)
96    }
97
98    #[cfg(feature = "unstable_logout")]
99    async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
100        self.conn
101            .request::<Option<_>>(
102                AGENT_METHOD_NAMES.logout,
103                Some(ClientRequest::LogoutRequest(args)),
104            )
105            .await
106            .map(Option::unwrap_or_default)
107    }
108
109    async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
110        self.conn
111            .request(
112                AGENT_METHOD_NAMES.session_new,
113                Some(ClientRequest::NewSessionRequest(args)),
114            )
115            .await
116    }
117
118    async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
119        self.conn
120            .request::<Option<_>>(
121                AGENT_METHOD_NAMES.session_load,
122                Some(ClientRequest::LoadSessionRequest(args)),
123            )
124            .await
125            .map(Option::unwrap_or_default)
126    }
127
128    async fn set_session_mode(
129        &self,
130        args: SetSessionModeRequest,
131    ) -> Result<SetSessionModeResponse> {
132        self.conn
133            .request(
134                AGENT_METHOD_NAMES.session_set_mode,
135                Some(ClientRequest::SetSessionModeRequest(args)),
136            )
137            .await
138    }
139
140    async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
141        self.conn
142            .request(
143                AGENT_METHOD_NAMES.session_prompt,
144                Some(ClientRequest::PromptRequest(args)),
145            )
146            .await
147    }
148
149    async fn cancel(&self, args: CancelNotification) -> Result<()> {
150        self.conn.notify(
151            AGENT_METHOD_NAMES.session_cancel,
152            Some(ClientNotification::CancelNotification(args)),
153        )
154    }
155
156    #[cfg(feature = "unstable_session_model")]
157    async fn set_session_model(
158        &self,
159        args: SetSessionModelRequest,
160    ) -> Result<SetSessionModelResponse> {
161        self.conn
162            .request(
163                AGENT_METHOD_NAMES.session_set_model,
164                Some(ClientRequest::SetSessionModelRequest(args)),
165            )
166            .await
167    }
168
169    async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
170        self.conn
171            .request(
172                AGENT_METHOD_NAMES.session_list,
173                Some(ClientRequest::ListSessionsRequest(args)),
174            )
175            .await
176    }
177
178    #[cfg(feature = "unstable_session_fork")]
179    async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
180        self.conn
181            .request(
182                AGENT_METHOD_NAMES.session_fork,
183                Some(ClientRequest::ForkSessionRequest(args)),
184            )
185            .await
186    }
187
188    #[cfg(feature = "unstable_session_resume")]
189    async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
190        self.conn
191            .request(
192                AGENT_METHOD_NAMES.session_resume,
193                Some(ClientRequest::ResumeSessionRequest(args)),
194            )
195            .await
196    }
197
198    #[cfg(feature = "unstable_session_close")]
199    async fn close_session(&self, args: CloseSessionRequest) -> Result<CloseSessionResponse> {
200        self.conn
201            .request::<Option<_>>(
202                AGENT_METHOD_NAMES.session_close,
203                Some(ClientRequest::CloseSessionRequest(args)),
204            )
205            .await
206            .map(Option::unwrap_or_default)
207    }
208
209    async fn set_session_config_option(
210        &self,
211        args: SetSessionConfigOptionRequest,
212    ) -> Result<SetSessionConfigOptionResponse> {
213        self.conn
214            .request(
215                AGENT_METHOD_NAMES.session_set_config_option,
216                Some(ClientRequest::SetSessionConfigOptionRequest(args)),
217            )
218            .await
219    }
220
221    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
222        self.conn
223            .request(
224                format!("_{}", args.method),
225                Some(ClientRequest::ExtMethodRequest(args)),
226            )
227            .await
228    }
229
230    async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
231        self.conn.notify(
232            format!("_{}", args.method),
233            Some(ClientNotification::ExtNotification(args)),
234        )
235    }
236}
237
238/// Marker type representing the client side of an ACP connection.
239///
240/// This type is used by the RPC layer to determine which messages
241/// are incoming vs outgoing from the client's perspective.
242///
243/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
244#[derive(Clone, Debug)]
245pub struct ClientSide;
246
247impl Side for ClientSide {
248    type InNotification = AgentNotification;
249    type InRequest = AgentRequest;
250    type OutResponse = ClientResponse;
251
252    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
253        let params = params.ok_or_else(Error::invalid_params)?;
254
255        match method {
256            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
257                serde_json::from_str(params.get())
258                    .map(AgentRequest::RequestPermissionRequest)
259                    .map_err(Into::into)
260            }
261            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
262                .map(AgentRequest::WriteTextFileRequest)
263                .map_err(Into::into),
264            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
265                .map(AgentRequest::ReadTextFileRequest)
266                .map_err(Into::into),
267            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
268                .map(AgentRequest::CreateTerminalRequest)
269                .map_err(Into::into),
270            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
271                .map(AgentRequest::TerminalOutputRequest)
272                .map_err(Into::into),
273            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
274                .map(AgentRequest::KillTerminalRequest)
275                .map_err(Into::into),
276            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
277                .map(AgentRequest::ReleaseTerminalRequest)
278                .map_err(Into::into),
279            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
280                serde_json::from_str(params.get())
281                    .map(AgentRequest::WaitForTerminalExitRequest)
282                    .map_err(Into::into)
283            }
284            _ => {
285                if let Some(custom_method) = method.strip_prefix('_') {
286                    Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
287                        custom_method,
288                        params.to_owned().into(),
289                    )))
290                } else {
291                    Err(Error::method_not_found())
292                }
293            }
294        }
295    }
296
297    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
298        let params = params.ok_or_else(Error::invalid_params)?;
299
300        match method {
301            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
302                .map(AgentNotification::SessionNotification)
303                .map_err(Into::into),
304            _ => {
305                if let Some(custom_method) = method.strip_prefix('_') {
306                    Ok(AgentNotification::ExtNotification(ExtNotification::new(
307                        custom_method,
308                        RawValue::from_string(params.get().to_string())?.into(),
309                    )))
310                } else {
311                    Err(Error::method_not_found())
312                }
313            }
314        }
315    }
316}
317
318impl<T: Client> MessageHandler<ClientSide> for T {
319    async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
320        match request {
321            AgentRequest::RequestPermissionRequest(args) => {
322                let response = self.request_permission(args).await?;
323                Ok(ClientResponse::RequestPermissionResponse(response))
324            }
325            AgentRequest::WriteTextFileRequest(args) => {
326                let response = self.write_text_file(args).await?;
327                Ok(ClientResponse::WriteTextFileResponse(response))
328            }
329            AgentRequest::ReadTextFileRequest(args) => {
330                let response = self.read_text_file(args).await?;
331                Ok(ClientResponse::ReadTextFileResponse(response))
332            }
333            AgentRequest::CreateTerminalRequest(args) => {
334                let response = self.create_terminal(args).await?;
335                Ok(ClientResponse::CreateTerminalResponse(response))
336            }
337            AgentRequest::TerminalOutputRequest(args) => {
338                let response = self.terminal_output(args).await?;
339                Ok(ClientResponse::TerminalOutputResponse(response))
340            }
341            AgentRequest::ReleaseTerminalRequest(args) => {
342                let response = self.release_terminal(args).await?;
343                Ok(ClientResponse::ReleaseTerminalResponse(response))
344            }
345            AgentRequest::WaitForTerminalExitRequest(args) => {
346                let response = self.wait_for_terminal_exit(args).await?;
347                Ok(ClientResponse::WaitForTerminalExitResponse(response))
348            }
349            AgentRequest::KillTerminalRequest(args) => {
350                let response = self.kill_terminal(args).await?;
351                Ok(ClientResponse::KillTerminalResponse(response))
352            }
353            AgentRequest::ExtMethodRequest(args) => {
354                let response = self.ext_method(args).await?;
355                Ok(ClientResponse::ExtMethodResponse(response))
356            }
357            _ => Err(Error::method_not_found()),
358        }
359    }
360
361    async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
362        match notification {
363            AgentNotification::SessionNotification(args) => {
364                self.session_notification(args).await?;
365            }
366            AgentNotification::ExtNotification(args) => {
367                self.ext_notification(args).await?;
368            }
369            // Ignore unknown notifications
370            _ => {}
371        }
372        Ok(())
373    }
374}
375
376// Agent to Client
377
378/// An agent-side connection to a client.
379///
380/// This struct provides the agent's view of an ACP connection, allowing
381/// agents to communicate with clients. It implements the [`Client`] trait
382/// to provide methods for requesting permissions, accessing the file system,
383/// and sending session updates.
384///
385/// See protocol docs: [Agent](https://agentclientprotocol.com/protocol/overview#agent)
386#[derive(Debug)]
387pub struct AgentSideConnection {
388    conn: RpcConnection<AgentSide, ClientSide>,
389}
390
391impl AgentSideConnection {
392    /// Creates a new agent-side connection to a client.
393    ///
394    /// This establishes the communication channel from the agent's perspective
395    /// following the ACP specification.
396    ///
397    /// # Arguments
398    ///
399    /// * `agent` - A handler that implements the [`Agent`] trait to process incoming client requests
400    /// * `outgoing_bytes` - The stream for sending data to the client (typically stdout)
401    /// * `incoming_bytes` - The stream for receiving data from the client (typically stdin)
402    /// * `spawn` - A function to spawn async tasks (e.g., `tokio::spawn`)
403    ///
404    /// # Returns
405    ///
406    /// Returns a tuple containing:
407    /// - The connection instance for making requests to the client
408    /// - An I/O future that must be spawned to handle the underlying communication
409    ///
410    /// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
411    pub fn new(
412        agent: impl MessageHandler<AgentSide> + 'static,
413        outgoing_bytes: impl Unpin + AsyncWrite,
414        incoming_bytes: impl Unpin + AsyncRead,
415        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
416    ) -> (Self, impl Future<Output = Result<()>>) {
417        let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
418        (Self { conn }, io_task)
419    }
420
421    /// Subscribe to receive stream updates from the client.
422    ///
423    /// This allows the agent to receive real-time notifications about
424    /// client activities and cancellation requests.
425    ///
426    /// # Returns
427    ///
428    /// A [`StreamReceiver`] that can be used to receive stream messages.
429    pub fn subscribe(&self) -> StreamReceiver {
430        self.conn.subscribe()
431    }
432}
433
434#[async_trait::async_trait(?Send)]
435impl Client for AgentSideConnection {
436    async fn request_permission(
437        &self,
438        args: RequestPermissionRequest,
439    ) -> Result<RequestPermissionResponse> {
440        self.conn
441            .request(
442                CLIENT_METHOD_NAMES.session_request_permission,
443                Some(AgentRequest::RequestPermissionRequest(args)),
444            )
445            .await
446    }
447
448    async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
449        self.conn
450            .request::<Option<_>>(
451                CLIENT_METHOD_NAMES.fs_write_text_file,
452                Some(AgentRequest::WriteTextFileRequest(args)),
453            )
454            .await
455            .map(Option::unwrap_or_default)
456    }
457
458    async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
459        self.conn
460            .request(
461                CLIENT_METHOD_NAMES.fs_read_text_file,
462                Some(AgentRequest::ReadTextFileRequest(args)),
463            )
464            .await
465    }
466
467    async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
468        self.conn
469            .request(
470                CLIENT_METHOD_NAMES.terminal_create,
471                Some(AgentRequest::CreateTerminalRequest(args)),
472            )
473            .await
474    }
475
476    async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
477        self.conn
478            .request(
479                CLIENT_METHOD_NAMES.terminal_output,
480                Some(AgentRequest::TerminalOutputRequest(args)),
481            )
482            .await
483    }
484
485    async fn release_terminal(
486        &self,
487        args: ReleaseTerminalRequest,
488    ) -> Result<ReleaseTerminalResponse> {
489        self.conn
490            .request::<Option<_>>(
491                CLIENT_METHOD_NAMES.terminal_release,
492                Some(AgentRequest::ReleaseTerminalRequest(args)),
493            )
494            .await
495            .map(Option::unwrap_or_default)
496    }
497
498    async fn wait_for_terminal_exit(
499        &self,
500        args: WaitForTerminalExitRequest,
501    ) -> Result<WaitForTerminalExitResponse> {
502        self.conn
503            .request(
504                CLIENT_METHOD_NAMES.terminal_wait_for_exit,
505                Some(AgentRequest::WaitForTerminalExitRequest(args)),
506            )
507            .await
508    }
509
510    async fn kill_terminal(&self, args: KillTerminalRequest) -> Result<KillTerminalResponse> {
511        self.conn
512            .request::<Option<_>>(
513                CLIENT_METHOD_NAMES.terminal_kill,
514                Some(AgentRequest::KillTerminalRequest(args)),
515            )
516            .await
517            .map(Option::unwrap_or_default)
518    }
519
520    async fn session_notification(&self, args: SessionNotification) -> Result<()> {
521        self.conn.notify(
522            CLIENT_METHOD_NAMES.session_update,
523            Some(AgentNotification::SessionNotification(args)),
524        )
525    }
526
527    async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
528        self.conn
529            .request(
530                format!("_{}", args.method),
531                Some(AgentRequest::ExtMethodRequest(args)),
532            )
533            .await
534    }
535
536    async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
537        self.conn.notify(
538            format!("_{}", args.method),
539            Some(AgentNotification::ExtNotification(args)),
540        )
541    }
542}
543
544/// Marker type representing the agent side of an ACP connection.
545///
546/// This type is used by the RPC layer to determine which messages
547/// are incoming vs outgoing from the agent's perspective.
548///
549/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
550#[derive(Clone, Debug)]
551pub struct AgentSide;
552
553impl Side for AgentSide {
554    type InRequest = ClientRequest;
555    type InNotification = ClientNotification;
556    type OutResponse = AgentResponse;
557
558    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
559        let params = params.ok_or_else(Error::invalid_params)?;
560
561        match method {
562            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
563                .map(ClientRequest::InitializeRequest)
564                .map_err(Into::into),
565            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
566                .map(ClientRequest::AuthenticateRequest)
567                .map_err(Into::into),
568            #[cfg(feature = "unstable_logout")]
569            m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
570                .map(ClientRequest::LogoutRequest)
571                .map_err(Into::into),
572            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
573                .map(ClientRequest::NewSessionRequest)
574                .map_err(Into::into),
575            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
576                .map(ClientRequest::LoadSessionRequest)
577                .map_err(Into::into),
578            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
579                .map(ClientRequest::SetSessionModeRequest)
580                .map_err(Into::into),
581            #[cfg(feature = "unstable_session_model")]
582            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
583                .map(ClientRequest::SetSessionModelRequest)
584                .map_err(Into::into),
585            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
586                .map(ClientRequest::ListSessionsRequest)
587                .map_err(Into::into),
588            #[cfg(feature = "unstable_session_fork")]
589            m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
590                .map(ClientRequest::ForkSessionRequest)
591                .map_err(Into::into),
592            #[cfg(feature = "unstable_session_resume")]
593            m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
594                .map(ClientRequest::ResumeSessionRequest)
595                .map_err(Into::into),
596            #[cfg(feature = "unstable_session_close")]
597            m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
598                .map(ClientRequest::CloseSessionRequest)
599                .map_err(Into::into),
600            m if m == AGENT_METHOD_NAMES.session_set_config_option => {
601                serde_json::from_str(params.get())
602                    .map(ClientRequest::SetSessionConfigOptionRequest)
603                    .map_err(Into::into)
604            }
605            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
606                .map(ClientRequest::PromptRequest)
607                .map_err(Into::into),
608            _ => {
609                if let Some(custom_method) = method.strip_prefix('_') {
610                    Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
611                        custom_method,
612                        params.to_owned().into(),
613                    )))
614                } else {
615                    Err(Error::method_not_found())
616                }
617            }
618        }
619    }
620
621    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
622        let params = params.ok_or_else(Error::invalid_params)?;
623
624        match method {
625            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
626                .map(ClientNotification::CancelNotification)
627                .map_err(Into::into),
628            _ => {
629                if let Some(custom_method) = method.strip_prefix('_') {
630                    Ok(ClientNotification::ExtNotification(ExtNotification::new(
631                        custom_method,
632                        RawValue::from_string(params.get().to_string())?.into(),
633                    )))
634                } else {
635                    Err(Error::method_not_found())
636                }
637            }
638        }
639    }
640}
641
642impl<T: Agent> MessageHandler<AgentSide> for T {
643    async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
644        match request {
645            ClientRequest::InitializeRequest(args) => {
646                let response = self.initialize(args).await?;
647                Ok(AgentResponse::InitializeResponse(response))
648            }
649            ClientRequest::AuthenticateRequest(args) => {
650                let response = self.authenticate(args).await?;
651                Ok(AgentResponse::AuthenticateResponse(response))
652            }
653            #[cfg(feature = "unstable_logout")]
654            ClientRequest::LogoutRequest(args) => {
655                let response = self.logout(args).await?;
656                Ok(AgentResponse::LogoutResponse(response))
657            }
658            ClientRequest::NewSessionRequest(args) => {
659                let response = self.new_session(args).await?;
660                Ok(AgentResponse::NewSessionResponse(response))
661            }
662            ClientRequest::LoadSessionRequest(args) => {
663                let response = self.load_session(args).await?;
664                Ok(AgentResponse::LoadSessionResponse(response))
665            }
666            ClientRequest::PromptRequest(args) => {
667                let response = self.prompt(args).await?;
668                Ok(AgentResponse::PromptResponse(response))
669            }
670            ClientRequest::SetSessionModeRequest(args) => {
671                let response = self.set_session_mode(args).await?;
672                Ok(AgentResponse::SetSessionModeResponse(response))
673            }
674            #[cfg(feature = "unstable_session_model")]
675            ClientRequest::SetSessionModelRequest(args) => {
676                let response = self.set_session_model(args).await?;
677                Ok(AgentResponse::SetSessionModelResponse(response))
678            }
679            ClientRequest::ListSessionsRequest(args) => {
680                let response = self.list_sessions(args).await?;
681                Ok(AgentResponse::ListSessionsResponse(response))
682            }
683            #[cfg(feature = "unstable_session_fork")]
684            ClientRequest::ForkSessionRequest(args) => {
685                let response = self.fork_session(args).await?;
686                Ok(AgentResponse::ForkSessionResponse(response))
687            }
688            #[cfg(feature = "unstable_session_resume")]
689            ClientRequest::ResumeSessionRequest(args) => {
690                let response = self.resume_session(args).await?;
691                Ok(AgentResponse::ResumeSessionResponse(response))
692            }
693            #[cfg(feature = "unstable_session_close")]
694            ClientRequest::CloseSessionRequest(args) => {
695                let response = self.close_session(args).await?;
696                Ok(AgentResponse::CloseSessionResponse(response))
697            }
698            ClientRequest::SetSessionConfigOptionRequest(args) => {
699                let response = self.set_session_config_option(args).await?;
700                Ok(AgentResponse::SetSessionConfigOptionResponse(response))
701            }
702            ClientRequest::ExtMethodRequest(args) => {
703                let response = self.ext_method(args).await?;
704                Ok(AgentResponse::ExtMethodResponse(response))
705            }
706            _ => Err(Error::method_not_found()),
707        }
708    }
709
710    async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
711        match notification {
712            ClientNotification::CancelNotification(args) => {
713                self.cancel(args).await?;
714            }
715            ClientNotification::ExtNotification(args) => {
716                self.ext_notification(args).await?;
717            }
718            // Ignore unknown notifications
719            _ => {}
720        }
721        Ok(())
722    }
723}