1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug,
24 PartialEq,
25 Clone,
26 Hash,
27 Eq,
28 Deserialize,
29 Serialize,
30 PartialOrd,
31 Ord,
32 Display,
33 JsonSchema,
34 From,
35)]
36#[serde(untagged)]
37#[allow(
38 clippy::exhaustive_enums,
39 reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43 #[display("null")]
44 Null,
45 Number(i64),
46 Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51 clippy::exhaustive_structs,
52 reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56 pub id: RequestId,
57 pub method: Arc<str>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64 clippy::exhaustive_enums,
65 reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70 Result { id: RequestId, result: Result },
71 Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
76 match result {
77 Ok(result) => Self::Result {
78 id: id.into(),
79 result,
80 },
81 Err(error) => Self::Error {
82 id: id.into(),
83 error,
84 },
85 }
86 }
87}
88
89#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
90#[allow(
91 clippy::exhaustive_structs,
92 reason = "This comes from the JSON-RPC specification itself"
93)]
94#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
95pub struct Notification<Params> {
96 pub method: Arc<str>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub params: Option<Params>,
99}
100
101#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
102#[serde(untagged)]
103#[schemars(inline)]
104#[allow(
105 clippy::exhaustive_enums,
106 reason = "This comes from the JSON-RPC specification itself"
107)]
108pub enum OutgoingMessage<Local: Side, Remote: Side> {
109 Request(Request<Remote::InRequest>),
110 Response(Response<Local::OutResponse>),
111 Notification(Notification<Remote::InNotification>),
112}
113
114#[derive(Debug, Serialize, Deserialize, JsonSchema)]
115#[schemars(inline)]
116enum JsonRpcVersion {
117 #[serde(rename = "2.0")]
118 V2,
119}
120
121#[derive(Debug, Serialize, Deserialize, JsonSchema)]
126#[schemars(inline)]
127pub struct JsonRpcMessage<M> {
128 jsonrpc: JsonRpcVersion,
129 #[serde(flatten)]
130 message: M,
131}
132
133impl<M> JsonRpcMessage<M> {
134 #[must_use]
137 pub fn wrap(message: M) -> Self {
138 Self {
139 jsonrpc: JsonRpcVersion::V2,
140 message,
141 }
142 }
143}
144
145pub trait Side: Clone {
146 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
147 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149
150 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
158
159 fn decode_notification(method: &str, params: Option<&RawValue>)
167 -> Result<Self::InNotification>;
168}
169
170#[derive(Clone, Default, Debug, JsonSchema)]
177#[non_exhaustive]
178pub struct ClientSide;
179
180impl Side for ClientSide {
181 type InRequest = AgentRequest;
182 type InNotification = AgentNotification;
183 type OutResponse = ClientResponse;
184
185 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
186 let params = params.ok_or_else(Error::invalid_params)?;
187
188 match method {
189 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
190 serde_json::from_str(params.get())
191 .map(AgentRequest::RequestPermissionRequest)
192 .map_err(Into::into)
193 }
194 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
195 .map(AgentRequest::WriteTextFileRequest)
196 .map_err(Into::into),
197 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
198 .map(AgentRequest::ReadTextFileRequest)
199 .map_err(Into::into),
200 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
201 .map(AgentRequest::CreateTerminalRequest)
202 .map_err(Into::into),
203 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
204 .map(AgentRequest::TerminalOutputRequest)
205 .map_err(Into::into),
206 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
207 .map(AgentRequest::KillTerminalCommandRequest)
208 .map_err(Into::into),
209 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
210 .map(AgentRequest::ReleaseTerminalRequest)
211 .map_err(Into::into),
212 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
213 serde_json::from_str(params.get())
214 .map(AgentRequest::WaitForTerminalExitRequest)
215 .map_err(Into::into)
216 }
217 _ => {
218 if let Some(custom_method) = method.strip_prefix('_') {
219 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
220 method: custom_method.into(),
221 params: params.to_owned().into(),
222 }))
223 } else {
224 Err(Error::method_not_found())
225 }
226 }
227 }
228 }
229
230 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
231 let params = params.ok_or_else(Error::invalid_params)?;
232
233 match method {
234 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
235 .map(AgentNotification::SessionNotification)
236 .map_err(Into::into),
237 _ => {
238 if let Some(custom_method) = method.strip_prefix('_') {
239 Ok(AgentNotification::ExtNotification(ExtNotification {
240 method: custom_method.into(),
241 params: RawValue::from_string(params.get().to_string())?.into(),
242 }))
243 } else {
244 Err(Error::method_not_found())
245 }
246 }
247 }
248 }
249}
250
251#[derive(Clone, Default, Debug, JsonSchema)]
258#[non_exhaustive]
259pub struct AgentSide;
260
261impl Side for AgentSide {
262 type InRequest = ClientRequest;
263 type InNotification = ClientNotification;
264 type OutResponse = AgentResponse;
265
266 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
267 let params = params.ok_or_else(Error::invalid_params)?;
268
269 match method {
270 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
271 .map(ClientRequest::InitializeRequest)
272 .map_err(Into::into),
273 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
274 .map(ClientRequest::AuthenticateRequest)
275 .map_err(Into::into),
276 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
277 .map(ClientRequest::NewSessionRequest)
278 .map_err(Into::into),
279 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
280 .map(ClientRequest::LoadSessionRequest)
281 .map_err(Into::into),
282 #[cfg(feature = "unstable_session_list")]
283 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284 .map(ClientRequest::ListSessionsRequest)
285 .map_err(Into::into),
286 #[cfg(feature = "unstable_session_fork")]
287 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288 .map(ClientRequest::ForkSessionRequest)
289 .map_err(Into::into),
290 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
291 .map(ClientRequest::SetSessionModeRequest)
292 .map_err(Into::into),
293 #[cfg(feature = "unstable_session_model")]
294 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
295 .map(ClientRequest::SetSessionModelRequest)
296 .map_err(Into::into),
297 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
298 .map(ClientRequest::PromptRequest)
299 .map_err(Into::into),
300 _ => {
301 if let Some(custom_method) = method.strip_prefix('_') {
302 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
303 method: custom_method.into(),
304 params: params.to_owned().into(),
305 }))
306 } else {
307 Err(Error::method_not_found())
308 }
309 }
310 }
311 }
312
313 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
314 let params = params.ok_or_else(Error::invalid_params)?;
315
316 match method {
317 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
318 .map(ClientNotification::CancelNotification)
319 .map_err(Into::into),
320 _ => {
321 if let Some(custom_method) = method.strip_prefix('_') {
322 Ok(ClientNotification::ExtNotification(ExtNotification {
323 method: custom_method.into(),
324 params: RawValue::from_string(params.get().to_string())?.into(),
325 }))
326 } else {
327 Err(Error::method_not_found())
328 }
329 }
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Pairs of a JSON value and the `RequestId` it corresponds to, shared by
    /// the serialization and deserialization tests.
    fn wire_cases() -> [(Value, RequestId); 4] {
        [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    /// Each JSON shape decodes into the matching variant.
    #[test]
    fn id_deserialization() {
        for (json, expected) in wire_cases() {
            let decoded = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    /// Each variant encodes into the matching bare JSON value.
    #[test]
    fn id_serialization() {
        for (expected, id) in wire_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    /// `Display` prints the inner value, or `null` for the null identifier.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
387
/// Checks the exact JSON produced for outgoing notifications in both
/// directions, including the `jsonrpc` envelope field.
#[test]
fn test_notification_wire_format() {
    // This function sits at the top level of the file, so `super` is the
    // parent of this module, not this module itself.
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: cancel notification.
    let cancel = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let client_to_agent = JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
            method: "cancel".into(),
            params: Some(cancel),
        }),
    );

    let serialized: Value = serde_json::to_value(&client_to_agent).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    assert_eq!(serialized, expected);

    // Agent -> client: session update carrying a text message chunk.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        meta: None,
    };
    let update = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    });
    let agent_to_client = JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
            method: "sessionUpdate".into(),
            params: Some(update),
        }),
    );

    let serialized: Value = serde_json::to_value(&agent_to_client).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(serialized, expected);
}