use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use std::fmt;
/// JSON-RPC protocol version string.
///
/// Every constructor in this module stamps it into the `jsonrpc` member, and
/// [`JSONRPCRequest::validate`] rejects requests carrying any other value.
pub const JSONRPC_VERSION: &str = "2.0";
/// Identifier used to pair a JSON-RPC request with its response.
///
/// Serialized untagged: a `String` id appears as a bare JSON string and a
/// `Number` id as a bare JSON integer. On deserialization serde tries the
/// variants in declaration order. Values that are neither a string nor an
/// integer representable as `i64` (e.g. `null`, `1.5`, or numbers above
/// `i64::MAX`) fail to deserialize.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
    /// A string id, e.g. `"req-123"`.
    String(String),
    /// An integer id, e.g. `42`.
    Number(i64),
}
impl fmt::Display for RequestId {
    /// Writes the bare id: string ids verbatim (no quotes), numeric ids in decimal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::String(text) => f.write_str(text),
            Self::Number(value) => write!(f, "{}", value),
        }
    }
}
impl From<String> for RequestId {
fn from(s: String) -> Self {
Self::String(s)
}
}
impl From<&str> for RequestId {
fn from(s: &str) -> Self {
Self::String(s.to_string())
}
}
impl From<i64> for RequestId {
fn from(n: i64) -> Self {
Self::Number(n)
}
}
impl From<u64> for RequestId {
    /// Converts an unsigned integer into a numeric id, saturating at `i64::MAX`.
    ///
    /// [`RequestId::Number`] stores an `i64`, so values above `i64::MAX` are clamped.
    /// NOTE(review): distinct ids above `i64::MAX` therefore collapse to the same
    /// `RequestId`. Callers issuing such ids should confirm this is acceptable.
    fn from(n: u64) -> Self {
        match i64::try_from(n) {
            Ok(fits) => Self::Number(fits),
            Err(_) => Self::Number(i64::MAX),
        }
    }
}
/// A JSON-RPC request: a method call carrying an `id`.
///
/// `P` is the parameter type and defaults to an untyped [`serde_json::Value`].
/// Fields serialize in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONRPCRequest<P = serde_json::Value> {
    /// Protocol version; [`JSONRPC_VERSION`] for requests built with `new`.
    pub jsonrpc: String,
    /// Identifier of this request; responses carry a [`RequestId`] as well.
    pub id: RequestId,
    /// Name of the method to invoke.
    pub method: String,
    /// Optional parameters; the member is omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<P>,
}
impl<P> JSONRPCRequest<P> {
pub fn new(id: impl Into<RequestId>, method: impl Into<String>, params: Option<P>) -> Self {
contract_pre_jsonrpc_framing!();
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
id: id.into(),
method: method.into(),
params,
}
}
pub fn validate(&self) -> Result<(), crate::Error> {
contract_pre_jsonrpc_framing!();
if self.jsonrpc != JSONRPC_VERSION {
return Err(crate::Error::validation(format!(
"Invalid JSON-RPC version: expected {}, got {}",
JSONRPC_VERSION, self.jsonrpc
)));
}
Ok(())
}
}
/// A JSON-RPC response to a request.
///
/// `R` is the success result type and `E` the error type (default [`JSONRPCError`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONRPCResponse<R = serde_json::Value, E = JSONRPCError> {
    /// Protocol version; [`JSONRPC_VERSION`] for responses built with
    /// `success`/`error`.
    pub jsonrpc: String,
    /// Id of the request this response answers.
    pub id: RequestId,
    /// Flattened into the envelope, so on the wire the response has a top-level
    /// `result` or `error` member rather than a nested `payload` object.
    #[serde(flatten)]
    pub payload: ResponsePayload<R, E>,
}
/// The body of a response: either a success `result` or an `error`.
///
/// Externally tagged with lowercase variant names, so each payload serializes
/// as a single-key object (`{"result": ...}` or `{"error": ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ResponsePayload<R, E> {
    /// Successful outcome, serialized under the `result` key.
    Result(R),
    /// Failed outcome, serialized under the `error` key.
    Error(E),
}
impl<R, E> JSONRPCResponse<R, E> {
pub fn success(id: RequestId, result: R) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
id,
payload: ResponsePayload::Result(result),
}
}
pub fn error(id: RequestId, error: E) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
id,
payload: ResponsePayload::Error(error),
}
}
pub fn is_success(&self) -> bool {
matches!(self.payload, ResponsePayload::Result(_))
}
pub fn is_error(&self) -> bool {
matches!(self.payload, ResponsePayload::Error(_))
}
pub fn result(&self) -> Option<&R> {
match &self.payload {
ResponsePayload::Result(r) => Some(r),
ResponsePayload::Error(_) => None,
}
}
pub fn get_error(&self) -> Option<&E> {
match &self.payload {
ResponsePayload::Error(e) => Some(e),
ResponsePayload::Result(_) => None,
}
}
}
/// A JSON-RPC notification: a method call that carries no `id`.
///
/// `P` is the parameter type and defaults to an untyped [`serde_json::Value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONRPCNotification<P = serde_json::Value> {
    /// Protocol version; [`JSONRPC_VERSION`] for notifications built with `new`.
    pub jsonrpc: String,
    /// Name of the method being notified.
    pub method: String,
    /// Optional parameters; the member is omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<P>,
}
impl<P> JSONRPCNotification<P> {
    /// Creates a notification for `method` stamped with [`JSONRPC_VERSION`].
    pub fn new(method: impl Into<String>, params: Option<P>) -> Self {
        let method: String = method.into();
        Self {
            params,
            method,
            jsonrpc: String::from(JSONRPC_VERSION),
        }
    }
}
/// A JSON-RPC error object, as carried in the `error` member of an error response.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JSONRPCError {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional structured detail; the member is omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
impl JSONRPCError {
    /// Creates an error with `code` and `message` and no `data`.
    pub fn new(code: i32, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self {
            code,
            message,
            data: None,
        }
    }

