1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19#[derive(Debug)]
30pub struct ClientSideConnection {
31 conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable_session_model")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 #[cfg(feature = "unstable_session_list")]
159 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
160 self.conn
161 .request(
162 AGENT_METHOD_NAMES.session_list,
163 Some(ClientRequest::ListSessionsRequest(args)),
164 )
165 .await
166 }
167
168 #[cfg(feature = "unstable_session_fork")]
169 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
170 self.conn
171 .request(
172 AGENT_METHOD_NAMES.session_fork,
173 Some(ClientRequest::ForkSessionRequest(args)),
174 )
175 .await
176 }
177
178 #[cfg(feature = "unstable_session_resume")]
179 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
180 self.conn
181 .request(
182 AGENT_METHOD_NAMES.session_resume,
183 Some(ClientRequest::ResumeSessionRequest(args)),
184 )
185 .await
186 }
187
188 async fn set_session_config_option(
189 &self,
190 args: SetSessionConfigOptionRequest,
191 ) -> Result<SetSessionConfigOptionResponse> {
192 self.conn
193 .request(
194 AGENT_METHOD_NAMES.session_set_config_option,
195 Some(ClientRequest::SetSessionConfigOptionRequest(args)),
196 )
197 .await
198 }
199
200 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
201 self.conn
202 .request(
203 format!("_{}", args.method),
204 Some(ClientRequest::ExtMethodRequest(args)),
205 )
206 .await
207 }
208
209 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
210 self.conn.notify(
211 format!("_{}", args.method),
212 Some(ClientNotification::ExtNotification(args)),
213 )
214 }
215}
216
/// Marker type for the client end of a connection.
///
/// Used as a `Side` type parameter: as a side, it receives [`AgentRequest`]s
/// and [`AgentNotification`]s and produces [`ClientResponse`]s.
#[derive(Clone, Debug)]
pub struct ClientSide;
225
226impl Side for ClientSide {
227 type InNotification = AgentNotification;
228 type InRequest = AgentRequest;
229 type OutResponse = ClientResponse;
230
231 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
232 let params = params.ok_or_else(Error::invalid_params)?;
233
234 match method {
235 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
236 serde_json::from_str(params.get())
237 .map(AgentRequest::RequestPermissionRequest)
238 .map_err(Into::into)
239 }
240 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
241 .map(AgentRequest::WriteTextFileRequest)
242 .map_err(Into::into),
243 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
244 .map(AgentRequest::ReadTextFileRequest)
245 .map_err(Into::into),
246 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
247 .map(AgentRequest::CreateTerminalRequest)
248 .map_err(Into::into),
249 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
250 .map(AgentRequest::TerminalOutputRequest)
251 .map_err(Into::into),
252 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
253 .map(AgentRequest::KillTerminalCommandRequest)
254 .map_err(Into::into),
255 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
256 .map(AgentRequest::ReleaseTerminalRequest)
257 .map_err(Into::into),
258 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
259 serde_json::from_str(params.get())
260 .map(AgentRequest::WaitForTerminalExitRequest)
261 .map_err(Into::into)
262 }
263 _ => {
264 if let Some(custom_method) = method.strip_prefix('_') {
265 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
266 custom_method,
267 params.to_owned().into(),
268 )))
269 } else {
270 Err(Error::method_not_found())
271 }
272 }
273 }
274 }
275
276 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
277 let params = params.ok_or_else(Error::invalid_params)?;
278
279 match method {
280 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
281 .map(AgentNotification::SessionNotification)
282 .map_err(Into::into),
283 _ => {
284 if let Some(custom_method) = method.strip_prefix('_') {
285 Ok(AgentNotification::ExtNotification(ExtNotification::new(
286 custom_method,
287 RawValue::from_string(params.get().to_string())?.into(),
288 )))
289 } else {
290 Err(Error::method_not_found())
291 }
292 }
293 }
294 }
295}
296
297impl<T: Client> MessageHandler<ClientSide> for T {
298 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
299 match request {
300 AgentRequest::RequestPermissionRequest(args) => {
301 let response = self.request_permission(args).await?;
302 Ok(ClientResponse::RequestPermissionResponse(response))
303 }
304 AgentRequest::WriteTextFileRequest(args) => {
305 let response = self.write_text_file(args).await?;
306 Ok(ClientResponse::WriteTextFileResponse(response))
307 }
308 AgentRequest::ReadTextFileRequest(args) => {
309 let response = self.read_text_file(args).await?;
310 Ok(ClientResponse::ReadTextFileResponse(response))
311 }
312 AgentRequest::CreateTerminalRequest(args) => {
313 let response = self.create_terminal(args).await?;
314 Ok(ClientResponse::CreateTerminalResponse(response))
315 }
316 AgentRequest::TerminalOutputRequest(args) => {
317 let response = self.terminal_output(args).await?;
318 Ok(ClientResponse::TerminalOutputResponse(response))
319 }
320 AgentRequest::ReleaseTerminalRequest(args) => {
321 let response = self.release_terminal(args).await?;
322 Ok(ClientResponse::ReleaseTerminalResponse(response))
323 }
324 AgentRequest::WaitForTerminalExitRequest(args) => {
325 let response = self.wait_for_terminal_exit(args).await?;
326 Ok(ClientResponse::WaitForTerminalExitResponse(response))
327 }
328 AgentRequest::KillTerminalCommandRequest(args) => {
329 let response = self.kill_terminal_command(args).await?;
330 Ok(ClientResponse::KillTerminalResponse(response))
331 }
332 AgentRequest::ExtMethodRequest(args) => {
333 let response = self.ext_method(args).await?;
334 Ok(ClientResponse::ExtMethodResponse(response))
335 }
336 _ => Err(Error::method_not_found()),
337 }
338 }
339
340 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
341 match notification {
342 AgentNotification::SessionNotification(args) => {
343 self.session_notification(args).await?;
344 }
345 AgentNotification::ExtNotification(args) => {
346 self.ext_notification(args).await?;
347 }
348 _ => {}
350 }
351 Ok(())
352 }
353}
354
355#[derive(Debug)]
366pub struct AgentSideConnection {
367 conn: RpcConnection<AgentSide, ClientSide>,
368}
369
370impl AgentSideConnection {
371 pub fn new(
391 agent: impl MessageHandler<AgentSide> + 'static,
392 outgoing_bytes: impl Unpin + AsyncWrite,
393 incoming_bytes: impl Unpin + AsyncRead,
394 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
395 ) -> (Self, impl Future<Output = Result<()>>) {
396 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
397 (Self { conn }, io_task)
398 }
399
400 pub fn subscribe(&self) -> StreamReceiver {
409 self.conn.subscribe()
410 }
411}
412
413#[async_trait::async_trait(?Send)]
414impl Client for AgentSideConnection {
415 async fn request_permission(
416 &self,
417 args: RequestPermissionRequest,
418 ) -> Result<RequestPermissionResponse> {
419 self.conn
420 .request(
421 CLIENT_METHOD_NAMES.session_request_permission,
422 Some(AgentRequest::RequestPermissionRequest(args)),
423 )
424 .await
425 }
426
427 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
428 self.conn
429 .request::<Option<_>>(
430 CLIENT_METHOD_NAMES.fs_write_text_file,
431 Some(AgentRequest::WriteTextFileRequest(args)),
432 )
433 .await
434 .map(Option::unwrap_or_default)
435 }
436
437 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
438 self.conn
439 .request(
440 CLIENT_METHOD_NAMES.fs_read_text_file,
441 Some(AgentRequest::ReadTextFileRequest(args)),
442 )
443 .await
444 }
445
446 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
447 self.conn
448 .request(
449 CLIENT_METHOD_NAMES.terminal_create,
450 Some(AgentRequest::CreateTerminalRequest(args)),
451 )
452 .await
453 }
454
455 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
456 self.conn
457 .request(
458 CLIENT_METHOD_NAMES.terminal_output,
459 Some(AgentRequest::TerminalOutputRequest(args)),
460 )
461 .await
462 }
463
464 async fn release_terminal(
465 &self,
466 args: ReleaseTerminalRequest,
467 ) -> Result<ReleaseTerminalResponse> {
468 self.conn
469 .request::<Option<_>>(
470 CLIENT_METHOD_NAMES.terminal_release,
471 Some(AgentRequest::ReleaseTerminalRequest(args)),
472 )
473 .await
474 .map(Option::unwrap_or_default)
475 }
476
477 async fn wait_for_terminal_exit(
478 &self,
479 args: WaitForTerminalExitRequest,
480 ) -> Result<WaitForTerminalExitResponse> {
481 self.conn
482 .request(
483 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
484 Some(AgentRequest::WaitForTerminalExitRequest(args)),
485 )
486 .await
487 }
488
489 async fn kill_terminal_command(
490 &self,
491 args: KillTerminalCommandRequest,
492 ) -> Result<KillTerminalCommandResponse> {
493 self.conn
494 .request::<Option<_>>(
495 CLIENT_METHOD_NAMES.terminal_kill,
496 Some(AgentRequest::KillTerminalCommandRequest(args)),
497 )
498 .await
499 .map(Option::unwrap_or_default)
500 }
501
502 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
503 self.conn.notify(
504 CLIENT_METHOD_NAMES.session_update,
505 Some(AgentNotification::SessionNotification(args)),
506 )
507 }
508
509 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
510 self.conn
511 .request(
512 format!("_{}", args.method),
513 Some(AgentRequest::ExtMethodRequest(args)),
514 )
515 .await
516 }
517
518 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
519 self.conn.notify(
520 format!("_{}", args.method),
521 Some(AgentNotification::ExtNotification(args)),
522 )
523 }
524}
525
/// Marker type for the agent end of a connection.
///
/// Used as a `Side` type parameter: as a side, it receives [`ClientRequest`]s
/// and [`ClientNotification`]s and produces [`AgentResponse`]s.
#[derive(Clone, Debug)]
pub struct AgentSide;
534
535impl Side for AgentSide {
536 type InRequest = ClientRequest;
537 type InNotification = ClientNotification;
538 type OutResponse = AgentResponse;
539
540 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
541 let params = params.ok_or_else(Error::invalid_params)?;
542
543 match method {
544 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
545 .map(ClientRequest::InitializeRequest)
546 .map_err(Into::into),
547 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
548 .map(ClientRequest::AuthenticateRequest)
549 .map_err(Into::into),
550 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
551 .map(ClientRequest::NewSessionRequest)
552 .map_err(Into::into),
553 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
554 .map(ClientRequest::LoadSessionRequest)
555 .map_err(Into::into),
556 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
557 .map(ClientRequest::SetSessionModeRequest)
558 .map_err(Into::into),
559 #[cfg(feature = "unstable_session_model")]
560 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
561 .map(ClientRequest::SetSessionModelRequest)
562 .map_err(Into::into),
563 #[cfg(feature = "unstable_session_list")]
564 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
565 .map(ClientRequest::ListSessionsRequest)
566 .map_err(Into::into),
567 #[cfg(feature = "unstable_session_fork")]
568 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
569 .map(ClientRequest::ForkSessionRequest)
570 .map_err(Into::into),
571 #[cfg(feature = "unstable_session_resume")]
572 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
573 .map(ClientRequest::ResumeSessionRequest)
574 .map_err(Into::into),
575 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
576 serde_json::from_str(params.get())
577 .map(ClientRequest::SetSessionConfigOptionRequest)
578 .map_err(Into::into)
579 }
580 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
581 .map(ClientRequest::PromptRequest)
582 .map_err(Into::into),
583 _ => {
584 if let Some(custom_method) = method.strip_prefix('_') {
585 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
586 custom_method,
587 params.to_owned().into(),
588 )))
589 } else {
590 Err(Error::method_not_found())
591 }
592 }
593 }
594 }
595
596 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
597 let params = params.ok_or_else(Error::invalid_params)?;
598
599 match method {
600 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
601 .map(ClientNotification::CancelNotification)
602 .map_err(Into::into),
603 _ => {
604 if let Some(custom_method) = method.strip_prefix('_') {
605 Ok(ClientNotification::ExtNotification(ExtNotification::new(
606 custom_method,
607 RawValue::from_string(params.get().to_string())?.into(),
608 )))
609 } else {
610 Err(Error::method_not_found())
611 }
612 }
613 }
614 }
615}
616
617impl<T: Agent> MessageHandler<AgentSide> for T {
618 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
619 match request {
620 ClientRequest::InitializeRequest(args) => {
621 let response = self.initialize(args).await?;
622 Ok(AgentResponse::InitializeResponse(response))
623 }
624 ClientRequest::AuthenticateRequest(args) => {
625 let response = self.authenticate(args).await?;
626 Ok(AgentResponse::AuthenticateResponse(response))
627 }
628 ClientRequest::NewSessionRequest(args) => {
629 let response = self.new_session(args).await?;
630 Ok(AgentResponse::NewSessionResponse(response))
631 }
632 ClientRequest::LoadSessionRequest(args) => {
633 let response = self.load_session(args).await?;
634 Ok(AgentResponse::LoadSessionResponse(response))
635 }
636 ClientRequest::PromptRequest(args) => {
637 let response = self.prompt(args).await?;
638 Ok(AgentResponse::PromptResponse(response))
639 }
640 ClientRequest::SetSessionModeRequest(args) => {
641 let response = self.set_session_mode(args).await?;
642 Ok(AgentResponse::SetSessionModeResponse(response))
643 }
644 #[cfg(feature = "unstable_session_model")]
645 ClientRequest::SetSessionModelRequest(args) => {
646 let response = self.set_session_model(args).await?;
647 Ok(AgentResponse::SetSessionModelResponse(response))
648 }
649 #[cfg(feature = "unstable_session_list")]
650 ClientRequest::ListSessionsRequest(args) => {
651 let response = self.list_sessions(args).await?;
652 Ok(AgentResponse::ListSessionsResponse(response))
653 }
654 #[cfg(feature = "unstable_session_fork")]
655 ClientRequest::ForkSessionRequest(args) => {
656 let response = self.fork_session(args).await?;
657 Ok(AgentResponse::ForkSessionResponse(response))
658 }
659 #[cfg(feature = "unstable_session_resume")]
660 ClientRequest::ResumeSessionRequest(args) => {
661 let response = self.resume_session(args).await?;
662 Ok(AgentResponse::ResumeSessionResponse(response))
663 }
664 ClientRequest::SetSessionConfigOptionRequest(args) => {
665 let response = self.set_session_config_option(args).await?;
666 Ok(AgentResponse::SetSessionConfigOptionResponse(response))
667 }
668 ClientRequest::ExtMethodRequest(args) => {
669 let response = self.ext_method(args).await?;
670 Ok(AgentResponse::ExtMethodResponse(response))
671 }
672 _ => Err(Error::method_not_found()),
673 }
674 }
675
676 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
677 match notification {
678 ClientNotification::CancelNotification(args) => {
679 self.cancel(args).await?;
680 }
681 ClientNotification::ExtNotification(args) => {
682 self.ext_notification(args).await?;
683 }
684 _ => {}
686 }
687 Ok(())
688 }
689}