1mod agent;
48mod client;
49mod content;
50mod error;
51mod plan;
52mod rpc;
53#[cfg(test)]
54mod rpc_tests;
55mod stream_broadcast;
56mod tool_call;
57mod version;
58
59pub use agent::*;
60pub use client::*;
61pub use content::*;
62pub use error::*;
63pub use plan::*;
64pub use stream_broadcast::{
65 StreamMessage, StreamMessageContent, StreamMessageDirection, StreamReceiver,
66};
67pub use tool_call::*;
68pub use version::*;
69
70use anyhow::Result;
71use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
72use schemars::JsonSchema;
73use serde::{Deserialize, Serialize};
74use serde_json::value::RawValue;
75use std::{fmt, sync::Arc};
76
77use crate::rpc::{MessageHandler, RpcConnection, Side};
78
79#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
95#[serde(transparent)]
96pub struct SessionId(pub Arc<str>);
97
98impl fmt::Display for SessionId {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 write!(f, "{}", self.0)
101 }
102}
103
104pub struct ClientSideConnection {
115 conn: RpcConnection<ClientSide, AgentSide>,
116}
117
118impl ClientSideConnection {
119 pub fn new(
139 client: impl MessageHandler<ClientSide> + 'static,
140 outgoing_bytes: impl Unpin + AsyncWrite,
141 incoming_bytes: impl Unpin + AsyncRead,
142 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
143 ) -> (Self, impl Future<Output = Result<()>>) {
144 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
145 (Self { conn }, io_task)
146 }
147
148 pub fn subscribe(&self) -> StreamReceiver {
157 self.conn.subscribe()
158 }
159}
160
161impl Agent for ClientSideConnection {
162 async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
163 self.conn
164 .request(
165 INITIALIZE_METHOD_NAME,
166 Some(ClientRequest::InitializeRequest(arguments)),
167 )
168 .await
169 }
170
171 async fn authenticate(&self, arguments: AuthenticateRequest) -> Result<(), Error> {
172 self.conn
173 .request(
174 AUTHENTICATE_METHOD_NAME,
175 Some(ClientRequest::AuthenticateRequest(arguments)),
176 )
177 .await
178 }
179
180 async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
181 self.conn
182 .request(
183 SESSION_NEW_METHOD_NAME,
184 Some(ClientRequest::NewSessionRequest(arguments)),
185 )
186 .await
187 }
188
189 async fn load_session(&self, arguments: LoadSessionRequest) -> Result<(), Error> {
190 self.conn
191 .request(
192 SESSION_LOAD_METHOD_NAME,
193 Some(ClientRequest::LoadSessionRequest(arguments)),
194 )
195 .await
196 }
197
198 async fn prompt(&self, arguments: PromptRequest) -> Result<PromptResponse, Error> {
199 self.conn
200 .request(
201 SESSION_PROMPT_METHOD_NAME,
202 Some(ClientRequest::PromptRequest(arguments)),
203 )
204 .await
205 }
206
207 async fn cancel(&self, notification: CancelNotification) -> Result<(), Error> {
208 self.conn.notify(
209 SESSION_CANCEL_METHOD_NAME,
210 Some(ClientNotification::CancelNotification(notification)),
211 )
212 }
213}
214
215#[derive(Clone)]
222pub struct ClientSide;
223
224impl Side for ClientSide {
225 type InNotification = AgentNotification;
226 type InRequest = AgentRequest;
227 type OutResponse = ClientResponse;
228
229 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
230 let params = params.ok_or_else(Error::invalid_params)?;
231
232 match method {
233 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
234 .map(AgentRequest::RequestPermissionRequest)
235 .map_err(Into::into),
236 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
237 .map(AgentRequest::WriteTextFileRequest)
238 .map_err(Into::into),
239 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
240 .map(AgentRequest::ReadTextFileRequest)
241 .map_err(Into::into),
242 _ => Err(Error::method_not_found()),
243 }
244 }
245
246 fn decode_notification(
247 method: &str,
248 params: Option<&RawValue>,
249 ) -> Result<AgentNotification, Error> {
250 let params = params.ok_or_else(Error::invalid_params)?;
251
252 match method {
253 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
254 .map(AgentNotification::SessionNotification)
255 .map_err(Into::into),
256 _ => Err(Error::method_not_found()),
257 }
258 }
259}
260
261impl<T: Client> MessageHandler<ClientSide> for T {
262 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
263 match request {
264 AgentRequest::RequestPermissionRequest(args) => {
265 let response = self.request_permission(args).await?;
266 Ok(ClientResponse::RequestPermissionResponse(response))
267 }
268 AgentRequest::WriteTextFileRequest(args) => {
269 self.write_text_file(args).await?;
270 Ok(ClientResponse::WriteTextFileResponse)
271 }
272 AgentRequest::ReadTextFileRequest(args) => {
273 let response = self.read_text_file(args).await?;
274 Ok(ClientResponse::ReadTextFileResponse(response))
275 }
276 }
277 }
278
279 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
280 match notification {
281 AgentNotification::SessionNotification(notification) => {
282 self.session_notification(notification).await?;
283 }
284 }
285 Ok(())
286 }
287}
288
289pub struct AgentSideConnection {
300 conn: RpcConnection<AgentSide, ClientSide>,
301}
302
303impl AgentSideConnection {
304 pub fn new(
324 agent: impl MessageHandler<AgentSide> + 'static,
325 outgoing_bytes: impl Unpin + AsyncWrite,
326 incoming_bytes: impl Unpin + AsyncRead,
327 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
328 ) -> (Self, impl Future<Output = Result<()>>) {
329 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
330 (Self { conn }, io_task)
331 }
332
333 pub fn subscribe(&self) -> StreamReceiver {
342 self.conn.subscribe()
343 }
344}
345
346impl Client for AgentSideConnection {
347 async fn request_permission(
348 &self,
349 arguments: RequestPermissionRequest,
350 ) -> Result<RequestPermissionResponse, Error> {
351 self.conn
352 .request(
353 SESSION_REQUEST_PERMISSION_METHOD_NAME,
354 Some(AgentRequest::RequestPermissionRequest(arguments)),
355 )
356 .await
357 }
358
359 async fn write_text_file(&self, arguments: WriteTextFileRequest) -> Result<(), Error> {
360 self.conn
361 .request(
362 FS_WRITE_TEXT_FILE_METHOD_NAME,
363 Some(AgentRequest::WriteTextFileRequest(arguments)),
364 )
365 .await
366 }
367
368 async fn read_text_file(
369 &self,
370 arguments: ReadTextFileRequest,
371 ) -> Result<ReadTextFileResponse, Error> {
372 self.conn
373 .request(
374 FS_READ_TEXT_FILE_METHOD_NAME,
375 Some(AgentRequest::ReadTextFileRequest(arguments)),
376 )
377 .await
378 }
379
380 async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
381 self.conn.notify(
382 SESSION_UPDATE_NOTIFICATION,
383 Some(AgentNotification::SessionNotification(notification)),
384 )
385 }
386}
387
388#[derive(Clone)]
395pub struct AgentSide;
396
397impl Side for AgentSide {
398 type InRequest = ClientRequest;
399 type InNotification = ClientNotification;
400 type OutResponse = AgentResponse;
401
402 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
403 let params = params.ok_or_else(Error::invalid_params)?;
404
405 match method {
406 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
407 .map(ClientRequest::InitializeRequest)
408 .map_err(Into::into),
409 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
410 .map(ClientRequest::AuthenticateRequest)
411 .map_err(Into::into),
412 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
413 .map(ClientRequest::NewSessionRequest)
414 .map_err(Into::into),
415 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
416 .map(ClientRequest::LoadSessionRequest)
417 .map_err(Into::into),
418 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
419 .map(ClientRequest::PromptRequest)
420 .map_err(Into::into),
421 _ => Err(Error::method_not_found()),
422 }
423 }
424
425 fn decode_notification(
426 method: &str,
427 params: Option<&RawValue>,
428 ) -> Result<ClientNotification, Error> {
429 let params = params.ok_or_else(Error::invalid_params)?;
430
431 match method {
432 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
433 .map(ClientNotification::CancelNotification)
434 .map_err(Into::into),
435 _ => Err(Error::method_not_found()),
436 }
437 }
438}
439
440impl<T: Agent> MessageHandler<AgentSide> for T {
441 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
442 match request {
443 ClientRequest::InitializeRequest(args) => {
444 let response = self.initialize(args).await?;
445 Ok(AgentResponse::InitializeResponse(response))
446 }
447 ClientRequest::AuthenticateRequest(args) => {
448 self.authenticate(args).await?;
449 Ok(AgentResponse::AuthenticateResponse)
450 }
451 ClientRequest::NewSessionRequest(args) => {
452 let response = self.new_session(args).await?;
453 Ok(AgentResponse::NewSessionResponse(response))
454 }
455 ClientRequest::LoadSessionRequest(args) => {
456 self.load_session(args).await?;
457 Ok(AgentResponse::LoadSessionResponse)
458 }
459 ClientRequest::PromptRequest(args) => {
460 let response = self.prompt(args).await?;
461 Ok(AgentResponse::PromptResponse(response))
462 }
463 }
464 }
465
466 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
467 match notification {
468 ClientNotification::CancelNotification(notification) => {
469 self.cancel(notification).await?;
470 }
471 }
472 Ok(())
473 }
474}