1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
/// Identifier correlating a JSON-RPC request with its response.
///
/// Serialized untagged: on the wire it is a bare JSON `null`, integer, or
/// string. Deserialization tries the variants in declaration order, so an
/// integer that does not fit in an `i64` matches no variant and is rejected.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Hash,
    Eq,
    Deserialize,
    Serialize,
    PartialOrd,
    Ord,
    Display,
    JsonSchema,
    From,
)]
#[serde(untagged)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[from(String, i64)]
pub enum RequestId {
    /// A `null` id; its `Display` output is the literal text `null`.
    #[display("null")]
    Null,
    /// An integer id; displayed as the number itself.
    Number(i64),
    /// A string id; displayed as the string itself (no quotes).
    Str(String),
}
48
/// A JSON-RPC request: an `id`, the `method` to invoke, and optional `params`.
///
/// The `"jsonrpc": "2.0"` marker is not part of this type; it is added by
/// wrapping the message in [`JsonRpcMessage`].
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Request<Params> {
    /// Identifier used to correlate this request with its [`Response`].
    pub id: RequestId,
    /// Name of the method being invoked.
    pub method: Arc<str>,
    /// Method parameters; the `params` key is omitted when serializing `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Params>,
}
61
/// A JSON-RPC response: either a successful `result` or an `error`, tagged
/// with the id of the request it answers.
///
/// Serialized untagged, so the variant is distinguished only by which of the
/// `result` / `error` keys is present. On deserialization `Result` is tried
/// before `Error` (declaration order).
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[serde(untagged)]
#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
// NOTE: the type parameter is named `Result`, which shadows the crate's
// `Result` alias inside this definition only.
pub enum Response<Result> {
    /// Successful completion of the request identified by `id`.
    Result { id: RequestId, result: Result },
    /// Failed completion of the request identified by `id`.
    Error { id: RequestId, error: Error },
}
73
74impl<R> Response<R> {
75 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
76 match result {
77 Ok(result) => Self::Result {
78 id: id.into(),
79 result,
80 },
81 Err(error) => Self::Error {
82 id: id.into(),
83 error,
84 },
85 }
86 }
87}
88
/// A JSON-RPC notification: a `method` name and optional `params`, with no id.
///
/// Because there is no `id`, nothing in a [`Response`] can refer back to it.
/// The `"jsonrpc": "2.0"` marker is added by wrapping in [`JsonRpcMessage`].
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Notification<Params> {
    /// Name of the notification method.
    pub method: Arc<str>,
    /// Notification parameters; the `params` key is omitted when serializing `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Params>,
}
100
/// Any message the `Local` side can send to its `Remote` peer.
///
/// - `Request`: a request for the peer to handle, typed as the peer's inbound
///   request type.
/// - `Response`: a reply to a request the peer sent, typed as this side's
///   outbound response type.
/// - `Notification`: a notification for the peer, typed as the peer's inbound
///   notification type.
///
/// Serialized untagged, so only the inner message's own fields appear on the
/// wire.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[serde(untagged)]
#[schemars(inline)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request addressed to the remote side.
    Request(Request<Remote::InRequest>),
    /// A response produced by the local side.
    Response(Response<Local::OutResponse>),
    /// A notification addressed to the remote side.
    Notification(Notification<Remote::InNotification>),
}
113
/// The JSON-RPC protocol version marker.
///
/// It has a single variant, so `"2.0"` is the only value produced and the only
/// value accepted when deserializing.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
enum JsonRpcVersion {
    /// Serialized as the string `"2.0"`.
    #[serde(rename = "2.0")]
    V2,
}
120
/// Envelope that adds the `"jsonrpc": "2.0"` field to a message.
///
/// The wrapped message's fields are flattened into the same JSON object, so
/// the result is a single object with `jsonrpc` alongside the message's keys.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
pub struct JsonRpcMessage<M> {
    // Always `JsonRpcVersion::V2`, i.e. `"jsonrpc": "2.0"` on the wire.
    jsonrpc: JsonRpcVersion,
    // Flattened so the message's own keys sit next to `jsonrpc`.
    #[serde(flatten)]
    message: M,
}
132
133impl<M> JsonRpcMessage<M> {
134 #[must_use]
137 pub fn wrap(message: M) -> Self {
138 Self {
139 jsonrpc: JsonRpcVersion::V2,
140 message,
141 }
142 }
143}
144
/// One end of a JSON-RPC connection (implemented by [`ClientSide`] and
/// [`AgentSide`]).
///
/// A side fixes which message types it receives and sends, and how inbound
/// messages are decoded from a method name plus raw JSON parameters.
pub trait Side: Clone {
    /// Requests this side receives from its peer.
    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Notifications this side receives from its peer.
    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Responses this side sends back to its peer.
    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;

    /// Decodes an inbound request from its `method` name and raw JSON `params`.
    ///
    /// # Errors
    ///
    /// The implementations in this module fail when `params` is absent, when
    /// `params` cannot be deserialized, or when `method` is not recognized.
    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;

