1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
/// Client-side end of an agent-client-protocol connection.
///
/// Wraps an `RpcConnection<ClientSide, AgentSide>`. Its [`Agent`] implementation forwards
/// every call to the remote agent, either as a request (awaiting the reply) or as a
/// notification.
#[derive(Debug)]
pub struct ClientSideConnection {
    /// The underlying RPC connection all `Agent` calls are sent through.
    conn: RpcConnection<ClientSide, AgentSide>,
}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 #[cfg(feature = "unstable_logout")]
99 async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
100 self.conn
101 .request::<Option<_>>(
102 AGENT_METHOD_NAMES.logout,
103 Some(ClientRequest::LogoutRequest(args)),
104 )
105 .await
106 .map(Option::unwrap_or_default)
107 }
108
109 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
110 self.conn
111 .request(
112 AGENT_METHOD_NAMES.session_new,
113 Some(ClientRequest::NewSessionRequest(args)),
114 )
115 .await
116 }
117
118 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
119 self.conn
120 .request::<Option<_>>(
121 AGENT_METHOD_NAMES.session_load,
122 Some(ClientRequest::LoadSessionRequest(args)),
123 )
124 .await
125 .map(Option::unwrap_or_default)
126 }
127
128 async fn set_session_mode(
129 &self,
130 args: SetSessionModeRequest,
131 ) -> Result<SetSessionModeResponse> {
132 self.conn
133 .request(
134 AGENT_METHOD_NAMES.session_set_mode,
135 Some(ClientRequest::SetSessionModeRequest(args)),
136 )
137 .await
138 }
139
140 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
141 self.conn
142 .request(
143 AGENT_METHOD_NAMES.session_prompt,
144 Some(ClientRequest::PromptRequest(args)),
145 )
146 .await
147 }
148
149 async fn cancel(&self, args: CancelNotification) -> Result<()> {
150 self.conn.notify(
151 AGENT_METHOD_NAMES.session_cancel,
152 Some(ClientNotification::CancelNotification(args)),
153 )
154 }
155
156 #[cfg(feature = "unstable_session_model")]
157 async fn set_session_model(
158 &self,
159 args: SetSessionModelRequest,
160 ) -> Result<SetSessionModelResponse> {
161 self.conn
162 .request(
163 AGENT_METHOD_NAMES.session_set_model,
164 Some(ClientRequest::SetSessionModelRequest(args)),
165 )
166 .await
167 }
168
169 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
170 self.conn
171 .request(
172 AGENT_METHOD_NAMES.session_list,
173 Some(ClientRequest::ListSessionsRequest(args)),
174 )
175 .await
176 }
177
178 #[cfg(feature = "unstable_session_fork")]
179 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
180 self.conn
181 .request(
182 AGENT_METHOD_NAMES.session_fork,
183 Some(ClientRequest::ForkSessionRequest(args)),
184 )
185 .await
186 }
187
188 #[cfg(feature = "unstable_session_resume")]
189 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
190 self.conn
191 .request(
192 AGENT_METHOD_NAMES.session_resume,
193 Some(ClientRequest::ResumeSessionRequest(args)),
194 )
195 .await
196 }
197
198 #[cfg(feature = "unstable_session_close")]
199 async fn close_session(&self, args: CloseSessionRequest) -> Result<CloseSessionResponse> {
200 self.conn
201 .request::<Option<_>>(
202 AGENT_METHOD_NAMES.session_close,
203 Some(ClientRequest::CloseSessionRequest(args)),
204 )
205 .await
206 .map(Option::unwrap_or_default)
207 }
208
209 async fn set_session_config_option(
210 &self,
211 args: SetSessionConfigOptionRequest,
212 ) -> Result<SetSessionConfigOptionResponse> {
213 self.conn
214 .request(
215 AGENT_METHOD_NAMES.session_set_config_option,
216 Some(ClientRequest::SetSessionConfigOptionRequest(args)),
217 )
218 .await
219 }
220
221 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
222 self.conn
223 .request(
224 format!("_{}", args.method),
225 Some(ClientRequest::ExtMethodRequest(args)),
226 )
227 .await
228 }
229
230 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
231 self.conn.notify(
232 format!("_{}", args.method),
233 Some(ClientNotification::ExtNotification(args)),
234 )
235 }
236}
237
/// Marker type for the client end of a connection, used as a [`Side`] parameter.
///
/// Its [`Side`] impl decodes the messages the client receives: [`AgentRequest`]s and
/// [`AgentNotification`]s. It answers requests with [`ClientResponse`]s.
#[derive(Clone, Debug)]
pub struct ClientSide;
246
247impl Side for ClientSide {
248 type InNotification = AgentNotification;
249 type InRequest = AgentRequest;
250 type OutResponse = ClientResponse;
251
252 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
253 let params = params.ok_or_else(Error::invalid_params)?;
254
255 match method {
256 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
257 serde_json::from_str(params.get())
258 .map(AgentRequest::RequestPermissionRequest)
259 .map_err(Into::into)
260 }
261 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
262 .map(AgentRequest::WriteTextFileRequest)
263 .map_err(Into::into),
264 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
265 .map(AgentRequest::ReadTextFileRequest)
266 .map_err(Into::into),
267 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
268 .map(AgentRequest::CreateTerminalRequest)
269 .map_err(Into::into),
270 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
271 .map(AgentRequest::TerminalOutputRequest)
272 .map_err(Into::into),
273 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
274 .map(AgentRequest::KillTerminalRequest)
275 .map_err(Into::into),
276 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
277 .map(AgentRequest::ReleaseTerminalRequest)
278 .map_err(Into::into),
279 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
280 serde_json::from_str(params.get())
281 .map(AgentRequest::WaitForTerminalExitRequest)
282 .map_err(Into::into)
283 }
284 _ => {
285 if let Some(custom_method) = method.strip_prefix('_') {
286 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
287 custom_method,
288 params.to_owned().into(),
289 )))
290 } else {
291 Err(Error::method_not_found())
292 }
293 }
294 }
295 }
296
297 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
298 let params = params.ok_or_else(Error::invalid_params)?;
299
300 match method {
301 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
302 .map(AgentNotification::SessionNotification)
303 .map_err(Into::into),
304 _ => {
305 if let Some(custom_method) = method.strip_prefix('_') {
306 Ok(AgentNotification::ExtNotification(ExtNotification::new(
307 custom_method,
308 RawValue::from_string(params.get().to_string())?.into(),
309 )))
310 } else {
311 Err(Error::method_not_found())
312 }
313 }
314 }
315 }
316}
317
318impl<T: Client> MessageHandler<ClientSide> for T {
319 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
320 match request {
321 AgentRequest::RequestPermissionRequest(args) => {
322 let response = self.request_permission(args).await?;
323 Ok(ClientResponse::RequestPermissionResponse(response))
324 }
325 AgentRequest::WriteTextFileRequest(args) => {
326 let response = self.write_text_file(args).await?;
327 Ok(ClientResponse::WriteTextFileResponse(response))
328 }
329 AgentRequest::ReadTextFileRequest(args) => {
330 let response = self.read_text_file(args).await?;
331 Ok(ClientResponse::ReadTextFileResponse(response))
332 }
333 AgentRequest::CreateTerminalRequest(args) => {
334 let response = self.create_terminal(args).await?;
335 Ok(ClientResponse::CreateTerminalResponse(response))
336 }
337 AgentRequest::TerminalOutputRequest(args) => {
338 let response = self.terminal_output(args).await?;
339 Ok(ClientResponse::TerminalOutputResponse(response))
340 }
341 AgentRequest::ReleaseTerminalRequest(args) => {
342 let response = self.release_terminal(args).await?;
343 Ok(ClientResponse::ReleaseTerminalResponse(response))
344 }
345 AgentRequest::WaitForTerminalExitRequest(args) => {
346 let response = self.wait_for_terminal_exit(args).await?;
347 Ok(ClientResponse::WaitForTerminalExitResponse(response))
348 }
349 AgentRequest::KillTerminalRequest(args) => {
350 let response = self.kill_terminal(args).await?;
351 Ok(ClientResponse::KillTerminalResponse(response))
352 }
353 AgentRequest::ExtMethodRequest(args) => {
354 let response = self.ext_method(args).await?;
355 Ok(ClientResponse::ExtMethodResponse(response))
356 }
357 _ => Err(Error::method_not_found()),
358 }
359 }
360
361 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
362 match notification {
363 AgentNotification::SessionNotification(args) => {
364 self.session_notification(args).await?;
365 }
366 AgentNotification::ExtNotification(args) => {
367 self.ext_notification(args).await?;
368 }
369 _ => {}
371 }
372 Ok(())
373 }
374}
375
/// Agent-side end of an agent-client-protocol connection.
///
/// Wraps an `RpcConnection<AgentSide, ClientSide>`. Its [`Client`] implementation forwards
/// every call to the remote client, either as a request (awaiting the reply) or as a
/// notification.
#[derive(Debug)]
pub struct AgentSideConnection {
    /// The underlying RPC connection all `Client` calls are sent through.
    conn: RpcConnection<AgentSide, ClientSide>,
}
390
391impl AgentSideConnection {
392 pub fn new(
412 agent: impl MessageHandler<AgentSide> + 'static,
413 outgoing_bytes: impl Unpin + AsyncWrite,
414 incoming_bytes: impl Unpin + AsyncRead,
415 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
416 ) -> (Self, impl Future<Output = Result<()>>) {
417 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
418 (Self { conn }, io_task)
419 }
420
421 pub fn subscribe(&self) -> StreamReceiver {
430 self.conn.subscribe()
431 }
432}
433
434#[async_trait::async_trait(?Send)]
435impl Client for AgentSideConnection {
436 async fn request_permission(
437 &self,
438 args: RequestPermissionRequest,
439 ) -> Result<RequestPermissionResponse> {
440 self.conn
441 .request(
442 CLIENT_METHOD_NAMES.session_request_permission,
443 Some(AgentRequest::RequestPermissionRequest(args)),
444 )
445 .await
446 }
447
448 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
449 self.conn
450 .request::<Option<_>>(
451 CLIENT_METHOD_NAMES.fs_write_text_file,
452 Some(AgentRequest::WriteTextFileRequest(args)),
453 )
454 .await
455 .map(Option::unwrap_or_default)
456 }
457
458 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
459 self.conn
460 .request(
461 CLIENT_METHOD_NAMES.fs_read_text_file,
462 Some(AgentRequest::ReadTextFileRequest(args)),
463 )
464 .await
465 }
466
467 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
468 self.conn
469 .request(
470 CLIENT_METHOD_NAMES.terminal_create,
471 Some(AgentRequest::CreateTerminalRequest(args)),
472 )
473 .await
474 }
475
476 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
477 self.conn
478 .request(
479 CLIENT_METHOD_NAMES.terminal_output,
480 Some(AgentRequest::TerminalOutputRequest(args)),
481 )
482 .await
483 }
484
485 async fn release_terminal(
486 &self,
487 args: ReleaseTerminalRequest,
488 ) -> Result<ReleaseTerminalResponse> {
489 self.conn
490 .request::<Option<_>>(
491 CLIENT_METHOD_NAMES.terminal_release,
492 Some(AgentRequest::ReleaseTerminalRequest(args)),
493 )
494 .await
495 .map(Option::unwrap_or_default)
496 }
497
498 async fn wait_for_terminal_exit(
499 &self,
500 args: WaitForTerminalExitRequest,
501 ) -> Result<WaitForTerminalExitResponse> {
502 self.conn
503 .request(
504 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
505 Some(AgentRequest::WaitForTerminalExitRequest(args)),
506 )
507 .await
508 }
509
510 async fn kill_terminal(&self, args: KillTerminalRequest) -> Result<KillTerminalResponse> {
511 self.conn
512 .request::<Option<_>>(
513 CLIENT_METHOD_NAMES.terminal_kill,
514 Some(AgentRequest::KillTerminalRequest(args)),
515 )
516 .await
517 .map(Option::unwrap_or_default)
518 }
519
520 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
521 self.conn.notify(
522 CLIENT_METHOD_NAMES.session_update,
523 Some(AgentNotification::SessionNotification(args)),
524 )
525 }
526
527 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
528 self.conn
529 .request(
530 format!("_{}", args.method),
531 Some(AgentRequest::ExtMethodRequest(args)),
532 )
533 .await
534 }
535
536 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
537 self.conn.notify(
538 format!("_{}", args.method),
539 Some(AgentNotification::ExtNotification(args)),
540 )
541 }
542}
543
/// Marker type for the agent end of a connection, used as a [`Side`] parameter.
///
/// Its [`Side`] impl decodes the messages the agent receives: [`ClientRequest`]s and
/// [`ClientNotification`]s. It answers requests with [`AgentResponse`]s.
#[derive(Clone, Debug)]
pub struct AgentSide;
552
553impl Side for AgentSide {
554 type InRequest = ClientRequest;
555 type InNotification = ClientNotification;
556 type OutResponse = AgentResponse;
557
558 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
559 let params = params.ok_or_else(Error::invalid_params)?;
560
561 match method {
562 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
563 .map(ClientRequest::InitializeRequest)
564 .map_err(Into::into),
565 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
566 .map(ClientRequest::AuthenticateRequest)
567 .map_err(Into::into),
568 #[cfg(feature = "unstable_logout")]
569 m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
570 .map(ClientRequest::LogoutRequest)
571 .map_err(Into::into),
572 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
573 .map(ClientRequest::NewSessionRequest)
574 .map_err(Into::into),
575 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
576 .map(ClientRequest::LoadSessionRequest)
577 .map_err(Into::into),
578 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
579 .map(ClientRequest::SetSessionModeRequest)
580 .map_err(Into::into),
581 #[cfg(feature = "unstable_session_model")]
582 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
583 .map(ClientRequest::SetSessionModelRequest)
584 .map_err(Into::into),
585 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
586 .map(ClientRequest::ListSessionsRequest)
587 .map_err(Into::into),
588 #[cfg(feature = "unstable_session_fork")]
589 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
590 .map(ClientRequest::ForkSessionRequest)
591 .map_err(Into::into),
592 #[cfg(feature = "unstable_session_resume")]
593 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
594 .map(ClientRequest::ResumeSessionRequest)
595 .map_err(Into::into),
596 #[cfg(feature = "unstable_session_close")]
597 m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
598 .map(ClientRequest::CloseSessionRequest)
599 .map_err(Into::into),
600 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
601 serde_json::from_str(params.get())
602 .map(ClientRequest::SetSessionConfigOptionRequest)
603 .map_err(Into::into)
604 }
605 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
606 .map(ClientRequest::PromptRequest)
607 .map_err(Into::into),
608 _ => {
609 if let Some(custom_method) = method.strip_prefix('_') {
610 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
611 custom_method,
612 params.to_owned().into(),
613 )))
614 } else {
615 Err(Error::method_not_found())
616 }
617 }
618 }
619 }
620
621 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
622 let params = params.ok_or_else(Error::invalid_params)?;
623
624 match method {
625 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
626 .map(ClientNotification::CancelNotification)
627 .map_err(Into::into),
628 _ => {
629 if let Some(custom_method) = method.strip_prefix('_') {
630 Ok(ClientNotification::ExtNotification(ExtNotification::new(
631 custom_method,
632 RawValue::from_string(params.get().to_string())?.into(),
633 )))
634 } else {
635 Err(Error::method_not_found())
636 }
637 }
638 }
639 }
640}
641
642impl<T: Agent> MessageHandler<AgentSide> for T {
643 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
644 match request {
645 ClientRequest::InitializeRequest(args) => {
646 let response = self.initialize(args).await?;
647 Ok(AgentResponse::InitializeResponse(response))
648 }
649 ClientRequest::AuthenticateRequest(args) => {
650 let response = self.authenticate(args).await?;
651 Ok(AgentResponse::AuthenticateResponse(response))
652 }
653 #[cfg(feature = "unstable_logout")]
654 ClientRequest::LogoutRequest(args) => {
655 let response = self.logout(args).await?;
656 Ok(AgentResponse::LogoutResponse(response))
657 }
658 ClientRequest::NewSessionRequest(args) => {
659 let response = self.new_session(args).await?;
660 Ok(AgentResponse::NewSessionResponse(response))
661 }
662 ClientRequest::LoadSessionRequest(args) => {
663 let response = self.load_session(args).await?;
664 Ok(AgentResponse::LoadSessionResponse(response))
665 }
666 ClientRequest::PromptRequest(args) => {
667 let response = self.prompt(args).await?;
668 Ok(AgentResponse::PromptResponse(response))
669 }
670 ClientRequest::SetSessionModeRequest(args) => {
671 let response = self.set_session_mode(args).await?;
672 Ok(AgentResponse::SetSessionModeResponse(response))
673 }
674 #[cfg(feature = "unstable_session_model")]
675 ClientRequest::SetSessionModelRequest(args) => {
676 let response = self.set_session_model(args).await?;
677 Ok(AgentResponse::SetSessionModelResponse(response))
678 }
679 ClientRequest::ListSessionsRequest(args) => {
680 let response = self.list_sessions(args).await?;
681 Ok(AgentResponse::ListSessionsResponse(response))
682 }
683 #[cfg(feature = "unstable_session_fork")]
684 ClientRequest::ForkSessionRequest(args) => {
685 let response = self.fork_session(args).await?;
686 Ok(AgentResponse::ForkSessionResponse(response))
687 }
688 #[cfg(feature = "unstable_session_resume")]
689 ClientRequest::ResumeSessionRequest(args) => {
690 let response = self.resume_session(args).await?;
691 Ok(AgentResponse::ResumeSessionResponse(response))
692 }
693 #[cfg(feature = "unstable_session_close")]
694 ClientRequest::CloseSessionRequest(args) => {
695 let response = self.close_session(args).await?;
696 Ok(AgentResponse::CloseSessionResponse(response))
697 }
698 ClientRequest::SetSessionConfigOptionRequest(args) => {
699 let response = self.set_session_config_option(args).await?;
700 Ok(AgentResponse::SetSessionConfigOptionResponse(response))
701 }
702 ClientRequest::ExtMethodRequest(args) => {
703 let response = self.ext_method(args).await?;
704 Ok(AgentResponse::ExtMethodResponse(response))
705 }
706 _ => Err(Error::method_not_found()),
707 }
708 }
709
710 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
711 match notification {
712 ClientNotification::CancelNotification(args) => {
713 self.cancel(args).await?;
714 }
715 ClientNotification::ExtNotification(args) => {
716 self.ext_notification(args).await?;
717 }
718 _ => {}
720 }
721 Ok(())
722 }
723}