1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
/// JSON-RPC request identifier.
///
/// The JSON-RPC 2.0 specification allows an `id` to be a String, a Number, or
/// `null`. Because the enum is `#[serde(untagged)]`, each variant is
/// (de)serialized as the bare JSON value: `null`, an integer, or a string.
/// Only integer ids that fit in an `i64` are representable.
///
/// Variant order is significant: untagged deserialization tries the variants
/// top to bottom, and the derived `Ord` sorts `Null < Number < Str`.
///
/// `Display` writes `null` for [`RequestId::Null`] and the bare inner value for
/// the other variants (string ids are not quoted). Conversions into
/// `RequestId` are derived with `derive_more::From`; see the `#[from(...)]`
/// attribute for the accepted source types.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Hash,
    Eq,
    Deserialize,
    Serialize,
    PartialOrd,
    Ord,
    Display,
    JsonSchema,
    From,
)]
#[serde(untagged)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[from(String, i64)]
pub enum RequestId {
    /// The JSON `null` id.
    #[display("null")]
    Null,
    /// An integer id.
    Number(i64),
    /// A string id.
    Str(String),
}
48
/// A JSON-RPC request: a call to `method` that expects a response carrying the
/// same `id`.
///
/// `params` is omitted from the serialized output when it is `None`. The
/// `"jsonrpc": "2.0"` member is not part of this type; it is added by wrapping
/// the message in [`JsonRpcMessage`].
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
// The schema is named after the `Params` type and tagged with the custom
// `x-docs-ignore` extension.
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Request<Params> {
    /// Identifier used to correlate the response with this request.
    pub id: RequestId,
    /// Name of the method being invoked.
    pub method: Arc<str>,
    /// Typed method parameters, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Params>,
}
61
/// A JSON-RPC response: either a successful `result` or an `error`, each
/// echoing the `id` of the request it answers.
///
/// The enum is untagged, so each variant serializes as a bare object with its
/// fields. On deserialization serde tries the `Result` shape first and falls
/// back to `Error`.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[serde(untagged)]
// The schema is named after the `Result` type parameter and tagged with the
// custom `x-docs-ignore` extension.
#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
pub enum Response<Result> {
    /// Successful outcome of the request identified by `id`.
    Result { id: RequestId, result: Result },
    /// Failed outcome of the request identified by `id`.
    Error { id: RequestId, error: Error },
}
73
74impl<R> Response<R> {
75 #[must_use]
76 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77 match result {
78 Ok(result) => Self::Result {
79 id: id.into(),
80 result,
81 },
82 Err(error) => Self::Error {
83 id: id.into(),
84 error,
85 },
86 }
87 }
88}
89
/// A JSON-RPC notification: a call to `method` that carries no `id`.
///
/// Per the JSON-RPC 2.0 specification, the receiver does not reply to a
/// notification. `params` is omitted from the serialized output when it is
/// `None`.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
// The schema is named after the `Params` type and tagged with the custom
// `x-docs-ignore` extension.
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Notification<Params> {
    /// Name of the notification method.
    pub method: Arc<str>,
    /// Typed notification parameters, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Params>,
}
101
/// Any message the `Local` side may send to the `Remote` side.
///
/// The variants are:
/// - a request the remote will handle (`Remote::InRequest`);
/// - a response to one of the remote's requests (`Local::OutResponse`);
/// - a notification the remote receives (`Remote::InNotification`).
///
/// The enum is untagged, so each variant serializes as the bare inner message.
/// Deserialization tries the variants in declaration order.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[serde(untagged)]
#[schemars(inline)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    Request(Request<Remote::InRequest>),
    Response(Response<Local::OutResponse>),
    Notification(Notification<Remote::InNotification>),
}
114
/// The JSON-RPC protocol version marker.
///
/// It serializes as the string `"2.0"`. It is a single renamed unit variant,
/// so any other string is rejected on deserialization.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
enum JsonRpcVersion {
    #[serde(rename = "2.0")]
    V2,
}
121
/// A message (request, response, or notification) together with the
/// `"jsonrpc": "2.0"` member required by the JSON-RPC 2.0 specification.
///
/// The fields of `message` are flattened into the same JSON object as
/// `jsonrpc`. On deserialization, a `jsonrpc` value other than `"2.0"` is
/// rejected.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
pub struct JsonRpcMessage<M> {
    // Always `"2.0"` on the wire.
    jsonrpc: JsonRpcVersion,
    // Flattened so the message's own members sit beside `jsonrpc`.
    #[serde(flatten)]
    message: M,
}
133
134impl<M> JsonRpcMessage<M> {
135 #[must_use]
138 pub fn wrap(message: M) -> Self {
139 Self {
140 jsonrpc: JsonRpcVersion::V2,
141 message,
142 }
143 }
144}
145
/// One end of a JSON-RPC connection (see [`ClientSide`] and [`AgentSide`]).
///
/// A side declares the message types it receives from its peer
/// (`InRequest`, `InNotification`) and the response payloads it sends back
/// (`OutResponse`). It also knows how to decode incoming raw payloads into
/// those types.
pub trait Side: Clone {
    /// Requests this side receives from its peer.
    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Notifications this side receives from its peer.
    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Response payloads this side sends to its peer.
    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;

