1mod agent;
54mod client;
55mod content;
56mod error;
57mod plan;
58mod rpc;
59#[cfg(test)]
60mod rpc_tests;
61mod stream_broadcast;
62mod tool_call;
63mod version;
64
65pub use agent::*;
66pub use client::*;
67pub use content::*;
68pub use error::*;
69pub use plan::*;
70pub use stream_broadcast::{
71 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
72};
73pub use tool_call::*;
74pub use version::*;
75
76use anyhow::Result;
77use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
78use schemars::JsonSchema;
79use serde::{Deserialize, Serialize};
80use serde_json::value::RawValue;
81use std::{fmt, sync::Arc};
82
83use crate::rpc::{MessageHandler, RpcConnection, Side};
84
/// Identifier of a session, stored as a reference-counted string slice.
///
/// `#[serde(transparent)]` makes the wrapper serialize and deserialize exactly
/// like its single inner field instead of as a one-element tuple struct.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct SessionId(pub Arc<str>);
103
104impl fmt::Display for SessionId {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 write!(f, "{}", self.0)
107 }
108}
109
/// Connection handle wrapping an [`RpcConnection`] parameterized as
/// `<ClientSide, AgentSide>`.
///
/// The `Agent` trait is implemented for this type in this file; each of its
/// methods forwards the call over the wrapped connection.
pub struct ClientSideConnection {
    // The wrapped RPC connection that every method of this type delegates to.
    conn: RpcConnection<ClientSide, AgentSide>,
}
123
124impl ClientSideConnection {
125 pub fn new(
145 client: impl MessageHandler<ClientSide> + 'static,
146 outgoing_bytes: impl Unpin + AsyncWrite,
147 incoming_bytes: impl Unpin + AsyncRead,
148 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
149 ) -> (Self, impl Future<Output = Result<()>>) {
150 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
151 (Self { conn }, io_task)
152 }
153
154 pub fn subscribe(&self) -> StreamReceiver {
163 self.conn.subscribe()
164 }
165}
166
167impl Agent for ClientSideConnection {
168 async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
169 self.conn
170 .request(
171 INITIALIZE_METHOD_NAME,
172 Some(ClientRequest::InitializeRequest(arguments)),
173 )
174 .await
175 }
176
177 async fn authenticate(&self, arguments: AuthenticateRequest) -> Result<(), Error> {
178 self.conn
179 .request(
180 AUTHENTICATE_METHOD_NAME,
181 Some(ClientRequest::AuthenticateRequest(arguments)),
182 )
183 .await
184 }
185
186 async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
187 self.conn
188 .request(
189 SESSION_NEW_METHOD_NAME,
190 Some(ClientRequest::NewSessionRequest(arguments)),
191 )
192 .await
193 }
194
195 async fn load_session(&self, arguments: LoadSessionRequest) -> Result<(), Error> {
196 self.conn
197 .request(
198 SESSION_LOAD_METHOD_NAME,
199 Some(ClientRequest::LoadSessionRequest(arguments)),
200 )
201 .await
202 }
203
204 async fn prompt(&self, arguments: PromptRequest) -> Result<PromptResponse, Error> {
205 self.conn
206 .request(
207 SESSION_PROMPT_METHOD_NAME,
208 Some(ClientRequest::PromptRequest(arguments)),
209 )
210 .await
211 }
212
213 async fn cancel(&self, notification: CancelNotification) -> Result<(), Error> {
214 self.conn.notify(
215 SESSION_CANCEL_METHOD_NAME,
216 Some(ClientNotification::CancelNotification(notification)),
217 )
218 }
219}
220
/// Field-less marker type used as a type parameter of [`RpcConnection`].
///
/// Its [`Side`] implementation in this file declares `AgentRequest` and
/// `AgentNotification` as the incoming message types and `ClientResponse` as
/// the outgoing response type.
#[derive(Clone)]
pub struct ClientSide;
229
230impl Side for ClientSide {
231 type InNotification = AgentNotification;
232 type InRequest = AgentRequest;
233 type OutResponse = ClientResponse;
234
235 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
236 let params = params.ok_or_else(Error::invalid_params)?;
237
238 match method {
239 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
240 .map(AgentRequest::RequestPermissionRequest)
241 .map_err(Into::into),
242 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
243 .map(AgentRequest::WriteTextFileRequest)
244 .map_err(Into::into),
245 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
246 .map(AgentRequest::ReadTextFileRequest)
247 .map_err(Into::into),
248 _ => Err(Error::method_not_found()),
249 }
250 }
251
252 fn decode_notification(
253 method: &str,
254 params: Option<&RawValue>,
255 ) -> Result<AgentNotification, Error> {
256 let params = params.ok_or_else(Error::invalid_params)?;
257
258 match method {
259 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
260 .map(AgentNotification::SessionNotification)
261 .map_err(Into::into),
262 _ => Err(Error::method_not_found()),
263 }
264 }
265}
266
267impl<T: Client> MessageHandler<ClientSide> for T {
268 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
269 match request {
270 AgentRequest::RequestPermissionRequest(args) => {
271 let response = self.request_permission(args).await?;
272 Ok(ClientResponse::RequestPermissionResponse(response))
273 }
274 AgentRequest::WriteTextFileRequest(args) => {
275 self.write_text_file(args).await?;
276 Ok(ClientResponse::WriteTextFileResponse)
277 }
278 AgentRequest::ReadTextFileRequest(args) => {
279 let response = self.read_text_file(args).await?;
280 Ok(ClientResponse::ReadTextFileResponse(response))
281 }
282 }
283 }
284
285 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
286 match notification {
287 AgentNotification::SessionNotification(notification) => {
288 self.session_notification(notification).await?;
289 }
290 }
291 Ok(())
292 }
293}
294
/// Connection handle wrapping an [`RpcConnection`] parameterized as
/// `<AgentSide, ClientSide>`.
///
/// The `Client` trait is implemented for this type in this file; each of its
/// methods forwards the call over the wrapped connection.
pub struct AgentSideConnection {
    // The wrapped RPC connection that every method of this type delegates to.
    conn: RpcConnection<AgentSide, ClientSide>,
}
308
309impl AgentSideConnection {
310 pub fn new(
330 agent: impl MessageHandler<AgentSide> + 'static,
331 outgoing_bytes: impl Unpin + AsyncWrite,
332 incoming_bytes: impl Unpin + AsyncRead,
333 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
334 ) -> (Self, impl Future<Output = Result<()>>) {
335 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
336 (Self { conn }, io_task)
337 }
338
339 pub fn subscribe(&self) -> StreamReceiver {
348 self.conn.subscribe()
349 }
350}
351
352impl Client for AgentSideConnection {
353 async fn request_permission(
354 &self,
355 arguments: RequestPermissionRequest,
356 ) -> Result<RequestPermissionResponse, Error> {
357 self.conn
358 .request(
359 SESSION_REQUEST_PERMISSION_METHOD_NAME,
360 Some(AgentRequest::RequestPermissionRequest(arguments)),
361 )
362 .await
363 }
364
365 async fn write_text_file(&self, arguments: WriteTextFileRequest) -> Result<(), Error> {
366 self.conn
367 .request(
368 FS_WRITE_TEXT_FILE_METHOD_NAME,
369 Some(AgentRequest::WriteTextFileRequest(arguments)),
370 )
371 .await
372 }
373
374 async fn read_text_file(
375 &self,
376 arguments: ReadTextFileRequest,
377 ) -> Result<ReadTextFileResponse, Error> {
378 self.conn
379 .request(
380 FS_READ_TEXT_FILE_METHOD_NAME,
381 Some(AgentRequest::ReadTextFileRequest(arguments)),
382 )
383 .await
384 }
385
386 async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
387 self.conn.notify(
388 SESSION_UPDATE_NOTIFICATION,
389 Some(AgentNotification::SessionNotification(notification)),
390 )
391 }
392}
393
/// Field-less marker type used as a type parameter of [`RpcConnection`].
///
/// Its [`Side`] implementation in this file declares `ClientRequest` and
/// `ClientNotification` as the incoming message types and `AgentResponse` as
/// the outgoing response type.
#[derive(Clone)]
pub struct AgentSide;
402
403impl Side for AgentSide {
404 type InRequest = ClientRequest;
405 type InNotification = ClientNotification;
406 type OutResponse = AgentResponse;
407
408 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
409 let params = params.ok_or_else(Error::invalid_params)?;
410
411 match method {
412 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
413 .map(ClientRequest::InitializeRequest)
414 .map_err(Into::into),
415 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
416 .map(ClientRequest::AuthenticateRequest)
417 .map_err(Into::into),
418 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
419 .map(ClientRequest::NewSessionRequest)
420 .map_err(Into::into),
421 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
422 .map(ClientRequest::LoadSessionRequest)
423 .map_err(Into::into),
424 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
425 .map(ClientRequest::PromptRequest)
426 .map_err(Into::into),
427 _ => Err(Error::method_not_found()),
428 }
429 }
430
431 fn decode_notification(
432 method: &str,
433 params: Option<&RawValue>,
434 ) -> Result<ClientNotification, Error> {
435 let params = params.ok_or_else(Error::invalid_params)?;
436
437 match method {
438 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
439 .map(ClientNotification::CancelNotification)
440 .map_err(Into::into),
441 _ => Err(Error::method_not_found()),
442 }
443 }
444}
445
446impl<T: Agent> MessageHandler<AgentSide> for T {
447 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
448 match request {
449 ClientRequest::InitializeRequest(args) => {
450 let response = self.initialize(args).await?;
451 Ok(AgentResponse::InitializeResponse(response))
452 }
453 ClientRequest::AuthenticateRequest(args) => {
454 self.authenticate(args).await?;
455 Ok(AgentResponse::AuthenticateResponse)
456 }
457 ClientRequest::NewSessionRequest(args) => {
458 let response = self.new_session(args).await?;
459 Ok(AgentResponse::NewSessionResponse(response))
460 }
461 ClientRequest::LoadSessionRequest(args) => {
462 self.load_session(args).await?;
463 Ok(AgentResponse::LoadSessionResponse)
464 }
465 ClientRequest::PromptRequest(args) => {
466 let response = self.prompt(args).await?;
467 Ok(AgentResponse::PromptResponse(response))
468 }
469 }
470 }
471
472 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
473 match notification {
474 ClientNotification::CancelNotification(notification) => {
475 self.cancel(notification).await?;
476 }
477 }
478 Ok(())
479 }
480}