1use serde::{Deserialize, Serialize};
6use serde_json::value::RawValue;
7use std::fmt;
8
/// JSON-RPC protocol version string written into the `jsonrpc` field of every
/// message constructed by this module.
pub const JSONRPC_VERSION: &str = "2.0";
11
/// Identifier carried by a request and repeated in the matching response.
///
/// Because of `#[serde(untagged)]` the identifier is (de)serialized as a bare
/// JSON string or integer, with no wrapper naming the variant.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
    /// Identifier carried as a string.
    String(String),
    /// Identifier carried as a signed 64-bit integer.
    Number(i64),
}
36
37impl fmt::Display for RequestId {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 match self {
40 Self::String(s) => write!(f, "{}", s),
41 Self::Number(n) => write!(f, "{}", n),
42 }
43 }
44}
45
46impl From<String> for RequestId {
47 fn from(s: String) -> Self {
48 Self::String(s)
49 }
50}
51
52impl From<&str> for RequestId {
53 fn from(s: &str) -> Self {
54 Self::String(s.to_string())
55 }
56}
57
58impl From<i64> for RequestId {
59 fn from(n: i64) -> Self {
60 Self::Number(n)
61 }
62}
63
64impl From<u64> for RequestId {
65 fn from(n: u64) -> Self {
66 let num = i64::try_from(n).unwrap_or(i64::MAX);
68 Self::Number(num)
69 }
70}
71
/// A JSON-RPC request: a method call that carries an `id`.
///
/// `P` is the type of the `params` payload and defaults to
/// [`serde_json::Value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONRPCRequest<P = serde_json::Value> {
    /// Protocol version string; [`JSONRPCRequest::validate`] compares it with
    /// [`JSONRPC_VERSION`].
    pub jsonrpc: String,
    /// Identifier of this request.
    pub id: RequestId,
    /// Name of the method being called.
    pub method: String,
    /// Optional method parameters; the field is left out of the serialized
    /// form when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<P>,
}
99
100impl<P> JSONRPCRequest<P> {
101 pub fn new(id: impl Into<RequestId>, method: impl Into<String>, params: Option<P>) -> Self {
103 Self {
104 jsonrpc: JSONRPC_VERSION.to_string(),
105 id: id.into(),
106 method: method.into(),
107 params,
108 }
109 }
110
111 pub fn validate(&self) -> Result<(), crate::Error> {
113 if self.jsonrpc != JSONRPC_VERSION {
114 return Err(crate::Error::validation(format!(
115 "Invalid JSON-RPC version: expected {}, got {}",
116 JSONRPC_VERSION, self.jsonrpc
117 )));
118 }
119 Ok(())
120 }
121}
122
/// A JSON-RPC response carrying either a result or an error for request `id`.
///
/// `R` is the success payload type (default [`serde_json::Value`]) and `E`
/// the error payload type (default [`JSONRPCError`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONRPCResponse<R = serde_json::Value, E = JSONRPCError> {
    /// Protocol version string; the constructors set it to
    /// [`JSONRPC_VERSION`].
    pub jsonrpc: String,
    /// Identifier carried by this response.
    pub id: RequestId,
    /// Outcome of the call. `#[serde(flatten)]` puts its `result` / `error`
    /// key directly on the response object instead of under a `payload` key.
    #[serde(flatten)]
    pub payload: ResponsePayload<R, E>,
}
134
/// The two possible outcomes of a call.
///
/// With `rename_all = "lowercase"` the variants are keyed as `result` and
/// `error` in the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ResponsePayload<R, E> {
    /// Successful outcome holding the result value.
    Result(R),
    /// Failed outcome holding the error value.
    Error(E),
}
144
145impl<R, E> JSONRPCResponse<R, E> {
146 pub fn success(id: RequestId, result: R) -> Self {
148 Self {
149 jsonrpc: JSONRPC_VERSION.to_string(),
150 id,
151 payload: ResponsePayload::Result(result),
152 }
153 }
154
155 pub fn error(id: RequestId, error: E) -> Self {
157 Self {
158 jsonrpc: JSONRPC_VERSION.to_string(),
159 id,
160 payload: ResponsePayload::Error(error),
161 }
162 }
163
164 pub fn is_success(&self) -> bool {
166 matches!(self.payload, ResponsePayload::Result(_))
167 }
168
169 pub fn is_error(&self) -> bool {
171 matches!(self.payload, ResponsePayload::Error(_))
172 }
173
174 pub fn result(&self) -> Option<&R> {
176 match &self.payload {
177 ResponsePayload::Result(r) => Some(r),
178 ResponsePayload::Error(_) => None,
179 }
180 }
181
182 pub fn get_error(&self) -> Option<&E> {
184 match &self.payload {
185 ResponsePayload::Error(e) => Some(e),
186 ResponsePayload::Result(_) => None,
187 }
188 }
189}
190
/// A JSON-RPC notification: a method call that carries no `id`.
///
/// `P` is the type of the `params` payload and defaults to
/// [`serde_json::Value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONRPCNotification<P = serde_json::Value> {
    /// Protocol version string; [`JSONRPCNotification::new`] sets it to
    /// [`JSONRPC_VERSION`].
    pub jsonrpc: String,
    /// Name of the method being notified.
    pub method: String,
    /// Optional parameters; the field is left out of the serialized form when
    /// it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<P>,
}
202
203impl<P> JSONRPCNotification<P> {
204 pub fn new(method: impl Into<String>, params: Option<P>) -> Self {
206 Self {
207 jsonrpc: JSONRPC_VERSION.to_string(),
208 method: method.into(),
209 params,
210 }
211 }
212}
213
/// Error object carried in the `error` member of a JSON-RPC response.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JSONRPCError {
    /// Numeric error code.
    pub code: i32,
    /// Description of the error.
    pub message: String,
    /// Optional additional data; the field is left out of the serialized form
    /// when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
