1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19#[derive(Debug)]
30pub struct ClientSideConnection {
31 conn: RpcConnection<ClientSide, AgentSide>,
32}
33
34impl ClientSideConnection {
35 pub fn new(
55 client: impl MessageHandler<ClientSide> + 'static,
56 outgoing_bytes: impl Unpin + AsyncWrite,
57 incoming_bytes: impl Unpin + AsyncRead,
58 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
59 ) -> (Self, impl Future<Output = Result<()>>) {
60 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
61 (Self { conn }, io_task)
62 }
63
64 pub fn subscribe(&self) -> StreamReceiver {
73 self.conn.subscribe()
74 }
75}
76
77#[async_trait::async_trait(?Send)]
78impl Agent for ClientSideConnection {
79 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
80 self.conn
81 .request(
82 AGENT_METHOD_NAMES.initialize,
83 Some(ClientRequest::InitializeRequest(args)),
84 )
85 .await
86 }
87
88 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
89 self.conn
90 .request::<Option<_>>(
91 AGENT_METHOD_NAMES.authenticate,
92 Some(ClientRequest::AuthenticateRequest(args)),
93 )
94 .await
95 .map(Option::unwrap_or_default)
96 }
97
98 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99 self.conn
100 .request(
101 AGENT_METHOD_NAMES.session_new,
102 Some(ClientRequest::NewSessionRequest(args)),
103 )
104 .await
105 }
106
107 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
108 self.conn
109 .request::<Option<_>>(
110 AGENT_METHOD_NAMES.session_load,
111 Some(ClientRequest::LoadSessionRequest(args)),
112 )
113 .await
114 .map(Option::unwrap_or_default)
115 }
116
117 async fn set_session_mode(
118 &self,
119 args: SetSessionModeRequest,
120 ) -> Result<SetSessionModeResponse> {
121 self.conn
122 .request(
123 AGENT_METHOD_NAMES.session_set_mode,
124 Some(ClientRequest::SetSessionModeRequest(args)),
125 )
126 .await
127 }
128
129 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
130 self.conn
131 .request(
132 AGENT_METHOD_NAMES.session_prompt,
133 Some(ClientRequest::PromptRequest(args)),
134 )
135 .await
136 }
137
138 async fn cancel(&self, args: CancelNotification) -> Result<()> {
139 self.conn.notify(
140 AGENT_METHOD_NAMES.session_cancel,
141 Some(ClientNotification::CancelNotification(args)),
142 )
143 }
144
145 #[cfg(feature = "unstable")]
146 async fn set_session_model(
147 &self,
148 args: SetSessionModelRequest,
149 ) -> Result<SetSessionModelResponse> {
150 self.conn
151 .request(
152 AGENT_METHOD_NAMES.session_set_model,
153 Some(ClientRequest::SetSessionModelRequest(args)),
154 )
155 .await
156 }
157
158 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
159 self.conn
160 .request(
161 format!("_{}", args.method),
162 Some(ClientRequest::ExtMethodRequest(args)),
163 )
164 .await
165 }
166
167 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
168 self.conn.notify(
169 format!("_{}", args.method),
170 Some(ClientNotification::ExtNotification(args)),
171 )
172 }
173}
174
175#[derive(Clone, Debug)]
182pub struct ClientSide;
183
184impl Side for ClientSide {
185 type InNotification = AgentNotification;
186 type InRequest = AgentRequest;
187 type OutResponse = ClientResponse;
188
189 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
190 let params = params.ok_or_else(Error::invalid_params)?;
191
192 match method {
193 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
194 serde_json::from_str(params.get())
195 .map(AgentRequest::RequestPermissionRequest)
196 .map_err(Into::into)
197 }
198 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
199 .map(AgentRequest::WriteTextFileRequest)
200 .map_err(Into::into),
201 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
202 .map(AgentRequest::ReadTextFileRequest)
203 .map_err(Into::into),
204 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
205 .map(AgentRequest::CreateTerminalRequest)
206 .map_err(Into::into),
207 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
208 .map(AgentRequest::TerminalOutputRequest)
209 .map_err(Into::into),
210 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
211 .map(AgentRequest::KillTerminalCommandRequest)
212 .map_err(Into::into),
213 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
214 .map(AgentRequest::ReleaseTerminalRequest)
215 .map_err(Into::into),
216 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
217 serde_json::from_str(params.get())
218 .map(AgentRequest::WaitForTerminalExitRequest)
219 .map_err(Into::into)
220 }
221 _ => {
222 if let Some(custom_method) = method.strip_prefix('_') {
223 Ok(AgentRequest::ExtMethodRequest(ExtRequest::new(
224 custom_method,
225 params.to_owned().into(),
226 )))
227 } else {
228 Err(Error::method_not_found())
229 }
230 }
231 }
232 }
233
234 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
235 let params = params.ok_or_else(Error::invalid_params)?;
236
237 match method {
238 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
239 .map(AgentNotification::SessionNotification)
240 .map_err(Into::into),
241 _ => {
242 if let Some(custom_method) = method.strip_prefix('_') {
243 Ok(AgentNotification::ExtNotification(ExtNotification::new(
244 custom_method,
245 RawValue::from_string(params.get().to_string())?.into(),
246 )))
247 } else {
248 Err(Error::method_not_found())
249 }
250 }
251 }
252 }
253}
254
255impl<T: Client> MessageHandler<ClientSide> for T {
256 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
257 match request {
258 AgentRequest::RequestPermissionRequest(args) => {
259 let response = self.request_permission(args).await?;
260 Ok(ClientResponse::RequestPermissionResponse(response))
261 }
262 AgentRequest::WriteTextFileRequest(args) => {
263 let response = self.write_text_file(args).await?;
264 Ok(ClientResponse::WriteTextFileResponse(response))
265 }
266 AgentRequest::ReadTextFileRequest(args) => {
267 let response = self.read_text_file(args).await?;
268 Ok(ClientResponse::ReadTextFileResponse(response))
269 }
270 AgentRequest::CreateTerminalRequest(args) => {
271 let response = self.create_terminal(args).await?;
272 Ok(ClientResponse::CreateTerminalResponse(response))
273 }
274 AgentRequest::TerminalOutputRequest(args) => {
275 let response = self.terminal_output(args).await?;
276 Ok(ClientResponse::TerminalOutputResponse(response))
277 }
278 AgentRequest::ReleaseTerminalRequest(args) => {
279 let response = self.release_terminal(args).await?;
280 Ok(ClientResponse::ReleaseTerminalResponse(response))
281 }
282 AgentRequest::WaitForTerminalExitRequest(args) => {
283 let response = self.wait_for_terminal_exit(args).await?;
284 Ok(ClientResponse::WaitForTerminalExitResponse(response))
285 }
286 AgentRequest::KillTerminalCommandRequest(args) => {
287 let response = self.kill_terminal_command(args).await?;
288 Ok(ClientResponse::KillTerminalResponse(response))
289 }
290 AgentRequest::ExtMethodRequest(args) => {
291 let response = self.ext_method(args).await?;
292 Ok(ClientResponse::ExtMethodResponse(response))
293 }
294 _ => Err(Error::method_not_found()),
295 }
296 }
297
298 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
299 match notification {
300 AgentNotification::SessionNotification(args) => {
301 self.session_notification(args).await?;
302 }
303 AgentNotification::ExtNotification(args) => {
304 self.ext_notification(args).await?;
305 }
306 _ => {}
308 }
309 Ok(())
310 }
311}
312
313#[derive(Debug)]
324pub struct AgentSideConnection {
325 conn: RpcConnection<AgentSide, ClientSide>,
326}
327
328impl AgentSideConnection {
329 pub fn new(
349 agent: impl MessageHandler<AgentSide> + 'static,
350 outgoing_bytes: impl Unpin + AsyncWrite,
351 incoming_bytes: impl Unpin + AsyncRead,
352 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
353 ) -> (Self, impl Future<Output = Result<()>>) {
354 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
355 (Self { conn }, io_task)
356 }
357
358 pub fn subscribe(&self) -> StreamReceiver {
367 self.conn.subscribe()
368 }
369}
370
371#[async_trait::async_trait(?Send)]
372impl Client for AgentSideConnection {
373 async fn request_permission(
374 &self,
375 args: RequestPermissionRequest,
376 ) -> Result<RequestPermissionResponse> {
377 self.conn
378 .request(
379 CLIENT_METHOD_NAMES.session_request_permission,
380 Some(AgentRequest::RequestPermissionRequest(args)),
381 )
382 .await
383 }
384
385 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
386 self.conn
387 .request::<Option<_>>(
388 CLIENT_METHOD_NAMES.fs_write_text_file,
389 Some(AgentRequest::WriteTextFileRequest(args)),
390 )
391 .await
392 .map(Option::unwrap_or_default)
393 }
394
395 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
396 self.conn
397 .request(
398 CLIENT_METHOD_NAMES.fs_read_text_file,
399 Some(AgentRequest::ReadTextFileRequest(args)),
400 )
401 .await
402 }
403
404 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
405 self.conn
406 .request(
407 CLIENT_METHOD_NAMES.terminal_create,
408 Some(AgentRequest::CreateTerminalRequest(args)),
409 )
410 .await
411 }
412
413 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
414 self.conn
415 .request(
416 CLIENT_METHOD_NAMES.terminal_output,
417 Some(AgentRequest::TerminalOutputRequest(args)),
418 )
419 .await
420 }
421
422 async fn release_terminal(
423 &self,
424 args: ReleaseTerminalRequest,
425 ) -> Result<ReleaseTerminalResponse> {
426 self.conn
427 .request::<Option<_>>(
428 CLIENT_METHOD_NAMES.terminal_release,
429 Some(AgentRequest::ReleaseTerminalRequest(args)),
430 )
431 .await
432 .map(Option::unwrap_or_default)
433 }
434
435 async fn wait_for_terminal_exit(
436 &self,
437 args: WaitForTerminalExitRequest,
438 ) -> Result<WaitForTerminalExitResponse> {
439 self.conn
440 .request(
441 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
442 Some(AgentRequest::WaitForTerminalExitRequest(args)),
443 )
444 .await
445 }
446
447 async fn kill_terminal_command(
448 &self,
449 args: KillTerminalCommandRequest,
450 ) -> Result<KillTerminalCommandResponse> {
451 self.conn
452 .request::<Option<_>>(
453 CLIENT_METHOD_NAMES.terminal_kill,
454 Some(AgentRequest::KillTerminalCommandRequest(args)),
455 )
456 .await
457 .map(Option::unwrap_or_default)
458 }
459
460 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
461 self.conn.notify(
462 CLIENT_METHOD_NAMES.session_update,
463 Some(AgentNotification::SessionNotification(args)),
464 )
465 }
466
467 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
468 self.conn
469 .request(
470 format!("_{}", args.method),
471 Some(AgentRequest::ExtMethodRequest(args)),
472 )
473 .await
474 }
475
476 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
477 self.conn.notify(
478 format!("_{}", args.method),
479 Some(AgentNotification::ExtNotification(args)),
480 )
481 }
482}
483
484#[derive(Clone, Debug)]
491pub struct AgentSide;
492
493impl Side for AgentSide {
494 type InRequest = ClientRequest;
495 type InNotification = ClientNotification;
496 type OutResponse = AgentResponse;
497
498 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
499 let params = params.ok_or_else(Error::invalid_params)?;
500
501 match method {
502 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
503 .map(ClientRequest::InitializeRequest)
504 .map_err(Into::into),
505 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
506 .map(ClientRequest::AuthenticateRequest)
507 .map_err(Into::into),
508 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
509 .map(ClientRequest::NewSessionRequest)
510 .map_err(Into::into),
511 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
512 .map(ClientRequest::LoadSessionRequest)
513 .map_err(Into::into),
514 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
515 .map(ClientRequest::SetSessionModeRequest)
516 .map_err(Into::into),
517 #[cfg(feature = "unstable")]
518 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
519 .map(ClientRequest::SetSessionModelRequest)
520 .map_err(Into::into),
521 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
522 .map(ClientRequest::PromptRequest)
523 .map_err(Into::into),
524 _ => {
525 if let Some(custom_method) = method.strip_prefix('_') {
526 Ok(ClientRequest::ExtMethodRequest(ExtRequest::new(
527 custom_method,
528 params.to_owned().into(),
529 )))
530 } else {
531 Err(Error::method_not_found())
532 }
533 }
534 }
535 }
536
537 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
538 let params = params.ok_or_else(Error::invalid_params)?;
539
540 match method {
541 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
542 .map(ClientNotification::CancelNotification)
543 .map_err(Into::into),
544 _ => {
545 if let Some(custom_method) = method.strip_prefix('_') {
546 Ok(ClientNotification::ExtNotification(ExtNotification::new(
547 custom_method,
548 RawValue::from_string(params.get().to_string())?.into(),
549 )))
550 } else {
551 Err(Error::method_not_found())
552 }
553 }
554 }
555 }
556}
557
558impl<T: Agent> MessageHandler<AgentSide> for T {
559 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
560 match request {
561 ClientRequest::InitializeRequest(args) => {
562 let response = self.initialize(args).await?;
563 Ok(AgentResponse::InitializeResponse(response))
564 }
565 ClientRequest::AuthenticateRequest(args) => {
566 let response = self.authenticate(args).await?;
567 Ok(AgentResponse::AuthenticateResponse(response))
568 }
569 ClientRequest::NewSessionRequest(args) => {
570 let response = self.new_session(args).await?;
571 Ok(AgentResponse::NewSessionResponse(response))
572 }
573 ClientRequest::LoadSessionRequest(args) => {
574 let response = self.load_session(args).await?;
575 Ok(AgentResponse::LoadSessionResponse(response))
576 }
577 ClientRequest::PromptRequest(args) => {
578 let response = self.prompt(args).await?;
579 Ok(AgentResponse::PromptResponse(response))
580 }
581 ClientRequest::SetSessionModeRequest(args) => {
582 let response = self.set_session_mode(args).await?;
583 Ok(AgentResponse::SetSessionModeResponse(response))
584 }
585 #[cfg(feature = "unstable")]
586 ClientRequest::SetSessionModelRequest(args) => {
587 let response = self.set_session_model(args).await?;
588 Ok(AgentResponse::SetSessionModelResponse(response))
589 }
590 ClientRequest::ExtMethodRequest(args) => {
591 let response = self.ext_method(args).await?;
592 Ok(AgentResponse::ExtMethodResponse(response))
593 }
594 _ => Err(Error::method_not_found()),
595 }
596 }
597
598 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
599 match notification {
600 ClientNotification::CancelNotification(args) => {
601 self.cancel(args).await?;
602 }
603 ClientNotification::ExtNotification(args) => {
604 self.ext_notification(args).await?;
605 }
606 _ => {}
608 }
609 Ok(())
610 }
611}