    /// Decodes an incoming request from its `method` name and raw `params`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return an [`Error`] in three cases:
    /// the method is not recognized, its params are missing, or its params fail
    /// to deserialize.
    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;

    /// Decodes an incoming notification from its `method` name and raw
    /// `params`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return an [`Error`] in three cases:
    /// the method is not recognized, its params are missing, or its params fail
    /// to deserialize.
    fn decode_notification(method: &str, params: Option<&RawValue>)
    -> Result<Self::InNotification>;
}
170
/// The client end of a connection.
///
/// Its [`Side`] impl receives [`AgentRequest`]s and [`AgentNotification`]s
/// from the agent and answers requests with [`ClientResponse`]s.
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct ClientSide;
180
181impl Side for ClientSide {
182 type InRequest = AgentRequest;
183 type InNotification = AgentNotification;
184 type OutResponse = ClientResponse;
185
186 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187 let params = params.ok_or_else(Error::invalid_params)?;
188
189 match method {
190 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191 serde_json::from_str(params.get())
192 .map(AgentRequest::RequestPermissionRequest)
193 .map_err(Into::into)
194 }
195 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196 .map(AgentRequest::WriteTextFileRequest)
197 .map_err(Into::into),
198 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199 .map(AgentRequest::ReadTextFileRequest)
200 .map_err(Into::into),
201 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202 .map(AgentRequest::CreateTerminalRequest)
203 .map_err(Into::into),
204 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205 .map(AgentRequest::TerminalOutputRequest)
206 .map_err(Into::into),
207 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208 .map(AgentRequest::KillTerminalRequest)
209 .map_err(Into::into),
210 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211 .map(AgentRequest::ReleaseTerminalRequest)
212 .map_err(Into::into),
213 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214 serde_json::from_str(params.get())
215 .map(AgentRequest::WaitForTerminalExitRequest)
216 .map_err(Into::into)
217 }
218 _ => {
219 if let Some(custom_method) = method.strip_prefix('_') {
220 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
221 method: custom_method.into(),
222 params: params.to_owned().into(),
223 }))
224 } else {
225 Err(Error::method_not_found())
226 }
227 }
228 }
229 }
230
231 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
232 let params = params.ok_or_else(Error::invalid_params)?;
233
234 match method {
235 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
236 .map(AgentNotification::SessionNotification)
237 .map_err(Into::into),
238 _ => {
239 if let Some(custom_method) = method.strip_prefix('_') {
240 Ok(AgentNotification::ExtNotification(ExtNotification {
241 method: custom_method.into(),
242 params: params.to_owned().into(),
243 }))
244 } else {
245 Err(Error::method_not_found())
246 }
247 }
248 }
249 }
250}
251
/// The agent end of a connection.
///
/// Its [`Side`] impl receives [`ClientRequest`]s and [`ClientNotification`]s
/// from the client and answers requests with [`AgentResponse`]s.
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct AgentSide;
261
262impl Side for AgentSide {
263 type InRequest = ClientRequest;
264 type InNotification = ClientNotification;
265 type OutResponse = AgentResponse;
266
267 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
268 let params = params.ok_or_else(Error::invalid_params)?;
269
270 match method {
271 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
272 .map(ClientRequest::InitializeRequest)
273 .map_err(Into::into),
274 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
275 .map(ClientRequest::AuthenticateRequest)
276 .map_err(Into::into),
277 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
278 .map(ClientRequest::NewSessionRequest)
279 .map_err(Into::into),
280 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
281 .map(ClientRequest::LoadSessionRequest)
282 .map_err(Into::into),
283 #[cfg(feature = "unstable_session_list")]
284 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
285 .map(ClientRequest::ListSessionsRequest)
286 .map_err(Into::into),
287 #[cfg(feature = "unstable_session_fork")]
288 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
289 .map(ClientRequest::ForkSessionRequest)
290 .map_err(Into::into),
291 #[cfg(feature = "unstable_session_resume")]
292 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
293 .map(ClientRequest::ResumeSessionRequest)
294 .map_err(Into::into),
295 #[cfg(feature = "unstable_session_stop")]
296 m if m == AGENT_METHOD_NAMES.session_stop => serde_json::from_str(params.get())
297 .map(ClientRequest::StopSessionRequest)
298 .map_err(Into::into),
299 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
300 .map(ClientRequest::SetSessionModeRequest)
301 .map_err(Into::into),
302 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
303 serde_json::from_str(params.get())
304 .map(ClientRequest::SetSessionConfigOptionRequest)
305 .map_err(Into::into)
306 }
307 #[cfg(feature = "unstable_session_model")]
308 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
309 .map(ClientRequest::SetSessionModelRequest)
310 .map_err(Into::into),
311 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
312 .map(ClientRequest::PromptRequest)
313 .map_err(Into::into),
314 _ => {
315 if let Some(custom_method) = method.strip_prefix('_') {
316 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
317 method: custom_method.into(),
318 params: params.to_owned().into(),
319 }))
320 } else {
321 Err(Error::method_not_found())
322 }
323 }
324 }
325 }
326
327 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
328 let params = params.ok_or_else(Error::invalid_params)?;
329
330 match method {
331 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
332 .map(ClientNotification::CancelNotification)
333 .map_err(Into::into),
334 _ => {
335 if let Some(custom_method) = method.strip_prefix('_') {
336 Ok(ClientNotification::ExtNotification(ExtNotification {
337 method: custom_method.into(),
338 params: params.to_owned().into(),
339 }))
340 } else {
341 Err(Error::method_not_found())
342 }
343 }
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Each JSON wire value paired with the `RequestId` it round-trips with.
    fn wire_cases() -> [(Value, RequestId); 4] {
        [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (json, expected) in wire_cases() {
            let parsed = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (expected, id) in wire_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}
401
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: a `cancel` notification carrying only a session id.
    let cancel_params = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let cancel_message = JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
            method: "cancel".into(),
            params: Some(cancel_params),
        }),
    );
    let cancel_json: Value = serde_json::to_value(&cancel_message).unwrap();
    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        }
    });
    assert_eq!(cancel_json, expected_cancel);

    // Agent -> client: a `sessionUpdate` notification streaming a text chunk.
    let text_chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        #[cfg(feature = "unstable_message_id")]
        message_id: None,
        meta: None,
    };
    let update_params = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(text_chunk),
        meta: None,
    });
    let update_message = JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
            method: "sessionUpdate".into(),
            params: Some(update_params),
        }),
    );
    let update_json: Value = serde_json::to_value(&update_message).unwrap();
    let expected_update = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(update_json, expected_update);
}
472}