    /// Creates an error with `code` and `message` that also carries `data`.
    pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
        Self {
            data: Some(data),
            ..Self::new(code, message)
        }
    }
}
impl From<crate::Error> for JSONRPCError {
    /// Converts a crate error into a JSON-RPC error object.
    ///
    /// A `Protocol` error maps field-for-field, with its code converted via
    /// `as_i32`. Every other variant becomes code `-32603` (JSON-RPC 2.0
    /// "Internal error"), with the error's `Display` text as the message.
    fn from(err: crate::Error) -> Self {
        // `err` is owned, so match by value and move `message`/`data` out instead
        // of borrowing and cloning them.
        match err {
            crate::Error::Protocol {
                code,
                message,
                data,
            } => Self {
                code: code.as_i32(),
                message,
                data,
            },
            other => Self::new(-32603, other.to_string()),
        }
    }
}
impl std::fmt::Display for JSONRPCError {
    /// Renders as `JSON-RPC error <code>: <message>`; `data` is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { code, message, .. } = self;
        write!(f, "JSON-RPC error {}: {}", code, message)
    }
}
#[derive(Debug, Deserialize)]
pub struct RawMessage {
pub jsonrpc: String,
#[serde(default)]
pub id: Option<RequestId>,
#[serde(default)]
pub method: Option<String>,
#[serde(default)]
pub params: Option<Box<RawValue>>,
#[serde(default)]
pub result: Option<Box<RawValue>>,
#[serde(default)]
pub error: Option<JSONRPCError>,
}
impl RawMessage {
pub fn message_type(&self) -> MessageType {
match (&self.id, &self.method, &self.result, &self.error) {
(Some(_), Some(_), None, None) => MessageType::Request,
(None, Some(_), None, None) => MessageType::Notification,
(Some(_), None, Some(_), None) => MessageType::Response,
(Some(_), None, None, Some(_)) => MessageType::ErrorResponse,
_ => MessageType::Invalid,
}
}
}
/// Kind of JSON-RPC message, as determined by [`RawMessage::message_type`]
/// from which members are present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Has `id` and `method`, and neither `result` nor `error`.
    Request,
    /// Has `method` but no `id`, and neither `result` nor `error`.
    Notification,
    /// Has `id` and `result`, and neither `method` nor `error`.
    Response,
    /// Has `id` and `error`, and neither `method` nor `result`.
    ErrorResponse,
    /// Any other combination of members.
    Invalid,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn request_id_conversion() {
        assert_eq!(
            RequestId::from("test"),
            RequestId::String("test".to_string())
        );
        assert_eq!(RequestId::from(42i64), RequestId::Number(42));
        assert_eq!(RequestId::from(42u64), RequestId::Number(42));
    }

    #[test]
    fn request_serialization() {
        let request = JSONRPCRequest::new(1i64, "test/method", Some(json!({"key": "value"})));
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["id"], 1);
        assert_eq!(json["method"], "test/method");
        assert_eq!(json["params"]["key"], "value");
    }

    #[test]
    fn response_success() {
        let response: JSONRPCResponse<serde_json::Value, JSONRPCError> =
            JSONRPCResponse::success(RequestId::from(1i64), json!({"result": true}));
        assert!(response.is_success());
        assert!(!response.is_error());
        assert_eq!(response.result(), Some(&json!({"result": true})));
    }

    #[test]
    fn response_error() {
        let error = JSONRPCError::new(-32600, "Invalid request");
        let response: JSONRPCResponse<serde_json::Value, JSONRPCError> =
            JSONRPCResponse::error(RequestId::from(1i64), error);
        assert!(!response.is_success());
        assert!(response.is_error());
        assert_eq!(response.get_error().unwrap().code, -32600);
    }

    #[test]
    fn response_payload_is_flattened_on_the_wire() {
        // The payload must appear as a top-level `result`/`error` member, not nested.
        let ok: JSONRPCResponse<serde_json::Value, JSONRPCError> =
            JSONRPCResponse::success(RequestId::from(7i64), json!({"ok": true}));
        assert_eq!(
            serde_json::to_value(&ok).unwrap(),
            json!({"jsonrpc": "2.0", "id": 7, "result": {"ok": true}})
        );
        let failed: JSONRPCResponse<serde_json::Value, JSONRPCError> = JSONRPCResponse::error(
            RequestId::from("req-7"),
            JSONRPCError::new(-32601, "Method not found"),
        );
        assert_eq!(
            serde_json::to_value(&failed).unwrap(),
            json!({
                "jsonrpc": "2.0",
                "id": "req-7",
                "error": {"code": -32601, "message": "Method not found"}
            })
        );
    }

    #[test]
    fn notification_serialization() {
        let notification = JSONRPCNotification::new("test/notify", None::<serde_json::Value>);
        let json = serde_json::to_value(&notification).unwrap();
        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["method"], "test/notify");
        assert_eq!(json.get("params"), None);
    }

    #[test]
    fn test_request_id_display() {
        let string_id = RequestId::String("req-123".to_string());
        let number_id = RequestId::Number(42);
        assert_eq!(format!("{}", string_id), "req-123");
        assert_eq!(format!("{}", number_id), "42");
    }

    #[test]
    fn test_request_id_from_u64_overflow() {
        let large_u64 = u64::MAX;
        let id = RequestId::from(large_u64);
        match id {
            RequestId::Number(n) => assert_eq!(n, i64::MAX),
            RequestId::String(_) => panic!("Expected Number variant"),
        }
    }

    #[test]
    fn test_request_validation() {
        let valid_request = JSONRPCRequest::new(1i64, "test", None::<serde_json::Value>);
        assert!(valid_request.validate().is_ok());
        let invalid_request: JSONRPCRequest<serde_json::Value> = JSONRPCRequest {
            jsonrpc: "1.0".to_string(),
            id: RequestId::Number(1),
            method: "test".to_string(),
            params: None,
        };
        let err = invalid_request.validate().unwrap_err();
        assert!(err.to_string().contains("Invalid JSON-RPC version"));
    }

    #[test]
    fn test_notification_with_params() {
        let params = json!({"key": "value", "number": 42});
        let notification = JSONRPCNotification::new("test/notify", Some(params.clone()));
        let json = serde_json::to_value(&notification).unwrap();
        assert_eq!(json["params"], params);
    }

    #[test]
    fn test_jsonrpc_error_constructors() {
        let error =
            JSONRPCError::with_data(-32000, "Custom error", json!({"details": "more info"}));
        assert_eq!(error.code, -32000);
        assert_eq!(error.message, "Custom error");
        assert_eq!(error.data, Some(json!({"details": "more info"})));
        let mcp_err = crate::error::Error::validation("Bad input");
        let jsonrpc_err = JSONRPCError::from(mcp_err);
        // Non-protocol errors map to the JSON-RPC "Internal error" code.
        assert_eq!(jsonrpc_err.code, -32603);
        assert!(jsonrpc_err.message.contains("Bad input"));
    }

    #[test]
    fn test_raw_message_type_detection() {
        let request_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "test",
            "params": null
        });
        let request: RawMessage = serde_json::from_value(request_json).unwrap();
        assert_eq!(request.message_type(), MessageType::Request);
        let notification_json = json!({
            "jsonrpc": "2.0",
            "method": "notify",
            "params": null
        });
        let notification: RawMessage = serde_json::from_value(notification_json).unwrap();
        assert_eq!(notification.message_type(), MessageType::Notification);
        let response_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": "success"
        });
        let response: RawMessage = serde_json::from_value(response_json).unwrap();
        assert_eq!(response.message_type(), MessageType::Response);
        let error_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request"
            }
        });
        let error_response: RawMessage = serde_json::from_value(error_json).unwrap();
        assert_eq!(error_response.message_type(), MessageType::ErrorResponse);
        let invalid_json = json!({
            "jsonrpc": "2.0"
        });
        let invalid: RawMessage = serde_json::from_value(invalid_json).unwrap();
        assert_eq!(invalid.message_type(), MessageType::Invalid);
    }

    #[test]
    fn test_response_payload_serialization() {
        let result_payload: ResponsePayload<String, JSONRPCError> =
            ResponsePayload::Result("success".to_string());
        let json = serde_json::to_value(&result_payload).unwrap();
        assert_eq!(json["result"], "success");
        let error_payload: ResponsePayload<String, JSONRPCError> =
            ResponsePayload::Error(JSONRPCError::new(-32601, "Method not found"));
        let json = serde_json::to_value(&error_payload).unwrap();
        assert_eq!(json["error"]["code"], -32601);
    }

    #[test]
    fn test_jsonrpc_response_methods() {
        type TestResponse = JSONRPCResponse<String, JSONRPCError>;
        let success_resp =
            TestResponse::success(RequestId::from("req-1"), "result data".to_string());
        assert!(success_resp.is_success());
        assert!(!success_resp.is_error());
        assert_eq!(success_resp.result(), Some(&"result data".to_string()));
        assert_eq!(success_resp.get_error(), None);
        let error_resp = TestResponse::error(
            RequestId::from("req-2"),
            JSONRPCError::new(-32700, "Parse error"),
        );
        assert!(!error_resp.is_success());
        assert!(error_resp.is_error());
        assert_eq!(error_resp.result(), None);
        assert_eq!(error_resp.get_error().unwrap().code, -32700);
    }

    #[test]
    fn test_jsonrpc_error_display() {
        let error = JSONRPCError::new(-32600, "Invalid request");
        let display = format!("{}", error);
        assert!(display.contains("Invalid request"));
        assert!(display.contains("-32600"));
        let error_with_data =
            JSONRPCError::with_data(-32000, "Server error", json!({"code": "ERR001"}));
        let display = format!("{}", error_with_data);
        assert!(display.contains("Server error"));
    }
}