1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19#[derive(Debug)]
30pub struct ClientSideConnection {
31 conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable_session_model")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
159 self.conn
160 .request(
161 AGENT_METHOD_NAMES.session_list,
162 Some(ClientRequest::ListSessionsRequest(args)),
163 )
164 .await
165 }
166
167 #[cfg(feature = "unstable_session_fork")]
168 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
169 self.conn
170 .request(
171 AGENT_METHOD_NAMES.session_fork,
172 Some(ClientRequest::ForkSessionRequest(args)),
173 )
174 .await
175 }
176
177 #[cfg(feature = "unstable_session_resume")]
178 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
179 self.conn
180 .request(
181 AGENT_METHOD_NAMES.session_resume,
182 Some(ClientRequest::ResumeSessionRequest(args)),
183 )
184 .await
185 }
186
187 #[cfg(feature = "unstable_session_close")]
188 async fn close_session(&self, args: CloseSessionRequest) -> Result<CloseSessionResponse> {
189 self.conn
190 .request::<Option<_>>(
191 AGENT_METHOD_NAMES.session_close,
192 Some(ClientRequest::CloseSessionRequest(args)),
193 )
194 .await
195 .map(Option::unwrap_or_default)
196 }
197
198 async fn set_session_config_option(
199 &self,
200 args: SetSessionConfigOptionRequest,
201 ) -> Result<SetSessionConfigOptionResponse> {
202 self.conn
203 .request(
204 AGENT_METHOD_NAMES.session_set_config_option,
205 Some(ClientRequest::SetSessionConfigOptionRequest(args)),
206 )
207 .await
208 }
209
210 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
211 self.conn
212 .request(
213 format!("_{}", args.method),
214 Some(ClientRequest::ExtMethodRequest(args)),
215 )
216 .await
217 }
218
219 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
220 self.conn.notify(
221 format!("_{}", args.method),
222 Some(ClientNotification::ExtNotification(args)),
223 )
224 }
225}
226
227#[derive(Clone, Debug)]
234pub struct ClientSide;
235
236impl Side for ClientSide {
237 type InNotification = AgentNotification;
238 type InRequest = AgentRequest;
239 type OutResponse = ClientResponse;
240
241 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
242 let params = params.ok_or_else(Error::invalid_params)?;
243
244 match method {
245 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
246 serde_json::from_str(params.get())
247 .map(AgentRequest::RequestPermissionRequest)
248 .map_err(Into::into)
249 }
250 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
251 .map(AgentRequest::WriteTextFileRequest)
252 .map_err(Into::into),
253 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
254 .map(AgentRequest::ReadTextFileRequest)
255 .map_err(Into::into),
256 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
257 .map(AgentRequest::CreateTerminalRequest)
258 .map_err(Into::into),
259 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
260 .map(AgentRequest::TerminalOutputRequest)
261 .map_err(Into::into),
262 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
263 .map(AgentRequest::KillTerminalRequest)
264 .map_err(Into::into),
265 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
266 .map(AgentRequest::ReleaseTerminalRequest)
267 .map_err(Into::into),
268 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
269 serde_json::from_str(params.get())
270 .map(AgentRequest::WaitForTerminalExitRequest)
271 .map_err(Into::into)
272 }
273 _ => {
274 if let Some(custom_method) = method.strip_prefix('_') {
275 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
276 custom_method,
277 params.to_owned().into(),
278 )))
279 } else {
280 Err(Error::method_not_found())
281 }
282 }
283 }
284 }
285
286 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
287 let params = params.ok_or_else(Error::invalid_params)?;
288
289 match method {
290 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
291 .map(AgentNotification::SessionNotification)
292 .map_err(Into::into),
293 _ => {
294 if let Some(custom_method) = method.strip_prefix('_') {
295 Ok(AgentNotification::ExtNotification(ExtNotification::new(
296 custom_method,
297 RawValue::from_string(params.get().to_string())?.into(),
298 )))
299 } else {
300 Err(Error::method_not_found())
301 }
302 }
303 }
304 }
305}
306
307impl<T: Client> MessageHandler<ClientSide> for T {
308 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
309 match request {
310 AgentRequest::RequestPermissionRequest(args) => {
311 let response = self.request_permission(args).await?;
312 Ok(ClientResponse::RequestPermissionResponse(response))
313 }
314 AgentRequest::WriteTextFileRequest(args) => {
315 let response = self.write_text_file(args).await?;
316 Ok(ClientResponse::WriteTextFileResponse(response))
317 }
318 AgentRequest::ReadTextFileRequest(args) => {
319 let response = self.read_text_file(args).await?;
320 Ok(ClientResponse::ReadTextFileResponse(response))
321 }
322 AgentRequest::CreateTerminalRequest(args) => {
323 let response = self.create_terminal(args).await?;
324 Ok(ClientResponse::CreateTerminalResponse(response))
325 }
326 AgentRequest::TerminalOutputRequest(args) => {
327 let response = self.terminal_output(args).await?;
328 Ok(ClientResponse::TerminalOutputResponse(response))
329 }
330 AgentRequest::ReleaseTerminalRequest(args) => {
331 let response = self.release_terminal(args).await?;
332 Ok(ClientResponse::ReleaseTerminalResponse(response))
333 }
334 AgentRequest::WaitForTerminalExitRequest(args) => {
335 let response = self.wait_for_terminal_exit(args).await?;
336 Ok(ClientResponse::WaitForTerminalExitResponse(response))
337 }
338 AgentRequest::KillTerminalRequest(args) => {
339 let response = self.kill_terminal(args).await?;
340 Ok(ClientResponse::KillTerminalResponse(response))
341 }
342 AgentRequest::ExtMethodRequest(args) => {
343 let response = self.ext_method(args).await?;
344 Ok(ClientResponse::ExtMethodResponse(response))
345 }
346 _ => Err(Error::method_not_found()),
347 }
348 }
349
350 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
351 match notification {
352 AgentNotification::SessionNotification(args) => {
353 self.session_notification(args).await?;
354 }
355 AgentNotification::ExtNotification(args) => {
356 self.ext_notification(args).await?;
357 }
358 _ => {}
360 }
361 Ok(())
362 }
363}
364
365#[derive(Debug)]
376pub struct AgentSideConnection {
377 conn: RpcConnection<AgentSide, ClientSide>,
378}
379
380impl AgentSideConnection {
381 pub fn new(
401 agent: impl MessageHandler<AgentSide> + 'static,
402 outgoing_bytes: impl Unpin + AsyncWrite,
403 incoming_bytes: impl Unpin + AsyncRead,
404 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
405 ) -> (Self, impl Future<Output = Result<()>>) {
406 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
407 (Self { conn }, io_task)
408 }
409
410 pub fn subscribe(&self) -> StreamReceiver {
419 self.conn.subscribe()
420 }
421}
422
423#[async_trait::async_trait(?Send)]
424impl Client for AgentSideConnection {
425 async fn request_permission(
426 &self,
427 args: RequestPermissionRequest,
428 ) -> Result<RequestPermissionResponse> {
429 self.conn
430 .request(
431 CLIENT_METHOD_NAMES.session_request_permission,
432 Some(AgentRequest::RequestPermissionRequest(args)),
433 )
434 .await
435 }
436
437 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
438 self.conn
439 .request::<Option<_>>(
440 CLIENT_METHOD_NAMES.fs_write_text_file,
441 Some(AgentRequest::WriteTextFileRequest(args)),
442 )
443 .await
444 .map(Option::unwrap_or_default)
445 }
446
447 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
448 self.conn
449 .request(
450 CLIENT_METHOD_NAMES.fs_read_text_file,
451 Some(AgentRequest::ReadTextFileRequest(args)),
452 )
453 .await
454 }
455
456 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
457 self.conn
458 .request(
459 CLIENT_METHOD_NAMES.terminal_create,
460 Some(AgentRequest::CreateTerminalRequest(args)),
461 )
462 .await
463 }
464
465 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
466 self.conn
467 .request(
468 CLIENT_METHOD_NAMES.terminal_output,
469 Some(AgentRequest::TerminalOutputRequest(args)),
470 )
471 .await
472 }
473
474 async fn release_terminal(
475 &self,
476 args: ReleaseTerminalRequest,
477 ) -> Result<ReleaseTerminalResponse> {
478 self.conn
479 .request::<Option<_>>(
480 CLIENT_METHOD_NAMES.terminal_release,
481 Some(AgentRequest::ReleaseTerminalRequest(args)),
482 )
483 .await
484 .map(Option::unwrap_or_default)
485 }
486
487 async fn wait_for_terminal_exit(
488 &self,
489 args: WaitForTerminalExitRequest,
490 ) -> Result<WaitForTerminalExitResponse> {
491 self.conn
492 .request(
493 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
494 Some(AgentRequest::WaitForTerminalExitRequest(args)),
495 )
496 .await
497 }
498
499 async fn kill_terminal(&self, args: KillTerminalRequest) -> Result<KillTerminalResponse> {
500 self.conn
501 .request::<Option<_>>(
502 CLIENT_METHOD_NAMES.terminal_kill,
503 Some(AgentRequest::KillTerminalRequest(args)),
504 )
505 .await
506 .map(Option::unwrap_or_default)
507 }
508
509 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
510 self.conn.notify(
511 CLIENT_METHOD_NAMES.session_update,
512 Some(AgentNotification::SessionNotification(args)),
513 )
514 }
515
516 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
517 self.conn
518 .request(
519 format!("_{}", args.method),
520 Some(AgentRequest::ExtMethodRequest(args)),
521 )
522 .await
523 }
524
525 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
526 self.conn.notify(
527 format!("_{}", args.method),
528 Some(AgentNotification::ExtNotification(args)),
529 )
530 }
531}
532
533#[derive(Clone, Debug)]
540pub struct AgentSide;
541
542impl Side for AgentSide {
543 type InRequest = ClientRequest;
544 type InNotification = ClientNotification;
545 type OutResponse = AgentResponse;
546
547 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
548 let params = params.ok_or_else(Error::invalid_params)?;
549
550 match method {
551 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
552 .map(ClientRequest::InitializeRequest)
553 .map_err(Into::into),
554 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
555 .map(ClientRequest::AuthenticateRequest)
556 .map_err(Into::into),
557 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
558 .map(ClientRequest::NewSessionRequest)
559 .map_err(Into::into),
560 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
561 .map(ClientRequest::LoadSessionRequest)
562 .map_err(Into::into),
563 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
564 .map(ClientRequest::SetSessionModeRequest)
565 .map_err(Into::into),
566 #[cfg(feature = "unstable_session_model")]
567 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
568 .map(ClientRequest::SetSessionModelRequest)
569 .map_err(Into::into),
570 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
571 .map(ClientRequest::ListSessionsRequest)
572 .map_err(Into::into),
573 #[cfg(feature = "unstable_session_fork")]
574 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
575 .map(ClientRequest::ForkSessionRequest)
576 .map_err(Into::into),
577 #[cfg(feature = "unstable_session_resume")]
578 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
579 .map(ClientRequest::ResumeSessionRequest)
580 .map_err(Into::into),
581 #[cfg(feature = "unstable_session_close")]
582 m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
583 .map(ClientRequest::CloseSessionRequest)
584 .map_err(Into::into),
585 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
586 serde_json::from_str(params.get())
587 .map(ClientRequest::SetSessionConfigOptionRequest)
588 .map_err(Into::into)
589 }
590 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
591 .map(ClientRequest::PromptRequest)
592 .map_err(Into::into),
593 _ => {
594 if let Some(custom_method) = method.strip_prefix('_') {
595 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
596 custom_method,
597 params.to_owned().into(),
598 )))
599 } else {
600 Err(Error::method_not_found())
601 }
602 }
603 }
604 }
605
606 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
607 let params = params.ok_or_else(Error::invalid_params)?;
608
609 match method {
610 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
611 .map(ClientNotification::CancelNotification)
612 .map_err(Into::into),
613 _ => {
614 if let Some(custom_method) = method.strip_prefix('_') {
615 Ok(ClientNotification::ExtNotification(ExtNotification::new(
616 custom_method,
617 RawValue::from_string(params.get().to_string())?.into(),
618 )))
619 } else {
620 Err(Error::method_not_found())
621 }
622 }
623 }
624 }
625}
626
627impl<T: Agent> MessageHandler<AgentSide> for T {
628 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
629 match request {
630 ClientRequest::InitializeRequest(args) => {
631 let response = self.initialize(args).await?;
632 Ok(AgentResponse::InitializeResponse(response))
633 }
634 ClientRequest::AuthenticateRequest(args) => {
635 let response = self.authenticate(args).await?;
636 Ok(AgentResponse::AuthenticateResponse(response))
637 }
638 ClientRequest::NewSessionRequest(args) => {
639 let response = self.new_session(args).await?;
640 Ok(AgentResponse::NewSessionResponse(response))
641 }
642 ClientRequest::LoadSessionRequest(args) => {
643 let response = self.load_session(args).await?;
644 Ok(AgentResponse::LoadSessionResponse(response))
645 }
646 ClientRequest::PromptRequest(args) => {
647 let response = self.prompt(args).await?;
648 Ok(AgentResponse::PromptResponse(response))
649 }
650 ClientRequest::SetSessionModeRequest(args) => {
651 let response = self.set_session_mode(args).await?;
652 Ok(AgentResponse::SetSessionModeResponse(response))
653 }
654 #[cfg(feature = "unstable_session_model")]
655 ClientRequest::SetSessionModelRequest(args) => {
656 let response = self.set_session_model(args).await?;
657 Ok(AgentResponse::SetSessionModelResponse(response))
658 }
659 ClientRequest::ListSessionsRequest(args) => {
660 let response = self.list_sessions(args).await?;
661 Ok(AgentResponse::ListSessionsResponse(response))
662 }
663 #[cfg(feature = "unstable_session_fork")]
664 ClientRequest::ForkSessionRequest(args) => {
665 let response = self.fork_session(args).await?;
666 Ok(AgentResponse::ForkSessionResponse(response))
667 }
668 #[cfg(feature = "unstable_session_resume")]
669 ClientRequest::ResumeSessionRequest(args) => {
670 let response = self.resume_session(args).await?;
671 Ok(AgentResponse::ResumeSessionResponse(response))
672 }
673 #[cfg(feature = "unstable_session_close")]
674 ClientRequest::CloseSessionRequest(args) => {
675 let response = self.close_session(args).await?;
676 Ok(AgentResponse::CloseSessionResponse(response))
677 }
678 ClientRequest::SetSessionConfigOptionRequest(args) => {
679 let response = self.set_session_config_option(args).await?;
680 Ok(AgentResponse::SetSessionConfigOptionResponse(response))
681 }
682 ClientRequest::ExtMethodRequest(args) => {
683 let response = self.ext_method(args).await?;
684 Ok(AgentResponse::ExtMethodResponse(response))
685 }
686 _ => Err(Error::method_not_found()),
687 }
688 }
689
690 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
691 match notification {
692 ClientNotification::CancelNotification(args) => {
693 self.cancel(args).await?;
694 }
695 ClientNotification::ExtNotification(args) => {
696 self.ext_notification(args).await?;
697 }
698 _ => {}
700 }
701 Ok(())
702 }
703}