agent_client_protocol_schema/
rpc.rs1use std::sync::Arc;
2
3use derive_more::Display;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display, JsonSchema,
24)]
25#[serde(untagged)]
26#[schemars(inline)]
27#[allow(
28 clippy::exhaustive_enums,
29 reason = "This comes from the JSON-RPC specification itself"
30)]
31pub enum RequestId {
32 #[display("null")]
33 Null,
34 Number(i64),
35 Str(String),
36}
37
38#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
39#[serde(untagged)]
40#[schemars(inline)]
41#[non_exhaustive]
42pub enum OutgoingMessage<Local: Side, Remote: Side> {
43 Request {
44 id: RequestId,
45 method: Arc<str>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 params: Option<Remote::InRequest>,
48 },
49 Response {
50 id: RequestId,
51 #[serde(flatten)]
52 result: ResponseResult<Local::OutResponse>,
53 },
54 Notification {
55 method: Arc<str>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 params: Option<Remote::InNotification>,
58 },
59}
60
61#[derive(Debug, Serialize, Deserialize, JsonSchema)]
62#[schemars(inline)]
63enum JsonRpcVersion {
64 #[serde(rename = "2.0")]
65 V2,
66}
67
68#[derive(Debug, Serialize, Deserialize, JsonSchema)]
73#[schemars(inline)]
74pub struct JsonRpcMessage<M> {
75 jsonrpc: JsonRpcVersion,
76 #[serde(flatten)]
77 message: M,
78}
79
80impl<M> JsonRpcMessage<M> {
81 #[must_use]
84 pub fn wrap(message: M) -> Self {
85 Self {
86 jsonrpc: JsonRpcVersion::V2,
87 message,
88 }
89 }
90}
91
92#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
93#[serde(rename_all = "snake_case")]
94#[non_exhaustive]
95pub enum ResponseResult<Res> {
96 Result(Res),
97 Error(Error),
98}
99
100impl<T> From<Result<T>> for ResponseResult<T> {
101 fn from(result: Result<T>) -> Self {
102 match result {
103 Ok(value) => ResponseResult::Result(value),
104 Err(error) => ResponseResult::Error(error),
105 }
106 }
107}
108
109pub trait Side: Clone {
110 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
111 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
112 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
113
114 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
122
123 fn decode_notification(method: &str, params: Option<&RawValue>)
131 -> Result<Self::InNotification>;
132}
133
134#[derive(Clone, Default, Debug, JsonSchema)]
141#[non_exhaustive]
142pub struct ClientSide;
143
144impl Side for ClientSide {
145 type InRequest = AgentRequest;
146 type InNotification = AgentNotification;
147 type OutResponse = ClientResponse;
148
149 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
150 let params = params.ok_or_else(Error::invalid_params)?;
151
152 match method {
153 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
154 serde_json::from_str(params.get())
155 .map(AgentRequest::RequestPermissionRequest)
156 .map_err(Into::into)
157 }
158 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
159 .map(AgentRequest::WriteTextFileRequest)
160 .map_err(Into::into),
161 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
162 .map(AgentRequest::ReadTextFileRequest)
163 .map_err(Into::into),
164 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
165 .map(AgentRequest::CreateTerminalRequest)
166 .map_err(Into::into),
167 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
168 .map(AgentRequest::TerminalOutputRequest)
169 .map_err(Into::into),
170 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
171 .map(AgentRequest::KillTerminalCommandRequest)
172 .map_err(Into::into),
173 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
174 .map(AgentRequest::ReleaseTerminalRequest)
175 .map_err(Into::into),
176 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
177 serde_json::from_str(params.get())
178 .map(AgentRequest::WaitForTerminalExitRequest)
179 .map_err(Into::into)
180 }
181 _ => {
182 if let Some(custom_method) = method.strip_prefix('_') {
183 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
184 method: custom_method.into(),
185 params: params.to_owned().into(),
186 }))
187 } else {
188 Err(Error::method_not_found())
189 }
190 }
191 }
192 }
193
194 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
195 let params = params.ok_or_else(Error::invalid_params)?;
196
197 match method {
198 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
199 .map(AgentNotification::SessionNotification)
200 .map_err(Into::into),
201 _ => {
202 if let Some(custom_method) = method.strip_prefix('_') {
203 Ok(AgentNotification::ExtNotification(ExtNotification {
204 method: custom_method.into(),
205 params: RawValue::from_string(params.get().to_string())?.into(),
206 }))
207 } else {
208 Err(Error::method_not_found())
209 }
210 }
211 }
212 }
213}
214
215#[derive(Clone, Default, Debug, JsonSchema)]
222#[non_exhaustive]
223pub struct AgentSide;
224
225impl Side for AgentSide {
226 type InRequest = ClientRequest;
227 type InNotification = ClientNotification;
228 type OutResponse = AgentResponse;
229
230 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
231 let params = params.ok_or_else(Error::invalid_params)?;
232
233 match method {
234 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
235 .map(ClientRequest::InitializeRequest)
236 .map_err(Into::into),
237 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
238 .map(ClientRequest::AuthenticateRequest)
239 .map_err(Into::into),
240 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
241 .map(ClientRequest::NewSessionRequest)
242 .map_err(Into::into),
243 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
244 .map(ClientRequest::LoadSessionRequest)
245 .map_err(Into::into),
246 #[cfg(feature = "unstable_session_list")]
247 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
248 .map(ClientRequest::ListSessionsRequest)
249 .map_err(Into::into),
250 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
251 .map(ClientRequest::SetSessionModeRequest)
252 .map_err(Into::into),
253 #[cfg(feature = "unstable_session_model")]
254 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
255 .map(ClientRequest::SetSessionModelRequest)
256 .map_err(Into::into),
257 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
258 .map(ClientRequest::PromptRequest)
259 .map_err(Into::into),
260 _ => {
261 if let Some(custom_method) = method.strip_prefix('_') {
262 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
263 method: custom_method.into(),
264 params: params.to_owned().into(),
265 }))
266 } else {
267 Err(Error::method_not_found())
268 }
269 }
270 }
271 }
272
273 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
274 let params = params.ok_or_else(Error::invalid_params)?;
275
276 match method {
277 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
278 .map(ClientNotification::CancelNotification)
279 .map_err(Into::into),
280 _ => {
281 if let Some(custom_method) = method.strip_prefix('_') {
282 Ok(ClientNotification::ExtNotification(ExtNotification {
283 method: custom_method.into(),
284 params: RawValue::from_string(params.get().to_string())?.into(),
285 }))
286 } else {
287 Err(Error::method_not_found())
288 }
289 }
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Every `RequestId` variant, including a negative number, deserializes from its bare
    /// JSON value.
    #[test]
    fn id_deserialization() {
        let cases = [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ];

        for (json, expected) in cases {
            let parsed = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// Every `RequestId` variant serializes to its bare JSON value.
    #[test]
    fn id_serialization() {
        let cases = [
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ];

        for (id, expected) in cases {
            let json = serde_json::to_value(id).unwrap();
            assert_eq!(json, expected);
        }
    }

    /// `Display` prints `null` for the null id and the bare number or string otherwise.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
347
/// Notifications in both directions serialize to the expected JSON-RPC 2.0 wire shape.
///
/// The `jsonrpc` tag is present, the method and params appear at the top level, and `None`
/// meta fields are absent.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: a cancel notification.
    let cancel = OutgoingMessage::<ClientSide, AgentSide>::Notification {
        method: "cancel".into(),
        params: Some(ClientNotification::CancelNotification(CancelNotification {
            session_id: SessionId("test-123".into()),
            meta: None,
        })),
    };
    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    let actual_cancel: Value = serde_json::to_value(&JsonRpcMessage::wrap(cancel)).unwrap();
    assert_eq!(actual_cancel, expected_cancel);

    // Agent -> client: a session update carrying a text chunk.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        meta: None,
    };
    let update = OutgoingMessage::<AgentSide, ClientSide>::Notification {
        method: "sessionUpdate".into(),
        params: Some(AgentNotification::SessionNotification(SessionNotification {
            session_id: SessionId("test-456".into()),
            update: SessionUpdate::AgentMessageChunk(chunk),
            meta: None,
        })),
    };
    let expected_update = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    let actual_update: Value = serde_json::to_value(&JsonRpcMessage::wrap(update)).unwrap();
    assert_eq!(actual_update, expected_update);
}