1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug,
24 PartialEq,
25 Clone,
26 Hash,
27 Eq,
28 Deserialize,
29 Serialize,
30 PartialOrd,
31 Ord,
32 Display,
33 JsonSchema,
34 From,
35)]
36#[serde(untagged)]
37#[allow(
38 clippy::exhaustive_enums,
39 reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43 #[display("null")]
44 Null,
45 Number(i64),
46 Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51 clippy::exhaustive_structs,
52 reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56 pub id: RequestId,
57 pub method: Arc<str>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64 clippy::exhaustive_enums,
65 reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70 Result { id: RequestId, result: Result },
71 Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
76 match result {
77 Ok(result) => Self::Result {
78 id: id.into(),
79 result,
80 },
81 Err(error) => Self::Error {
82 id: id.into(),
83 error,
84 },
85 }
86 }
87}
88
89#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
90#[allow(
91 clippy::exhaustive_structs,
92 reason = "This comes from the JSON-RPC specification itself"
93)]
94#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
95pub struct Notification<Params> {
96 pub method: Arc<str>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub params: Option<Params>,
99}
100
101#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
102#[serde(untagged)]
103#[schemars(inline)]
104#[allow(
105 clippy::exhaustive_enums,
106 reason = "This comes from the JSON-RPC specification itself"
107)]
108pub enum OutgoingMessage<Local: Side, Remote: Side> {
109 Request(Request<Remote::InRequest>),
110 Response(Response<Local::OutResponse>),
111 Notification(Notification<Remote::InNotification>),
112}
113
114#[derive(Debug, Serialize, Deserialize, JsonSchema)]
115#[schemars(inline)]
116enum JsonRpcVersion {
117 #[serde(rename = "2.0")]
118 V2,
119}
120
121#[derive(Debug, Serialize, Deserialize, JsonSchema)]
126#[schemars(inline)]
127pub struct JsonRpcMessage<M> {
128 jsonrpc: JsonRpcVersion,
129 #[serde(flatten)]
130 message: M,
131}
132
133impl<M> JsonRpcMessage<M> {
134 #[must_use]
137 pub fn wrap(message: M) -> Self {
138 Self {
139 jsonrpc: JsonRpcVersion::V2,
140 message,
141 }
142 }
143}
144
145pub trait Side: Clone {
146 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
147 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149
150 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
158
159 fn decode_notification(method: &str, params: Option<&RawValue>)
167 -> Result<Self::InNotification>;
168}
169
170#[derive(Clone, Default, Debug, JsonSchema)]
177#[non_exhaustive]
178pub struct ClientSide;
179
180impl Side for ClientSide {
181 type InRequest = AgentRequest;
182 type InNotification = AgentNotification;
183 type OutResponse = ClientResponse;
184
185 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
186 let params = params.ok_or_else(Error::invalid_params)?;
187
188 match method {
189 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
190 serde_json::from_str(params.get())
191 .map(AgentRequest::RequestPermissionRequest)
192 .map_err(Into::into)
193 }
194 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
195 .map(AgentRequest::WriteTextFileRequest)
196 .map_err(Into::into),
197 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
198 .map(AgentRequest::ReadTextFileRequest)
199 .map_err(Into::into),
200 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
201 .map(AgentRequest::CreateTerminalRequest)
202 .map_err(Into::into),
203 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
204 .map(AgentRequest::TerminalOutputRequest)
205 .map_err(Into::into),
206 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
207 .map(AgentRequest::KillTerminalCommandRequest)
208 .map_err(Into::into),
209 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
210 .map(AgentRequest::ReleaseTerminalRequest)
211 .map_err(Into::into),
212 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
213 serde_json::from_str(params.get())
214 .map(AgentRequest::WaitForTerminalExitRequest)
215 .map_err(Into::into)
216 }
217 _ => {
218 if let Some(custom_method) = method.strip_prefix('_') {
219 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
220 method: custom_method.into(),
221 params: params.to_owned().into(),
222 }))
223 } else {
224 Err(Error::method_not_found())
225 }
226 }
227 }
228 }
229
230 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
231 let params = params.ok_or_else(Error::invalid_params)?;
232
233 match method {
234 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
235 .map(AgentNotification::SessionNotification)
236 .map_err(Into::into),
237 _ => {
238 if let Some(custom_method) = method.strip_prefix('_') {
239 Ok(AgentNotification::ExtNotification(ExtNotification {
240 method: custom_method.into(),
241 params: RawValue::from_string(params.get().to_string())?.into(),
242 }))
243 } else {
244 Err(Error::method_not_found())
245 }
246 }
247 }
248 }
249}
250
251#[derive(Clone, Default, Debug, JsonSchema)]
258#[non_exhaustive]
259pub struct AgentSide;
260
261impl Side for AgentSide {
262 type InRequest = ClientRequest;
263 type InNotification = ClientNotification;
264 type OutResponse = AgentResponse;
265
266 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
267 let params = params.ok_or_else(Error::invalid_params)?;
268
269 match method {
270 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
271 .map(ClientRequest::InitializeRequest)
272 .map_err(Into::into),
273 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
274 .map(ClientRequest::AuthenticateRequest)
275 .map_err(Into::into),
276 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
277 .map(ClientRequest::NewSessionRequest)
278 .map_err(Into::into),
279 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
280 .map(ClientRequest::LoadSessionRequest)
281 .map_err(Into::into),
282 #[cfg(feature = "unstable_session_list")]
283 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284 .map(ClientRequest::ListSessionsRequest)
285 .map_err(Into::into),
286 #[cfg(feature = "unstable_session_fork")]
287 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288 .map(ClientRequest::ForkSessionRequest)
289 .map_err(Into::into),
290 #[cfg(feature = "unstable_session_resume")]
291 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
292 .map(ClientRequest::ResumeSessionRequest)
293 .map_err(Into::into),
294 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
295 .map(ClientRequest::SetSessionModeRequest)
296 .map_err(Into::into),
297 #[cfg(feature = "unstable_session_config_options")]
298 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
299 serde_json::from_str(params.get())
300 .map(ClientRequest::SetSessionConfigOptionRequest)
301 .map_err(Into::into)
302 }
303 #[cfg(feature = "unstable_session_model")]
304 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
305 .map(ClientRequest::SetSessionModelRequest)
306 .map_err(Into::into),
307 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
308 .map(ClientRequest::PromptRequest)
309 .map_err(Into::into),
310 _ => {
311 if let Some(custom_method) = method.strip_prefix('_') {
312 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
313 method: custom_method.into(),
314 params: params.to_owned().into(),
315 }))
316 } else {
317 Err(Error::method_not_found())
318 }
319 }
320 }
321 }
322
323 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
324 let params = params.ok_or_else(Error::invalid_params)?;
325
326 match method {
327 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
328 .map(ClientNotification::CancelNotification)
329 .map_err(Into::into),
330 _ => {
331 if let Some(custom_method) = method.strip_prefix('_') {
332 Ok(ClientNotification::ExtNotification(ExtNotification {
333 method: custom_method.into(),
334 params: RawValue::from_string(params.get().to_string())?.into(),
335 }))
336 } else {
337 Err(Error::method_not_found())
338 }
339 }
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Pairs each kind of `RequestId` with the JSON value it corresponds to on the wire.
    fn wire_cases() -> Vec<(RequestId, Value)> {
        vec![
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ]
    }

    /// Every JSON form decodes into the matching variant.
    #[test]
    fn id_deserialization() {
        for (expected, json) in wire_cases() {
            let decoded = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    /// Every variant encodes as its bare JSON value.
    #[test]
    fn id_serialization() {
        for (id, expected) in wire_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    /// `Display` prints `null` for the null id and the inner value otherwise.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}
397
/// Checks the exact JSON produced for a wrapped notification in each direction.
#[test]
fn test_notification_wire_format() {
    // This function sits directly in the file's module, so `super` is the parent module. The
    // glob supplies the payload types that the file-level imports do not bring in
    // (`CancelNotification`, `SessionId`, `SessionNotification`, ...).
    use super::*;

    use serde_json::{Value, json};

    // Notification whose payload is a `ClientNotification` (local = client, remote = agent).
    let cancel = Notification {
        method: "cancel".into(),
        params: Some(ClientNotification::CancelNotification(CancelNotification {
            session_id: SessionId("test-123".into()),
            meta: None,
        })),
    };
    let client_message =
        JsonRpcMessage::wrap(OutgoingMessage::<ClientSide, AgentSide>::Notification(cancel));

    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    let serialized: Value = serde_json::to_value(&client_message).unwrap();
    assert_eq!(serialized, expected_cancel);

    // Notification whose payload is an `AgentNotification` (local = agent, remote = client).
    let text = TextContent {
        annotations: None,
        text: "Hello".to_string(),
        meta: None,
    };
    let chunk = ContentChunk {
        content: ContentBlock::Text(text),
        meta: None,
    };
    let session_update = Notification {
        method: "sessionUpdate".into(),
        params: Some(AgentNotification::SessionNotification(
            SessionNotification {
                session_id: SessionId("test-456".into()),
                update: SessionUpdate::AgentMessageChunk(chunk),
                meta: None,
            },
        )),
    };
    let agent_message = JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(session_update),
    );

    let expected_update = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    let serialized: Value = serde_json::to_value(&agent_message).unwrap();
    assert_eq!(serialized, expected_update);
}