1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug,
24 PartialEq,
25 Clone,
26 Hash,
27 Eq,
28 Deserialize,
29 Serialize,
30 PartialOrd,
31 Ord,
32 Display,
33 JsonSchema,
34 From,
35)]
36#[serde(untagged)]
37#[allow(
38 clippy::exhaustive_enums,
39 reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43 #[display("null")]
44 Null,
45 Number(i64),
46 Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51 clippy::exhaustive_structs,
52 reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56 pub id: RequestId,
57 pub method: Arc<str>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64 clippy::exhaustive_enums,
65 reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70 Result { id: RequestId, result: Result },
71 Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75 #[must_use]
76 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77 match result {
78 Ok(result) => Self::Result {
79 id: id.into(),
80 result,
81 },
82 Err(error) => Self::Error {
83 id: id.into(),
84 error,
85 },
86 }
87 }
88}
89
90#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
91#[allow(
92 clippy::exhaustive_structs,
93 reason = "This comes from the JSON-RPC specification itself"
94)]
95#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
96pub struct Notification<Params> {
97 pub method: Arc<str>,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 pub params: Option<Params>,
100}
101
102#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
103#[serde(untagged)]
104#[schemars(inline)]
105#[allow(
106 clippy::exhaustive_enums,
107 reason = "This comes from the JSON-RPC specification itself"
108)]
109pub enum OutgoingMessage<Local: Side, Remote: Side> {
110 Request(Request<Remote::InRequest>),
111 Response(Response<Local::OutResponse>),
112 Notification(Notification<Remote::InNotification>),
113}
114
115#[derive(Debug, Serialize, Deserialize, JsonSchema)]
116#[schemars(inline)]
117enum JsonRpcVersion {
118 #[serde(rename = "2.0")]
119 V2,
120}
121
122#[derive(Debug, Serialize, Deserialize, JsonSchema)]
127#[schemars(inline)]
128pub struct JsonRpcMessage<M> {
129 jsonrpc: JsonRpcVersion,
130 #[serde(flatten)]
131 message: M,
132}
133
134impl<M> JsonRpcMessage<M> {
135 #[must_use]
138 pub fn wrap(message: M) -> Self {
139 Self {
140 jsonrpc: JsonRpcVersion::V2,
141 message,
142 }
143 }
144}
145
146pub trait Side: Clone {
147 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
150
151 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
159
160 fn decode_notification(method: &str, params: Option<&RawValue>)
168 -> Result<Self::InNotification>;
169}
170
171#[derive(Clone, Default, Debug, JsonSchema)]
178#[non_exhaustive]
179pub struct ClientSide;
180
181impl Side for ClientSide {
182 type InRequest = AgentRequest;
183 type InNotification = AgentNotification;
184 type OutResponse = ClientResponse;
185
186 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187 let params = params.ok_or_else(Error::invalid_params)?;
188
189 match method {
190 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191 serde_json::from_str(params.get())
192 .map(AgentRequest::RequestPermissionRequest)
193 .map_err(Into::into)
194 }
195 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196 .map(AgentRequest::WriteTextFileRequest)
197 .map_err(Into::into),
198 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199 .map(AgentRequest::ReadTextFileRequest)
200 .map_err(Into::into),
201 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202 .map(AgentRequest::CreateTerminalRequest)
203 .map_err(Into::into),
204 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205 .map(AgentRequest::TerminalOutputRequest)
206 .map_err(Into::into),
207 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208 .map(AgentRequest::KillTerminalRequest)
209 .map_err(Into::into),
210 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211 .map(AgentRequest::ReleaseTerminalRequest)
212 .map_err(Into::into),
213 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214 serde_json::from_str(params.get())
215 .map(AgentRequest::WaitForTerminalExitRequest)
216 .map_err(Into::into)
217 }
218 #[cfg(feature = "unstable_elicitation")]
219 m if m == CLIENT_METHOD_NAMES.session_elicitation => serde_json::from_str(params.get())
220 .map(AgentRequest::ElicitationRequest)
221 .map_err(Into::into),
222 _ => {
223 if let Some(custom_method) = method.strip_prefix('_') {
224 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
225 method: custom_method.into(),
226 params: params.to_owned().into(),
227 }))
228 } else {
229 Err(Error::method_not_found())
230 }
231 }
232 }
233 }
234
235 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
236 let params = params.ok_or_else(Error::invalid_params)?;
237
238 match method {
239 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
240 .map(AgentNotification::SessionNotification)
241 .map_err(Into::into),
242 #[cfg(feature = "unstable_elicitation")]
243 m if m == CLIENT_METHOD_NAMES.session_elicitation_complete => {
244 serde_json::from_str(params.get())
245 .map(AgentNotification::ElicitationCompleteNotification)
246 .map_err(Into::into)
247 }
248 _ => {
249 if let Some(custom_method) = method.strip_prefix('_') {
250 Ok(AgentNotification::ExtNotification(ExtNotification {
251 method: custom_method.into(),
252 params: params.to_owned().into(),
253 }))
254 } else {
255 Err(Error::method_not_found())
256 }
257 }
258 }
259 }
260}
261
262#[derive(Clone, Default, Debug, JsonSchema)]
269#[non_exhaustive]
270pub struct AgentSide;
271
272impl Side for AgentSide {
273 type InRequest = ClientRequest;
274 type InNotification = ClientNotification;
275 type OutResponse = AgentResponse;
276
277 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
278 let params = params.ok_or_else(Error::invalid_params)?;
279
280 match method {
281 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
282 .map(ClientRequest::InitializeRequest)
283 .map_err(Into::into),
284 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
285 .map(ClientRequest::AuthenticateRequest)
286 .map_err(Into::into),
287 #[cfg(feature = "unstable_logout")]
288 m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
289 .map(ClientRequest::LogoutRequest)
290 .map_err(Into::into),
291 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
292 .map(ClientRequest::NewSessionRequest)
293 .map_err(Into::into),
294 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
295 .map(ClientRequest::LoadSessionRequest)
296 .map_err(Into::into),
297 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
298 .map(ClientRequest::ListSessionsRequest)
299 .map_err(Into::into),
300 #[cfg(feature = "unstable_session_fork")]
301 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
302 .map(ClientRequest::ForkSessionRequest)
303 .map_err(Into::into),
304 #[cfg(feature = "unstable_session_resume")]
305 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
306 .map(ClientRequest::ResumeSessionRequest)
307 .map_err(Into::into),
308 #[cfg(feature = "unstable_session_close")]
309 m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
310 .map(ClientRequest::CloseSessionRequest)
311 .map_err(Into::into),
312 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
313 .map(ClientRequest::SetSessionModeRequest)
314 .map_err(Into::into),
315 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
316 serde_json::from_str(params.get())
317 .map(ClientRequest::SetSessionConfigOptionRequest)
318 .map_err(Into::into)
319 }
320 #[cfg(feature = "unstable_session_model")]
321 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
322 .map(ClientRequest::SetSessionModelRequest)
323 .map_err(Into::into),
324 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
325 .map(ClientRequest::PromptRequest)
326 .map_err(Into::into),
327 _ => {
328 if let Some(custom_method) = method.strip_prefix('_') {
329 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
330 method: custom_method.into(),
331 params: params.to_owned().into(),
332 }))
333 } else {
334 Err(Error::method_not_found())
335 }
336 }
337 }
338 }
339
340 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
341 let params = params.ok_or_else(Error::invalid_params)?;
342
343 match method {
344 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
345 .map(ClientNotification::CancelNotification)
346 .map_err(Into::into),
347 _ => {
348 if let Some(custom_method) = method.strip_prefix('_') {
349 Ok(ClientNotification::ExtNotification(ExtNotification {
350 method: custom_method.into(),
351 params: params.to_owned().into(),
352 }))
353 } else {
354 Err(Error::method_not_found())
355 }
356 }
357 }
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Every `RequestId` shape paired with its JSON encoding; used in both directions.
    fn id_json_pairs() -> [(RequestId, Value); 4] {
        [
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (expected, json) in id_json_pairs() {
            let decoded = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (id, expected) in id_json_pairs() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}
414
/// Checks the JSON-RPC 2.0 wire shape of outgoing notifications in both directions:
/// `"jsonrpc": "2.0"` sits next to the flattened notification, and `None` fields are omitted.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: a cancel notification.
    let cancel = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let message = JsonRpcMessage::wrap(OutgoingMessage::<ClientSide, AgentSide>::Notification(
        Notification {
            method: "cancel".into(),
            params: Some(cancel),
        },
    ));
    let actual: Value = serde_json::to_value(&message).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    assert_eq!(actual, expected);

    // Agent -> client: a session update carrying a text message chunk.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        #[cfg(feature = "unstable_message_id")]
        message_id: None,
        meta: None,
    };
    let update = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    });
    let message = JsonRpcMessage::wrap(OutgoingMessage::<AgentSide, ClientSide>::Notification(
        Notification {
            method: "sessionUpdate".into(),
            params: Some(update),
        },
    ));
    let actual: Value = serde_json::to_value(&message).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(actual, expected);
}