    /// Decodes an inbound notification from its `method` name and raw JSON
    /// `params`.
    ///
    /// # Errors
    ///
    /// The implementations in this module fail under the same conditions as
    /// [`Side::decode_request`].
    fn decode_notification(method: &str, params: Option<&RawValue>)
    -> Result<Self::InNotification>;
}
169
/// The client end of a connection.
///
/// It receives [`AgentRequest`]s and [`AgentNotification`]s and sends
/// [`ClientResponse`]s.
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct ClientSide;
179
180impl Side for ClientSide {
181 type InRequest = AgentRequest;
182 type InNotification = AgentNotification;
183 type OutResponse = ClientResponse;
184
185 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
186 let params = params.ok_or_else(Error::invalid_params)?;
187
188 match method {
189 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
190 serde_json::from_str(params.get())
191 .map(AgentRequest::RequestPermissionRequest)
192 .map_err(Into::into)
193 }
194 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
195 .map(AgentRequest::WriteTextFileRequest)
196 .map_err(Into::into),
197 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
198 .map(AgentRequest::ReadTextFileRequest)
199 .map_err(Into::into),
200 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
201 .map(AgentRequest::CreateTerminalRequest)
202 .map_err(Into::into),
203 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
204 .map(AgentRequest::TerminalOutputRequest)
205 .map_err(Into::into),
206 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
207 .map(AgentRequest::KillTerminalCommandRequest)
208 .map_err(Into::into),
209 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
210 .map(AgentRequest::ReleaseTerminalRequest)
211 .map_err(Into::into),
212 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
213 serde_json::from_str(params.get())
214 .map(AgentRequest::WaitForTerminalExitRequest)
215 .map_err(Into::into)
216 }
217 _ => {
218 if let Some(custom_method) = method.strip_prefix('_') {
219 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
220 method: custom_method.into(),
221 params: params.to_owned().into(),
222 }))
223 } else {
224 Err(Error::method_not_found())
225 }
226 }
227 }
228 }
229
230 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
231 let params = params.ok_or_else(Error::invalid_params)?;
232
233 match method {
234 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
235 .map(AgentNotification::SessionNotification)
236 .map_err(Into::into),
237 _ => {
238 if let Some(custom_method) = method.strip_prefix('_') {
239 Ok(AgentNotification::ExtNotification(ExtNotification {
240 method: custom_method.into(),
241 params: RawValue::from_string(params.get().to_string())?.into(),
242 }))
243 } else {
244 Err(Error::method_not_found())
245 }
246 }
247 }
248 }
249}
250
/// The agent end of a connection.
///
/// It receives [`ClientRequest`]s and [`ClientNotification`]s and sends
/// [`AgentResponse`]s.
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct AgentSide;
260
261impl Side for AgentSide {
262 type InRequest = ClientRequest;
263 type InNotification = ClientNotification;
264 type OutResponse = AgentResponse;
265
266 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
267 let params = params.ok_or_else(Error::invalid_params)?;
268
269 match method {
270 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
271 .map(ClientRequest::InitializeRequest)
272 .map_err(Into::into),
273 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
274 .map(ClientRequest::AuthenticateRequest)
275 .map_err(Into::into),
276 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
277 .map(ClientRequest::NewSessionRequest)
278 .map_err(Into::into),
279 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
280 .map(ClientRequest::LoadSessionRequest)
281 .map_err(Into::into),
282 #[cfg(feature = "unstable_session_list")]
283 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284 .map(ClientRequest::ListSessionsRequest)
285 .map_err(Into::into),
286 #[cfg(feature = "unstable_session_fork")]
287 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288 .map(ClientRequest::ForkSessionRequest)
289 .map_err(Into::into),
290 #[cfg(feature = "unstable_session_resume")]
291 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
292 .map(ClientRequest::ResumeSessionRequest)
293 .map_err(Into::into),
294 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
295 .map(ClientRequest::SetSessionModeRequest)
296 .map_err(Into::into),
297 #[cfg(feature = "unstable_session_model")]
298 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
299 .map(ClientRequest::SetSessionModelRequest)
300 .map_err(Into::into),
301 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
302 .map(ClientRequest::PromptRequest)
303 .map_err(Into::into),
304 _ => {
305 if let Some(custom_method) = method.strip_prefix('_') {
306 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
307 method: custom_method.into(),
308 params: params.to_owned().into(),
309 }))
310 } else {
311 Err(Error::method_not_found())
312 }
313 }
314 }
315 }
316
317 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
318 let params = params.ok_or_else(Error::invalid_params)?;
319
320 match method {
321 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
322 .map(ClientNotification::CancelNotification)
323 .map_err(Into::into),
324 _ => {
325 if let Some(custom_method) = method.strip_prefix('_') {
326 Ok(ClientNotification::ExtNotification(ExtNotification {
327 method: custom_method.into(),
328 params: RawValue::from_string(params.get().to_string())?.into(),
329 }))
330 } else {
331 Err(Error::method_not_found())
332 }
333 }
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Wire-level JSON values paired with the `RequestId` each one maps to.
    fn wire_cases() -> [(Value, RequestId); 4] {
        [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (wire, expected) in wire_cases() {
            let decoded = serde_json::from_value::<RequestId>(wire).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (expected, id) in wire_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}
391
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: a `cancel` notification carrying only a session id.
    let cancel = Notification {
        method: "cancel".into(),
        params: Some(ClientNotification::CancelNotification(CancelNotification {
            session_id: SessionId("test-123".into()),
            meta: None,
        })),
    };
    let cancel_json: Value = serde_json::to_value(&JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(cancel),
    ))
    .unwrap();
    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    assert_eq!(cancel_json, expected_cancel);

    // Agent -> client: a `sessionUpdate` carrying one text message chunk.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        meta: None,
    };
    let update = Notification {
        method: "sessionUpdate".into(),
        params: Some(AgentNotification::SessionNotification(
            SessionNotification {
                session_id: SessionId("test-456".into()),
                update: SessionUpdate::AgentMessageChunk(chunk),
                meta: None,
            },
        )),
    };
    let update_json: Value = serde_json::to_value(&JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(update),
    ))
    .unwrap();
    let expected_update = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(update_json, expected_update);
}