225
226impl JSONRPCError {
227 pub fn new(code: i32, message: impl Into<String>) -> Self {
229 Self {
230 code,
231 message: message.into(),
232 data: None,
233 }
234 }
235
236 pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
238 Self {
239 code,
240 message: message.into(),
241 data: Some(data),
242 }
243 }
244}
245
246impl From<crate::Error> for JSONRPCError {
247 fn from(err: crate::Error) -> Self {
248 match &err {
249 crate::Error::Protocol {
250 code,
251 message,
252 data,
253 } => Self {
254 code: code.as_i32(),
255 message: message.clone(),
256 data: data.clone(),
257 },
258 _ => Self::new(-32603, err.to_string()),
259 }
260 }
261}
262
263impl std::fmt::Display for JSONRPCError {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 write!(f, "JSON-RPC error {}: {}", self.code, self.message)
266 }
267}
268
269#[derive(Debug, Deserialize)]
271pub struct RawMessage {
272 pub jsonrpc: String,
274 #[serde(default)]
276 pub id: Option<RequestId>,
277 #[serde(default)]
279 pub method: Option<String>,
280 #[serde(default)]
282 pub params: Option<Box<RawValue>>,
283 #[serde(default)]
285 pub result: Option<Box<RawValue>>,
286 #[serde(default)]
288 pub error: Option<JSONRPCError>,
289}
290
291impl RawMessage {
292 pub fn message_type(&self) -> MessageType {
294 match (&self.id, &self.method, &self.result, &self.error) {
295 (Some(_), Some(_), None, None) => MessageType::Request,
296 (None, Some(_), None, None) => MessageType::Notification,
297 (Some(_), None, Some(_), None) => MessageType::Response,
298 (Some(_), None, None, Some(_)) => MessageType::ErrorResponse,
299 _ => MessageType::Invalid,
300 }
301 }
302}
303
/// Kind of JSON-RPC message, as classified by [`RawMessage::message_type`]
/// from which of `id`, `method`, `result` and `error` are present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Has an `id` and a `method`, and neither `result` nor `error`.
    Request,
    /// Has a `method` but no `id`, and neither `result` nor `error`.
    Notification,
    /// Has an `id` and a `result`, and neither `method` nor `error`.
    Response,
    /// Has an `id` and an `error`, and neither `method` nor `result`.
    ErrorResponse,
    /// Any other combination of fields.
    Invalid,
}
318
#[cfg(test)]
mod tests {
    //! Unit tests for the JSON-RPC message types: construction, accessors,
    //! serialization shape and raw-message classification.

    use super::*;
    use serde_json::json;

    /// `RequestId` can be built from `&str`, `i64` and `u64`.
    #[test]
    fn request_id_conversion() {
        assert_eq!(
            RequestId::from("test"),
            RequestId::String("test".to_string())
        );
        assert_eq!(RequestId::from(42i64), RequestId::Number(42));
        assert_eq!(RequestId::from(42u64), RequestId::Number(42));
    }

