1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19#[derive(Debug)]
30pub struct ClientSideConnection {
31 conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable_session_model")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 #[cfg(feature = "unstable_session_list")]
159 async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
160 self.conn
161 .request(
162 AGENT_METHOD_NAMES.session_list,
163 Some(ClientRequest::ListSessionsRequest(args)),
164 )
165 .await
166 }
167
168 #[cfg(feature = "unstable_session_fork")]
169 async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
170 self.conn
171 .request(
172 AGENT_METHOD_NAMES.session_fork,
173 Some(ClientRequest::ForkSessionRequest(args)),
174 )
175 .await
176 }
177
178 #[cfg(feature = "unstable_session_resume")]
179 async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
180 self.conn
181 .request(
182 AGENT_METHOD_NAMES.session_resume,
183 Some(ClientRequest::ResumeSessionRequest(args)),
184 )
185 .await
186 }
187
188 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
189 self.conn
190 .request(
191 format!("_{}", args.method),
192 Some(ClientRequest::ExtMethodRequest(args)),
193 )
194 .await
195 }
196
197 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
198 self.conn.notify(
199 format!("_{}", args.method),
200 Some(ClientNotification::ExtNotification(args)),
201 )
202 }
203}
204
205#[derive(Clone, Debug)]
212pub struct ClientSide;
213
214impl Side for ClientSide {
215 type InNotification = AgentNotification;
216 type InRequest = AgentRequest;
217 type OutResponse = ClientResponse;
218
219 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
220 let params = params.ok_or_else(Error::invalid_params)?;
221
222 match method {
223 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
224 serde_json::from_str(params.get())
225 .map(AgentRequest::RequestPermissionRequest)
226 .map_err(Into::into)
227 }
228 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
229 .map(AgentRequest::WriteTextFileRequest)
230 .map_err(Into::into),
231 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
232 .map(AgentRequest::ReadTextFileRequest)
233 .map_err(Into::into),
234 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
235 .map(AgentRequest::CreateTerminalRequest)
236 .map_err(Into::into),
237 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
238 .map(AgentRequest::TerminalOutputRequest)
239 .map_err(Into::into),
240 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
241 .map(AgentRequest::KillTerminalCommandRequest)
242 .map_err(Into::into),
243 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
244 .map(AgentRequest::ReleaseTerminalRequest)
245 .map_err(Into::into),
246 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
247 serde_json::from_str(params.get())
248 .map(AgentRequest::WaitForTerminalExitRequest)
249 .map_err(Into::into)
250 }
251 _ => {
252 if let Some(custom_method) = method.strip_prefix('_') {
253 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
254 custom_method,
255 params.to_owned().into(),
256 )))
257 } else {
258 Err(Error::method_not_found())
259 }
260 }
261 }
262 }
263
264 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
265 let params = params.ok_or_else(Error::invalid_params)?;
266
267 match method {
268 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
269 .map(AgentNotification::SessionNotification)
270 .map_err(Into::into),
271 _ => {
272 if let Some(custom_method) = method.strip_prefix('_') {
273 Ok(AgentNotification::ExtNotification(ExtNotification::new(
274 custom_method,
275 RawValue::from_string(params.get().to_string())?.into(),
276 )))
277 } else {
278 Err(Error::method_not_found())
279 }
280 }
281 }
282 }
283}
284
285impl<T: Client> MessageHandler<ClientSide> for T {
286 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
287 match request {
288 AgentRequest::RequestPermissionRequest(args) => {
289 let response = self.request_permission(args).await?;
290 Ok(ClientResponse::RequestPermissionResponse(response))
291 }
292 AgentRequest::WriteTextFileRequest(args) => {
293 let response = self.write_text_file(args).await?;
294 Ok(ClientResponse::WriteTextFileResponse(response))
295 }
296 AgentRequest::ReadTextFileRequest(args) => {
297 let response = self.read_text_file(args).await?;
298 Ok(ClientResponse::ReadTextFileResponse(response))
299 }
300 AgentRequest::CreateTerminalRequest(args) => {
301 let response = self.create_terminal(args).await?;
302 Ok(ClientResponse::CreateTerminalResponse(response))
303 }
304 AgentRequest::TerminalOutputRequest(args) => {
305 let response = self.terminal_output(args).await?;
306 Ok(ClientResponse::TerminalOutputResponse(response))
307 }
308 AgentRequest::ReleaseTerminalRequest(args) => {
309 let response = self.release_terminal(args).await?;
310 Ok(ClientResponse::ReleaseTerminalResponse(response))
311 }
312 AgentRequest::WaitForTerminalExitRequest(args) => {
313 let response = self.wait_for_terminal_exit(args).await?;
314 Ok(ClientResponse::WaitForTerminalExitResponse(response))
315 }
316 AgentRequest::KillTerminalCommandRequest(args) => {
317 let response = self.kill_terminal_command(args).await?;
318 Ok(ClientResponse::KillTerminalResponse(response))
319 }
320 AgentRequest::ExtMethodRequest(args) => {
321 let response = self.ext_method(args).await?;
322 Ok(ClientResponse::ExtMethodResponse(response))
323 }
324 _ => Err(Error::method_not_found()),
325 }
326 }
327
328 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
329 match notification {
330 AgentNotification::SessionNotification(args) => {
331 self.session_notification(args).await?;
332 }
333 AgentNotification::ExtNotification(args) => {
334 self.ext_notification(args).await?;
335 }
336 _ => {}
338 }
339 Ok(())
340 }
341}
342
343#[derive(Debug)]
354pub struct AgentSideConnection {
355 conn: RpcConnection<AgentSide, ClientSide>,
356}
357
358impl AgentSideConnection {
359 pub fn new(
379 agent: impl MessageHandler<AgentSide> + 'static,
380 outgoing_bytes: impl Unpin + AsyncWrite,
381 incoming_bytes: impl Unpin + AsyncRead,
382 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
383 ) -> (Self, impl Future<Output = Result<()>>) {
384 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
385 (Self { conn }, io_task)
386 }
387
388 pub fn subscribe(&self) -> StreamReceiver {
397 self.conn.subscribe()
398 }
399}
400
401#[async_trait::async_trait(?Send)]
402impl Client for AgentSideConnection {
403 async fn request_permission(
404 &self,
405 args: RequestPermissionRequest,
406 ) -> Result<RequestPermissionResponse> {
407 self.conn
408 .request(
409 CLIENT_METHOD_NAMES.session_request_permission,
410 Some(AgentRequest::RequestPermissionRequest(args)),
411 )
412 .await
413 }
414
415 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
416 self.conn
417 .request::<Option<_>>(
418 CLIENT_METHOD_NAMES.fs_write_text_file,
419 Some(AgentRequest::WriteTextFileRequest(args)),
420 )
421 .await
422 .map(Option::unwrap_or_default)
423 }
424
425 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
426 self.conn
427 .request(
428 CLIENT_METHOD_NAMES.fs_read_text_file,
429 Some(AgentRequest::ReadTextFileRequest(args)),
430 )
431 .await
432 }
433
434 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
435 self.conn
436 .request(
437 CLIENT_METHOD_NAMES.terminal_create,
438 Some(AgentRequest::CreateTerminalRequest(args)),
439 )
440 .await
441 }
442
443 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
444 self.conn
445 .request(
446 CLIENT_METHOD_NAMES.terminal_output,
447 Some(AgentRequest::TerminalOutputRequest(args)),
448 )
449 .await
450 }
451
452 async fn release_terminal(
453 &self,
454 args: ReleaseTerminalRequest,
455 ) -> Result<ReleaseTerminalResponse> {
456 self.conn
457 .request::<Option<_>>(
458 CLIENT_METHOD_NAMES.terminal_release,
459 Some(AgentRequest::ReleaseTerminalRequest(args)),
460 )
461 .await
462 .map(Option::unwrap_or_default)
463 }
464
465 async fn wait_for_terminal_exit(
466 &self,
467 args: WaitForTerminalExitRequest,
468 ) -> Result<WaitForTerminalExitResponse> {
469 self.conn
470 .request(
471 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
472 Some(AgentRequest::WaitForTerminalExitRequest(args)),
473 )
474 .await
475 }
476
477 async fn kill_terminal_command(
478 &self,
479 args: KillTerminalCommandRequest,
480 ) -> Result<KillTerminalCommandResponse> {
481 self.conn
482 .request::<Option<_>>(
483 CLIENT_METHOD_NAMES.terminal_kill,
484 Some(AgentRequest::KillTerminalCommandRequest(args)),
485 )
486 .await
487 .map(Option::unwrap_or_default)
488 }
489
490 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
491 self.conn.notify(
492 CLIENT_METHOD_NAMES.session_update,
493 Some(AgentNotification::SessionNotification(args)),
494 )
495 }
496
497 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
498 self.conn
499 .request(
500 format!("_{}", args.method),
501 Some(AgentRequest::ExtMethodRequest(args)),
502 )
503 .await
504 }
505
506 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
507 self.conn.notify(
508 format!("_{}", args.method),
509 Some(AgentNotification::ExtNotification(args)),
510 )
511 }
512}
513
514#[derive(Clone, Debug)]
521pub struct AgentSide;
522
523impl Side for AgentSide {
524 type InRequest = ClientRequest;
525 type InNotification = ClientNotification;
526 type OutResponse = AgentResponse;
527
528 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
529 let params = params.ok_or_else(Error::invalid_params)?;
530
531 match method {
532 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
533 .map(ClientRequest::InitializeRequest)
534 .map_err(Into::into),
535 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
536 .map(ClientRequest::AuthenticateRequest)
537 .map_err(Into::into),
538 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
539 .map(ClientRequest::NewSessionRequest)
540 .map_err(Into::into),
541 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
542 .map(ClientRequest::LoadSessionRequest)
543 .map_err(Into::into),
544 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
545 .map(ClientRequest::SetSessionModeRequest)
546 .map_err(Into::into),
547 #[cfg(feature = "unstable_session_model")]
548 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
549 .map(ClientRequest::SetSessionModelRequest)
550 .map_err(Into::into),
551 #[cfg(feature = "unstable_session_list")]
552 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
553 .map(ClientRequest::ListSessionsRequest)
554 .map_err(Into::into),
555 #[cfg(feature = "unstable_session_fork")]
556 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
557 .map(ClientRequest::ForkSessionRequest)
558 .map_err(Into::into),
559 #[cfg(feature = "unstable_session_resume")]
560 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
561 .map(ClientRequest::ResumeSessionRequest)
562 .map_err(Into::into),
563 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
564 .map(ClientRequest::PromptRequest)
565 .map_err(Into::into),
566 _ => {
567 if let Some(custom_method) = method.strip_prefix('_') {
568 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
569 custom_method,
570 params.to_owned().into(),
571 )))
572 } else {
573 Err(Error::method_not_found())
574 }
575 }
576 }
577 }
578
579 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
580 let params = params.ok_or_else(Error::invalid_params)?;
581
582 match method {
583 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
584 .map(ClientNotification::CancelNotification)
585 .map_err(Into::into),
586 _ => {
587 if let Some(custom_method) = method.strip_prefix('_') {
588 Ok(ClientNotification::ExtNotification(ExtNotification::new(
589 custom_method,
590 RawValue::from_string(params.get().to_string())?.into(),
591 )))
592 } else {
593 Err(Error::method_not_found())
594 }
595 }
596 }
597 }
598}
599
600impl<T: Agent> MessageHandler<AgentSide> for T {
601 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
602 match request {
603 ClientRequest::InitializeRequest(args) => {
604 let response = self.initialize(args).await?;
605 Ok(AgentResponse::InitializeResponse(response))
606 }
607 ClientRequest::AuthenticateRequest(args) => {
608 let response = self.authenticate(args).await?;
609 Ok(AgentResponse::AuthenticateResponse(response))
610 }
611 ClientRequest::NewSessionRequest(args) => {
612 let response = self.new_session(args).await?;
613 Ok(AgentResponse::NewSessionResponse(response))
614 }
615 ClientRequest::LoadSessionRequest(args) => {
616 let response = self.load_session(args).await?;
617 Ok(AgentResponse::LoadSessionResponse(response))
618 }
619 ClientRequest::PromptRequest(args) => {
620 let response = self.prompt(args).await?;
621 Ok(AgentResponse::PromptResponse(response))
622 }
623 ClientRequest::SetSessionModeRequest(args) => {
624 let response = self.set_session_mode(args).await?;
625 Ok(AgentResponse::SetSessionModeResponse(response))
626 }
627 #[cfg(feature = "unstable_session_model")]
628 ClientRequest::SetSessionModelRequest(args) => {
629 let response = self.set_session_model(args).await?;
630 Ok(AgentResponse::SetSessionModelResponse(response))
631 }
632 #[cfg(feature = "unstable_session_list")]
633 ClientRequest::ListSessionsRequest(args) => {
634 let response = self.list_sessions(args).await?;
635 Ok(AgentResponse::ListSessionsResponse(response))
636 }
637 #[cfg(feature = "unstable_session_fork")]
638 ClientRequest::ForkSessionRequest(args) => {
639 let response = self.fork_session(args).await?;
640 Ok(AgentResponse::ForkSessionResponse(response))
641 }
642 #[cfg(feature = "unstable_session_resume")]
643 ClientRequest::ResumeSessionRequest(args) => {
644 let response = self.resume_session(args).await?;
645 Ok(AgentResponse::ResumeSessionResponse(response))
646 }
647 ClientRequest::ExtMethodRequest(args) => {
648 let response = self.ext_method(args).await?;
649 Ok(AgentResponse::ExtMethodResponse(response))
650 }
651 _ => Err(Error::method_not_found()),
652 }
653 }
654
655 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
656 match notification {
657 ClientNotification::CancelNotification(args) => {
658 self.cancel(args).await?;
659 }
660 ClientNotification::ExtNotification(args) => {
661 self.ext_notification(args).await?;
662 }
663 _ => {}
665 }
666 Ok(())
667 }
668}