1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
/// The client's end of a connection to an agent.
///
/// Implements [`Agent`]: each trait call is forwarded over the wrapped
/// [`RpcConnection`] as a request or notification, and the agent's
/// requests are dispatched to the `MessageHandler<ClientSide>` supplied to
/// [`ClientSideConnection::new`].
#[derive(Debug)]
pub struct ClientSideConnection {
    // Local side is `ClientSide`, remote side is `AgentSide`.
    conn: RpcConnection<ClientSide, AgentSide>,
}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable_session_model")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 #[cfg(feature = "unstable_session_list")]
159 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
160 self.conn
161 .request(
162 AGENT_METHOD_NAMES.session_list,
163 Some(ClientRequest::ListSessionsRequest(args)),
164 )
165 .await
166 }
167
168 #[cfg(feature = "unstable_session_fork")]
169 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
170 self.conn
171 .request(
172 AGENT_METHOD_NAMES.session_fork,
173 Some(ClientRequest::ForkSessionRequest(args)),
174 )
175 .await
176 }
177
178 #[cfg(feature = "unstable_session_resume")]
179 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
180 self.conn
181 .request(
182 AGENT_METHOD_NAMES.session_resume,
183 Some(ClientRequest::ResumeSessionRequest(args)),
184 )
185 .await
186 }
187
188 async fn set_session_config_option(
189 &self,
190 args: SetSessionConfigOptionRequest,
191 ) -> Result<SetSessionConfigOptionResponse> {
192 self.conn
193 .request(
194 AGENT_METHOD_NAMES.session_set_config_option,
195 Some(ClientRequest::SetSessionConfigOptionRequest(args)),
196 )
197 .await
198 }
199
200 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
201 self.conn
202 .request(
203 format!("_{}", args.method),
204 Some(ClientRequest::ExtMethodRequest(args)),
205 )
206 .await
207 }
208
209 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
210 self.conn.notify(
211 format!("_{}", args.method),
212 Some(ClientNotification::ExtNotification(args)),
213 )
214 }
215}
216
/// Marker type for the client end of a connection.
///
/// Its [`Side`] impl decodes incoming [`AgentRequest`]s and
/// [`AgentNotification`]s and answers with [`ClientResponse`]s.
#[derive(Clone, Debug)]
pub struct ClientSide;
225
226impl Side for ClientSide {
227 type InNotification = AgentNotification;
228 type InRequest = AgentRequest;
229 type OutResponse = ClientResponse;
230
231 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
232 let params = params.ok_or_else(Error::invalid_params)?;
233
234 match method {
235 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
236 serde_json::from_str(params.get())
237 .map(AgentRequest::RequestPermissionRequest)
238 .map_err(Into::into)
239 }
240 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
241 .map(AgentRequest::WriteTextFileRequest)
242 .map_err(Into::into),
243 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
244 .map(AgentRequest::ReadTextFileRequest)
245 .map_err(Into::into),
246 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
247 .map(AgentRequest::CreateTerminalRequest)
248 .map_err(Into::into),
249 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
250 .map(AgentRequest::TerminalOutputRequest)
251 .map_err(Into::into),
252 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
253 .map(AgentRequest::KillTerminalRequest)
254 .map_err(Into::into),
255 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
256 .map(AgentRequest::ReleaseTerminalRequest)
257 .map_err(Into::into),
258 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
259 serde_json::from_str(params.get())
260 .map(AgentRequest::WaitForTerminalExitRequest)
261 .map_err(Into::into)
262 }
263 _ => {
264 if let Some(custom_method) = method.strip_prefix('_') {
265 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
266 custom_method,
267 params.to_owned().into(),
268 )))
269 } else {
270 Err(Error::method_not_found())
271 }
272 }
273 }
274 }
275
276 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
277 let params = params.ok_or_else(Error::invalid_params)?;
278
279 match method {
280 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
281 .map(AgentNotification::SessionNotification)
282 .map_err(Into::into),
283 _ => {
284 if let Some(custom_method) = method.strip_prefix('_') {
285 Ok(AgentNotification::ExtNotification(ExtNotification::new(
286 custom_method,
287 RawValue::from_string(params.get().to_string())?.into(),
288 )))
289 } else {
290 Err(Error::method_not_found())
291 }
292 }
293 }
294 }
295}
296
297impl<T: Client> MessageHandler<ClientSide> for T {
298 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
299 match request {
300 AgentRequest::RequestPermissionRequest(args) => {
301 let response = self.request_permission(args).await?;
302 Ok(ClientResponse::RequestPermissionResponse(response))
303 }
304 AgentRequest::WriteTextFileRequest(args) => {
305 let response = self.write_text_file(args).await?;
306 Ok(ClientResponse::WriteTextFileResponse(response))
307 }
308 AgentRequest::ReadTextFileRequest(args) => {
309 let response = self.read_text_file(args).await?;
310 Ok(ClientResponse::ReadTextFileResponse(response))
311 }
312 AgentRequest::CreateTerminalRequest(args) => {
313 let response = self.create_terminal(args).await?;
314 Ok(ClientResponse::CreateTerminalResponse(response))
315 }
316 AgentRequest::TerminalOutputRequest(args) => {
317 let response = self.terminal_output(args).await?;
318 Ok(ClientResponse::TerminalOutputResponse(response))
319 }
320 AgentRequest::ReleaseTerminalRequest(args) => {
321 let response = self.release_terminal(args).await?;
322 Ok(ClientResponse::ReleaseTerminalResponse(response))
323 }
324 AgentRequest::WaitForTerminalExitRequest(args) => {
325 let response = self.wait_for_terminal_exit(args).await?;
326 Ok(ClientResponse::WaitForTerminalExitResponse(response))
327 }
328 AgentRequest::KillTerminalRequest(args) => {
329 let response = self.kill_terminal(args).await?;
330 Ok(ClientResponse::KillTerminalResponse(response))
331 }
332 AgentRequest::ExtMethodRequest(args) => {
333 let response = self.ext_method(args).await?;
334 Ok(ClientResponse::ExtMethodResponse(response))
335 }
336 _ => Err(Error::method_not_found()),
337 }
338 }
339
340 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
341 match notification {
342 AgentNotification::SessionNotification(args) => {
343 self.session_notification(args).await?;
344 }
345 AgentNotification::ExtNotification(args) => {
346 self.ext_notification(args).await?;
347 }
348 _ => {}
350 }
351 Ok(())
352 }
353}
354
/// The agent's end of a connection to a client.
///
/// Implements [`Client`]: each trait call is forwarded over the wrapped
/// [`RpcConnection`] as a request or notification, and the client's
/// requests are dispatched to the `MessageHandler<AgentSide>` supplied to
/// [`AgentSideConnection::new`].
#[derive(Debug)]
pub struct AgentSideConnection {
    // Local side is `AgentSide`, remote side is `ClientSide`.
    conn: RpcConnection<AgentSide, ClientSide>,
}
369
370impl AgentSideConnection {
371 pub fn new(
391 agent: impl MessageHandler<AgentSide> + 'static,
392 outgoing_bytes: impl Unpin + AsyncWrite,
393 incoming_bytes: impl Unpin + AsyncRead,
394 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
395 ) -> (Self, impl Future<Output = Result<()>>) {
396 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
397 (Self { conn }, io_task)
398 }
399
400 pub fn subscribe(&self) -> StreamReceiver {
409 self.conn.subscribe()
410 }
411}
412
413#[async_trait::async_trait(?Send)]
414impl Client for AgentSideConnection {
415 async fn request_permission(
416 &self,
417 args: RequestPermissionRequest,
418 ) -> Result<RequestPermissionResponse> {
419 self.conn
420 .request(
421 CLIENT_METHOD_NAMES.session_request_permission,
422 Some(AgentRequest::RequestPermissionRequest(args)),
423 )
424 .await
425 }
426
427 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
428 self.conn
429 .request::<Option<_>>(
430 CLIENT_METHOD_NAMES.fs_write_text_file,
431 Some(AgentRequest::WriteTextFileRequest(args)),
432 )
433 .await
434 .map(Option::unwrap_or_default)
435 }
436
437 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
438 self.conn
439 .request(
440 CLIENT_METHOD_NAMES.fs_read_text_file,
441 Some(AgentRequest::ReadTextFileRequest(args)),
442 )
443 .await
444 }
445
446 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
447 self.conn
448 .request(
449 CLIENT_METHOD_NAMES.terminal_create,
450 Some(AgentRequest::CreateTerminalRequest(args)),
451 )
452 .await
453 }
454
455 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
456 self.conn
457 .request(
458 CLIENT_METHOD_NAMES.terminal_output,
459 Some(AgentRequest::TerminalOutputRequest(args)),
460 )
461 .await
462 }
463
464 async fn release_terminal(
465 &self,
466 args: ReleaseTerminalRequest,
467 ) -> Result<ReleaseTerminalResponse> {
468 self.conn
469 .request::<Option<_>>(
470 CLIENT_METHOD_NAMES.terminal_release,
471 Some(AgentRequest::ReleaseTerminalRequest(args)),
472 )
473 .await
474 .map(Option::unwrap_or_default)
475 }
476
477 async fn wait_for_terminal_exit(
478 &self,
479 args: WaitForTerminalExitRequest,
480 ) -> Result<WaitForTerminalExitResponse> {
481 self.conn
482 .request(
483 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
484 Some(AgentRequest::WaitForTerminalExitRequest(args)),
485 )
486 .await
487 }
488
489 async fn kill_terminal(&self, args: KillTerminalRequest) -> Result<KillTerminalResponse> {
490 self.conn
491 .request::<Option<_>>(
492 CLIENT_METHOD_NAMES.terminal_kill,
493 Some(AgentRequest::KillTerminalRequest(args)),
494 )
495 .await
496 .map(Option::unwrap_or_default)
497 }
498
499 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
500 self.conn.notify(
501 CLIENT_METHOD_NAMES.session_update,
502 Some(AgentNotification::SessionNotification(args)),
503 )
504 }
505
506 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
507 self.conn
508 .request(
509 format!("_{}", args.method),
510 Some(AgentRequest::ExtMethodRequest(args)),
511 )
512 .await
513 }
514
515 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
516 self.conn.notify(
517 format!("_{}", args.method),
518 Some(AgentNotification::ExtNotification(args)),
519 )
520 }
521}
522
/// Marker type for the agent end of a connection.
///
/// Its [`Side`] impl decodes incoming [`ClientRequest`]s and
/// [`ClientNotification`]s and answers with [`AgentResponse`]s.
#[derive(Clone, Debug)]
pub struct AgentSide;
531
532impl Side for AgentSide {
533 type InRequest = ClientRequest;
534 type InNotification = ClientNotification;
535 type OutResponse = AgentResponse;
536
537 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
538 let params = params.ok_or_else(Error::invalid_params)?;
539
540 match method {
541 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
542 .map(ClientRequest::InitializeRequest)
543 .map_err(Into::into),
544 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
545 .map(ClientRequest::AuthenticateRequest)
546 .map_err(Into::into),
547 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
548 .map(ClientRequest::NewSessionRequest)
549 .map_err(Into::into),
550 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
551 .map(ClientRequest::LoadSessionRequest)
552 .map_err(Into::into),
553 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
554 .map(ClientRequest::SetSessionModeRequest)
555 .map_err(Into::into),
556 #[cfg(feature = "unstable_session_model")]
557 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
558 .map(ClientRequest::SetSessionModelRequest)
559 .map_err(Into::into),
560 #[cfg(feature = "unstable_session_list")]
561 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
562 .map(ClientRequest::ListSessionsRequest)
563 .map_err(Into::into),
564 #[cfg(feature = "unstable_session_fork")]
565 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
566 .map(ClientRequest::ForkSessionRequest)
567 .map_err(Into::into),
568 #[cfg(feature = "unstable_session_resume")]
569 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
570 .map(ClientRequest::ResumeSessionRequest)
571 .map_err(Into::into),
572 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
573 serde_json::from_str(params.get())
574 .map(ClientRequest::SetSessionConfigOptionRequest)
575 .map_err(Into::into)
576 }
577 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
578 .map(ClientRequest::PromptRequest)
579 .map_err(Into::into),
580 _ => {
581 if let Some(custom_method) = method.strip_prefix('_') {
582 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
583 custom_method,
584 params.to_owned().into(),
585 )))
586 } else {
587 Err(Error::method_not_found())
588 }
589 }
590 }
591 }
592
593 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
594 let params = params.ok_or_else(Error::invalid_params)?;
595
596 match method {
597 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
598 .map(ClientNotification::CancelNotification)
599 .map_err(Into::into),
600 _ => {
601 if let Some(custom_method) = method.strip_prefix('_') {
602 Ok(ClientNotification::ExtNotification(ExtNotification::new(
603 custom_method,
604 RawValue::from_string(params.get().to_string())?.into(),
605 )))
606 } else {
607 Err(Error::method_not_found())
608 }
609 }
610 }
611 }
612}
613
614impl<T: Agent> MessageHandler<AgentSide> for T {
615 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
616 match request {
617 ClientRequest::InitializeRequest(args) => {
618 let response = self.initialize(args).await?;
619 Ok(AgentResponse::InitializeResponse(response))
620 }
621 ClientRequest::AuthenticateRequest(args) => {
622 let response = self.authenticate(args).await?;
623 Ok(AgentResponse::AuthenticateResponse(response))
624 }
625 ClientRequest::NewSessionRequest(args) => {
626 let response = self.new_session(args).await?;
627 Ok(AgentResponse::NewSessionResponse(response))
628 }
629 ClientRequest::LoadSessionRequest(args) => {
630 let response = self.load_session(args).await?;
631 Ok(AgentResponse::LoadSessionResponse(response))
632 }
633 ClientRequest::PromptRequest(args) => {
634 let response = self.prompt(args).await?;
635 Ok(AgentResponse::PromptResponse(response))
636 }
637 ClientRequest::SetSessionModeRequest(args) => {
638 let response = self.set_session_mode(args).await?;
639 Ok(AgentResponse::SetSessionModeResponse(response))
640 }
641 #[cfg(feature = "unstable_session_model")]
642 ClientRequest::SetSessionModelRequest(args) => {
643 let response = self.set_session_model(args).await?;
644 Ok(AgentResponse::SetSessionModelResponse(response))
645 }
646 #[cfg(feature = "unstable_session_list")]
647 ClientRequest::ListSessionsRequest(args) => {
648 let response = self.list_sessions(args).await?;
649 Ok(AgentResponse::ListSessionsResponse(response))
650 }
651 #[cfg(feature = "unstable_session_fork")]
652 ClientRequest::ForkSessionRequest(args) => {
653 let response = self.fork_session(args).await?;
654 Ok(AgentResponse::ForkSessionResponse(response))
655 }
656 #[cfg(feature = "unstable_session_resume")]
657 ClientRequest::ResumeSessionRequest(args) => {
658 let response = self.resume_session(args).await?;
659 Ok(AgentResponse::ResumeSessionResponse(response))
660 }
661 ClientRequest::SetSessionConfigOptionRequest(args) => {
662 let response = self.set_session_config_option(args).await?;
663 Ok(AgentResponse::SetSessionConfigOptionResponse(response))
664 }
665 ClientRequest::ExtMethodRequest(args) => {
666 let response = self.ext_method(args).await?;
667 Ok(AgentResponse::ExtMethodResponse(response))
668 }
669 _ => Err(Error::method_not_found()),
670 }
671 }
672
673 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
674 match notification {
675 ClientNotification::CancelNotification(args) => {
676 self.cancel(args).await?;
677 }
678 ClientNotification::ExtNotification(args) => {
679 self.ext_notification(args).await?;
680 }
681 _ => {}
683 }
684 Ok(())
685 }
686}