1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19#[derive(Debug)]
30pub struct ClientSideConnection {
31 conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable_session_model")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 #[cfg(feature = "unstable_session_list")]
159 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
160 self.conn
161 .request(
162 AGENT_METHOD_NAMES.session_list,
163 Some(ClientRequest::ListSessionsRequest(args)),
164 )
165 .await
166 }
167
168 #[cfg(feature = "unstable_session_fork")]
169 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
170 self.conn
171 .request(
172 AGENT_METHOD_NAMES.session_fork,
173 Some(ClientRequest::ForkSessionRequest(args)),
174 )
175 .await
176 }
177
178 #[cfg(feature = "unstable_session_resume")]
179 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
180 self.conn
181 .request(
182 AGENT_METHOD_NAMES.session_resume,
183 Some(ClientRequest::ResumeSessionRequest(args)),
184 )
185 .await
186 }
187
188 #[cfg(feature = "unstable_session_config_options")]
189 async fn set_session_config_option(
190 &self,
191 args: SetSessionConfigOptionRequest,
192 ) -> Result<SetSessionConfigOptionResponse> {
193 self.conn
194 .request(
195 AGENT_METHOD_NAMES.session_set_config_option,
196 Some(ClientRequest::SetSessionConfigOptionRequest(args)),
197 )
198 .await
199 }
200
201 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
202 self.conn
203 .request(
204 format!("_{}", args.method),
205 Some(ClientRequest::ExtMethodRequest(args)),
206 )
207 .await
208 }
209
210 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
211 self.conn.notify(
212 format!("_{}", args.method),
213 Some(ClientNotification::ExtNotification(args)),
214 )
215 }
216}
217
218#[derive(Clone, Debug)]
225pub struct ClientSide;
226
227impl Side for ClientSide {
228 type InNotification = AgentNotification;
229 type InRequest = AgentRequest;
230 type OutResponse = ClientResponse;
231
232 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
233 let params = params.ok_or_else(Error::invalid_params)?;
234
235 match method {
236 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
237 serde_json::from_str(params.get())
238 .map(AgentRequest::RequestPermissionRequest)
239 .map_err(Into::into)
240 }
241 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
242 .map(AgentRequest::WriteTextFileRequest)
243 .map_err(Into::into),
244 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
245 .map(AgentRequest::ReadTextFileRequest)
246 .map_err(Into::into),
247 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
248 .map(AgentRequest::CreateTerminalRequest)
249 .map_err(Into::into),
250 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
251 .map(AgentRequest::TerminalOutputRequest)
252 .map_err(Into::into),
253 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
254 .map(AgentRequest::KillTerminalCommandRequest)
255 .map_err(Into::into),
256 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
257 .map(AgentRequest::ReleaseTerminalRequest)
258 .map_err(Into::into),
259 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
260 serde_json::from_str(params.get())
261 .map(AgentRequest::WaitForTerminalExitRequest)
262 .map_err(Into::into)
263 }
264 _ => {
265 if let Some(custom_method) = method.strip_prefix('_') {
266 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
267 custom_method,
268 params.to_owned().into(),
269 )))
270 } else {
271 Err(Error::method_not_found())
272 }
273 }
274 }
275 }
276
277 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
278 let params = params.ok_or_else(Error::invalid_params)?;
279
280 match method {
281 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
282 .map(AgentNotification::SessionNotification)
283 .map_err(Into::into),
284 _ => {
285 if let Some(custom_method) = method.strip_prefix('_') {
286 Ok(AgentNotification::ExtNotification(ExtNotification::new(
287 custom_method,
288 RawValue::from_string(params.get().to_string())?.into(),
289 )))
290 } else {
291 Err(Error::method_not_found())
292 }
293 }
294 }
295 }
296}
297
298impl<T: Client> MessageHandler<ClientSide> for T {
299 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
300 match request {
301 AgentRequest::RequestPermissionRequest(args) => {
302 let response = self.request_permission(args).await?;
303 Ok(ClientResponse::RequestPermissionResponse(response))
304 }
305 AgentRequest::WriteTextFileRequest(args) => {
306 let response = self.write_text_file(args).await?;
307 Ok(ClientResponse::WriteTextFileResponse(response))
308 }
309 AgentRequest::ReadTextFileRequest(args) => {
310 let response = self.read_text_file(args).await?;
311 Ok(ClientResponse::ReadTextFileResponse(response))
312 }
313 AgentRequest::CreateTerminalRequest(args) => {
314 let response = self.create_terminal(args).await?;
315 Ok(ClientResponse::CreateTerminalResponse(response))
316 }
317 AgentRequest::TerminalOutputRequest(args) => {
318 let response = self.terminal_output(args).await?;
319 Ok(ClientResponse::TerminalOutputResponse(response))
320 }
321 AgentRequest::ReleaseTerminalRequest(args) => {
322 let response = self.release_terminal(args).await?;
323 Ok(ClientResponse::ReleaseTerminalResponse(response))
324 }
325 AgentRequest::WaitForTerminalExitRequest(args) => {
326 let response = self.wait_for_terminal_exit(args).await?;
327 Ok(ClientResponse::WaitForTerminalExitResponse(response))
328 }
329 AgentRequest::KillTerminalCommandRequest(args) => {
330 let response = self.kill_terminal_command(args).await?;
331 Ok(ClientResponse::KillTerminalResponse(response))
332 }
333 AgentRequest::ExtMethodRequest(args) => {
334 let response = self.ext_method(args).await?;
335 Ok(ClientResponse::ExtMethodResponse(response))
336 }
337 _ => Err(Error::method_not_found()),
338 }
339 }
340
341 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
342 match notification {
343 AgentNotification::SessionNotification(args) => {
344 self.session_notification(args).await?;
345 }
346 AgentNotification::ExtNotification(args) => {
347 self.ext_notification(args).await?;
348 }
349 _ => {}
351 }
352 Ok(())
353 }
354}
355
356#[derive(Debug)]
367pub struct AgentSideConnection {
368 conn: RpcConnection<AgentSide, ClientSide>,
369}
370
371impl AgentSideConnection {
372 pub fn new(
392 agent: impl MessageHandler<AgentSide> + 'static,
393 outgoing_bytes: impl Unpin + AsyncWrite,
394 incoming_bytes: impl Unpin + AsyncRead,
395 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
396 ) -> (Self, impl Future<Output = Result<()>>) {
397 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
398 (Self { conn }, io_task)
399 }
400
401 pub fn subscribe(&self) -> StreamReceiver {
410 self.conn.subscribe()
411 }
412}
413
414#[async_trait::async_trait(?Send)]
415impl Client for AgentSideConnection {
416 async fn request_permission(
417 &self,
418 args: RequestPermissionRequest,
419 ) -> Result<RequestPermissionResponse> {
420 self.conn
421 .request(
422 CLIENT_METHOD_NAMES.session_request_permission,
423 Some(AgentRequest::RequestPermissionRequest(args)),
424 )
425 .await
426 }
427
428 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
429 self.conn
430 .request::<Option<_>>(
431 CLIENT_METHOD_NAMES.fs_write_text_file,
432 Some(AgentRequest::WriteTextFileRequest(args)),
433 )
434 .await
435 .map(Option::unwrap_or_default)
436 }
437
438 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
439 self.conn
440 .request(
441 CLIENT_METHOD_NAMES.fs_read_text_file,
442 Some(AgentRequest::ReadTextFileRequest(args)),
443 )
444 .await
445 }
446
447 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
448 self.conn
449 .request(
450 CLIENT_METHOD_NAMES.terminal_create,
451 Some(AgentRequest::CreateTerminalRequest(args)),
452 )
453 .await
454 }
455
456 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
457 self.conn
458 .request(
459 CLIENT_METHOD_NAMES.terminal_output,
460 Some(AgentRequest::TerminalOutputRequest(args)),
461 )
462 .await
463 }
464
465 async fn release_terminal(
466 &self,
467 args: ReleaseTerminalRequest,
468 ) -> Result<ReleaseTerminalResponse> {
469 self.conn
470 .request::<Option<_>>(
471 CLIENT_METHOD_NAMES.terminal_release,
472 Some(AgentRequest::ReleaseTerminalRequest(args)),
473 )
474 .await
475 .map(Option::unwrap_or_default)
476 }
477
478 async fn wait_for_terminal_exit(
479 &self,
480 args: WaitForTerminalExitRequest,
481 ) -> Result<WaitForTerminalExitResponse> {
482 self.conn
483 .request(
484 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
485 Some(AgentRequest::WaitForTerminalExitRequest(args)),
486 )
487 .await
488 }
489
490 async fn kill_terminal_command(
491 &self,
492 args: KillTerminalCommandRequest,
493 ) -> Result<KillTerminalCommandResponse> {
494 self.conn
495 .request::<Option<_>>(
496 CLIENT_METHOD_NAMES.terminal_kill,
497 Some(AgentRequest::KillTerminalCommandRequest(args)),
498 )
499 .await
500 .map(Option::unwrap_or_default)
501 }
502
503 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
504 self.conn.notify(
505 CLIENT_METHOD_NAMES.session_update,
506 Some(AgentNotification::SessionNotification(args)),
507 )
508 }
509
510 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
511 self.conn
512 .request(
513 format!("_{}", args.method),
514 Some(AgentRequest::ExtMethodRequest(args)),
515 )
516 .await
517 }
518
519 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
520 self.conn.notify(
521 format!("_{}", args.method),
522 Some(AgentNotification::ExtNotification(args)),
523 )
524 }
525}
526
527#[derive(Clone, Debug)]
534pub struct AgentSide;
535
536impl Side for AgentSide {
537 type InRequest = ClientRequest;
538 type InNotification = ClientNotification;
539 type OutResponse = AgentResponse;
540
541 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
542 let params = params.ok_or_else(Error::invalid_params)?;
543
544 match method {
545 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
546 .map(ClientRequest::InitializeRequest)
547 .map_err(Into::into),
548 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
549 .map(ClientRequest::AuthenticateRequest)
550 .map_err(Into::into),
551 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
552 .map(ClientRequest::NewSessionRequest)
553 .map_err(Into::into),
554 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
555 .map(ClientRequest::LoadSessionRequest)
556 .map_err(Into::into),
557 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
558 .map(ClientRequest::SetSessionModeRequest)
559 .map_err(Into::into),
560 #[cfg(feature = "unstable_session_model")]
561 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
562 .map(ClientRequest::SetSessionModelRequest)
563 .map_err(Into::into),
564 #[cfg(feature = "unstable_session_list")]
565 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
566 .map(ClientRequest::ListSessionsRequest)
567 .map_err(Into::into),
568 #[cfg(feature = "unstable_session_fork")]
569 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
570 .map(ClientRequest::ForkSessionRequest)
571 .map_err(Into::into),
572 #[cfg(feature = "unstable_session_resume")]
573 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
574 .map(ClientRequest::ResumeSessionRequest)
575 .map_err(Into::into),
576 #[cfg(feature = "unstable_session_config_options")]
577 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
578 serde_json::from_str(params.get())
579 .map(ClientRequest::SetSessionConfigOptionRequest)
580 .map_err(Into::into)
581 }
582 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
583 .map(ClientRequest::PromptRequest)
584 .map_err(Into::into),
585 _ => {
586 if let Some(custom_method) = method.strip_prefix('_') {
587 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
588 custom_method,
589 params.to_owned().into(),
590 )))
591 } else {
592 Err(Error::method_not_found())
593 }
594 }
595 }
596 }
597
598 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
599 let params = params.ok_or_else(Error::invalid_params)?;
600
601 match method {
602 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
603 .map(ClientNotification::CancelNotification)
604 .map_err(Into::into),
605 _ => {
606 if let Some(custom_method) = method.strip_prefix('_') {
607 Ok(ClientNotification::ExtNotification(ExtNotification::new(
608 custom_method,
609 RawValue::from_string(params.get().to_string())?.into(),
610 )))
611 } else {
612 Err(Error::method_not_found())
613 }
614 }
615 }
616 }
617}
618
619impl<T: Agent> MessageHandler<AgentSide> for T {
620 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
621 match request {
622 ClientRequest::InitializeRequest(args) => {
623 let response = self.initialize(args).await?;
624 Ok(AgentResponse::InitializeResponse(response))
625 }
626 ClientRequest::AuthenticateRequest(args) => {
627 let response = self.authenticate(args).await?;
628 Ok(AgentResponse::AuthenticateResponse(response))
629 }
630 ClientRequest::NewSessionRequest(args) => {
631 let response = self.new_session(args).await?;
632 Ok(AgentResponse::NewSessionResponse(response))
633 }
634 ClientRequest::LoadSessionRequest(args) => {
635 let response = self.load_session(args).await?;
636 Ok(AgentResponse::LoadSessionResponse(response))
637 }
638 ClientRequest::PromptRequest(args) => {
639 let response = self.prompt(args).await?;
640 Ok(AgentResponse::PromptResponse(response))
641 }
642 ClientRequest::SetSessionModeRequest(args) => {
643 let response = self.set_session_mode(args).await?;
644 Ok(AgentResponse::SetSessionModeResponse(response))
645 }
646 #[cfg(feature = "unstable_session_model")]
647 ClientRequest::SetSessionModelRequest(args) => {
648 let response = self.set_session_model(args).await?;
649 Ok(AgentResponse::SetSessionModelResponse(response))
650 }
651 #[cfg(feature = "unstable_session_list")]
652 ClientRequest::ListSessionsRequest(args) => {
653 let response = self.list_sessions(args).await?;
654 Ok(AgentResponse::ListSessionsResponse(response))
655 }
656 #[cfg(feature = "unstable_session_fork")]
657 ClientRequest::ForkSessionRequest(args) => {
658 let response = self.fork_session(args).await?;
659 Ok(AgentResponse::ForkSessionResponse(response))
660 }
661 #[cfg(feature = "unstable_session_resume")]
662 ClientRequest::ResumeSessionRequest(args) => {
663 let response = self.resume_session(args).await?;
664 Ok(AgentResponse::ResumeSessionResponse(response))
665 }
666 #[cfg(feature = "unstable_session_config_options")]
667 ClientRequest::SetSessionConfigOptionRequest(args) => {
668 let response = self.set_session_config_option(args).await?;
669 Ok(AgentResponse::SetSessionConfigOptionResponse(response))
670 }
671 ClientRequest::ExtMethodRequest(args) => {
672 let response = self.ext_method(args).await?;
673 Ok(AgentResponse::ExtMethodResponse(response))
674 }
675 _ => Err(Error::method_not_found()),
676 }
677 }
678
679 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
680 match notification {
681 ClientNotification::CancelNotification(args) => {
682 self.cancel(args).await?;
683 }
684 ClientNotification::ExtNotification(args) => {
685 self.ext_notification(args).await?;
686 }
687 _ => {}
689 }
690 Ok(())
691 }
692}