1mod agent;
54mod client;
55mod content;
56mod error;
57mod ext;
58mod plan;
59mod rpc;
60#[cfg(test)]
61mod rpc_tests;
62mod stream_broadcast;
63mod tool_call;
64mod version;
65
66pub use agent::*;
67pub use client::*;
68pub use content::*;
69pub use error::*;
70pub use ext::*;
71pub use plan::*;
72pub use serde_json::value::RawValue;
73pub use stream_broadcast::{
74 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
75};
76pub use tool_call::*;
77pub use version::*;
78
79use anyhow::Result;
80use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
81use schemars::JsonSchema;
82use serde::{Deserialize, Serialize};
83use std::{fmt, sync::Arc};
84
85use crate::rpc::{MessageHandler, RpcConnection, Side};
86
87#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
103#[serde(transparent)]
104pub struct SessionId(pub Arc<str>);
105
106impl fmt::Display for SessionId {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "{}", self.0)
109 }
110}
111
112pub struct ClientSideConnection {
123 conn: RpcConnection<ClientSide, AgentSide>,
124}
125
126impl ClientSideConnection {
127 pub fn new(
147 client: impl MessageHandler<ClientSide> + 'static,
148 outgoing_bytes: impl Unpin + AsyncWrite,
149 incoming_bytes: impl Unpin + AsyncRead,
150 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
151 ) -> (Self, impl Future<Output = Result<()>>) {
152 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
153 (Self { conn }, io_task)
154 }
155
156 pub fn subscribe(&self) -> StreamReceiver {
165 self.conn.subscribe()
166 }
167}
168
169#[async_trait::async_trait(?Send)]
170impl Agent for ClientSideConnection {
171 async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse, Error> {
172 self.conn
173 .request(
174 INITIALIZE_METHOD_NAME,
175 Some(ClientRequest::InitializeRequest(args)),
176 )
177 .await
178 }
179
180 async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse, Error> {
181 self.conn
182 .request::<Option<_>>(
183 AUTHENTICATE_METHOD_NAME,
184 Some(ClientRequest::AuthenticateRequest(args)),
185 )
186 .await
187 .map(Option::unwrap_or_default)
188 }
189
190 async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse, Error> {
191 self.conn
192 .request(
193 SESSION_NEW_METHOD_NAME,
194 Some(ClientRequest::NewSessionRequest(args)),
195 )
196 .await
197 }
198
199 async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse, Error> {
200 self.conn
201 .request::<Option<_>>(
202 SESSION_LOAD_METHOD_NAME,
203 Some(ClientRequest::LoadSessionRequest(args)),
204 )
205 .await
206 .map(Option::unwrap_or_default)
207 }
208
209 async fn set_session_mode(
210 &self,
211 args: SetSessionModeRequest,
212 ) -> Result<SetSessionModeResponse, Error> {
213 self.conn
214 .request(
215 SESSION_SET_MODE_METHOD_NAME,
216 Some(ClientRequest::SetSessionModeRequest(args)),
217 )
218 .await
219 }
220
221 async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse, Error> {
222 self.conn
223 .request(
224 SESSION_PROMPT_METHOD_NAME,
225 Some(ClientRequest::PromptRequest(args)),
226 )
227 .await
228 }
229
230 async fn cancel(&self, args: CancelNotification) -> Result<(), Error> {
231 self.conn.notify(
232 SESSION_CANCEL_METHOD_NAME,
233 Some(ClientNotification::CancelNotification(args)),
234 )
235 }
236
237 #[cfg(feature = "unstable")]
238 async fn set_session_model(
239 &self,
240 args: SetSessionModelRequest,
241 ) -> Result<SetSessionModelResponse, Error> {
242 self.conn
243 .request(
244 SESSION_SET_MODEL_METHOD_NAME,
245 Some(ClientRequest::SetSessionModelRequest(args)),
246 )
247 .await
248 }
249
250 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
251 self.conn
252 .request(
253 format!("_{}", args.method),
254 Some(ClientRequest::ExtMethodRequest(args)),
255 )
256 .await
257 }
258
259 async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
260 self.conn.notify(
261 format!("_{}", args.method),
262 Some(ClientNotification::ExtNotification(args)),
263 )
264 }
265}
266
267#[derive(Clone)]
274pub struct ClientSide;
275
276impl Side for ClientSide {
277 type InNotification = AgentNotification;
278 type InRequest = AgentRequest;
279 type OutResponse = ClientResponse;
280
281 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
282 let params = params.ok_or_else(Error::invalid_params)?;
283
284 match method {
285 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
286 .map(AgentRequest::RequestPermissionRequest)
287 .map_err(Into::into),
288 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
289 .map(AgentRequest::WriteTextFileRequest)
290 .map_err(Into::into),
291 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
292 .map(AgentRequest::ReadTextFileRequest)
293 .map_err(Into::into),
294 TERMINAL_CREATE_METHOD_NAME => serde_json::from_str(params.get())
295 .map(AgentRequest::CreateTerminalRequest)
296 .map_err(Into::into),
297 TERMINAL_OUTPUT_METHOD_NAME => serde_json::from_str(params.get())
298 .map(AgentRequest::TerminalOutputRequest)
299 .map_err(Into::into),
300 TERMINAL_KILL_METHOD_NAME => serde_json::from_str(params.get())
301 .map(AgentRequest::KillTerminalCommandRequest)
302 .map_err(Into::into),
303 TERMINAL_RELEASE_METHOD_NAME => serde_json::from_str(params.get())
304 .map(AgentRequest::ReleaseTerminalRequest)
305 .map_err(Into::into),
306 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME => serde_json::from_str(params.get())
307 .map(AgentRequest::WaitForTerminalExitRequest)
308 .map_err(Into::into),
309 _ => {
310 if let Some(custom_method) = method.strip_prefix('_') {
311 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
312 method: custom_method.into(),
313 params: RawValue::from_string(params.get().to_string())?.into(),
314 }))
315 } else {
316 Err(Error::method_not_found())
317 }
318 }
319 }
320 }
321
322 fn decode_notification(
323 method: &str,
324 params: Option<&RawValue>,
325 ) -> Result<AgentNotification, Error> {
326 let params = params.ok_or_else(Error::invalid_params)?;
327
328 match method {
329 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
330 .map(AgentNotification::SessionNotification)
331 .map_err(Into::into),
332 _ => {
333 if let Some(custom_method) = method.strip_prefix('_') {
334 Ok(AgentNotification::ExtNotification(ExtNotification {
335 method: custom_method.into(),
336 params: RawValue::from_string(params.get().to_string())?.into(),
337 }))
338 } else {
339 Err(Error::method_not_found())
340 }
341 }
342 }
343 }
344}
345
346impl<T: Client> MessageHandler<ClientSide> for T {
347 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
348 match request {
349 AgentRequest::RequestPermissionRequest(args) => {
350 let response = self.request_permission(args).await?;
351 Ok(ClientResponse::RequestPermissionResponse(response))
352 }
353 AgentRequest::WriteTextFileRequest(args) => {
354 let response = self.write_text_file(args).await?;
355 Ok(ClientResponse::WriteTextFileResponse(response))
356 }
357 AgentRequest::ReadTextFileRequest(args) => {
358 let response = self.read_text_file(args).await?;
359 Ok(ClientResponse::ReadTextFileResponse(response))
360 }
361 AgentRequest::CreateTerminalRequest(args) => {
362 let response = self.create_terminal(args).await?;
363 Ok(ClientResponse::CreateTerminalResponse(response))
364 }
365 AgentRequest::TerminalOutputRequest(args) => {
366 let response = self.terminal_output(args).await?;
367 Ok(ClientResponse::TerminalOutputResponse(response))
368 }
369 AgentRequest::ReleaseTerminalRequest(args) => {
370 let response = self.release_terminal(args).await?;
371 Ok(ClientResponse::ReleaseTerminalResponse(response))
372 }
373 AgentRequest::WaitForTerminalExitRequest(args) => {
374 let response = self.wait_for_terminal_exit(args).await?;
375 Ok(ClientResponse::WaitForTerminalExitResponse(response))
376 }
377 AgentRequest::KillTerminalCommandRequest(args) => {
378 let response = self.kill_terminal_command(args).await?;
379 Ok(ClientResponse::KillTerminalResponse(response))
380 }
381 AgentRequest::ExtMethodRequest(args) => {
382 let response = self.ext_method(args).await?;
383 Ok(ClientResponse::ExtMethodResponse(response))
384 }
385 }
386 }
387
388 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
389 match notification {
390 AgentNotification::SessionNotification(args) => {
391 self.session_notification(args).await?;
392 }
393 AgentNotification::ExtNotification(args) => {
394 self.ext_notification(args).await?;
395 }
396 }
397 Ok(())
398 }
399}
400
401pub struct AgentSideConnection {
412 conn: RpcConnection<AgentSide, ClientSide>,
413}
414
415impl AgentSideConnection {
416 pub fn new(
436 agent: impl MessageHandler<AgentSide> + 'static,
437 outgoing_bytes: impl Unpin + AsyncWrite,
438 incoming_bytes: impl Unpin + AsyncRead,
439 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
440 ) -> (Self, impl Future<Output = Result<()>>) {
441 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
442 (Self { conn }, io_task)
443 }
444
445 pub fn subscribe(&self) -> StreamReceiver {
454 self.conn.subscribe()
455 }
456}
457
458#[async_trait::async_trait(?Send)]
459impl Client for AgentSideConnection {
460 async fn request_permission(
461 &self,
462 args: RequestPermissionRequest,
463 ) -> Result<RequestPermissionResponse, Error> {
464 self.conn
465 .request(
466 SESSION_REQUEST_PERMISSION_METHOD_NAME,
467 Some(AgentRequest::RequestPermissionRequest(args)),
468 )
469 .await
470 }
471
472 async fn write_text_file(
473 &self,
474 args: WriteTextFileRequest,
475 ) -> Result<WriteTextFileResponse, Error> {
476 self.conn
477 .request::<Option<_>>(
478 FS_WRITE_TEXT_FILE_METHOD_NAME,
479 Some(AgentRequest::WriteTextFileRequest(args)),
480 )
481 .await
482 .map(Option::unwrap_or_default)
483 }
484
485 async fn read_text_file(
486 &self,
487 args: ReadTextFileRequest,
488 ) -> Result<ReadTextFileResponse, Error> {
489 self.conn
490 .request(
491 FS_READ_TEXT_FILE_METHOD_NAME,
492 Some(AgentRequest::ReadTextFileRequest(args)),
493 )
494 .await
495 }
496
497 async fn create_terminal(
498 &self,
499 args: CreateTerminalRequest,
500 ) -> Result<CreateTerminalResponse, Error> {
501 self.conn
502 .request(
503 TERMINAL_CREATE_METHOD_NAME,
504 Some(AgentRequest::CreateTerminalRequest(args)),
505 )
506 .await
507 }
508
509 async fn terminal_output(
510 &self,
511 args: TerminalOutputRequest,
512 ) -> Result<TerminalOutputResponse, Error> {
513 self.conn
514 .request(
515 TERMINAL_OUTPUT_METHOD_NAME,
516 Some(AgentRequest::TerminalOutputRequest(args)),
517 )
518 .await
519 }
520
521 async fn release_terminal(
522 &self,
523 args: ReleaseTerminalRequest,
524 ) -> Result<ReleaseTerminalResponse, Error> {
525 self.conn
526 .request::<Option<_>>(
527 TERMINAL_RELEASE_METHOD_NAME,
528 Some(AgentRequest::ReleaseTerminalRequest(args)),
529 )
530 .await
531 .map(Option::unwrap_or_default)
532 }
533
534 async fn wait_for_terminal_exit(
535 &self,
536 args: WaitForTerminalExitRequest,
537 ) -> Result<WaitForTerminalExitResponse, Error> {
538 self.conn
539 .request(
540 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME,
541 Some(AgentRequest::WaitForTerminalExitRequest(args)),
542 )
543 .await
544 }
545
546 async fn kill_terminal_command(
547 &self,
548 args: KillTerminalCommandRequest,
549 ) -> Result<KillTerminalCommandResponse, Error> {
550 self.conn
551 .request::<Option<_>>(
552 TERMINAL_KILL_METHOD_NAME,
553 Some(AgentRequest::KillTerminalCommandRequest(args)),
554 )
555 .await
556 .map(Option::unwrap_or_default)
557 }
558
559 async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> {
560 self.conn.notify(
561 SESSION_UPDATE_NOTIFICATION,
562 Some(AgentNotification::SessionNotification(args)),
563 )
564 }
565
566 async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error> {
567 self.conn
568 .request(
569 format!("_{}", args.method),
570 Some(AgentRequest::ExtMethodRequest(args)),
571 )
572 .await
573 }
574
575 async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> {
576 self.conn.notify(
577 format!("_{}", args.method),
578 Some(AgentNotification::ExtNotification(args)),
579 )
580 }
581}
582
583#[derive(Clone)]
590pub struct AgentSide;
591
592impl Side for AgentSide {
593 type InRequest = ClientRequest;
594 type InNotification = ClientNotification;
595 type OutResponse = AgentResponse;
596
597 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
598 let params = params.ok_or_else(Error::invalid_params)?;
599
600 match method {
601 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
602 .map(ClientRequest::InitializeRequest)
603 .map_err(Into::into),
604 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
605 .map(ClientRequest::AuthenticateRequest)
606 .map_err(Into::into),
607 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
608 .map(ClientRequest::NewSessionRequest)
609 .map_err(Into::into),
610 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
611 .map(ClientRequest::LoadSessionRequest)
612 .map_err(Into::into),
613 SESSION_SET_MODE_METHOD_NAME => serde_json::from_str(params.get())
614 .map(ClientRequest::SetSessionModeRequest)
615 .map_err(Into::into),
616 #[cfg(feature = "unstable")]
617 SESSION_SET_MODEL_METHOD_NAME => serde_json::from_str(params.get())
618 .map(ClientRequest::SetSessionModelRequest)
619 .map_err(Into::into),
620 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
621 .map(ClientRequest::PromptRequest)
622 .map_err(Into::into),
623 _ => {
624 if let Some(custom_method) = method.strip_prefix('_') {
625 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
626 method: custom_method.into(),
627 params: RawValue::from_string(params.get().to_string())?.into(),
628 }))
629 } else {
630 Err(Error::method_not_found())
631 }
632 }
633 }
634 }
635
636 fn decode_notification(
637 method: &str,
638 params: Option<&RawValue>,
639 ) -> Result<ClientNotification, Error> {
640 let params = params.ok_or_else(Error::invalid_params)?;
641
642 match method {
643 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
644 .map(ClientNotification::CancelNotification)
645 .map_err(Into::into),
646 _ => {
647 if let Some(custom_method) = method.strip_prefix('_') {
648 Ok(ClientNotification::ExtNotification(ExtNotification {
649 method: custom_method.into(),
650 params: RawValue::from_string(params.get().to_string())?.into(),
651 }))
652 } else {
653 Err(Error::method_not_found())
654 }
655 }
656 }
657 }
658}
659
660impl<T: Agent> MessageHandler<AgentSide> for T {
661 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
662 match request {
663 ClientRequest::InitializeRequest(args) => {
664 let response = self.initialize(args).await?;
665 Ok(AgentResponse::InitializeResponse(response))
666 }
667 ClientRequest::AuthenticateRequest(args) => {
668 let response = self.authenticate(args).await?;
669 Ok(AgentResponse::AuthenticateResponse(response))
670 }
671 ClientRequest::NewSessionRequest(args) => {
672 let response = self.new_session(args).await?;
673 Ok(AgentResponse::NewSessionResponse(response))
674 }
675 ClientRequest::LoadSessionRequest(args) => {
676 let response = self.load_session(args).await?;
677 Ok(AgentResponse::LoadSessionResponse(response))
678 }
679 ClientRequest::PromptRequest(args) => {
680 let response = self.prompt(args).await?;
681 Ok(AgentResponse::PromptResponse(response))
682 }
683 ClientRequest::SetSessionModeRequest(args) => {
684 let response = self.set_session_mode(args).await?;
685 Ok(AgentResponse::SetSessionModeResponse(response))
686 }
687 #[cfg(feature = "unstable")]
688 ClientRequest::SetSessionModelRequest(args) => {
689 let response = self.set_session_model(args).await?;
690 Ok(AgentResponse::SetSessionModelResponse(response))
691 }
692 ClientRequest::ExtMethodRequest(args) => {
693 let response = self.ext_method(args).await?;
694 Ok(AgentResponse::ExtMethodResponse(response))
695 }
696 }
697 }
698
699 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
700 match notification {
701 ClientNotification::CancelNotification(args) => {
702 self.cancel(args).await?;
703 }
704 ClientNotification::ExtNotification(args) => {
705 self.ext_notification(args).await?;
706 }
707 }
708 Ok(())
709 }
710}