1mod agent;
54mod client;
55mod content;
56mod error;
57mod ext;
58mod plan;
59mod rpc;
60#[cfg(test)]
61mod rpc_tests;
62mod stream_broadcast;
63mod tool_call;
64mod version;
65
66pub use agent::*;
67pub use client::*;
68pub use content::*;
69pub use error::*;
70pub use ext::*;
71pub use plan::*;
72pub use serde_json::value::RawValue;
73pub use stream_broadcast::{
74 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
75};
76pub use tool_call::*;
77pub use version::*;
78
79use anyhow::Result;
80use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
81use schemars::JsonSchema;
82use serde::{Deserialize, Serialize};
83use std::{fmt, sync::Arc};
84
85use crate::rpc::{MessageHandler, RpcConnection, Side};
86
87#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
103#[serde(transparent)]
104pub struct SessionId(pub Arc<str>);
105
106impl fmt::Display for SessionId {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "{}", self.0)
109 }
110}
111
/// The client's end of a connection to an agent.
///
/// Wraps an `RpcConnection` whose local side is [`ClientSide`] and whose
/// remote side is [`AgentSide`]. Through its [`Agent`] implementation, each
/// trait call is sent to the agent as a request (or notification) on that
/// connection.
pub struct ClientSideConnection {
    // Underlying RPC connection: local side `ClientSide`, remote side `AgentSide`.
    conn: RpcConnection<ClientSide, AgentSide>,
}
125
126impl ClientSideConnection {
127 pub fn new(
147 client: impl MessageHandler<ClientSide> + 'static,
148 outgoing_bytes: impl Unpin + AsyncWrite,
149 incoming_bytes: impl Unpin + AsyncRead,
150 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
151 ) -> (Self, impl Future<Output = Result<()>>) {
152 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
153 (Self { conn }, io_task)
154 }
155
156 pub fn subscribe(&self) -> StreamReceiver {
165 self.conn.subscribe()
166 }
167}
168
169#[async_trait::async_trait(?Send)]
170impl Agent for ClientSideConnection {
171 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse, Error> {
172 self.conn
173 .request(
174 INITIALIZE_METHOD_NAME,
175 Some(ClientRequest::InitializeRequest(args)),
176 )
177 .await
178 }
179
180 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse, Error> {
181 self.conn
182 .request::<Option<_>>(
183 AUTHENTICATE_METHOD_NAME,
184 Some(ClientRequest::AuthenticateRequest(args)),
185 )
186 .await
187 .map(|value| value.unwrap_or_default())
188 }
189
190 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse, Error> {
191 self.conn
192 .request(
193 SESSION_NEW_METHOD_NAME,
194 Some(ClientRequest::NewSessionRequest(args)),
195 )
196 .await
197 }
198
199 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse, Error> {
200 self.conn
201 .request::<Option<_>>(
202 SESSION_LOAD_METHOD_NAME,
203 Some(ClientRequest::LoadSessionRequest(args)),
204 )
205 .await
206 .map(|value| value.unwrap_or_default())
207 }
208
209 async fn set_session_mode(
210 &self,
211 args: SetSessionModeRequest,
212 ) -> Result<SetSessionModeResponse, Error> {
213 self.conn
214 .request(
215 SESSION_SET_MODE_METHOD_NAME,
216 Some(ClientRequest::SetSessionModeRequest(args)),
217 )
218 .await
219 }
220
221 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse, Error> {
222 self.conn
223 .request(
224 SESSION_PROMPT_METHOD_NAME,
225 Some(ClientRequest::PromptRequest(args)),
226 )
227 .await
228 }
229
230 async fn cancel(&self, args: CancelNotification) -> Result<(), Error> {
231 self.conn.notify(
232 SESSION_CANCEL_METHOD_NAME,
233 Some(ClientNotification::CancelNotification(args)),
234 )
235 }
236
237 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
238 self.conn
239 .request(
240 format!("_{}", args.method),
241 Some(ClientRequest::ExtMethodRequest(args)),
242 )
243 .await
244 }
245
246 async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
247 self.conn.notify(
248 format!("_{}", args.method),
249 Some(ClientNotification::ExtNotification(args)),
250 )
251 }
252}
253
/// Marker type for the client's end of the protocol.
///
/// Its `Side` implementation decodes what the client receives from the agent:
/// incoming requests become [`AgentRequest`], incoming notifications become
/// [`AgentNotification`], and replies are sent as [`ClientResponse`].
#[derive(Clone)]
pub struct ClientSide;
262
263impl Side for ClientSide {
264 type InNotification = AgentNotification;
265 type InRequest = AgentRequest;
266 type OutResponse = ClientResponse;
267
268 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
269 let params = params.ok_or_else(Error::invalid_params)?;
270
271 match method {
272 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
273 .map(AgentRequest::RequestPermissionRequest)
274 .map_err(Into::into),
275 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
276 .map(AgentRequest::WriteTextFileRequest)
277 .map_err(Into::into),
278 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
279 .map(AgentRequest::ReadTextFileRequest)
280 .map_err(Into::into),
281 TERMINAL_CREATE_METHOD_NAME => serde_json::from_str(params.get())
282 .map(AgentRequest::CreateTerminalRequest)
283 .map_err(Into::into),
284 TERMINAL_OUTPUT_METHOD_NAME => serde_json::from_str(params.get())
285 .map(AgentRequest::TerminalOutputRequest)
286 .map_err(Into::into),
287 TERMINAL_KILL_METHOD_NAME => serde_json::from_str(params.get())
288 .map(AgentRequest::KillTerminalCommandRequest)
289 .map_err(Into::into),
290 TERMINAL_RELEASE_METHOD_NAME => serde_json::from_str(params.get())
291 .map(AgentRequest::ReleaseTerminalRequest)
292 .map_err(Into::into),
293 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME => serde_json::from_str(params.get())
294 .map(AgentRequest::WaitForTerminalExitRequest)
295 .map_err(Into::into),
296 _ => {
297 if let Some(custom_method) = method.strip_prefix('_') {
298 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
299 method: custom_method.into(),
300 params: RawValue::from_string(params.get().to_string())?.into(),
301 }))
302 } else {
303 Err(Error::method_not_found())
304 }
305 }
306 }
307 }
308
309 fn decode_notification(
310 method: &str,
311 params: Option<&RawValue>,
312 ) -> Result<AgentNotification, Error> {
313 let params = params.ok_or_else(Error::invalid_params)?;
314
315 match method {
316 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
317 .map(AgentNotification::SessionNotification)
318 .map_err(Into::into),
319 _ => {
320 if let Some(custom_method) = method.strip_prefix('_') {
321 Ok(AgentNotification::ExtNotification(ExtNotification {
322 method: custom_method.into(),
323 params: RawValue::from_string(params.get().to_string())?.into(),
324 }))
325 } else {
326 Err(Error::method_not_found())
327 }
328 }
329 }
330 }
331}
332
333impl<T: Client> MessageHandler<ClientSide> for T {
334 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
335 match request {
336 AgentRequest::RequestPermissionRequest(args) => {
337 let response = self.request_permission(args).await?;
338 Ok(ClientResponse::RequestPermissionResponse(response))
339 }
340 AgentRequest::WriteTextFileRequest(args) => {
341 let response = self.write_text_file(args).await?;
342 Ok(ClientResponse::WriteTextFileResponse(response))
343 }
344 AgentRequest::ReadTextFileRequest(args) => {
345 let response = self.read_text_file(args).await?;
346 Ok(ClientResponse::ReadTextFileResponse(response))
347 }
348 AgentRequest::CreateTerminalRequest(args) => {
349 let response = self.create_terminal(args).await?;
350 Ok(ClientResponse::CreateTerminalResponse(response))
351 }
352 AgentRequest::TerminalOutputRequest(args) => {
353 let response = self.terminal_output(args).await?;
354 Ok(ClientResponse::TerminalOutputResponse(response))
355 }
356 AgentRequest::ReleaseTerminalRequest(args) => {
357 let response = self.release_terminal(args).await?;
358 Ok(ClientResponse::ReleaseTerminalResponse(response))
359 }
360 AgentRequest::WaitForTerminalExitRequest(args) => {
361 let response = self.wait_for_terminal_exit(args).await?;
362 Ok(ClientResponse::WaitForTerminalExitResponse(response))
363 }
364 AgentRequest::KillTerminalCommandRequest(args) => {
365 let response = self.kill_terminal_command(args).await?;
366 Ok(ClientResponse::KillTerminalResponse(response))
367 }
368 AgentRequest::ExtMethodRequest(args) => {
369 let response = self.ext_method(args).await?;
370 Ok(ClientResponse::ExtMethodResponse(response))
371 }
372 }
373 }
374
375 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
376 match notification {
377 AgentNotification::SessionNotification(args) => {
378 self.session_notification(args).await?;
379 }
380 AgentNotification::ExtNotification(args) => {
381 self.ext_notification(args).await?;
382 }
383 }
384 Ok(())
385 }
386}
387
/// The agent's end of a connection to a client.
///
/// Wraps an `RpcConnection` whose local side is [`AgentSide`] and whose
/// remote side is [`ClientSide`]. Through its [`Client`] implementation, each
/// trait call is sent to the client as a request (or notification) on that
/// connection.
pub struct AgentSideConnection {
    // Underlying RPC connection: local side `AgentSide`, remote side `ClientSide`.
    conn: RpcConnection<AgentSide, ClientSide>,
}
401
402impl AgentSideConnection {
403 pub fn new(
423 agent: impl MessageHandler<AgentSide> + 'static,
424 outgoing_bytes: impl Unpin + AsyncWrite,
425 incoming_bytes: impl Unpin + AsyncRead,
426 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
427 ) -> (Self, impl Future<Output = Result<()>>) {
428 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
429 (Self { conn }, io_task)
430 }
431
432 pub fn subscribe(&self) -> StreamReceiver {
441 self.conn.subscribe()
442 }
443}
444
445#[async_trait::async_trait(?Send)]
446impl Client for AgentSideConnection {
447 async fn request_permission(
448 &self,
449 args: RequestPermissionRequest,
450 ) -> Result<RequestPermissionResponse, Error> {
451 self.conn
452 .request(
453 SESSION_REQUEST_PERMISSION_METHOD_NAME,
454 Some(AgentRequest::RequestPermissionRequest(args)),
455 )
456 .await
457 }
458
459 async fn write_text_file(
460 &self,
461 args: WriteTextFileRequest,
462 ) -> Result<WriteTextFileResponse, Error> {
463 self.conn
464 .request::<Option<_>>(
465 FS_WRITE_TEXT_FILE_METHOD_NAME,
466 Some(AgentRequest::WriteTextFileRequest(args)),
467 )
468 .await
469 .map(|value| value.unwrap_or_default())
470 }
471
472 async fn read_text_file(
473 &self,
474 args: ReadTextFileRequest,
475 ) -> Result<ReadTextFileResponse, Error> {
476 self.conn
477 .request(
478 FS_READ_TEXT_FILE_METHOD_NAME,
479 Some(AgentRequest::ReadTextFileRequest(args)),
480 )
481 .await
482 }
483
484 async fn create_terminal(
485 &self,
486 args: CreateTerminalRequest,
487 ) -> Result<CreateTerminalResponse, Error> {
488 self.conn
489 .request(
490 TERMINAL_CREATE_METHOD_NAME,
491 Some(AgentRequest::CreateTerminalRequest(args)),
492 )
493 .await
494 }
495
496 async fn terminal_output(
497 &self,
498 args: TerminalOutputRequest,
499 ) -> Result<TerminalOutputResponse, Error> {
500 self.conn
501 .request(
502 TERMINAL_OUTPUT_METHOD_NAME,
503 Some(AgentRequest::TerminalOutputRequest(args)),
504 )
505 .await
506 }
507
508 async fn release_terminal(
509 &self,
510 args: ReleaseTerminalRequest,
511 ) -> Result<ReleaseTerminalResponse, Error> {
512 self.conn
513 .request::<Option<_>>(
514 TERMINAL_RELEASE_METHOD_NAME,
515 Some(AgentRequest::ReleaseTerminalRequest(args)),
516 )
517 .await
518 .map(|value| value.unwrap_or_default())
519 }
520
521 async fn wait_for_terminal_exit(
522 &self,
523 args: WaitForTerminalExitRequest,
524 ) -> Result<WaitForTerminalExitResponse, Error> {
525 self.conn
526 .request(
527 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME,
528 Some(AgentRequest::WaitForTerminalExitRequest(args)),
529 )
530 .await
531 }
532
533 async fn kill_terminal_command(
534 &self,
535 args: KillTerminalCommandRequest,
536 ) -> Result<KillTerminalCommandResponse, Error> {
537 self.conn
538 .request::<Option<_>>(
539 TERMINAL_KILL_METHOD_NAME,
540 Some(AgentRequest::KillTerminalCommandRequest(args)),
541 )
542 .await
543 .map(|value| value.unwrap_or_default())
544 }
545
546 async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> {
547 self.conn.notify(
548 SESSION_UPDATE_NOTIFICATION,
549 Some(AgentNotification::SessionNotification(args)),
550 )
551 }
552
553 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
554 self.conn
555 .request(
556 format!("_{}", args.method),
557 Some(AgentRequest::ExtMethodRequest(args)),
558 )
559 .await
560 }
561
562 async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
563 self.conn.notify(
564 format!("_{}", args.method),
565 Some(AgentNotification::ExtNotification(args)),
566 )
567 }
568}
569
/// Marker type for the agent's end of the protocol.
///
/// Its `Side` implementation decodes what the agent receives from the client:
/// incoming requests become [`ClientRequest`], incoming notifications become
/// [`ClientNotification`], and replies are sent as [`AgentResponse`].
#[derive(Clone)]
pub struct AgentSide;
578
579impl Side for AgentSide {
580 type InRequest = ClientRequest;
581 type InNotification = ClientNotification;
582 type OutResponse = AgentResponse;
583
584 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
585 let params = params.ok_or_else(Error::invalid_params)?;
586
587 match method {
588 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
589 .map(ClientRequest::InitializeRequest)
590 .map_err(Into::into),
591 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
592 .map(ClientRequest::AuthenticateRequest)
593 .map_err(Into::into),
594 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
595 .map(ClientRequest::NewSessionRequest)
596 .map_err(Into::into),
597 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
598 .map(ClientRequest::LoadSessionRequest)
599 .map_err(Into::into),
600 SESSION_SET_MODE_METHOD_NAME => serde_json::from_str(params.get())
601 .map(ClientRequest::SetSessionModeRequest)
602 .map_err(Into::into),
603 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
604 .map(ClientRequest::PromptRequest)
605 .map_err(Into::into),
606 _ => {
607 if let Some(custom_method) = method.strip_prefix('_') {
608 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
609 method: custom_method.into(),
610 params: RawValue::from_string(params.get().to_string())?.into(),
611 }))
612 } else {
613 Err(Error::method_not_found())
614 }
615 }
616 }
617 }
618
619 fn decode_notification(
620 method: &str,
621 params: Option<&RawValue>,
622 ) -> Result<ClientNotification, Error> {
623 let params = params.ok_or_else(Error::invalid_params)?;
624
625 match method {
626 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
627 .map(ClientNotification::CancelNotification)
628 .map_err(Into::into),
629 _ => {
630 if let Some(custom_method) = method.strip_prefix('_') {
631 Ok(ClientNotification::ExtNotification(ExtNotification {
632 method: custom_method.into(),
633 params: RawValue::from_string(params.get().to_string())?.into(),
634 }))
635 } else {
636 Err(Error::method_not_found())
637 }
638 }
639 }
640 }
641}
642
643impl<T: Agent> MessageHandler<AgentSide> for T {
644 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
645 match request {
646 ClientRequest::InitializeRequest(args) => {
647 let response = self.initialize(args).await?;
648 Ok(AgentResponse::InitializeResponse(response))
649 }
650 ClientRequest::AuthenticateRequest(args) => {
651 let response = self.authenticate(args).await?;
652 Ok(AgentResponse::AuthenticateResponse(response))
653 }
654 ClientRequest::NewSessionRequest(args) => {
655 let response = self.new_session(args).await?;
656 Ok(AgentResponse::NewSessionResponse(response))
657 }
658 ClientRequest::LoadSessionRequest(args) => {
659 let response = self.load_session(args).await?;
660 Ok(AgentResponse::LoadSessionResponse(response))
661 }
662 ClientRequest::PromptRequest(args) => {
663 let response = self.prompt(args).await?;
664 Ok(AgentResponse::PromptResponse(response))
665 }
666 ClientRequest::SetSessionModeRequest(args) => {
667 let response = self.set_session_mode(args).await?;
668 Ok(AgentResponse::SetSessionModeResponse(response))
669 }
670 ClientRequest::ExtMethodRequest(args) => {
671 let response = self.ext_method(args).await?;
672 Ok(AgentResponse::ExtMethodResponse(response))
673 }
674 }
675 }
676
677 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
678 match notification {
679 ClientNotification::CancelNotification(args) => {
680 self.cancel(args).await?;
681 }
682 ClientNotification::ExtNotification(args) => {
683 self.ext_notification(args).await?;
684 }
685 }
686 Ok(())
687 }
688}