1use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture};
2use rpc::RpcConnection;
3
4mod agent;
5mod client;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod stream_broadcast;
10
11pub use agent::*;
12pub use agent_client_protocol_schema::*;
13pub use client::*;
14pub use rpc::*;
15pub use stream_broadcast::{
16 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
17};
18
19pub struct ClientSideConnection {
30 conn: RpcConnection<ClientSide, AgentSide>,
31}
32
33impl ClientSideConnection {
34 pub fn new(
54 client: impl MessageHandler<ClientSide> + 'static,
55 outgoing_bytes: impl Unpin + AsyncWrite,
56 incoming_bytes: impl Unpin + AsyncRead,
57 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
58 ) -> (Self, impl Future<Output = Result<()>>) {
59 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
60 (Self { conn }, io_task)
61 }
62
63 pub fn subscribe(&self) -> StreamReceiver {
72 self.conn.subscribe()
73 }
74}
75
76#[async_trait::async_trait(?Send)]
77impl Agent for ClientSideConnection {
78 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
79 self.conn
80 .request(
81 AGENT_METHOD_NAMES.initialize,
82 Some(ClientRequest::InitializeRequest(args)),
83 )
84 .await
85 }
86
87 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
88 self.conn
89 .request::<Option<_>>(
90 AGENT_METHOD_NAMES.authenticate,
91 Some(ClientRequest::AuthenticateRequest(args)),
92 )
93 .await
94 .map(Option::unwrap_or_default)
95 }
96
97 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
98 self.conn
99 .request(
100 AGENT_METHOD_NAMES.session_new,
101 Some(ClientRequest::NewSessionRequest(args)),
102 )
103 .await
104 }
105
106 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
107 self.conn
108 .request::<Option<_>>(
109 AGENT_METHOD_NAMES.session_load,
110 Some(ClientRequest::LoadSessionRequest(args)),
111 )
112 .await
113 .map(Option::unwrap_or_default)
114 }
115
116 async fn set_session_mode(
117 &self,
118 args: SetSessionModeRequest,
119 ) -> Result<SetSessionModeResponse> {
120 self.conn
121 .request(
122 AGENT_METHOD_NAMES.session_set_mode,
123 Some(ClientRequest::SetSessionModeRequest(args)),
124 )
125 .await
126 }
127
128 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
129 self.conn
130 .request(
131 AGENT_METHOD_NAMES.session_prompt,
132 Some(ClientRequest::PromptRequest(args)),
133 )
134 .await
135 }
136
137 async fn cancel(&self, args: CancelNotification) -> Result<()> {
138 self.conn.notify(
139 AGENT_METHOD_NAMES.session_cancel,
140 Some(ClientNotification::CancelNotification(args)),
141 )
142 }
143
144 #[cfg(feature = "unstable")]
145 async fn set_session_model(
146 &self,
147 args: SetSessionModelRequest,
148 ) -> Result<SetSessionModelResponse> {
149 self.conn
150 .request(
151 AGENT_METHOD_NAMES.session_set_model,
152 Some(ClientRequest::SetSessionModelRequest(args)),
153 )
154 .await
155 }
156
157 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
158 self.conn
159 .request(
160 format!("_{}", args.method),
161 Some(ClientRequest::ExtMethodRequest(args)),
162 )
163 .await
164 }
165
166 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
167 self.conn.notify(
168 format!("_{}", args.method),
169 Some(ClientNotification::ExtNotification(args)),
170 )
171 }
172}
173
/// Marker type for the client end of a connection.
///
/// Its `Side` implementation decodes what the agent sends (`AgentRequest`,
/// `AgentNotification`) and names the reply type (`ClientResponse`).
#[derive(Clone)]
pub struct ClientSide;
182
183impl Side for ClientSide {
184 type InNotification = AgentNotification;
185 type InRequest = AgentRequest;
186 type OutResponse = ClientResponse;
187
188 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
189 let params = params.ok_or_else(Error::invalid_params)?;
190
191 match method {
192 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
193 serde_json::from_str(params.get())
194 .map(AgentRequest::RequestPermissionRequest)
195 .map_err(Into::into)
196 }
197 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
198 .map(AgentRequest::WriteTextFileRequest)
199 .map_err(Into::into),
200 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
201 .map(AgentRequest::ReadTextFileRequest)
202 .map_err(Into::into),
203 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
204 .map(AgentRequest::CreateTerminalRequest)
205 .map_err(Into::into),
206 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
207 .map(AgentRequest::TerminalOutputRequest)
208 .map_err(Into::into),
209 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
210 .map(AgentRequest::KillTerminalCommandRequest)
211 .map_err(Into::into),
212 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
213 .map(AgentRequest::ReleaseTerminalRequest)
214 .map_err(Into::into),
215 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
216 serde_json::from_str(params.get())
217 .map(AgentRequest::WaitForTerminalExitRequest)
218 .map_err(Into::into)
219 }
220 _ => {
221 if let Some(custom_method) = method.strip_prefix('_') {
222 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
223 method: custom_method.into(),
224 params: params.to_owned().into(),
225 }))
226 } else {
227 Err(Error::method_not_found())
228 }
229 }
230 }
231 }
232
233 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
234 let params = params.ok_or_else(Error::invalid_params)?;
235
236 match method {
237 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
238 .map(AgentNotification::SessionNotification)
239 .map_err(Into::into),
240 _ => {
241 if let Some(custom_method) = method.strip_prefix('_') {
242 Ok(AgentNotification::ExtNotification(ExtNotification {
243 method: custom_method.into(),
244 params: RawValue::from_string(params.get().to_string())?.into(),
245 }))
246 } else {
247 Err(Error::method_not_found())
248 }
249 }
250 }
251 }
252}
253
254impl<T: Client> MessageHandler<ClientSide> for T {
255 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse> {
256 match request {
257 AgentRequest::RequestPermissionRequest(args) => {
258 let response = self.request_permission(args).await?;
259 Ok(ClientResponse::RequestPermissionResponse(response))
260 }
261 AgentRequest::WriteTextFileRequest(args) => {
262 let response = self.write_text_file(args).await?;
263 Ok(ClientResponse::WriteTextFileResponse(response))
264 }
265 AgentRequest::ReadTextFileRequest(args) => {
266 let response = self.read_text_file(args).await?;
267 Ok(ClientResponse::ReadTextFileResponse(response))
268 }
269 AgentRequest::CreateTerminalRequest(args) => {
270 let response = self.create_terminal(args).await?;
271 Ok(ClientResponse::CreateTerminalResponse(response))
272 }
273 AgentRequest::TerminalOutputRequest(args) => {
274 let response = self.terminal_output(args).await?;
275 Ok(ClientResponse::TerminalOutputResponse(response))
276 }
277 AgentRequest::ReleaseTerminalRequest(args) => {
278 let response = self.release_terminal(args).await?;
279 Ok(ClientResponse::ReleaseTerminalResponse(response))
280 }
281 AgentRequest::WaitForTerminalExitRequest(args) => {
282 let response = self.wait_for_terminal_exit(args).await?;
283 Ok(ClientResponse::WaitForTerminalExitResponse(response))
284 }
285 AgentRequest::KillTerminalCommandRequest(args) => {
286 let response = self.kill_terminal_command(args).await?;
287 Ok(ClientResponse::KillTerminalResponse(response))
288 }
289 AgentRequest::ExtMethodRequest(args) => {
290 let response = self.ext_method(args).await?;
291 Ok(ClientResponse::ExtMethodResponse(response))
292 }
293 }
294 }
295
296 async fn handle_notification(&self, notification: AgentNotification) -> Result<()> {
297 match notification {
298 AgentNotification::SessionNotification(args) => {
299 self.session_notification(args).await?;
300 }
301 AgentNotification::ExtNotification(args) => {
302 self.ext_notification(args).await?;
303 }
304 }
305 Ok(())
306 }
307}
308
309pub struct AgentSideConnection {
320 conn: RpcConnection<AgentSide, ClientSide>,
321}
322
323impl AgentSideConnection {
324 pub fn new(
344 agent: impl MessageHandler<AgentSide> + 'static,
345 outgoing_bytes: impl Unpin + AsyncWrite,
346 incoming_bytes: impl Unpin + AsyncRead,
347 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
348 ) -> (Self, impl Future<Output = Result<()>>) {
349 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
350 (Self { conn }, io_task)
351 }
352
353 pub fn subscribe(&self) -> StreamReceiver {
362 self.conn.subscribe()
363 }
364}
365
366#[async_trait::async_trait(?Send)]
367impl Client for AgentSideConnection {
368 async fn request_permission(
369 &self,
370 args: RequestPermissionRequest,
371 ) -> Result<RequestPermissionResponse> {
372 self.conn
373 .request(
374 CLIENT_METHOD_NAMES.session_request_permission,
375 Some(AgentRequest::RequestPermissionRequest(args)),
376 )
377 .await
378 }
379
380 async fn write_text_file(&self, args: WriteTextFileRequest) -> Result<WriteTextFileResponse> {
381 self.conn
382 .request::<Option<_>>(
383 CLIENT_METHOD_NAMES.fs_write_text_file,
384 Some(AgentRequest::WriteTextFileRequest(args)),
385 )
386 .await
387 .map(Option::unwrap_or_default)
388 }
389
390 async fn read_text_file(&self, args: ReadTextFileRequest) -> Result<ReadTextFileResponse> {
391 self.conn
392 .request(
393 CLIENT_METHOD_NAMES.fs_read_text_file,
394 Some(AgentRequest::ReadTextFileRequest(args)),
395 )
396 .await
397 }
398
399 async fn create_terminal(&self, args: CreateTerminalRequest) -> Result<CreateTerminalResponse> {
400 self.conn
401 .request(
402 CLIENT_METHOD_NAMES.terminal_create,
403 Some(AgentRequest::CreateTerminalRequest(args)),
404 )
405 .await
406 }
407
408 async fn terminal_output(&self, args: TerminalOutputRequest) -> Result<TerminalOutputResponse> {
409 self.conn
410 .request(
411 CLIENT_METHOD_NAMES.terminal_output,
412 Some(AgentRequest::TerminalOutputRequest(args)),
413 )
414 .await
415 }
416
417 async fn release_terminal(
418 &self,
419 args: ReleaseTerminalRequest,
420 ) -> Result<ReleaseTerminalResponse> {
421 self.conn
422 .request::<Option<_>>(
423 CLIENT_METHOD_NAMES.terminal_release,
424 Some(AgentRequest::ReleaseTerminalRequest(args)),
425 )
426 .await
427 .map(Option::unwrap_or_default)
428 }
429
430 async fn wait_for_terminal_exit(
431 &self,
432 args: WaitForTerminalExitRequest,
433 ) -> Result<WaitForTerminalExitResponse> {
434 self.conn
435 .request(
436 CLIENT_METHOD_NAMES.terminal_wait_for_exit,
437 Some(AgentRequest::WaitForTerminalExitRequest(args)),
438 )
439 .await
440 }
441
442 async fn kill_terminal_command(
443 &self,
444 args: KillTerminalCommandRequest,
445 ) -> Result<KillTerminalCommandResponse> {
446 self.conn
447 .request::<Option<_>>(
448 CLIENT_METHOD_NAMES.terminal_kill,
449 Some(AgentRequest::KillTerminalCommandRequest(args)),
450 )
451 .await
452 .map(Option::unwrap_or_default)
453 }
454
455 async fn session_notification(&self, args: SessionNotification) -> Result<()> {
456 self.conn.notify(
457 CLIENT_METHOD_NAMES.session_update,
458 Some(AgentNotification::SessionNotification(args)),
459 )
460 }
461
462 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
463 self.conn
464 .request(
465 format!("_{}", args.method),
466 Some(AgentRequest::ExtMethodRequest(args)),
467 )
468 .await
469 }
470
471 async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
472 self.conn.notify(
473 format!("_{}", args.method),
474 Some(AgentNotification::ExtNotification(args)),
475 )
476 }
477}
478
/// Marker type for the agent end of a connection.
///
/// Its `Side` implementation decodes what the client sends (`ClientRequest`,
/// `ClientNotification`) and names the reply type (`AgentResponse`).
#[derive(Clone)]
pub struct AgentSide;
487
488impl Side for AgentSide {
489 type InRequest = ClientRequest;
490 type InNotification = ClientNotification;
491 type OutResponse = AgentResponse;
492
493 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
494 let params = params.ok_or_else(Error::invalid_params)?;
495
496 match method {
497 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
498 .map(ClientRequest::InitializeRequest)
499 .map_err(Into::into),
500 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
501 .map(ClientRequest::AuthenticateRequest)
502 .map_err(Into::into),
503 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
504 .map(ClientRequest::NewSessionRequest)
505 .map_err(Into::into),
506 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
507 .map(ClientRequest::LoadSessionRequest)
508 .map_err(Into::into),
509 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
510 .map(ClientRequest::SetSessionModeRequest)
511 .map_err(Into::into),
512 #[cfg(feature = "unstable")]
513 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
514 .map(ClientRequest::SetSessionModelRequest)
515 .map_err(Into::into),
516 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
517 .map(ClientRequest::PromptRequest)
518 .map_err(Into::into),
519 _ => {
520 if let Some(custom_method) = method.strip_prefix('_') {
521 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
522 method: custom_method.into(),
523 params: params.to_owned().into(),
524 }))
525 } else {
526 Err(Error::method_not_found())
527 }
528 }
529 }
530 }
531
532 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
533 let params = params.ok_or_else(Error::invalid_params)?;
534
535 match method {
536 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
537 .map(ClientNotification::CancelNotification)
538 .map_err(Into::into),
539 _ => {
540 if let Some(custom_method) = method.strip_prefix('_') {
541 Ok(ClientNotification::ExtNotification(ExtNotification {
542 method: custom_method.into(),
543 params: RawValue::from_string(params.get().to_string())?.into(),
544 }))
545 } else {
546 Err(Error::method_not_found())
547 }
548 }
549 }
550 }
551}
552
553impl<T: Agent> MessageHandler<AgentSide> for T {
554 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse> {
555 match request {
556 ClientRequest::InitializeRequest(args) => {
557 let response = self.initialize(args).await?;
558 Ok(AgentResponse::InitializeResponse(response))
559 }
560 ClientRequest::AuthenticateRequest(args) => {
561 let response = self.authenticate(args).await?;
562 Ok(AgentResponse::AuthenticateResponse(response))
563 }
564 ClientRequest::NewSessionRequest(args) => {
565 let response = self.new_session(args).await?;
566 Ok(AgentResponse::NewSessionResponse(response))
567 }
568 ClientRequest::LoadSessionRequest(args) => {
569 let response = self.load_session(args).await?;
570 Ok(AgentResponse::LoadSessionResponse(response))
571 }
572 ClientRequest::PromptRequest(args) => {
573 let response = self.prompt(args).await?;
574 Ok(AgentResponse::PromptResponse(response))
575 }
576 ClientRequest::SetSessionModeRequest(args) => {
577 let response = self.set_session_mode(args).await?;
578 Ok(AgentResponse::SetSessionModeResponse(response))
579 }
580 #[cfg(feature = "unstable")]
581 ClientRequest::SetSessionModelRequest(args) => {
582 let response = self.set_session_model(args).await?;
583 Ok(AgentResponse::SetSessionModelResponse(response))
584 }
585 ClientRequest::ExtMethodRequest(args) => {
586 let response = self.ext_method(args).await?;
587 Ok(AgentResponse::ExtMethodResponse(response))
588 }
589 }
590 }
591
592 async fn handle_notification(&self, notification: ClientNotification) -> Result<()> {
593 match notification {
594 ClientNotification::CancelNotification(args) => {
595 self.cancel(args).await?;
596 }
597 ClientNotification::ExtNotification(args) => {
598 self.ext_notification(args).await?;
599 }
600 }
601 Ok(())
602 }
603}