#![allow(dead_code)]
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Identifier that ties a JSON-RPC request to its response.
///
/// Serialized `untagged`, so on the wire it is a bare JSON number or string.
/// Variant order matters: when deserializing, serde tries `Number` first and
/// falls back to `String`.
// NOTE(review): JSON-RPC 2.0 also allows negative/fractional numbers and `null`
// as ids; none of those fit `u64`/`String`, so such messages fail to
// deserialize. Confirm peers never send them.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum JsonRpcId {
    /// A non-negative integer id.
    Number(u64),
    /// A string id.
    String(String),
}
impl std::fmt::Display for JsonRpcId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JsonRpcId::Number(n) => write!(f, "{n}"),
JsonRpcId::String(s) => write!(f, "{s}"),
}
}
}
impl From<u64> for JsonRpcId {
fn from(n: u64) -> Self {
JsonRpcId::Number(n)
}
}
impl From<String> for JsonRpcId {
fn from(s: String) -> Self {
JsonRpcId::String(s)
}
}
impl From<&str> for JsonRpcId {
fn from(s: &str) -> Self {
JsonRpcId::String(s.to_string())
}
}
/// A JSON-RPC request: a method call that carries an `id`, so a response is expected.
///
/// Field order here is the key order of the serialized JSON.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub(crate) struct JsonRpcRequest {
    /// Protocol version string; `JsonRpcRequest::new` sets `"2.0"`.
    /// Deserialization does not check its value.
    pub jsonrpc: String,
    /// Identifier used to match this request with its response.
    pub id: JsonRpcId,
    /// Name of the method to invoke.
    pub method: String,
    /// Method arguments; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
impl JsonRpcRequest {
    /// Builds a JSON-RPC `"2.0"` request.
    ///
    /// `id` accepts anything convertible into `JsonRpcId` (e.g. `u64`, `String`,
    /// `&str`); `method` accepts anything convertible into `String`.
    pub fn new(
        id: impl Into<JsonRpcId>,
        method: impl Into<String>,
        params: Option<Value>,
    ) -> Self {
        let id: JsonRpcId = id.into();
        let method: String = method.into();
        Self {
            jsonrpc: String::from("2.0"),
            id,
            method,
            params,
        }
    }
}
/// A JSON-RPC response, correlated with its request through `id`.
///
/// The JSON-RPC 2.0 spec expects exactly one of `result`/`error` to be present;
/// this type does not enforce that.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub(crate) struct JsonRpcResponse {
    /// Protocol version string; deserialization does not check its value.
    pub jsonrpc: String,
    /// Id of the request this response answers.
    pub id: JsonRpcId,
    /// Success payload; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Failure payload; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
impl JsonRpcResponse {
    /// Builds a `"2.0"` success response carrying `result` (and no `error`).
    pub fn success(id: JsonRpcId, result: Value) -> Self {
        Self {
            jsonrpc: "2.0".into(),
            id,
            error: None,
            result: Some(result),
        }
    }

    /// Builds a `"2.0"` error response carrying `error` (and no `result`).
    pub fn error(id: JsonRpcId, error: JsonRpcError) -> Self {
        Self {
            jsonrpc: "2.0".into(),
            id,
            error: Some(error),
            result: None,
        }
    }

    /// Converts the response into a standard `Result`.
    ///
    /// `error` wins if it is set, even when `result` is also present. Otherwise the
    /// `result` value is returned, with `Value::Null` substituted when it is `None`.
    pub fn into_result(self) -> std::result::Result<Value, JsonRpcError> {
        match (self.error, self.result) {
            (Some(failure), _) => Err(failure),
            (None, payload) => Ok(payload.unwrap_or(Value::Null)),
        }
    }
}
/// A JSON-RPC notification: a method call without an `id`, so no response is expected.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub(crate) struct JsonRpcNotification {
    /// Protocol version string; deserialization does not check its value.
    pub jsonrpc: String,
    /// Name of the method being notified.
    pub method: String,
    /// Method arguments; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