    /// A request serializes all four members under their plain names.
    #[test]
    fn request_serialization() {
        let request = JSONRPCRequest::new(1i64, "test/method", Some(json!({"key": "value"})));
        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["id"], 1);
        assert_eq!(json["method"], "test/method");
        assert_eq!(json["params"]["key"], "value");
    }

    #[test]
    fn response_success() {
        let response: JSONRPCResponse<serde_json::Value, JSONRPCError> =
            JSONRPCResponse::success(RequestId::from(1i64), json!({"result": true}));
        assert!(response.is_success());
        assert!(!response.is_error());
        assert_eq!(response.result(), Some(&json!({"result": true})));
    }

    #[test]
    fn response_error() {
        let error = JSONRPCError::new(-32600, "Invalid request");
        let response: JSONRPCResponse<serde_json::Value, JSONRPCError> =
            JSONRPCResponse::error(RequestId::from(1i64), error);
        assert!(!response.is_success());
        assert!(response.is_error());
        assert_eq!(response.get_error().unwrap().code, -32600);
    }

    /// A notification without params omits the `params` member entirely.
    #[test]
    fn notification_serialization() {
        let notification = JSONRPCNotification::new("test/notify", None::<serde_json::Value>);
        let json = serde_json::to_value(&notification).unwrap();

        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["method"], "test/notify");
        assert_eq!(json.get("params"), None);
    }

    #[test]
    fn test_request_id_display() {
        let string_id = RequestId::String("req-123".to_string());
        let number_id = RequestId::Number(42);

        assert_eq!(format!("{}", string_id), "req-123");
        assert_eq!(format!("{}", number_id), "42");
    }

    /// A `u64` that does not fit in `i64` is clamped to `i64::MAX`.
    #[test]
    fn test_request_id_from_u64_overflow() {
        let large_u64 = u64::MAX;
        let id = RequestId::from(large_u64);
        match id {
            RequestId::Number(n) => assert_eq!(n, i64::MAX),
            RequestId::String(_) => panic!("Expected Number variant"),
        }
    }

    #[test]
    fn test_request_validation() {
        let valid_request = JSONRPCRequest::new(1i64, "test", None::<serde_json::Value>);
        assert!(valid_request.validate().is_ok());

        let invalid_request: JSONRPCRequest<serde_json::Value> = JSONRPCRequest {
            jsonrpc: "1.0".to_string(),
            id: RequestId::Number(1),
            method: "test".to_string(),
            params: None,
        };
        let err = invalid_request.validate().unwrap_err();
        assert!(err.to_string().contains("Invalid JSON-RPC version"));
    }

    #[test]
    fn test_notification_with_params() {
        let params = json!({"key": "value", "number": 42});
        let notification = JSONRPCNotification::new("test/notify", Some(params.clone()));
        let json = serde_json::to_value(&notification).unwrap();

        assert_eq!(json["params"], params);
    }

    #[test]
    fn test_jsonrpc_error_constructors() {
        let error =
            JSONRPCError::with_data(-32000, "Custom error", json!({"details": "more info"}));
        assert_eq!(error.code, -32000);
        assert_eq!(error.message, "Custom error");
        assert_eq!(error.data, Some(json!({"details": "more info"})));

        // A non-protocol crate error is reported with code -32603.
        let mcp_err = crate::error::Error::validation("Bad input");
        let jsonrpc_err = JSONRPCError::from(mcp_err);
        assert_eq!(jsonrpc_err.code, -32603);
        assert!(jsonrpc_err.message.contains("Bad input"));
    }

    #[test]
    fn test_raw_message_type_detection() {
        let request_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "test",
            "params": null
        });
        let request: RawMessage = serde_json::from_value(request_json).unwrap();
        assert_eq!(request.message_type(), MessageType::Request);

        let notification_json = json!({
            "jsonrpc": "2.0",
            "method": "notify",
            "params": null
        });
        let notification: RawMessage = serde_json::from_value(notification_json).unwrap();
        assert_eq!(notification.message_type(), MessageType::Notification);

        let response_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": "success"
        });
        let response: RawMessage = serde_json::from_value(response_json).unwrap();
        assert_eq!(response.message_type(), MessageType::Response);

        let error_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request"
            }
        });
        let error_response: RawMessage = serde_json::from_value(error_json).unwrap();
        assert_eq!(error_response.message_type(), MessageType::ErrorResponse);

        let invalid_json = json!({
            "jsonrpc": "2.0"
        });
        let invalid: RawMessage = serde_json::from_value(invalid_json).unwrap();
        assert_eq!(invalid.message_type(), MessageType::Invalid);
    }

    /// The payload variants are keyed as lowercase `result` / `error`.
    #[test]
    fn test_response_payload_serialization() {
        let result_payload: ResponsePayload<String, JSONRPCError> =
            ResponsePayload::Result("success".to_string());
        let json = serde_json::to_value(&result_payload).unwrap();
        assert_eq!(json["result"], "success");

        let error_payload: ResponsePayload<String, JSONRPCError> =
            ResponsePayload::Error(JSONRPCError::new(-32601, "Method not found"));
        let json = serde_json::to_value(&error_payload).unwrap();
        assert_eq!(json["error"]["code"], -32601);
    }

    #[test]
    fn test_jsonrpc_response_methods() {
        type TestResponse = JSONRPCResponse<String, JSONRPCError>;

        let success_resp =
            TestResponse::success(RequestId::from("req-1"), "result data".to_string());
        assert!(success_resp.is_success());
        assert!(!success_resp.is_error());
        assert_eq!(success_resp.result(), Some(&"result data".to_string()));
        assert_eq!(success_resp.get_error(), None);

        let error_resp = TestResponse::error(
            RequestId::from("req-2"),
            JSONRPCError::new(-32700, "Parse error"),
        );
        assert!(!error_resp.is_success());
        assert!(error_resp.is_error());
        assert_eq!(error_resp.result(), None);
        assert_eq!(error_resp.get_error().unwrap().code, -32700);
    }

    #[test]
    fn test_jsonrpc_error_display() {
        let error = JSONRPCError::new(-32600, "Invalid request");
        let display = format!("{}", error);
        assert!(display.contains("Invalid request"));
        assert!(display.contains("-32600"));

        let error_with_data =
            JSONRPCError::with_data(-32000, "Server error", json!({"code": "ERR001"}));
        let display = format!("{}", error_with_data);
        assert!(display.contains("Server error"));
    }
}