1mod agent;
2mod client;
3mod content;
4mod error;
5mod plan;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod tool_call;
10mod version;
11
12pub use agent::*;
13pub use client::*;
14pub use content::*;
15pub use error::*;
16pub use plan::*;
17pub use tool_call::*;
18pub use version::*;
19
20use anyhow::Result;
21use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::value::RawValue;
25use std::{fmt, sync::Arc};
26
27use crate::rpc::{MessageHandler, RpcConnection, Side};
28
29#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
30#[serde(transparent)]
31pub struct SessionId(pub Arc<str>);
32
33impl fmt::Display for SessionId {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 write!(f, "{}", self.0)
36 }
37}
38
39pub struct ClientSideConnection {
42 conn: RpcConnection<ClientSide, AgentSide>,
43}
44
45impl ClientSideConnection {
46 pub fn new(
47 client: impl MessageHandler<ClientSide> + 'static,
48 outgoing_bytes: impl Unpin + AsyncWrite,
49 incoming_bytes: impl Unpin + AsyncRead,
50 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
51 ) -> (Self, impl Future<Output = Result<()>>) {
52 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
53 (Self { conn }, io_task)
54 }
55}
56
57impl Agent for ClientSideConnection {
58 async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
59 self.conn
60 .request(
61 INITIALIZE_METHOD_NAME,
62 Some(ClientRequest::InitializeRequest(arguments)),
63 )
64 .await
65 }
66
67 async fn authenticate(&self, arguments: AuthenticateRequest) -> Result<(), Error> {
68 self.conn
69 .request(
70 AUTHENTICATE_METHOD_NAME,
71 Some(ClientRequest::AuthenticateRequest(arguments)),
72 )
73 .await
74 }
75
76 async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
77 self.conn
78 .request(
79 SESSION_NEW_METHOD_NAME,
80 Some(ClientRequest::NewSessionRequest(arguments)),
81 )
82 .await
83 }
84
85 async fn load_session(
86 &self,
87 arguments: LoadSessionRequest,
88 ) -> Result<LoadSessionResponse, Error> {
89 self.conn
90 .request(
91 SESSION_LOAD_METHOD_NAME,
92 Some(ClientRequest::LoadSessionRequest(arguments)),
93 )
94 .await
95 }
96
97 async fn prompt(&self, arguments: PromptRequest) -> Result<(), Error> {
98 self.conn
99 .request(
100 SESSION_PROMPT_METHOD_NAME,
101 Some(ClientRequest::PromptRequest(arguments)),
102 )
103 .await
104 }
105
106 async fn cancelled(&self, notification: CancelledNotification) -> Result<(), Error> {
107 self.conn.notify(
108 SESSION_CANCELLED_METHOD_NAME,
109 Some(ClientNotification::CancelledNotification(notification)),
110 )
111 }
112}
113
114pub struct ClientSide;
115
116impl Side for ClientSide {
117 type InNotification = AgentNotification;
118 type InRequest = AgentRequest;
119 type OutResponse = ClientResponse;
120
121 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
122 let params = params.ok_or_else(Error::invalid_params)?;
123
124 match method {
125 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
126 .map(AgentRequest::RequestPermissionRequest)
127 .map_err(Into::into),
128 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
129 .map(AgentRequest::WriteTextFileRequest)
130 .map_err(Into::into),
131 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
132 .map(AgentRequest::ReadTextFileRequest)
133 .map_err(Into::into),
134 _ => Err(Error::method_not_found()),
135 }
136 }
137
138 fn decode_notification(
139 method: &str,
140 params: Option<&RawValue>,
141 ) -> Result<AgentNotification, Error> {
142 let params = params.ok_or_else(Error::invalid_params)?;
143
144 match method {
145 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
146 .map(AgentNotification::SessionNotification)
147 .map_err(Into::into),
148 _ => Err(Error::method_not_found()),
149 }
150 }
151}
152
153impl<T: Client> MessageHandler<ClientSide> for T {
154 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
155 match request {
156 AgentRequest::RequestPermissionRequest(args) => {
157 let response = self.request_permission(args).await?;
158 Ok(ClientResponse::RequestPermissionResponse(response))
159 }
160 AgentRequest::WriteTextFileRequest(args) => {
161 self.write_text_file(args).await?;
162 Ok(ClientResponse::WriteTextFileResponse)
163 }
164 AgentRequest::ReadTextFileRequest(args) => {
165 let response = self.read_text_file(args).await?;
166 Ok(ClientResponse::ReadTextFileResponse(response))
167 }
168 }
169 }
170
171 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
172 match notification {
173 AgentNotification::SessionNotification(notification) => {
174 self.session_notification(notification).await?;
175 }
176 }
177 Ok(())
178 }
179}
180
181pub struct AgentSideConnection {
184 conn: RpcConnection<AgentSide, ClientSide>,
185}
186
187impl AgentSideConnection {
188 pub fn new(
189 agent: impl MessageHandler<AgentSide> + 'static,
190 outgoing_bytes: impl Unpin + AsyncWrite,
191 incoming_bytes: impl Unpin + AsyncRead,
192 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
193 ) -> (Self, impl Future<Output = Result<()>>) {
194 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
195 (Self { conn }, io_task)
196 }
197}
198
199impl Client for AgentSideConnection {
200 async fn request_permission(
201 &self,
202 arguments: RequestPermissionRequest,
203 ) -> Result<RequestPermissionResponse, Error> {
204 self.conn
205 .request(
206 SESSION_REQUEST_PERMISSION_METHOD_NAME,
207 Some(AgentRequest::RequestPermissionRequest(arguments)),
208 )
209 .await
210 }
211
212 async fn write_text_file(&self, arguments: WriteTextFileRequest) -> Result<(), Error> {
213 self.conn
214 .request(
215 FS_WRITE_TEXT_FILE_METHOD_NAME,
216 Some(AgentRequest::WriteTextFileRequest(arguments)),
217 )
218 .await
219 }
220
221 async fn read_text_file(
222 &self,
223 arguments: ReadTextFileRequest,
224 ) -> Result<ReadTextFileResponse, Error> {
225 self.conn
226 .request(
227 FS_READ_TEXT_FILE_METHOD_NAME,
228 Some(AgentRequest::ReadTextFileRequest(arguments)),
229 )
230 .await
231 }
232
233 async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
234 self.conn.notify(
235 SESSION_UPDATE_NOTIFICATION,
236 Some(AgentNotification::SessionNotification(notification)),
237 )
238 }
239}
240
241pub struct AgentSide;
242
243impl Side for AgentSide {
244 type InRequest = ClientRequest;
245 type InNotification = ClientNotification;
246 type OutResponse = AgentResponse;
247
248 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
249 let params = params.ok_or_else(Error::invalid_params)?;
250
251 match method {
252 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
253 .map(ClientRequest::InitializeRequest)
254 .map_err(Into::into),
255 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
256 .map(ClientRequest::AuthenticateRequest)
257 .map_err(Into::into),
258 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
259 .map(ClientRequest::NewSessionRequest)
260 .map_err(Into::into),
261 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
262 .map(ClientRequest::LoadSessionRequest)
263 .map_err(Into::into),
264 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
265 .map(ClientRequest::PromptRequest)
266 .map_err(Into::into),
267 _ => Err(Error::method_not_found()),
268 }
269 }
270
271 fn decode_notification(
272 method: &str,
273 params: Option<&RawValue>,
274 ) -> Result<ClientNotification, Error> {
275 let params = params.ok_or_else(Error::invalid_params)?;
276
277 match method {
278 SESSION_CANCELLED_METHOD_NAME => serde_json::from_str(params.get())
279 .map(ClientNotification::CancelledNotification)
280 .map_err(Into::into),
281 _ => Err(Error::method_not_found()),
282 }
283 }
284}
285
286impl<T: Agent> MessageHandler<AgentSide> for T {
287 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
288 match request {
289 ClientRequest::InitializeRequest(args) => {
290 let response = self.initialize(args).await?;
291 Ok(AgentResponse::InitializeResponse(response))
292 }
293 ClientRequest::AuthenticateRequest(args) => {
294 self.authenticate(args).await?;
295 Ok(AgentResponse::AuthenticateResponse)
296 }
297 ClientRequest::NewSessionRequest(args) => {
298 let response = self.new_session(args).await?;
299 Ok(AgentResponse::NewSessionResponse(response))
300 }
301 ClientRequest::LoadSessionRequest(args) => {
302 let response = self.load_session(args).await?;
303 Ok(AgentResponse::LoadSessionResponse(response))
304 }
305 ClientRequest::PromptRequest(args) => {
306 self.prompt(args).await?;
307 Ok(AgentResponse::PromptResponse)
308 }
309 }
310 }
311
312 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
313 match notification {
314 ClientNotification::CancelledNotification(notification) => {
315 self.cancelled(notification).await?;
316 }
317 }
318 Ok(())
319 }
320}