1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19#[derive(Debug)]
30pub struct ClientSideConnection {
31 conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable_session_model")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
159 self.conn
160 .request(
161 AGENT_METHOD_NAMES.session_list,
162 Some(ClientRequest::ListSessionsRequest(args)),
163 )
164 .await
165 }
166
167 #[cfg(feature = "unstable_session_fork")]
168 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
169 self.conn
170 .request(
171 AGENT_METHOD_NAMES.session_fork,
172 Some(ClientRequest::ForkSessionRequest(args)),
173 )
174 .await
175 }
176
177 #[cfg(feature = "unstable_session_resume")]
178 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
179 self.conn
180 .request(
181 AGENT_METHOD_NAMES.session_resume,
182 Some(ClientRequest::ResumeSessionRequest(args)),
183 )
184 .await
185 }
186
187 async fn set_session_config_option(
188 &self,
189 args: SetSessionConfigOptionRequest,
190 ) -> Result<SetSessionConfigOptionResponse> {
191 self.conn
192 .request(
193 AGENT_METHOD_NAMES.session_set_config_option,
194 Some(ClientRequest::SetSessionConfigOptionRequest(args)),
195 )
196 .await
197 }
198
199 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
200 self.conn
201 .request(
202 format!("_{}", args.method),
203 Some(ClientRequest::ExtMethodRequest(args)),
204 )
205 .await
206 }
207
208 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
209 self.conn.notify(
210 format!("_{}", args.method),
211 Some(ClientNotification::ExtNotification(args)),
212 )
213 }
214}
215
/// Marker type for the client end of a connection.
///
/// Its [`Side`] implementation decodes messages sent by the agent
/// ([`AgentRequest`] / [`AgentNotification`]) and answers with
/// [`ClientResponse`].
#[derive(Clone, Debug)]
pub struct ClientSide;
224
225impl Side for ClientSide {
226 type InNotification = AgentNotification;
227 type InRequest = AgentRequest;
228 type OutResponse = ClientResponse;
229
230 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
231 let params = params.ok_or_else(Error::invalid_params)?;
232
233 match method {
234 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
235 serde_json::from_str(params.get())
236 .map(AgentRequest::RequestPermissionRequest)
237 .map_err(Into::into)
238 }
239 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
240 .map(AgentRequest::WriteTextFileRequest)
241 .map_err(Into::into),
242 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
243 .map(AgentRequest::ReadTextFileRequest)
244 .map_err(Into::into),
245 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
246 .map(AgentRequest::CreateTerminalRequest)
247 .map_err(Into::into),
248 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
249 .map(AgentRequest::TerminalOutputRequest)
250 .map_err(Into::into),
251 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
252 .map(AgentRequest::KillTerminalRequest)
253 .map_err(Into::into),
254 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
255 .map(AgentRequest::ReleaseTerminalRequest)
256 .map_err(Into::into),
257 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
258 serde_json::from_str(params.get())
259 .map(AgentRequest::WaitForTerminalExitRequest)
260 .map_err(Into::into)
261 }
262 _ => {
263 if let Some(custom_method) = method.strip_prefix('_') {
264 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
265 custom_method,
266 params.to_owned().into(),
267 )))
268 } else {
269 Err(Error::method_not_found())
270 }
271 }
272 }
273 }
274
275 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
276 let params = params.ok_or_else(Error::invalid_params)?;
277
278 match method {
279 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
280 .map(AgentNotification::SessionNotification)
281 .map_err(Into::into),
282 _ => {
283 if let Some(custom_method) = method.strip_prefix('_') {
284 Ok(AgentNotification::ExtNotification(ExtNotification::new(
285 custom_method,
286 RawValue::from_string(params.get().to_string())?.into(),
287 )))
288 } else {
289 Err(Error::method_not_found())
290 }
291 }
292 }
293 }
294}
295
296impl<T: Client> MessageHandler<ClientSide> for T {
297 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
298 match request {
299 AgentRequest::RequestPermissionRequest(args) => {
300 let response = self.request_permission(args).await?;
301 Ok(ClientResponse::RequestPermissionResponse(response))
302 }
303 AgentRequest::WriteTextFileRequest(args) => {
304 let response = self.write_text_file(args).await?;
305 Ok(ClientResponse::WriteTextFileResponse(response))
306 }
307 AgentRequest::ReadTextFileRequest(args) => {
308 let response = self.read_text_file(args).await?;
309 Ok(ClientResponse::ReadTextFileResponse(response))
310 }
311 AgentRequest::CreateTerminalRequest(args) => {
312 let response = self.create_terminal(args).await?;
313 Ok(ClientResponse::CreateTerminalResponse(response))
314 }
315 AgentRequest::TerminalOutputRequest(args) => {
316 let response = self.terminal_output(args).await?;
317 Ok(ClientResponse::TerminalOutputResponse(response))
318 }
319 AgentRequest::ReleaseTerminalRequest(args) => {
320 let response = self.release_terminal(args).await?;
321 Ok(ClientResponse::ReleaseTerminalResponse(response))
322 }
323 AgentRequest::WaitForTerminalExitRequest(args) => {
324 let response = self.wait_for_terminal_exit(args).await?;
325 Ok(ClientResponse::WaitForTerminalExitResponse(response))
326 }
327 AgentRequest::KillTerminalRequest(args) => {
328 let response = self.kill_terminal(args).await?;
329 Ok(ClientResponse::KillTerminalResponse(response))
330 }
331 AgentRequest::ExtMethodRequest(args) => {
332 let response = self.ext_method(args).await?;
333 Ok(ClientResponse::ExtMethodResponse(response))
334 }
335 _ => Err(Error::method_not_found()),
336 }
337 }
338
339 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
340 match notification {
341 AgentNotification::SessionNotification(args) => {
342 self.session_notification(args).await?;
343 }
344 AgentNotification::ExtNotification(args) => {
345 self.ext_notification(args).await?;
346 }
347 _ => {}
349 }
350 Ok(())
351 }
352}
353
354#[derive(Debug)]
365pub struct AgentSideConnection {
366 conn: RpcConnection<AgentSide, ClientSide>,
367}
368
369impl AgentSideConnection {
370 pub fn new(
390 agent: impl MessageHandler<AgentSide> + 'static,
391 outgoing_bytes: impl Unpin + AsyncWrite,
392 incoming_bytes: impl Unpin + AsyncRead,
393 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
394 ) -> (Self, impl Future<Output = Result<()>>) {
395 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
396 (Self { conn }, io_task)
397 }
398
399 pub fn subscribe(&self) -> StreamReceiver {
408 self.conn.subscribe()
409 }
410}
411
412#[async_trait::async_trait(?Send)]
413impl Client for AgentSideConnection {
414 async fn request_permission(
415 &self,
416 args: RequestPermissionRequest,
417 ) -> Result<RequestPermissionResponse> {
418 self.conn
419 .request(
420 CLIENT_METHOD_NAMES.session_request_permission,
421 Some(AgentRequest::RequestPermissionRequest(args)),
422 )
423 .await
424 }
425
426 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
427 self.conn
428 .request::<Option<_>>(
429 CLIENT_METHOD_NAMES.fs_write_text_file,
430 Some(AgentRequest::WriteTextFileRequest(args)),
431 )
432 .await
433 .map(Option::unwrap_or_default)
434 }
435
436 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
437 self.conn
438 .request(
439 CLIENT_METHOD_NAMES.fs_read_text_file,
440 Some(AgentRequest::ReadTextFileRequest(args)),
441 )
442 .await
443 }
444
445 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
446 self.conn
447 .request(
448 CLIENT_METHOD_NAMES.terminal_create,
449 Some(AgentRequest::CreateTerminalRequest(args)),
450 )
451 .await
452 }
453
454 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
455 self.conn
456 .request(
457 CLIENT_METHOD_NAMES.terminal_output,
458 Some(AgentRequest::TerminalOutputRequest(args)),
459 )
460 .await
461 }
462
463 async fn release_terminal(
464 &self,
465 args: ReleaseTerminalRequest,
466 ) -> Result<ReleaseTerminalResponse> {
467 self.conn
468 .request::<Option<_>>(
469 CLIENT_METHOD_NAMES.terminal_release,
470 Some(AgentRequest::ReleaseTerminalRequest(args)),
471 )
472 .await
473 .map(Option::unwrap_or_default)
474 }
475
476 async fn wait_for_terminal_exit(
477 &self,
478 args: WaitForTerminalExitRequest,
479 ) -> Result<WaitForTerminalExitResponse> {
480 self.conn
481 .request(
482 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
483 Some(AgentRequest::WaitForTerminalExitRequest(args)),
484 )
485 .await
486 }
487
488 async fn kill_terminal(&self, args: KillTerminalRequest) -> Result<KillTerminalResponse> {
489 self.conn
490 .request::<Option<_>>(
491 CLIENT_METHOD_NAMES.terminal_kill,
492 Some(AgentRequest::KillTerminalRequest(args)),
493 )
494 .await
495 .map(Option::unwrap_or_default)
496 }
497
498 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
499 self.conn.notify(
500 CLIENT_METHOD_NAMES.session_update,
501 Some(AgentNotification::SessionNotification(args)),
502 )
503 }
504
505 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
506 self.conn
507 .request(
508 format!("_{}", args.method),
509 Some(AgentRequest::ExtMethodRequest(args)),
510 )
511 .await
512 }
513
514 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
515 self.conn.notify(
516 format!("_{}", args.method),
517 Some(AgentNotification::ExtNotification(args)),
518 )
519 }
520}
521
/// Marker type for the agent end of a connection.
///
/// Its [`Side`] implementation decodes messages sent by the client
/// ([`ClientRequest`] / [`ClientNotification`]) and answers with
/// [`AgentResponse`].
#[derive(Clone, Debug)]
pub struct AgentSide;
530
531impl Side for AgentSide {
532 type InRequest = ClientRequest;
533 type InNotification = ClientNotification;
534 type OutResponse = AgentResponse;
535
536 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
537 let params = params.ok_or_else(Error::invalid_params)?;
538
539 match method {
540 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
541 .map(ClientRequest::InitializeRequest)
542 .map_err(Into::into),
543 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
544 .map(ClientRequest::AuthenticateRequest)
545 .map_err(Into::into),
546 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
547 .map(ClientRequest::NewSessionRequest)
548 .map_err(Into::into),
549 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
550 .map(ClientRequest::LoadSessionRequest)
551 .map_err(Into::into),
552 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
553 .map(ClientRequest::SetSessionModeRequest)
554 .map_err(Into::into),
555 #[cfg(feature = "unstable_session_model")]
556 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
557 .map(ClientRequest::SetSessionModelRequest)
558 .map_err(Into::into),
559 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
560 .map(ClientRequest::ListSessionsRequest)
561 .map_err(Into::into),
562 #[cfg(feature = "unstable_session_fork")]
563 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
564 .map(ClientRequest::ForkSessionRequest)
565 .map_err(Into::into),
566 #[cfg(feature = "unstable_session_resume")]
567 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
568 .map(ClientRequest::ResumeSessionRequest)
569 .map_err(Into::into),
570 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
571 serde_json::from_str(params.get())
572 .map(ClientRequest::SetSessionConfigOptionRequest)
573 .map_err(Into::into)
574 }
575 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
576 .map(ClientRequest::PromptRequest)
577 .map_err(Into::into),
578 _ => {
579 if let Some(custom_method) = method.strip_prefix('_') {
580 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
581 custom_method,
582 params.to_owned().into(),
583 )))
584 } else {
585 Err(Error::method_not_found())
586 }
587 }
588 }
589 }
590
591 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
592 let params = params.ok_or_else(Error::invalid_params)?;
593
594 match method {
595 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
596 .map(ClientNotification::CancelNotification)
597 .map_err(Into::into),
598 _ => {
599 if let Some(custom_method) = method.strip_prefix('_') {
600 Ok(ClientNotification::ExtNotification(ExtNotification::new(
601 custom_method,
602 RawValue::from_string(params.get().to_string())?.into(),
603 )))
604 } else {
605 Err(Error::method_not_found())
606 }
607 }
608 }
609 }
610}
611
612impl<T: Agent> MessageHandler<AgentSide> for T {
613 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
614 match request {
615 ClientRequest::InitializeRequest(args) => {
616 let response = self.initialize(args).await?;
617 Ok(AgentResponse::InitializeResponse(response))
618 }
619 ClientRequest::AuthenticateRequest(args) => {
620 let response = self.authenticate(args).await?;
621 Ok(AgentResponse::AuthenticateResponse(response))
622 }
623 ClientRequest::NewSessionRequest(args) => {
624 let response = self.new_session(args).await?;
625 Ok(AgentResponse::NewSessionResponse(response))
626 }
627 ClientRequest::LoadSessionRequest(args) => {
628 let response = self.load_session(args).await?;
629 Ok(AgentResponse::LoadSessionResponse(response))
630 }
631 ClientRequest::PromptRequest(args) => {
632 let response = self.prompt(args).await?;
633 Ok(AgentResponse::PromptResponse(response))
634 }
635 ClientRequest::SetSessionModeRequest(args) => {
636 let response = self.set_session_mode(args).await?;
637 Ok(AgentResponse::SetSessionModeResponse(response))
638 }
639 #[cfg(feature = "unstable_session_model")]
640 ClientRequest::SetSessionModelRequest(args) => {
641 let response = self.set_session_model(args).await?;
642 Ok(AgentResponse::SetSessionModelResponse(response))
643 }
644 ClientRequest::ListSessionsRequest(args) => {
645 let response = self.list_sessions(args).await?;
646 Ok(AgentResponse::ListSessionsResponse(response))
647 }
648 #[cfg(feature = "unstable_session_fork")]
649 ClientRequest::ForkSessionRequest(args) => {
650 let response = self.fork_session(args).await?;
651 Ok(AgentResponse::ForkSessionResponse(response))
652 }
653 #[cfg(feature = "unstable_session_resume")]
654 ClientRequest::ResumeSessionRequest(args) => {
655 let response = self.resume_session(args).await?;
656 Ok(AgentResponse::ResumeSessionResponse(response))
657 }
658 ClientRequest::SetSessionConfigOptionRequest(args) => {
659 let response = self.set_session_config_option(args).await?;
660 Ok(AgentResponse::SetSessionConfigOptionResponse(response))
661 }
662 ClientRequest::ExtMethodRequest(args) => {
663 let response = self.ext_method(args).await?;
664 Ok(AgentResponse::ExtMethodResponse(response))
665 }
666 _ => Err(Error::method_not_found()),
667 }
668 }
669
670 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
671 match notification {
672 ClientNotification::CancelNotification(args) => {
673 self.cancel(args).await?;
674 }
675 ClientNotification::ExtNotification(args) => {
676 self.ext_notification(args).await?;
677 }
678 _ => {}
680 }
681 Ok(())
682 }
683}