1mod agent;
54mod client;
55mod content;
56mod error;
57mod ext;
58mod plan;
59mod rpc;
60#[cfg(test)]
61mod rpc_tests;
62mod stream_broadcast;
63mod tool_call;
64mod version;
65
66pub use agent::*;
67pub use client::*;
68pub use content::*;
69pub use error::*;
70pub use ext::*;
71pub use plan::*;
72pub use stream_broadcast::{
73 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
74};
75pub use tool_call::*;
76pub use version::*;
77
78use anyhow::Result;
79use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
80use schemars::JsonSchema;
81use serde::{Deserialize, Serialize};
82use serde_json::value::RawValue;
83use std::{fmt, sync::Arc};
84
85use crate::rpc::{MessageHandler, RpcConnection, Side};
86
87#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
103#[serde(transparent)]
104pub struct SessionId(pub Arc<str>);
105
106impl fmt::Display for SessionId {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "{}", self.0)
109 }
110}
111
112pub struct ClientSideConnection {
123 conn: RpcConnection<ClientSide, AgentSide>,
124}
125
126impl ClientSideConnection {
127 pub fn new(
147 client: impl MessageHandler<ClientSide> + 'static,
148 outgoing_bytes: impl Unpin + AsyncWrite,
149 incoming_bytes: impl Unpin + AsyncRead,
150 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
151 ) -> (Self, impl Future<Output = Result<()>>) {
152 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
153 (Self { conn }, io_task)
154 }
155
156 pub fn subscribe(&self) -> StreamReceiver {
165 self.conn.subscribe()
166 }
167}
168
169impl Agent for ClientSideConnection {
170 async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
171 self.conn
172 .request(
173 INITIALIZE_METHOD_NAME,
174 Some(ClientRequest::InitializeRequest(arguments)),
175 )
176 .await
177 }
178
179 async fn authenticate(
180 &self,
181 arguments: AuthenticateRequest,
182 ) -> Result<AuthenticateResponse, Error> {
183 self.conn
184 .request::<Option<_>>(
185 AUTHENTICATE_METHOD_NAME,
186 Some(ClientRequest::AuthenticateRequest(arguments)),
187 )
188 .await
189 .map(|value| value.unwrap_or_default())
190 }
191
192 async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
193 self.conn
194 .request(
195 SESSION_NEW_METHOD_NAME,
196 Some(ClientRequest::NewSessionRequest(arguments)),
197 )
198 .await
199 }
200
201 async fn load_session(
202 &self,
203 arguments: LoadSessionRequest,
204 ) -> Result<LoadSessionResponse, Error> {
205 self.conn
206 .request::<Option<_>>(
207 SESSION_LOAD_METHOD_NAME,
208 Some(ClientRequest::LoadSessionRequest(arguments)),
209 )
210 .await
211 .map(|value| value.unwrap_or_default())
212 }
213
214 #[cfg(feature = "unstable")]
215 async fn set_session_mode(
216 &self,
217 arguments: SetSessionModeRequest,
218 ) -> Result<SetSessionModeResponse, Error> {
219 self.conn
220 .request(
221 SESSION_SET_MODE_METHOD_NAME,
222 Some(ClientRequest::SetSessionModeRequest(arguments)),
223 )
224 .await
225 }
226
227 async fn prompt(&self, arguments: PromptRequest) -> Result<PromptResponse, Error> {
228 self.conn
229 .request(
230 SESSION_PROMPT_METHOD_NAME,
231 Some(ClientRequest::PromptRequest(arguments)),
232 )
233 .await
234 }
235
236 async fn cancel(&self, notification: CancelNotification) -> Result<(), Error> {
237 self.conn.notify(
238 SESSION_CANCEL_METHOD_NAME,
239 Some(ClientNotification::CancelNotification(notification)),
240 )
241 }
242
243 async fn ext_method(
244 &self,
245 method: Arc<str>,
246 params: Arc<RawValue>,
247 ) -> Result<Arc<RawValue>, Error> {
248 self.conn
249 .request(
250 format!("_{method}"),
251 Some(ClientRequest::ExtMethodRequest(ExtMethod {
252 method,
253 params,
254 })),
255 )
256 .await
257 }
258
259 async fn ext_notification(&self, method: Arc<str>, params: Arc<RawValue>) -> Result<(), Error> {
260 self.conn.notify(
261 format!("_{method}"),
262 Some(ClientNotification::ExtNotification(ExtMethod {
263 method,
264 params,
265 })),
266 )
267 }
268}
269
270#[derive(Clone)]
277pub struct ClientSide;
278
279impl Side for ClientSide {
280 type InNotification = AgentNotification;
281 type InRequest = AgentRequest;
282 type OutResponse = ClientResponse;
283
284 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
285 let params = params.ok_or_else(Error::invalid_params)?;
286
287 match method {
288 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
289 .map(AgentRequest::RequestPermissionRequest)
290 .map_err(Into::into),
291 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
292 .map(AgentRequest::WriteTextFileRequest)
293 .map_err(Into::into),
294 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
295 .map(AgentRequest::ReadTextFileRequest)
296 .map_err(Into::into),
297 TERMINAL_CREATE_METHOD_NAME => serde_json::from_str(params.get())
298 .map(AgentRequest::CreateTerminalRequest)
299 .map_err(Into::into),
300 TERMINAL_OUTPUT_METHOD_NAME => serde_json::from_str(params.get())
301 .map(AgentRequest::TerminalOutputRequest)
302 .map_err(Into::into),
303 TERMINAL_KILL_METHOD_NAME => serde_json::from_str(params.get())
304 .map(AgentRequest::KillTerminalCommandRequest)
305 .map_err(Into::into),
306 TERMINAL_RELEASE_METHOD_NAME => serde_json::from_str(params.get())
307 .map(AgentRequest::ReleaseTerminalRequest)
308 .map_err(Into::into),
309 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME => serde_json::from_str(params.get())
310 .map(AgentRequest::WaitForTerminalExitRequest)
311 .map_err(Into::into),
312 _ => {
313 if let Some(custom_method) = method.strip_prefix('_') {
314 Ok(AgentRequest::ExtMethodRequest(ExtMethod {
315 method: custom_method.into(),
316 params: RawValue::from_string(params.get().to_string())?.into(),
317 }))
318 } else {
319 Err(Error::method_not_found())
320 }
321 }
322 }
323 }
324
325 fn decode_notification(
326 method: &str,
327 params: Option<&RawValue>,
328 ) -> Result<AgentNotification, Error> {
329 let params = params.ok_or_else(Error::invalid_params)?;
330
331 match method {
332 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
333 .map(AgentNotification::SessionNotification)
334 .map_err(Into::into),
335 _ => {
336 if let Some(custom_method) = method.strip_prefix('_') {
337 Ok(AgentNotification::ExtNotification(ExtMethod {
338 method: custom_method.into(),
339 params: RawValue::from_string(params.get().to_string())?.into(),
340 }))
341 } else {
342 Err(Error::method_not_found())
343 }
344 }
345 }
346 }
347}
348
349impl<T: Client> MessageHandler<ClientSide> for T {
350 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
351 match request {
352 AgentRequest::RequestPermissionRequest(args) => {
353 let response = self.request_permission(args).await?;
354 Ok(ClientResponse::RequestPermissionResponse(response))
355 }
356 AgentRequest::WriteTextFileRequest(args) => {
357 let response = self.write_text_file(args).await?;
358 Ok(ClientResponse::WriteTextFileResponse(response))
359 }
360 AgentRequest::ReadTextFileRequest(args) => {
361 let response = self.read_text_file(args).await?;
362 Ok(ClientResponse::ReadTextFileResponse(response))
363 }
364 AgentRequest::CreateTerminalRequest(args) => {
365 let response = self.create_terminal(args).await?;
366 Ok(ClientResponse::CreateTerminalResponse(response))
367 }
368 AgentRequest::TerminalOutputRequest(args) => {
369 let response = self.terminal_output(args).await?;
370 Ok(ClientResponse::TerminalOutputResponse(response))
371 }
372 AgentRequest::ReleaseTerminalRequest(args) => {
373 let response = self.release_terminal(args).await?;
374 Ok(ClientResponse::ReleaseTerminalResponse(response))
375 }
376 AgentRequest::WaitForTerminalExitRequest(args) => {
377 let response = self.wait_for_terminal_exit(args).await?;
378 Ok(ClientResponse::WaitForTerminalExitResponse(response))
379 }
380 AgentRequest::KillTerminalCommandRequest(args) => {
381 let response = self.kill_terminal_command(args).await?;
382 Ok(ClientResponse::KillTerminalResponse(response))
383 }
384 AgentRequest::ExtMethodRequest(args) => {
385 let response = self.ext_method(args.method, args.params).await?;
386 Ok(ClientResponse::ExtMethodResponse(response))
387 }
388 }
389 }
390
391 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
392 match notification {
393 AgentNotification::SessionNotification(notification) => {
394 self.session_notification(notification).await?;
395 }
396 AgentNotification::ExtNotification(args) => {
397 self.ext_notification(args.method, args.params).await?;
398 }
399 }
400 Ok(())
401 }
402}
403
404pub struct AgentSideConnection {
415 conn: RpcConnection<AgentSide, ClientSide>,
416}
417
418impl AgentSideConnection {
419 pub fn new(
439 agent: impl MessageHandler<AgentSide> + 'static,
440 outgoing_bytes: impl Unpin + AsyncWrite,
441 incoming_bytes: impl Unpin + AsyncRead,
442 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
443 ) -> (Self, impl Future<Output = Result<()>>) {
444 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
445 (Self { conn }, io_task)
446 }
447
448 pub fn subscribe(&self) -> StreamReceiver {
457 self.conn.subscribe()
458 }
459}
460
461impl Client for AgentSideConnection {
462 async fn request_permission(
463 &self,
464 arguments: RequestPermissionRequest,
465 ) -> Result<RequestPermissionResponse, Error> {
466 self.conn
467 .request(
468 SESSION_REQUEST_PERMISSION_METHOD_NAME,
469 Some(AgentRequest::RequestPermissionRequest(arguments)),
470 )
471 .await
472 }
473
474 async fn write_text_file(
475 &self,
476 arguments: WriteTextFileRequest,
477 ) -> Result<WriteTextFileResponse, Error> {
478 self.conn
479 .request::<Option<_>>(
480 FS_WRITE_TEXT_FILE_METHOD_NAME,
481 Some(AgentRequest::WriteTextFileRequest(arguments)),
482 )
483 .await
484 .map(|value| value.unwrap_or_default())
485 }
486
487 async fn read_text_file(
488 &self,
489 arguments: ReadTextFileRequest,
490 ) -> Result<ReadTextFileResponse, Error> {
491 self.conn
492 .request(
493 FS_READ_TEXT_FILE_METHOD_NAME,
494 Some(AgentRequest::ReadTextFileRequest(arguments)),
495 )
496 .await
497 }
498
499 async fn create_terminal(
500 &self,
501 arguments: CreateTerminalRequest,
502 ) -> Result<CreateTerminalResponse, Error> {
503 self.conn
504 .request(
505 TERMINAL_CREATE_METHOD_NAME,
506 Some(AgentRequest::CreateTerminalRequest(arguments)),
507 )
508 .await
509 }
510
511 async fn terminal_output(
512 &self,
513 arguments: TerminalOutputRequest,
514 ) -> Result<TerminalOutputResponse, Error> {
515 self.conn
516 .request(
517 TERMINAL_OUTPUT_METHOD_NAME,
518 Some(AgentRequest::TerminalOutputRequest(arguments)),
519 )
520 .await
521 }
522
523 async fn release_terminal(
524 &self,
525 arguments: ReleaseTerminalRequest,
526 ) -> Result<ReleaseTerminalResponse, Error> {
527 self.conn
528 .request::<Option<_>>(
529 TERMINAL_RELEASE_METHOD_NAME,
530 Some(AgentRequest::ReleaseTerminalRequest(arguments)),
531 )
532 .await
533 .map(|value| value.unwrap_or_default())
534 }
535
536 async fn wait_for_terminal_exit(
537 &self,
538 arguments: WaitForTerminalExitRequest,
539 ) -> Result<WaitForTerminalExitResponse, Error> {
540 self.conn
541 .request(
542 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME,
543 Some(AgentRequest::WaitForTerminalExitRequest(arguments)),
544 )
545 .await
546 }
547
548 async fn kill_terminal_command(
549 &self,
550 arguments: KillTerminalCommandRequest,
551 ) -> Result<KillTerminalCommandResponse, Error> {
552 self.conn
553 .request::<Option<_>>(
554 TERMINAL_KILL_METHOD_NAME,
555 Some(AgentRequest::KillTerminalCommandRequest(arguments)),
556 )
557 .await
558 .map(|value| value.unwrap_or_default())
559 }
560
561 async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
562 self.conn.notify(
563 SESSION_UPDATE_NOTIFICATION,
564 Some(AgentNotification::SessionNotification(notification)),
565 )
566 }
567
568 async fn ext_method(
569 &self,
570 method: Arc<str>,
571 params: Arc<RawValue>,
572 ) -> Result<Arc<RawValue>, Error> {
573 self.conn
574 .request(
575 format!("_{method}"),
576 Some(AgentRequest::ExtMethodRequest(ExtMethod { method, params })),
577 )
578 .await
579 }
580
581 async fn ext_notification(&self, method: Arc<str>, params: Arc<RawValue>) -> Result<(), Error> {
582 self.conn.notify(
583 format!("_{method}"),
584 Some(AgentNotification::ExtNotification(ExtMethod {
585 method,
586 params,
587 })),
588 )
589 }
590}
591
592#[derive(Clone)]
599pub struct AgentSide;
600
601impl Side for AgentSide {
602 type InRequest = ClientRequest;
603 type InNotification = ClientNotification;
604 type OutResponse = AgentResponse;
605
606 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
607 let params = params.ok_or_else(Error::invalid_params)?;
608
609 match method {
610 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
611 .map(ClientRequest::InitializeRequest)
612 .map_err(Into::into),
613 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
614 .map(ClientRequest::AuthenticateRequest)
615 .map_err(Into::into),
616 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
617 .map(ClientRequest::NewSessionRequest)
618 .map_err(Into::into),
619 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
620 .map(ClientRequest::LoadSessionRequest)
621 .map_err(Into::into),
622 #[cfg(feature = "unstable")]
623 SESSION_SET_MODE_METHOD_NAME => serde_json::from_str(params.get())
624 .map(ClientRequest::SetSessionModeRequest)
625 .map_err(Into::into),
626 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
627 .map(ClientRequest::PromptRequest)
628 .map_err(Into::into),
629 _ => {
630 if let Some(custom_method) = method.strip_prefix('_') {
631 Ok(ClientRequest::ExtMethodRequest(ExtMethod {
632 method: custom_method.into(),
633 params: RawValue::from_string(params.get().to_string())?.into(),
634 }))
635 } else {
636 Err(Error::method_not_found())
637 }
638 }
639 }
640 }
641
642 fn decode_notification(
643 method: &str,
644 params: Option<&RawValue>,
645 ) -> Result<ClientNotification, Error> {
646 let params = params.ok_or_else(Error::invalid_params)?;
647
648 match method {
649 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
650 .map(ClientNotification::CancelNotification)
651 .map_err(Into::into),
652 _ => {
653 if let Some(custom_method) = method.strip_prefix('_') {
654 Ok(ClientNotification::ExtNotification(ExtMethod {
655 method: custom_method.into(),
656 params: RawValue::from_string(params.get().to_string())?.into(),
657 }))
658 } else {
659 Err(Error::method_not_found())
660 }
661 }
662 }
663 }
664}
665
666impl<T: Agent> MessageHandler<AgentSide> for T {
667 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
668 match request {
669 ClientRequest::InitializeRequest(args) => {
670 let response = self.initialize(args).await?;
671 Ok(AgentResponse::InitializeResponse(response))
672 }
673 ClientRequest::AuthenticateRequest(args) => {
674 let response = self.authenticate(args).await?;
675 Ok(AgentResponse::AuthenticateResponse(response))
676 }
677 ClientRequest::NewSessionRequest(args) => {
678 let response = self.new_session(args).await?;
679 Ok(AgentResponse::NewSessionResponse(response))
680 }
681 ClientRequest::LoadSessionRequest(args) => {
682 let response = self.load_session(args).await?;
683 Ok(AgentResponse::LoadSessionResponse(response))
684 }
685 ClientRequest::PromptRequest(args) => {
686 let response = self.prompt(args).await?;
687 Ok(AgentResponse::PromptResponse(response))
688 }
689 #[cfg(feature = "unstable")]
690 ClientRequest::SetSessionModeRequest(args) => {
691 let response = self.set_session_mode(args).await?;
692 Ok(AgentResponse::SetSessionModeResponse(response))
693 }
694 ClientRequest::ExtMethodRequest(args) => {
695 let response = self.ext_method(args.method, args.params).await?;
696 Ok(AgentResponse::ExtMethodResponse(response))
697 }
698 }
699 }
700
701 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
702 match notification {
703 ClientNotification::CancelNotification(notification) => {
704 self.cancel(notification).await?;
705 }
706 ClientNotification::ExtNotification(args) => {
707 self.ext_notification(args.method, args.params).await?;
708 }
709 }
710 Ok(())
711 }
712}