1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug,
24 PartialEq,
25 Clone,
26 Hash,
27 Eq,
28 Deserialize,
29 Serialize,
30 PartialOrd,
31 Ord,
32 Display,
33 JsonSchema,
34 From,
35)]
36#[serde(untagged)]
37#[allow(
38 clippy::exhaustive_enums,
39 reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43 #[display("null")]
44 Null,
45 Number(i64),
46 Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51 clippy::exhaustive_structs,
52 reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56 pub id: RequestId,
57 pub method: Arc<str>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64 clippy::exhaustive_enums,
65 reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70 Result { id: RequestId, result: Result },
71 Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75 #[must_use]
76 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77 match result {
78 Ok(result) => Self::Result {
79 id: id.into(),
80 result,
81 },
82 Err(error) => Self::Error {
83 id: id.into(),
84 error,
85 },
86 }
87 }
88}
89
90#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
91#[allow(
92 clippy::exhaustive_structs,
93 reason = "This comes from the JSON-RPC specification itself"
94)]
95#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
96pub struct Notification<Params> {
97 pub method: Arc<str>,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 pub params: Option<Params>,
100}
101
102#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
103#[serde(untagged)]
104#[schemars(inline)]
105#[allow(
106 clippy::exhaustive_enums,
107 reason = "This comes from the JSON-RPC specification itself"
108)]
109pub enum OutgoingMessage<Local: Side, Remote: Side> {
110 Request(Request<Remote::InRequest>),
111 Response(Response<Local::OutResponse>),
112 Notification(Notification<Remote::InNotification>),
113}
114
115#[derive(Debug, Serialize, Deserialize, JsonSchema)]
116#[schemars(inline)]
117enum JsonRpcVersion {
118 #[serde(rename = "2.0")]
119 V2,
120}
121
122#[derive(Debug, Serialize, Deserialize, JsonSchema)]
127#[schemars(inline)]
128pub struct JsonRpcMessage<M> {
129 jsonrpc: JsonRpcVersion,
130 #[serde(flatten)]
131 message: M,
132}
133
134impl<M> JsonRpcMessage<M> {
135 #[must_use]
138 pub fn wrap(message: M) -> Self {
139 Self {
140 jsonrpc: JsonRpcVersion::V2,
141 message,
142 }
143 }
144}
145
146pub trait Side: Clone {
147 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
150
151 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
159
160 fn decode_notification(method: &str, params: Option<&RawValue>)
168 -> Result<Self::InNotification>;
169}
170
171#[derive(Clone, Default, Debug, JsonSchema)]
178#[non_exhaustive]
179pub struct ClientSide;
180
181impl Side for ClientSide {
182 type InRequest = AgentRequest;
183 type InNotification = AgentNotification;
184 type OutResponse = ClientResponse;
185
186 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187 let params = params.ok_or_else(Error::invalid_params)?;
188
189 match method {
190 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191 serde_json::from_str(params.get())
192 .map(AgentRequest::RequestPermissionRequest)
193 .map_err(Into::into)
194 }
195 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196 .map(AgentRequest::WriteTextFileRequest)
197 .map_err(Into::into),
198 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199 .map(AgentRequest::ReadTextFileRequest)
200 .map_err(Into::into),
201 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202 .map(AgentRequest::CreateTerminalRequest)
203 .map_err(Into::into),
204 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205 .map(AgentRequest::TerminalOutputRequest)
206 .map_err(Into::into),
207 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208 .map(AgentRequest::KillTerminalRequest)
209 .map_err(Into::into),
210 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211 .map(AgentRequest::ReleaseTerminalRequest)
212 .map_err(Into::into),
213 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214 serde_json::from_str(params.get())
215 .map(AgentRequest::WaitForTerminalExitRequest)
216 .map_err(Into::into)
217 }
218 _ => {
219 if let Some(custom_method) = method.strip_prefix('_') {
220 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
221 method: custom_method.into(),
222 params: params.to_owned().into(),
223 }))
224 } else {
225 Err(Error::method_not_found())
226 }
227 }
228 }
229 }
230
231 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
232 let params = params.ok_or_else(Error::invalid_params)?;
233
234 match method {
235 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
236 .map(AgentNotification::SessionNotification)
237 .map_err(Into::into),
238 _ => {
239 if let Some(custom_method) = method.strip_prefix('_') {
240 Ok(AgentNotification::ExtNotification(ExtNotification {
241 method: custom_method.into(),
242 params: params.to_owned().into(),
243 }))
244 } else {
245 Err(Error::method_not_found())
246 }
247 }
248 }
249 }
250}
251
252#[derive(Clone, Default, Debug, JsonSchema)]
259#[non_exhaustive]
260pub struct AgentSide;
261
262impl Side for AgentSide {
263 type InRequest = ClientRequest;
264 type InNotification = ClientNotification;
265 type OutResponse = AgentResponse;
266
267 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
268 let params = params.ok_or_else(Error::invalid_params)?;
269
270 match method {
271 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
272 .map(ClientRequest::InitializeRequest)
273 .map_err(Into::into),
274 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
275 .map(ClientRequest::AuthenticateRequest)
276 .map_err(Into::into),
277 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
278 .map(ClientRequest::NewSessionRequest)
279 .map_err(Into::into),
280 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
281 .map(ClientRequest::LoadSessionRequest)
282 .map_err(Into::into),
283 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284 .map(ClientRequest::ListSessionsRequest)
285 .map_err(Into::into),
286 #[cfg(feature = "unstable_session_fork")]
287 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288 .map(ClientRequest::ForkSessionRequest)
289 .map_err(Into::into),
290 #[cfg(feature = "unstable_session_resume")]
291 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
292 .map(ClientRequest::ResumeSessionRequest)
293 .map_err(Into::into),
294 #[cfg(feature = "unstable_session_close")]
295 m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
296 .map(ClientRequest::CloseSessionRequest)
297 .map_err(Into::into),
298 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
299 .map(ClientRequest::SetSessionModeRequest)
300 .map_err(Into::into),
301 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
302 serde_json::from_str(params.get())
303 .map(ClientRequest::SetSessionConfigOptionRequest)
304 .map_err(Into::into)
305 }
306 #[cfg(feature = "unstable_session_model")]
307 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
308 .map(ClientRequest::SetSessionModelRequest)
309 .map_err(Into::into),
310 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
311 .map(ClientRequest::PromptRequest)
312 .map_err(Into::into),
313 _ => {
314 if let Some(custom_method) = method.strip_prefix('_') {
315 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
316 method: custom_method.into(),
317 params: params.to_owned().into(),
318 }))
319 } else {
320 Err(Error::method_not_found())
321 }
322 }
323 }
324 }
325
326 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
327 let params = params.ok_or_else(Error::invalid_params)?;
328
329 match method {
330 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
331 .map(ClientNotification::CancelNotification)
332 .map_err(Into::into),
333 _ => {
334 if let Some(custom_method) = method.strip_prefix('_') {
335 Ok(ClientNotification::ExtNotification(ExtNotification {
336 method: custom_method.into(),
337 params: params.to_owned().into(),
338 }))
339 } else {
340 Err(Error::method_not_found())
341 }
342 }
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Each `RequestId` under test paired with the JSON value it maps to.
    fn id_json_pairs() -> Vec<(RequestId, Value)> {
        vec![
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ]
    }

    /// JSON null, numbers, and strings decode into the matching variant.
    #[test]
    fn id_deserialization() {
        for (expected, json) in id_json_pairs() {
            let decoded = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    /// Each variant encodes to the bare JSON value, with no tagging.
    #[test]
    fn id_serialization() {
        for (id, expected) in id_json_pairs() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    /// `Display` prints `null` for the null id and the inner value otherwise.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}
400
/// Checks the serialized form of wrapped outgoing notifications in both
/// directions: the envelope carries `jsonrpc`, `method`, and `params` at the
/// top level, and fields set to `None` do not appear in the output.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Notification sent with `ClientSide` as the local side.
    let cancel = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let notification = Notification {
        method: "cancel".into(),
        params: Some(cancel),
    };
    let message = OutgoingMessage::<ClientSide, AgentSide>::Notification(notification);
    let envelope = JsonRpcMessage::wrap(message);

    let actual: Value = serde_json::to_value(&envelope).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    assert_eq!(actual, expected);

    // Notification sent with `AgentSide` as the local side.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        #[cfg(feature = "unstable_message_id")]
        message_id: None,
        meta: None,
    };
    let update = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    });
    let notification = Notification {
        method: "sessionUpdate".into(),
        params: Some(update),
    };
    let message = OutgoingMessage::<AgentSide, ClientSide>::Notification(notification);
    let envelope = JsonRpcMessage::wrap(message);

    let actual: Value = serde_json::to_value(&envelope).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(actual, expected);
}