1mod agent;
2mod client;
3mod content;
4mod error;
5mod plan;
6mod rpc;
7#[cfg(test)]
8mod rpc_tests;
9mod tool_call;
10mod version;
11
12pub use agent::*;
13pub use client::*;
14pub use content::*;
15pub use error::*;
16pub use plan::*;
17pub use tool_call::*;
18pub use version::*;
19
20use anyhow::Result;
21use futures::{AsyncRead, AsyncWrite, Future, future::LocalBoxFuture};
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::value::RawValue;
25use std::{fmt, sync::Arc};
26
27use crate::rpc::{MessageHandler, RpcConnection, Side};
28
29#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
30#[serde(transparent)]
31pub struct SessionId(pub Arc<str>);
32
33impl fmt::Display for SessionId {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 write!(f, "{}", self.0)
36 }
37}
38
39pub struct ClientSideConnection {
42 conn: RpcConnection<ClientSide, AgentSide>,
43}
44
45impl ClientSideConnection {
46 pub fn new(
47 client: impl MessageHandler<ClientSide> + 'static,
48 outgoing_bytes: impl Unpin + AsyncWrite,
49 incoming_bytes: impl Unpin + AsyncRead,
50 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
51 ) -> (Self, impl Future<Output = Result<()>>) {
52 let (conn, io_task) = RpcConnection::new(client, outgoing_bytes, incoming_bytes, spawn);
53 (Self { conn }, io_task)
54 }
55}
56
57impl Agent for ClientSideConnection {
58 async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse, Error> {
59 self.conn
60 .request(
61 INITIALIZE_METHOD_NAME,
62 Some(ClientRequest::InitializeRequest(arguments)),
63 )
64 .await
65 }
66
67 async fn authenticate(&self, arguments: AuthenticateRequest) -> Result<(), Error> {
68 self.conn
69 .request(
70 AUTHENTICATE_METHOD_NAME,
71 Some(ClientRequest::AuthenticateRequest(arguments)),
72 )
73 .await
74 }
75
76 async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse, Error> {
77 self.conn
78 .request(
79 SESSION_NEW_METHOD_NAME,
80 Some(ClientRequest::NewSessionRequest(arguments)),
81 )
82 .await
83 }
84
85 async fn load_session(&self, arguments: LoadSessionRequest) -> Result<(), Error> {
86 self.conn
87 .request(
88 SESSION_LOAD_METHOD_NAME,
89 Some(ClientRequest::LoadSessionRequest(arguments)),
90 )
91 .await
92 }
93
94 async fn prompt(&self, arguments: PromptRequest) -> Result<PromptResponse, Error> {
95 self.conn
96 .request(
97 SESSION_PROMPT_METHOD_NAME,
98 Some(ClientRequest::PromptRequest(arguments)),
99 )
100 .await
101 }
102
103 async fn cancel(&self, notification: CancelNotification) -> Result<(), Error> {
104 self.conn.notify(
105 SESSION_CANCEL_METHOD_NAME,
106 Some(ClientNotification::CancelNotification(notification)),
107 )
108 }
109}
110
/// Marker type for the client end of the protocol.
///
/// Used as a `Side` type parameter of the RPC layer. Being a zero-sized
/// marker, it is cheaply `Copy` and implements `Debug` like every public type.
#[derive(Debug, Clone, Copy)]
pub struct ClientSide;
112
113impl Side for ClientSide {
114 type InNotification = AgentNotification;
115 type InRequest = AgentRequest;
116 type OutResponse = ClientResponse;
117
118 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest, Error> {
119 let params = params.ok_or_else(Error::invalid_params)?;
120
121 match method {
122 SESSION_REQUEST_PERMISSION_METHOD_NAME => serde_json::from_str(params.get())
123 .map(AgentRequest::RequestPermissionRequest)
124 .map_err(Into::into),
125 FS_WRITE_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
126 .map(AgentRequest::WriteTextFileRequest)
127 .map_err(Into::into),
128 FS_READ_TEXT_FILE_METHOD_NAME => serde_json::from_str(params.get())
129 .map(AgentRequest::ReadTextFileRequest)
130 .map_err(Into::into),
131 _ => Err(Error::method_not_found()),
132 }
133 }
134
135 fn decode_notification(
136 method: &str,
137 params: Option<&RawValue>,
138 ) -> Result<AgentNotification, Error> {
139 let params = params.ok_or_else(Error::invalid_params)?;
140
141 match method {
142 SESSION_UPDATE_NOTIFICATION => serde_json::from_str(params.get())
143 .map(AgentNotification::SessionNotification)
144 .map_err(Into::into),
145 _ => Err(Error::method_not_found()),
146 }
147 }
148}
149
150impl<T: Client> MessageHandler<ClientSide> for T {
151 async fn handle_request(&self, request: AgentRequest) -> Result<ClientResponse, Error> {
152 match request {
153 AgentRequest::RequestPermissionRequest(args) => {
154 let response = self.request_permission(args).await?;
155 Ok(ClientResponse::RequestPermissionResponse(response))
156 }
157 AgentRequest::WriteTextFileRequest(args) => {
158 self.write_text_file(args).await?;
159 Ok(ClientResponse::WriteTextFileResponse)
160 }
161 AgentRequest::ReadTextFileRequest(args) => {
162 let response = self.read_text_file(args).await?;
163 Ok(ClientResponse::ReadTextFileResponse(response))
164 }
165 }
166 }
167
168 async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> {
169 match notification {
170 AgentNotification::SessionNotification(notification) => {
171 self.session_notification(notification).await?;
172 }
173 }
174 Ok(())
175 }
176}
177
178pub struct AgentSideConnection {
181 conn: RpcConnection<AgentSide, ClientSide>,
182}
183
184impl AgentSideConnection {
185 pub fn new(
186 agent: impl MessageHandler<AgentSide> + 'static,
187 outgoing_bytes: impl Unpin + AsyncWrite,
188 incoming_bytes: impl Unpin + AsyncRead,
189 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
190 ) -> (Self, impl Future<Output = Result<()>>) {
191 let (conn, io_task) = RpcConnection::new(agent, outgoing_bytes, incoming_bytes, spawn);
192 (Self { conn }, io_task)
193 }
194}
195
196impl Client for AgentSideConnection {
197 async fn request_permission(
198 &self,
199 arguments: RequestPermissionRequest,
200 ) -> Result<RequestPermissionResponse, Error> {
201 self.conn
202 .request(
203 SESSION_REQUEST_PERMISSION_METHOD_NAME,
204 Some(AgentRequest::RequestPermissionRequest(arguments)),
205 )
206 .await
207 }
208
209 async fn write_text_file(&self, arguments: WriteTextFileRequest) -> Result<(), Error> {
210 self.conn
211 .request(
212 FS_WRITE_TEXT_FILE_METHOD_NAME,
213 Some(AgentRequest::WriteTextFileRequest(arguments)),
214 )
215 .await
216 }
217
218 async fn read_text_file(
219 &self,
220 arguments: ReadTextFileRequest,
221 ) -> Result<ReadTextFileResponse, Error> {
222 self.conn
223 .request(
224 FS_READ_TEXT_FILE_METHOD_NAME,
225 Some(AgentRequest::ReadTextFileRequest(arguments)),
226 )
227 .await
228 }
229
230 async fn session_notification(&self, notification: SessionNotification) -> Result<(), Error> {
231 self.conn.notify(
232 SESSION_UPDATE_NOTIFICATION,
233 Some(AgentNotification::SessionNotification(notification)),
234 )
235 }
236}
237
/// Marker type for the agent end of the protocol.
///
/// Used as a `Side` type parameter of the RPC layer. Being a zero-sized
/// marker, it is cheaply `Copy` and implements `Debug` like every public type.
#[derive(Debug, Clone, Copy)]
pub struct AgentSide;
239
240impl Side for AgentSide {
241 type InRequest = ClientRequest;
242 type InNotification = ClientNotification;
243 type OutResponse = AgentResponse;
244
245 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest, Error> {
246 let params = params.ok_or_else(Error::invalid_params)?;
247
248 match method {
249 INITIALIZE_METHOD_NAME => serde_json::from_str(params.get())
250 .map(ClientRequest::InitializeRequest)
251 .map_err(Into::into),
252 AUTHENTICATE_METHOD_NAME => serde_json::from_str(params.get())
253 .map(ClientRequest::AuthenticateRequest)
254 .map_err(Into::into),
255 SESSION_NEW_METHOD_NAME => serde_json::from_str(params.get())
256 .map(ClientRequest::NewSessionRequest)
257 .map_err(Into::into),
258 SESSION_LOAD_METHOD_NAME => serde_json::from_str(params.get())
259 .map(ClientRequest::LoadSessionRequest)
260 .map_err(Into::into),
261 SESSION_PROMPT_METHOD_NAME => serde_json::from_str(params.get())
262 .map(ClientRequest::PromptRequest)
263 .map_err(Into::into),
264 _ => Err(Error::method_not_found()),
265 }
266 }
267
268 fn decode_notification(
269 method: &str,
270 params: Option<&RawValue>,
271 ) -> Result<ClientNotification, Error> {
272 let params = params.ok_or_else(Error::invalid_params)?;
273
274 match method {
275 SESSION_CANCEL_METHOD_NAME => serde_json::from_str(params.get())
276 .map(ClientNotification::CancelNotification)
277 .map_err(Into::into),
278 _ => Err(Error::method_not_found()),
279 }
280 }
281}
282
283impl<T: Agent> MessageHandler<AgentSide> for T {
284 async fn handle_request(&self, request: ClientRequest) -> Result<AgentResponse, Error> {
285 match request {
286 ClientRequest::InitializeRequest(args) => {
287 let response = self.initialize(args).await?;
288 Ok(AgentResponse::InitializeResponse(response))
289 }
290 ClientRequest::AuthenticateRequest(args) => {
291 self.authenticate(args).await?;
292 Ok(AgentResponse::AuthenticateResponse)
293 }
294 ClientRequest::NewSessionRequest(args) => {
295 let response = self.new_session(args).await?;
296 Ok(AgentResponse::NewSessionResponse(response))
297 }
298 ClientRequest::LoadSessionRequest(args) => {
299 self.load_session(args).await?;
300 Ok(AgentResponse::LoadSessionResponse)
301 }
302 ClientRequest::PromptRequest(args) => {
303 let response = self.prompt(args).await?;
304 Ok(AgentResponse::PromptResponse(response))
305 }
306 }
307 }
308
309 async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> {
310 match notification {
311 ClientNotification::CancelNotification(notification) => {
312 self.cancel(notification).await?;
313 }
314 }
315 Ok(())
316 }
317}