1use anyhow::Result;
2use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
3use rpc::{MessageHandler, RpcConnection, Side};
4
5mod agent;
6mod client;
7mod rpc;
8#[cfg(test)]
9mod rpc_tests;
10mod stream_broadcast;
11
12pub use agent::*;
13pub use agent_client_protocol_schema::*;
14pub use client::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19pub struct ClientSideConnection {
30 conn: RpcConnection<ClientSide, AgentSide>,
31}
32
33impl ClientSideConnection {
34 pub fn new(
54 client: impl MessageHandler<ClientSide> + 'static,
55 outgoing_bytes: impl Unpin + AsyncWrite,
56 incoming_bytes: impl Unpin + AsyncRead,
57 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
58 ) -> (Self, impl Future<Output = Result<()>>) {
59 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
60 (Self { conn }, io_task)
61 }
62
63 pub fn subscribe(&self) -> StreamReceiver {
72 self.conn.subscribe()
73 }
74}
75
76#[async_trait::async_trait(?Send)]
77impl Agent for ClientSideConnection {
78 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse, Error> {
79 self.conn
80 .request(
81 AGENT_METHOD_NAMES.initialize,
82 Some(ClientRequest::InitializeRequest(args)),
83 )
84 .await
85 }
86
87 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse, Error> {
88 self.conn
89 .request::<Option<_>>(
90 AGENT_METHOD_NAMES.authenticate,
91 Some(ClientRequest::AuthenticateRequest(args)),
92 )
93 .await
94 .map(Option::unwrap_or_default)
95 }
96
97 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse, Error> {
98 self.conn
99 .request(
100 AGENT_METHOD_NAMES.session_new,
101 Some(ClientRequest::NewSessionRequest(args)),
102 )
103 .await
104 }
105
106 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse, Error> {
107 self.conn
108 .request::<Option<_>>(
109 AGENT_METHOD_NAMES.session_load,
110 Some(ClientRequest::LoadSessionRequest(args)),
111 )
112 .await
113 .map(Option::unwrap_or_default)
114 }
115
116 async fn set_session_mode(
117 &self,
118 args: SetSessionModeRequest,
119 ) -> Result<SetSessionModeResponse, Error> {
120 self.conn
121 .request(
122 AGENT_METHOD_NAMES.session_set_mode,
123 Some(ClientRequest::SetSessionModeRequest(args)),
124 )
125 .await
126 }
127
128 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse, Error> {
129 self.conn
130 .request(
131 AGENT_METHOD_NAMES.session_prompt,
132 Some(ClientRequest::PromptRequest(args)),
133 )
134 .await
135 }
136
137 async fn cancel(&self, args: CancelNotification) -> Result<(), Error> {
138 self.conn.notify(
139 AGENT_METHOD_NAMES.session_cancel,
140 Some(ClientNotification::CancelNotification(args)),
141 )
142 }
143
144 #[cfg(feature = "unstable")]
145 async fn set_session_model(
146 &self,
147 args: SetSessionModelRequest,
148 ) -> Result<SetSessionModelResponse, Error> {
149 self.conn
150 .request(
151 AGENT_METHOD_NAMES.session_set_model,
152 Some(ClientRequest::SetSessionModelRequest(args)),
153 )
154 .await
155 }
156
157 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
158 self.conn
159 .request(
160 format!("_{}", args.method),
161 Some(ClientRequest::ExtMethodRequest(args)),
162 )
163 .await
164 }
165
166 async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
167 self.conn.notify(
168 format!("_{}", args.method),
169 Some(ClientNotification::ExtNotification(args)),
170 )
171 }
172}
173
/// Marker type for the client's side of a connection.
///
/// Its `Side` implementation fixes what the client receives ([`AgentRequest`],
/// [`AgentNotification`]) and what it answers with ([`ClientResponse`]).
///
/// `Debug` is derived so this public type can be printed in diagnostics, per
/// the Rust API guidelines.
#[derive(Debug, Clone)]
pub struct ClientSide;
182
183impl Side for ClientSide {
184 type InNotification = AgentNotification;
185 type InRequest = AgentRequest;
186 type OutResponse = ClientResponse;
187
188 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
189 let params = params.ok_or_else(Error::invalid_params)?;
190
191 match method {
192 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
193 serde_json::from_str(params.get())
194 .map(AgentRequest::RequestPermissionRequest)
195 .map_err(Into::into)
196 }
197 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
198 .map(AgentRequest::WriteTextFileRequest)
199 .map_err(Into::into),
200 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
201 .map(AgentRequest::ReadTextFileRequest)
202 .map_err(Into::into),
203 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
204 .map(AgentRequest::CreateTerminalRequest)
205 .map_err(Into::into),
206 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
207 .map(AgentRequest::TerminalOutputRequest)
208 .map_err(Into::into),
209 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
210 .map(AgentRequest::KillTerminalCommandRequest)
211 .map_err(Into::into),
212 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
213 .map(AgentRequest::ReleaseTerminalRequest)
214 .map_err(Into::into),
215 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
216 serde_json::from_str(params.get())
217 .map(AgentRequest::WaitForTerminalExitRequest)
218 .map_err(Into::into)
219 }
220 _ => {
221 if let Some(custom_method) = method.strip_prefix('_') {
222 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
223 method: custom_method.into(),
224 params: RawValue::from_string(params.get().to_string())?.into(),
225 }))
226 } else {
227 Err(Error::method_not_found())
228 }
229 }
230 }
231 }
232
233 fn decode_notification(
234 method: &str,
235 params: Option<&RawValue>,
236 ) -> Result<AgentNotification, Error> {
237 let params = params.ok_or_else(Error::invalid_params)?;
238
239 match method {
240 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
241 .map(AgentNotification::SessionNotification)
242 .map_err(Into::into),
243 _ => {
244 if let Some(custom_method) = method.strip_prefix('_') {
245 Ok(AgentNotification::ExtNotification(ExtNotification {
246 method: custom_method.into(),
247 params: RawValue::from_string(params.get().to_string())?.into(),
248 }))
249 } else {
250 Err(Error::method_not_found())
251 }
252 }
253 }
254 }
255}
256
257impl<T: Client> MessageHandler<ClientSide> for T {
258 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
259 match request {
260 AgentRequest::RequestPermissionRequest(args) => {
261 let response = self.request_permission(args).await?;
262 Ok(ClientResponse::RequestPermissionResponse(response))
263 }
264 AgentRequest::WriteTextFileRequest(args) => {
265 let response = self.write_text_file(args).await?;
266 Ok(ClientResponse::WriteTextFileResponse(response))
267 }
268 AgentRequest::ReadTextFileRequest(args) => {
269 let response = self.read_text_file(args).await?;
270 Ok(ClientResponse::ReadTextFileResponse(response))
271 }
272 AgentRequest::CreateTerminalRequest(args) => {
273 let response = self.create_terminal(args).await?;
274 Ok(ClientResponse::CreateTerminalResponse(response))
275 }
276 AgentRequest::TerminalOutputRequest(args) => {
277 let response = self.terminal_output(args).await?;
278 Ok(ClientResponse::TerminalOutputResponse(response))
279 }
280 AgentRequest::ReleaseTerminalRequest(args) => {
281 let response = self.release_terminal(args).await?;
282 Ok(ClientResponse::ReleaseTerminalResponse(response))
283 }
284 AgentRequest::WaitForTerminalExitRequest(args) => {
285 let response = self.wait_for_terminal_exit(args).await?;
286 Ok(ClientResponse::WaitForTerminalExitResponse(response))
287 }
288 AgentRequest::KillTerminalCommandRequest(args) => {
289 let response = self.kill_terminal_command(args).await?;
290 Ok(ClientResponse::KillTerminalResponse(response))
291 }
292 AgentRequest::ExtMethodRequest(args) => {
293 let response = self.ext_method(args).await?;
294 Ok(ClientResponse::ExtMethodResponse(response))
295 }
296 }
297 }
298
299 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
300 match notification {
301 AgentNotification::SessionNotification(args) => {
302 self.session_notification(args).await?;
303 }
304 AgentNotification::ExtNotification(args) => {
305 self.ext_notification(args).await?;
306 }
307 }
308 Ok(())
309 }
310}
311
312pub struct AgentSideConnection {
323 conn: RpcConnection<AgentSide, ClientSide>,
324}
325
326impl AgentSideConnection {
327 pub fn new(
347 agent: impl MessageHandler<AgentSide> + 'static,
348 outgoing_bytes: impl Unpin + AsyncWrite,
349 incoming_bytes: impl Unpin + AsyncRead,
350 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
351 ) -> (Self, impl Future<Output = Result<()>>) {
352 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
353 (Self { conn }, io_task)
354 }
355
356 pub fn subscribe(&self) -> StreamReceiver {
365 self.conn.subscribe()
366 }
367}
368
369#[async_trait::async_trait(?Send)]
370impl Client for AgentSideConnection {
371 async fn request_permission(
372 &self,
373 args: RequestPermissionRequest,
374 ) -> Result<RequestPermissionResponse, Error> {
375 self.conn
376 .request(
377 CLIENT_METHOD_NAMES.session_request_permission,
378 Some(AgentRequest::RequestPermissionRequest(args)),
379 )
380 .await
381 }
382
383 async fn write_text_file(
384 &self,
385 args: WriteTextFileRequest,
386 ) -> Result<WriteTextFileResponse, Error> {
387 self.conn
388 .request::<Option<_>>(
389 CLIENT_METHOD_NAMES.fs_write_text_file,
390 Some(AgentRequest::WriteTextFileRequest(args)),
391 )
392 .await
393 .map(Option::unwrap_or_default)
394 }
395
396 async fn read_text_file(
397 &self,
398 args: ReadTextFileRequest,
399 ) -> Result<ReadTextFileResponse, Error> {
400 self.conn
401 .request(
402 CLIENT_METHOD_NAMES.fs_read_text_file,
403 Some(AgentRequest::ReadTextFileRequest(args)),
404 )
405 .await
406 }
407
408 async fn create_terminal(
409 &self,
410 args: CreateTerminalRequest,
411 ) -> Result<CreateTerminalResponse, Error> {
412 self.conn
413 .request(
414 CLIENT_METHOD_NAMES.terminal_create,
415 Some(AgentRequest::CreateTerminalRequest(args)),
416 )
417 .await
418 }
419
420 async fn terminal_output(
421 &self,
422 args: TerminalOutputRequest,
423 ) -> Result<TerminalOutputResponse, Error> {
424 self.conn
425 .request(
426 CLIENT_METHOD_NAMES.terminal_output,
427 Some(AgentRequest::TerminalOutputRequest(args)),
428 )
429 .await
430 }
431
432 async fn release_terminal(
433 &self,
434 args: ReleaseTerminalRequest,
435 ) -> Result<ReleaseTerminalResponse, Error> {
436 self.conn
437 .request::<Option<_>>(
438 CLIENT_METHOD_NAMES.terminal_release,
439 Some(AgentRequest::ReleaseTerminalRequest(args)),
440 )
441 .await
442 .map(Option::unwrap_or_default)
443 }
444
445 async fn wait_for_terminal_exit(
446 &self,
447 args: WaitForTerminalExitRequest,
448 ) -> Result<WaitForTerminalExitResponse, Error> {
449 self.conn
450 .request(
451 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
452 Some(AgentRequest::WaitForTerminalExitRequest(args)),
453 )
454 .await
455 }
456
457 async fn kill_terminal_command(
458 &self,
459 args: KillTerminalCommandRequest,
460 ) -> Result<KillTerminalCommandResponse, Error> {
461 self.conn
462 .request::<Option<_>>(
463 CLIENT_METHOD_NAMES.terminal_kill,
464 Some(AgentRequest::KillTerminalCommandRequest(args)),
465 )
466 .await
467 .map(Option::unwrap_or_default)
468 }
469
470 async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> {
471 self.conn.notify(
472 CLIENT_METHOD_NAMES.session_update,
473 Some(AgentNotification::SessionNotification(args)),
474 )
475 }
476
477 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
478 self.conn
479 .request(
480 format!("_{}", args.method),
481 Some(AgentRequest::ExtMethodRequest(args)),
482 )
483 .await
484 }
485
486 async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
487 self.conn.notify(
488 format!("_{}", args.method),
489 Some(AgentNotification::ExtNotification(args)),
490 )
491 }
492}
493
/// Marker type for the agent's side of a connection.
///
/// Its `Side` implementation fixes what the agent receives ([`ClientRequest`],
/// [`ClientNotification`]) and what it answers with ([`AgentResponse`]).
///
/// `Debug` is derived so this public type can be printed in diagnostics, per
/// the Rust API guidelines.
#[derive(Debug, Clone)]
pub struct AgentSide;
502
503impl Side for AgentSide {
504 type InRequest = ClientRequest;
505 type InNotification = ClientNotification;
506 type OutResponse = AgentResponse;
507
508 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
509 let params = params.ok_or_else(Error::invalid_params)?;
510
511 match method {
512 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
513 .map(ClientRequest::InitializeRequest)
514 .map_err(Into::into),
515 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
516 .map(ClientRequest::AuthenticateRequest)
517 .map_err(Into::into),
518 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
519 .map(ClientRequest::NewSessionRequest)
520 .map_err(Into::into),
521 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
522 .map(ClientRequest::LoadSessionRequest)
523 .map_err(Into::into),
524 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
525 .map(ClientRequest::SetSessionModeRequest)
526 .map_err(Into::into),
527 #[cfg(feature = "unstable")]
528 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
529 .map(ClientRequest::SetSessionModelRequest)
530 .map_err(Into::into),
531 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
532 .map(ClientRequest::PromptRequest)
533 .map_err(Into::into),
534 _ => {
535 if let Some(custom_method) = method.strip_prefix('_') {
536 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
537 method: custom_method.into(),
538 params: RawValue::from_string(params.get().to_string())?.into(),
539 }))
540 } else {
541 Err(Error::method_not_found())
542 }
543 }
544 }
545 }
546
547 fn decode_notification(
548 method: &str,
549 params: Option<&RawValue>,
550 ) -> Result<ClientNotification, Error> {
551 let params = params.ok_or_else(Error::invalid_params)?;
552
553 match method {
554 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
555 .map(ClientNotification::CancelNotification)
556 .map_err(Into::into),
557 _ => {
558 if let Some(custom_method) = method.strip_prefix('_') {
559 Ok(ClientNotification::ExtNotification(ExtNotification {
560 method: custom_method.into(),
561 params: RawValue::from_string(params.get().to_string())?.into(),
562 }))
563 } else {
564 Err(Error::method_not_found())
565 }
566 }
567 }
568 }
569}
570
571impl<T: Agent> MessageHandler<AgentSide> for T {
572 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
573 match request {
574 ClientRequest::InitializeRequest(args) => {
575 let response = self.initialize(args).await?;
576 Ok(AgentResponse::InitializeResponse(response))
577 }
578 ClientRequest::AuthenticateRequest(args) => {
579 let response = self.authenticate(args).await?;
580 Ok(AgentResponse::AuthenticateResponse(response))
581 }
582 ClientRequest::NewSessionRequest(args) => {
583 let response = self.new_session(args).await?;
584 Ok(AgentResponse::NewSessionResponse(response))
585 }
586 ClientRequest::LoadSessionRequest(args) => {
587 let response = self.load_session(args).await?;
588 Ok(AgentResponse::LoadSessionResponse(response))
589 }
590 ClientRequest::PromptRequest(args) => {
591 let response = self.prompt(args).await?;
592 Ok(AgentResponse::PromptResponse(response))
593 }
594 ClientRequest::SetSessionModeRequest(args) => {
595 let response = self.set_session_mode(args).await?;
596 Ok(AgentResponse::SetSessionModeResponse(response))
597 }
598 #[cfg(feature = "unstable")]
599 ClientRequest::SetSessionModelRequest(args) => {
600 let response = self.set_session_model(args).await?;
601 Ok(AgentResponse::SetSessionModelResponse(response))
602 }
603 ClientRequest::ExtMethodRequest(args) => {
604 let response = self.ext_method(args).await?;
605 Ok(AgentResponse::ExtMethodResponse(response))
606 }
607 }
608 }
609
610 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
611 match notification {
612 ClientNotification::CancelNotification(args) => {
613 self.cancel(args).await?;
614 }
615 ClientNotification::ExtNotification(args) => {
616 self.ext_notification(args).await?;
617 }
618 }
619 Ok(())
620 }
621}