impl JsonRpcNotification {
    /// Builds a JSON-RPC `"2.0"` notification for `method` with optional `params`.
    pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
        let method: String = method.into();
        Self {
            jsonrpc: String::from("2.0"),
            method,
            params,
        }
    }
}
/// The `error` member of a failed JSON-RPC response.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub(crate) struct JsonRpcError {
    /// Numeric error code (see the `error_codes` module for the constants used here).
    pub code: i64,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional additional detail; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
/// Numeric values for `JsonRpcError::code`.
pub(crate) mod error_codes {
    // Pre-defined by the JSON-RPC 2.0 specification (reserved range -32768..=-32000).

    /// Invalid JSON was received.
    pub const PARSE_ERROR: i64 = -32700;
    /// The JSON sent is not a valid request object.
    pub const INVALID_REQUEST: i64 = -32600;
    /// The requested method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i64 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i64 = -32602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i64 = -32603;

    // Application-defined; outside the spec's reserved range.
    // NOTE(review): mirrors HTTP 401 Unauthorized. Presumably it signals that
    // authentication is needed before the call can proceed; confirm with usage.
    pub const AUTH_REQUIRED: i64 = 401;
}
/// A single JSON-RPC message of any kind, as produced by `JsonRpcMessage::parse`.
#[derive(Clone, Debug)]
pub(crate) enum JsonRpcMessage {
    /// A call that carries an `id` and expects a response.
    Request(JsonRpcRequest),
    /// A reply to an earlier request.
    Response(JsonRpcResponse),
    /// A call without an `id`.
    Notification(JsonRpcNotification),
}
/// The shape of a raw JSON-RPC message, as decided by `JsonRpcMessage::classify`.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub(crate) enum MessageKind {
    /// Has both an `id` and a `method` member.
    Request,
    /// Has an `id` member but no `method`.
    Response,
    /// Has a `method` member but no `id`.
    Notification,
}
impl JsonRpcMessage {
pub fn classify(value: &Value) -> crate::Result<MessageKind> {
let has_id = value.get("id").is_some();
let has_method = value.get("method").is_some();
match (has_id, has_method) {
(true, true) => Ok(MessageKind::Request),
(true, false) => Ok(MessageKind::Response),
(false, true) => Ok(MessageKind::Notification),
(false, false) => Err(crate::Error::ProtocolError(
"Invalid JSON-RPC message: missing both 'id' and 'method'".to_string(),
)),
}
}
pub fn parse(value: Value) -> crate::Result<Self> {
match Self::classify(&value)? {
MessageKind::Request => {
let req: JsonRpcRequest = serde_json::from_value(value)?;
Ok(JsonRpcMessage::Request(req))
}
MessageKind::Response => {
let resp: JsonRpcResponse = serde_json::from_value(value)?;
Ok(JsonRpcMessage::Response(resp))
}
MessageKind::Notification => {
let notif: JsonRpcNotification = serde_json::from_value(value)?;
Ok(JsonRpcMessage::Notification(notif))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_classify_request() {
        let value = json!({"id": 1, "method": "foo", "jsonrpc": "2.0"});
        let kind = JsonRpcMessage::classify(&value).expect("should classify");
        assert_eq!(kind, MessageKind::Request);
    }

    #[test]
    fn test_classify_response() {
        let value = json!({"id": 1, "result": {}, "jsonrpc": "2.0"});
        let kind = JsonRpcMessage::classify(&value).expect("should classify");
        assert_eq!(kind, MessageKind::Response);
    }

    #[test]
    fn test_classify_notification() {
        let value = json!({"method": "foo", "jsonrpc": "2.0"});
        let kind = JsonRpcMessage::classify(&value).expect("should classify");
        assert_eq!(kind, MessageKind::Notification);
    }

    #[test]
    fn test_classify_invalid() {
        let value = json!({});
        let result = JsonRpcMessage::classify(&value);
        assert!(result.is_err(), "expected Err for message with no id and no method");
    }

    #[test]
    fn test_classify_rejects_non_objects() {
        // Batch arrays and scalars are not single JSON-RPC message objects.
        for value in [
            json!([]),
            json!([{"id": 1, "method": "foo"}]),
            json!("foo"),
            json!(null),
        ] {
            assert!(
                JsonRpcMessage::classify(&value).is_err(),
                "expected Err for {value}"
            );
        }
    }

    #[test]
    fn test_parse_request_with_string_id() {
        let value = json!({"jsonrpc": "2.0", "id": "abc", "method": "ping"});
        match JsonRpcMessage::parse(value).expect("should parse") {
            JsonRpcMessage::Request(req) => {
                assert_eq!(req.id, JsonRpcId::String("abc".to_string()));
                assert_eq!(req.method, "ping");
                assert_eq!(req.params, None);
            }
            other => panic!("expected Request, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_notification() {
        let value = json!({"jsonrpc": "2.0", "method": "notify", "params": [1, 2]});
        match JsonRpcMessage::parse(value).expect("should parse") {
            JsonRpcMessage::Notification(notif) => {
                assert_eq!(notif.method, "notify");
                assert_eq!(notif.params, Some(json!([1, 2])));
            }
            other => panic!("expected Notification, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_error_response() {
        let value = json!({
            "jsonrpc": "2.0",
            "id": 7,
            "error": {"code": -32601, "message": "no such method"}
        });
        match JsonRpcMessage::parse(value).expect("should parse") {
            JsonRpcMessage::Response(resp) => {
                assert_eq!(resp.id, JsonRpcId::Number(7));
                let err = resp.into_result().expect_err("expected an error result");
                assert_eq!(err.code, error_codes::METHOD_NOT_FOUND);
                assert_eq!(err.message, "no such method");
                assert!(err.data.is_none());
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn test_request_roundtrip() {
        let req = JsonRpcRequest::new(42u64, "myMethod", Some(json!({"key": "value"})));
        let serialized = serde_json::to_string(&req).expect("serialize");
        let deserialized: JsonRpcRequest = serde_json::from_str(&serialized).expect("deserialize");
        assert_eq!(deserialized.jsonrpc, "2.0");
        assert_eq!(deserialized.id, JsonRpcId::Number(42));
        assert_eq!(deserialized.method, "myMethod");
        assert_eq!(deserialized.params, Some(json!({"key": "value"})));
    }

    #[test]
    fn test_notification_omits_absent_params() {
        let notif = JsonRpcNotification::new("ping", None);
        let value = serde_json::to_value(&notif).expect("serialize");
        assert_eq!(value, json!({"jsonrpc": "2.0", "method": "ping"}));
    }

    #[test]
    fn test_response_into_result_success() {
        let resp = JsonRpcResponse::success(JsonRpcId::Number(1), json!({"ok": true}));
        let result = resp.into_result();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), json!({"ok": true}));
    }

    #[test]
    fn test_response_into_result_missing_result_is_null() {
        let resp: JsonRpcResponse =
            serde_json::from_value(json!({"jsonrpc": "2.0", "id": 3})).expect("deserialize");
        assert_eq!(resp.into_result().expect("should be Ok"), Value::Null);
    }

    #[test]
    fn test_response_into_result_error() {
        let rpc_err = JsonRpcError {
            code: error_codes::INTERNAL_ERROR,
            message: "something went wrong".to_string(),
            data: None,
        };
        let resp = JsonRpcResponse::error(JsonRpcId::Number(2), rpc_err);
        let result = resp.into_result();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.code, error_codes::INTERNAL_ERROR);
        assert_eq!(err.message, "something went wrong");
    }

    #[test]
    fn test_id_display() {
        let num_id = JsonRpcId::Number(99);
        assert_eq!(num_id.to_string(), "99");
        let str_id = JsonRpcId::String("my-request-id".to_string());
        assert_eq!(str_id.to_string(), "my-request-id");
    }

    #[test]
    fn test_id_conversions_and_wire_format() {
        assert_eq!(JsonRpcId::from(5u64), JsonRpcId::Number(5));
        assert_eq!(JsonRpcId::from("a"), JsonRpcId::String("a".to_string()));
        assert_eq!(
            JsonRpcId::from("b".to_string()),
            JsonRpcId::String("b".to_string())
        );
        // `untagged`: ids appear on the wire as bare JSON numbers / strings.
        assert_eq!(
            serde_json::to_value(JsonRpcId::Number(3)).expect("serialize"),
            json!(3)
        );
        assert_eq!(
            serde_json::to_value(JsonRpcId::from("x")).expect("serialize"),
            json!("x")
        );
    }
}