1mod agent;
54mod client;
55mod content;
56mod error;
57mod ext;
58mod plan;
59mod rpc;
60#[cfg(test)]
61mod rpc_tests;
62mod stream_broadcast;
63mod tool_call;
64mod version;
65
66pub use agent::*;
67pub use client::*;
68pub use content::*;
69pub use error::*;
70pub use ext::*;
71pub use plan::*;
72pub use stream_broadcast::{
73 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
74};
75pub use tool_call::*;
76pub use version::*;
77
78use anyhow::Result;
79use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
80use schemars::JsonSchema;
81use serde::{Deserialize, Serialize};
82use serde_json::value::RawValue;
83use std::{fmt, sync::Arc};
84
85use crate::rpc::{MessageHandler, RpcConnection, Side};
86
87#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
103#[serde(transparent)]
104pub struct SessionId(pub Arc<str>);
105
106impl fmt::Display for SessionId {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "{}", self.0)
109 }
110}
111
112pub struct ClientSideConnection {
123 conn: RpcConnection<ClientSide, AgentSide>,
124}
125
126impl ClientSideConnection {
127 pub fn new(
147 client: impl MessageHandler<ClientSide> + 'static,
148 outgoing_bytes: impl Unpin + AsyncWrite,
149 incoming_bytes: impl Unpin + AsyncRead,
150 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
151 ) -> (Self, impl Future<Output = Result<()>>) {
152 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
153 (Self { conn }, io_task)
154 }
155
156 pub fn subscribe(&self) -> StreamReceiver {
165 self.conn.subscribe()
166 }
167}
168
169impl Agent for ClientSideConnection {
170 async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
171 self.conn
172 .request(
173 INITIALIZE_METHOD_NAME,
174 Some(ClientRequest::InitializeRequest(arguments)),
175 )
176 .await
177 }
178
179 async fn authenticate(
180 &self,
181 arguments: AuthenticateRequest,
182 ) -> Result<AuthenticateResponse, Error> {
183 self.conn
184 .request::<Option<_>>(
185 AUTHENTICATE_METHOD_NAME,
186 Some(ClientRequest::AuthenticateRequest(arguments)),
187 )
188 .await
189 .map(|value| value.unwrap_or_default())
190 }
191
192 async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
193 self.conn
194 .request(
195 SESSION_NEW_METHOD_NAME,
196 Some(ClientRequest::NewSessionRequest(arguments)),
197 )
198 .await
199 }
200
201 async fn load_session(
202 &self,
203 arguments: LoadSessionRequest,
204 ) -> Result<LoadSessionResponse, Error> {
205 self.conn
206 .request::<Option<_>>(
207 SESSION_LOAD_METHOD_NAME,
208 Some(ClientRequest::LoadSessionRequest(arguments)),
209 )
210 .await
211 .map(|value| value.unwrap_or_default())
212 }
213
214 async fn set_session_mode(
215 &self,
216 arguments: SetSessionModeRequest,
217 ) -> Result<SetSessionModeResponse, Error> {
218 self.conn
219 .request(
220 SESSION_SET_MODE_METHOD_NAME,
221 Some(ClientRequest::SetSessionModeRequest(arguments)),
222 )
223 .await
224 }
225
226 async fn prompt(&self, arguments: PromptRequest) -> Result<PromptResponse, Error> {
227 self.conn
228 .request(
229 SESSION_PROMPT_METHOD_NAME,
230 Some(ClientRequest::PromptRequest(arguments)),
231 )
232 .await
233 }
234
235 async fn cancel(&self, notification: CancelNotification) -> Result<(), Error> {
236 self.conn.notify(
237 SESSION_CANCEL_METHOD_NAME,
238 Some(ClientNotification::CancelNotification(notification)),
239 )
240 }
241
242 async fn ext_method(
243 &self,
244 method: Arc<str>,
245 params: Arc<RawValue>,
246 ) -> Result<Arc<RawValue>, Error> {
247 self.conn
248 .request(
249 format!("_{method}"),
250 Some(ClientRequest::ExtMethodRequest(ExtMethod {
251 method,
252 params,
253 })),
254 )
255 .await
256 }
257
258 async fn ext_notification(&self, method: Arc<str>, params: Arc<RawValue>) -> Result<(), Error> {
259 self.conn.notify(
260 format!("_{method}"),
261 Some(ClientNotification::ExtNotification(ExtMethod {
262 method,
263 params,
264 })),
265 )
266 }
267}
268
269#[derive(Clone)]
276pub struct ClientSide;
277
278impl Side for ClientSide {
279 type InNotification = AgentNotification;
280 type InRequest = AgentRequest;
281 type OutResponse = ClientResponse;
282
283 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
284 let params = params.ok_or_else(Error::invalid_params)?;
285
286 match method {
287 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
288 .map(AgentRequest::RequestPermissionRequest)
289 .map_err(Into::into),
290 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
291 .map(AgentRequest::WriteTextFileRequest)
292 .map_err(Into::into),
293 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
294 .map(AgentRequest::ReadTextFileRequest)
295 .map_err(Into::into),
296 TERMINAL_CREATE_METHOD_NAME => serde_json::from_str(params.get())
297 .map(AgentRequest::CreateTerminalRequest)
298 .map_err(Into::into),
299 TERMINAL_OUTPUT_METHOD_NAME => serde_json::from_str(params.get())
300 .map(AgentRequest::TerminalOutputRequest)
301 .map_err(Into::into),
302 TERMINAL_KILL_METHOD_NAME => serde_json::from_str(params.get())
303 .map(AgentRequest::KillTerminalCommandRequest)
304 .map_err(Into::into),
305 TERMINAL_RELEASE_METHOD_NAME => serde_json::from_str(params.get())
306 .map(AgentRequest::ReleaseTerminalRequest)
307 .map_err(Into::into),
308 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME => serde_json::from_str(params.get())
309 .map(AgentRequest::WaitForTerminalExitRequest)
310 .map_err(Into::into),
311 _ => {
312 if let Some(custom_method) = method.strip_prefix('_') {
313 Ok(AgentRequest::ExtMethodRequest(ExtMethod {
314 method: custom_method.into(),
315 params: RawValue::from_string(params.get().to_string())?.into(),
316 }))
317 } else {
318 Err(Error::method_not_found())
319 }
320 }
321 }
322 }
323
324 fn decode_notification(
325 method: &str,
326 params: Option<&RawValue>,
327 ) -> Result<AgentNotification, Error> {
328 let params = params.ok_or_else(Error::invalid_params)?;
329
330 match method {
331 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
332 .map(AgentNotification::SessionNotification)
333 .map_err(Into::into),
334 _ => {
335 if let Some(custom_method) = method.strip_prefix('_') {
336 Ok(AgentNotification::ExtNotification(ExtMethod {
337 method: custom_method.into(),
338 params: RawValue::from_string(params.get().to_string())?.into(),
339 }))
340 } else {
341 Err(Error::method_not_found())
342 }
343 }
344 }
345 }
346}
347
348impl<T: Client> MessageHandler<ClientSide> for T {
349 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
350 match request {
351 AgentRequest::RequestPermissionRequest(args) => {
352 let response = self.request_permission(args).await?;
353 Ok(ClientResponse::RequestPermissionResponse(response))
354 }
355 AgentRequest::WriteTextFileRequest(args) => {
356 let response = self.write_text_file(args).await?;
357 Ok(ClientResponse::WriteTextFileResponse(response))
358 }
359 AgentRequest::ReadTextFileRequest(args) => {
360 let response = self.read_text_file(args).await?;
361 Ok(ClientResponse::ReadTextFileResponse(response))
362 }
363 AgentRequest::CreateTerminalRequest(args) => {
364 let response = self.create_terminal(args).await?;
365 Ok(ClientResponse::CreateTerminalResponse(response))
366 }
367 AgentRequest::TerminalOutputRequest(args) => {
368 let response = self.terminal_output(args).await?;
369 Ok(ClientResponse::TerminalOutputResponse(response))
370 }
371 AgentRequest::ReleaseTerminalRequest(args) => {
372 let response = self.release_terminal(args).await?;
373 Ok(ClientResponse::ReleaseTerminalResponse(response))
374 }
375 AgentRequest::WaitForTerminalExitRequest(args) => {
376 let response = self.wait_for_terminal_exit(args).await?;
377 Ok(ClientResponse::WaitForTerminalExitResponse(response))
378 }
379 AgentRequest::KillTerminalCommandRequest(args) => {
380 let response = self.kill_terminal_command(args).await?;
381 Ok(ClientResponse::KillTerminalResponse(response))
382 }
383 AgentRequest::ExtMethodRequest(args) => {
384 let response = self.ext_method(args.method, args.params).await?;
385 Ok(ClientResponse::ExtMethodResponse(response))
386 }
387 }
388 }
389
390 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
391 match notification {
392 AgentNotification::SessionNotification(notification) => {
393 self.session_notification(notification).await?;
394 }
395 AgentNotification::ExtNotification(args) => {
396 self.ext_notification(args.method, args.params).await?;
397 }
398 }
399 Ok(())
400 }
401}
402
403pub struct AgentSideConnection {
414 conn: RpcConnection<AgentSide, ClientSide>,
415}
416
417impl AgentSideConnection {
418 pub fn new(
438 agent: impl MessageHandler<AgentSide> + 'static,
439 outgoing_bytes: impl Unpin + AsyncWrite,
440 incoming_bytes: impl Unpin + AsyncRead,
441 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
442 ) -> (Self, impl Future<Output = Result<()>>) {
443 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
444 (Self { conn }, io_task)
445 }
446
447 pub fn subscribe(&self) -> StreamReceiver {
456 self.conn.subscribe()
457 }
458}
459
460impl Client for AgentSideConnection {
461 async fn request_permission(
462 &self,
463 arguments: RequestPermissionRequest,
464 ) -> Result<RequestPermissionResponse, Error> {
465 self.conn
466 .request(
467 SESSION_REQUEST_PERMISSION_METHOD_NAME,
468 Some(AgentRequest::RequestPermissionRequest(arguments)),
469 )
470 .await
471 }
472
473 async fn write_text_file(
474 &self,
475 arguments: WriteTextFileRequest,
476 ) -> Result<WriteTextFileResponse, Error> {
477 self.conn
478 .request::<Option<_>>(
479 FS_WRITE_TEXT_FILE_METHOD_NAME,
480 Some(AgentRequest::WriteTextFileRequest(arguments)),
481 )
482 .await
483 .map(|value| value.unwrap_or_default())
484 }
485
486 async fn read_text_file(
487 &self,
488 arguments: ReadTextFileRequest,
489 ) -> Result<ReadTextFileResponse, Error> {
490 self.conn
491 .request(
492 FS_READ_TEXT_FILE_METHOD_NAME,
493 Some(AgentRequest::ReadTextFileRequest(arguments)),
494 )
495 .await
496 }
497
498 async fn create_terminal(
499 &self,
500 arguments: CreateTerminalRequest,
501 ) -> Result<CreateTerminalResponse, Error> {
502 self.conn
503 .request(
504 TERMINAL_CREATE_METHOD_NAME,
505 Some(AgentRequest::CreateTerminalRequest(arguments)),
506 )
507 .await
508 }
509
510 async fn terminal_output(
511 &self,
512 arguments: TerminalOutputRequest,
513 ) -> Result<TerminalOutputResponse, Error> {
514 self.conn
515 .request(
516 TERMINAL_OUTPUT_METHOD_NAME,
517 Some(AgentRequest::TerminalOutputRequest(arguments)),
518 )
519 .await
520 }
521
522 async fn release_terminal(
523 &self,
524 arguments: ReleaseTerminalRequest,
525 ) -> Result<ReleaseTerminalResponse, Error> {
526 self.conn
527 .request::<Option<_>>(
528 TERMINAL_RELEASE_METHOD_NAME,
529 Some(AgentRequest::ReleaseTerminalRequest(arguments)),
530 )
531 .await
532 .map(|value| value.unwrap_or_default())
533 }
534
535 async fn wait_for_terminal_exit(
536 &self,
537 arguments: WaitForTerminalExitRequest,
538 ) -> Result<WaitForTerminalExitResponse, Error> {
539 self.conn
540 .request(
541 TERMINAL_WAIT_FOR_EXIT_METHOD_NAME,
542 Some(AgentRequest::WaitForTerminalExitRequest(arguments)),
543 )
544 .await
545 }
546
547 async fn kill_terminal_command(
548 &self,
549 arguments: KillTerminalCommandRequest,
550 ) -> Result<KillTerminalCommandResponse, Error> {
551 self.conn
552 .request::<Option<_>>(
553 TERMINAL_KILL_METHOD_NAME,
554 Some(AgentRequest::KillTerminalCommandRequest(arguments)),
555 )
556 .await
557 .map(|value| value.unwrap_or_default())
558 }
559
560 async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
561 self.conn.notify(
562 SESSION_UPDATE_NOTIFICATION,
563 Some(AgentNotification::SessionNotification(notification)),
564 )
565 }
566
567 async fn ext_method(
568 &self,
569 method: Arc<str>,
570 params: Arc<RawValue>,
571 ) -> Result<Arc<RawValue>, Error> {
572 self.conn
573 .request(
574 format!("_{method}"),
575 Some(AgentRequest::ExtMethodRequest(ExtMethod { method, params })),
576 )
577 .await
578 }
579
580 async fn ext_notification(&self, method: Arc<str>, params: Arc<RawValue>) -> Result<(), Error> {
581 self.conn.notify(
582 format!("_{method}"),
583 Some(AgentNotification::ExtNotification(ExtMethod {
584 method,
585 params,
586 })),
587 )
588 }
589}
590
591#[derive(Clone)]
598pub struct AgentSide;
599
600impl Side for AgentSide {
601 type InRequest = ClientRequest;
602 type InNotification = ClientNotification;
603 type OutResponse = AgentResponse;
604
605 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
606 let params = params.ok_or_else(Error::invalid_params)?;
607
608 match method {
609 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
610 .map(ClientRequest::InitializeRequest)
611 .map_err(Into::into),
612 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
613 .map(ClientRequest::AuthenticateRequest)
614 .map_err(Into::into),
615 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
616 .map(ClientRequest::NewSessionRequest)
617 .map_err(Into::into),
618 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
619 .map(ClientRequest::LoadSessionRequest)
620 .map_err(Into::into),
621 SESSION_SET_MODE_METHOD_NAME => serde_json::from_str(params.get())
622 .map(ClientRequest::SetSessionModeRequest)
623 .map_err(Into::into),
624 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
625 .map(ClientRequest::PromptRequest)
626 .map_err(Into::into),
627 _ => {
628 if let Some(custom_method) = method.strip_prefix('_') {
629 Ok(ClientRequest::ExtMethodRequest(ExtMethod {
630 method: custom_method.into(),
631 params: RawValue::from_string(params.get().to_string())?.into(),
632 }))
633 } else {
634 Err(Error::method_not_found())
635 }
636 }
637 }
638 }
639
640 fn decode_notification(
641 method: &str,
642 params: Option<&RawValue>,
643 ) -> Result<ClientNotification, Error> {
644 let params = params.ok_or_else(Error::invalid_params)?;
645
646 match method {
647 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
648 .map(ClientNotification::CancelNotification)
649 .map_err(Into::into),
650 _ => {
651 if let Some(custom_method) = method.strip_prefix('_') {
652 Ok(ClientNotification::ExtNotification(ExtMethod {
653 method: custom_method.into(),
654 params: RawValue::from_string(params.get().to_string())?.into(),
655 }))
656 } else {
657 Err(Error::method_not_found())
658 }
659 }
660 }
661 }
662}
663
664impl<T: Agent> MessageHandler<AgentSide> for T {
665 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
666 match request {
667 ClientRequest::InitializeRequest(args) => {
668 let response = self.initialize(args).await?;
669 Ok(AgentResponse::InitializeResponse(response))
670 }
671 ClientRequest::AuthenticateRequest(args) => {
672 let response = self.authenticate(args).await?;
673 Ok(AgentResponse::AuthenticateResponse(response))
674 }
675 ClientRequest::NewSessionRequest(args) => {
676 let response = self.new_session(args).await?;
677 Ok(AgentResponse::NewSessionResponse(response))
678 }
679 ClientRequest::LoadSessionRequest(args) => {
680 let response = self.load_session(args).await?;
681 Ok(AgentResponse::LoadSessionResponse(response))
682 }
683 ClientRequest::PromptRequest(args) => {
684 let response = self.prompt(args).await?;
685 Ok(AgentResponse::PromptResponse(response))
686 }
687 ClientRequest::SetSessionModeRequest(args) => {
688 let response = self.set_session_mode(args).await?;
689 Ok(AgentResponse::SetSessionModeResponse(response))
690 }
691 ClientRequest::ExtMethodRequest(args) => {
692 let response = self.ext_method(args.method, args.params).await?;
693 Ok(AgentResponse::ExtMethodResponse(response))
694 }
695 }
696 }
697
698 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
699 match notification {
700 ClientNotification::CancelNotification(notification) => {
701 self.cancel(notification).await?;
702 }
703 ClientNotification::ExtNotification(args) => {
704 self.ext_notification(args.method, args.params).await?;
705 }
706 }
707 Ok(())
708 }
709}