use serde::{Deserialize, Deserializer, de::DeserializeOwned};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
/// A single JSON-RPC 2.0 message: a request, a notification, or a response.
///
/// `params`, `result`, and `error.data` are kept as unparsed JSON text
/// ([`RawValue`]); decode them on demand with [`JsonRpcEnvelope::params_as`]
/// or [`JsonRpcEnvelope::result_as`]. Envelopes returned by
/// [`JsonRpcEnvelope::parse`] have a validated member combination; the fields
/// are public, so an envelope built by hand is not checked.
#[derive(Debug, Clone)]
pub struct JsonRpcEnvelope {
    /// `None` when the `id` member is absent; `Some(JsonRpcId::Null)` for an
    /// explicit `"id": null`.
    pub id: Option<JsonRpcId>,
    /// Method name; set on requests and notifications.
    pub method: Option<String>,
    /// Raw `params` JSON text.
    pub params: Option<Box<RawValue>>,
    /// Raw `result` JSON text of a success response.
    pub result: Option<Box<RawValue>>,
    /// Error object of a failed response.
    pub error: Option<JsonRpcError>,
}
/// The `id` member of a JSON-RPC message.
///
/// Numeric ids are restricted to values representable as `i64`;
/// [`JsonRpcEnvelope::parse`] rejects fractional or out-of-range numbers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JsonRpcId {
    /// Integer id, e.g. `"id": 7`.
    Number(i64),
    /// String id, e.g. `"id": "r1"`.
    String(String),
    /// Explicit `"id": null`.
    Null,
}
/// The `error` object of a failed JSON-RPC response.
#[derive(Debug, Clone)]
pub struct JsonRpcError {
    /// Numeric error code, e.g. `-32600`.
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
    /// Optional extra error payload, kept as raw JSON text.
    pub data: Option<Box<RawValue>>,
}
/// Reason [`JsonRpcEnvelope::parse`] rejected its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The input could not be decoded (e.g. empty, truncated, or not
    /// syntactically valid JSON).
    NotJson,
    /// The `"jsonrpc"` member is missing or is not `"2.0"`.
    NotJsonRpc20,
    /// The input is not a single message with a valid combination of members
    /// (e.g. a top-level array, or both `result` and `error`).
    InvalidShape,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            ParseError::NotJson => "input is not valid JSON",
            ParseError::NotJsonRpc20 => {
                "message is not JSON-RPC 2.0 (\"jsonrpc\" must be \"2.0\")"
            }
            ParseError::InvalidShape => "invalid JSON-RPC message shape",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate `ParseError` with `?` into `Box<dyn Error>` & co.
impl std::error::Error for ParseError {}
#[derive(Deserialize)]
struct Raw {
#[serde(default)]
jsonrpc: Option<String>,
#[serde(default, deserialize_with = "some_value")]
id: Option<serde_json::Value>,
#[serde(default)]
method: Option<String>,
#[serde(default)]
params: Option<Box<RawValue>>,
#[serde(default)]
result: Option<Box<RawValue>>,
#[serde(default)]
error: Option<RawError>,
}
fn some_value<'de, D>(d: D) -> Result<Option<serde_json::Value>, D::Error>
where
D: Deserializer<'de>,
{
serde_json::Value::deserialize(d).map(Some)
}
/// Wire form of the `error` member. `code` and `message` carry no
/// `#[serde(default)]`, so an `error` object missing either one fails
/// deserialization.
#[derive(Deserialize)]
struct RawError {
    code: i32,
    message: String,
    #[serde(default)]
    data: Option<Box<RawValue>>,
}
impl JsonRpcEnvelope {
pub fn parse(bytes: &[u8]) -> Result<Self, ParseError> {
if first_non_ws(bytes) == Some(b'[') {
return Err(ParseError::InvalidShape);
}
let raw: Raw = serde_json::from_slice(bytes).map_err(|_| ParseError::NotJson)?;
if raw.jsonrpc.as_deref() != Some("2.0") {
return Err(ParseError::NotJsonRpc20);
}
let id = match raw.id {
None => None,
Some(serde_json::Value::Null) => Some(JsonRpcId::Null),
Some(serde_json::Value::Number(n)) => Some(JsonRpcId::Number(
n.as_i64().ok_or(ParseError::InvalidShape)?,
)),
Some(serde_json::Value::String(s)) => Some(JsonRpcId::String(s)),
Some(_) => return Err(ParseError::InvalidShape),
};
let error = raw.error.map(|e| JsonRpcError {
code: e.code,
message: e.message,
data: e.data,
});
let shape = (
raw.method.is_some(),
id.is_some(),
raw.result.is_some(),
error.is_some(),
);
let valid = matches!(
shape,
(true, true, false, false) | (true, false, false, false) | (false, true, true, false) | (false, true, false, true) );
if !valid {
return Err(ParseError::InvalidShape);
}
Ok(JsonRpcEnvelope {
id,
method: raw.method,
params: raw.params,
result: raw.result,
error,
})
}
pub fn params_as<T: DeserializeOwned>(&self) -> Option<T> {
let raw = self.params.as_ref()?;
serde_json::from_str(raw.get()).ok()
}
pub fn result_as<T: DeserializeOwned>(&self) -> Option<T> {
let raw = self.result.as_ref()?;
serde_json::from_str(raw.get()).ok()
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut map = Map::with_capacity(5);
map.insert("jsonrpc".into(), Value::String("2.0".into()));
if let Some(id) = &self.id {
map.insert("id".into(), id_to_value(id));
}
if let Some(method) = &self.method {
map.insert("method".into(), Value::String(method.clone()));
}
if let Some(params) = &self.params {
map.insert(
"params".into(),
serde_json::from_str(params.get()).unwrap_or(Value::Null),
);
}
if let Some(result) = &self.result {
map.insert(
"result".into(),
serde_json::from_str(result.get()).unwrap_or(Value::Null),
);
}
if let Some(error) = &self.error {
let mut err = Map::with_capacity(3);
err.insert("code".into(), Value::Number((error.code as i64).into()));
err.insert("message".into(), Value::String(error.message.clone()));
if let Some(data) = &error.data {
err.insert(
"data".into(),
serde_json::from_str(data.get()).unwrap_or(Value::Null),
);
}
map.insert("error".into(), Value::Object(err));
}
serde_json::to_vec(&Value::Object(map)).unwrap_or_default()
}
}
/// Converts a [`JsonRpcId`] into the JSON value written on the wire.
fn id_to_value(id: &JsonRpcId) -> Value {
    match id {
        JsonRpcId::String(text) => Value::from(text.as_str()),
        JsonRpcId::Number(num) => Value::from(*num),
        JsonRpcId::Null => Value::Null,
    }
}
/// Returns the first byte of `bytes` that is not JSON whitespace, if any.
///
/// JSON (RFC 8259 §2) treats only space, tab, line feed and carriage return as
/// insignificant whitespace. `u8::is_ascii_whitespace` also skips form feed
/// (0x0C), which would make this peek disagree with the JSON parser.
fn first_non_ws(bytes: &[u8]) -> Option<u8> {
    bytes
        .iter()
        .copied()
        .find(|&b| !matches!(b, b' ' | b'\t' | b'\n' | b'\r'))
}
#[cfg(test)]
#[allow(non_snake_case)] // test names use `subject__behavior` double-underscore naming
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Greet {
        name: String,
    }

    #[test]
    fn parse__request_shape() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{"x":1}}"#,
        )
        .unwrap();
        assert_eq!(env.id, Some(JsonRpcId::Number(1)));
        assert_eq!(env.method.as_deref(), Some("tools/list"));
        assert!(env.params.is_some());
        assert!(env.result.is_none());
        assert!(env.error.is_none());
    }

    #[test]
    fn parse__notification_shape() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"p":0.5}}"#,
        )
        .unwrap();
        assert!(env.id.is_none());
        assert_eq!(env.method.as_deref(), Some("notifications/progress"));
    }

    #[test]
    fn parse__result_shape() {
        let env =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":"r1","result":{"ok":true}}"#).unwrap();
        assert_eq!(env.id, Some(JsonRpcId::String("r1".into())));
        assert!(env.result.is_some());
    }

    #[test]
    fn parse__error_shape() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"invalid"}}"#,
        )
        .unwrap();
        let err = env.error.unwrap();
        assert_eq!(err.code, -32600);
        assert_eq!(err.message, "invalid");
    }

    #[test]
    fn parse__null_id_accepted() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"parse error"}}"#,
        )
        .unwrap();
        assert_eq!(env.id, Some(JsonRpcId::Null));
    }

    #[test]
    fn parse__id_fractional_number_rejected() {
        let err =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1.5,"method":"x"}"#).unwrap_err();
        assert_eq!(err, ParseError::InvalidShape);
    }

    #[test]
    fn parse__empty_body_returns_not_json() {
        assert_eq!(
            JsonRpcEnvelope::parse(b"").unwrap_err(),
            ParseError::NotJson
        );
    }

    #[test]
    fn parse__garbage_bytes_return_not_json() {
        assert_eq!(
            JsonRpcEnvelope::parse(b"not json at all").unwrap_err(),
            ParseError::NotJson,
        );
    }

    #[test]
    fn parse__missing_jsonrpc_field() {
        let err = JsonRpcEnvelope::parse(br#"{"id":1,"method":"foo"}"#).unwrap_err();
        assert_eq!(err, ParseError::NotJsonRpc20);
    }

    #[test]
    fn parse__wrong_jsonrpc_version() {
        let err =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"1.0","id":1,"method":"foo"}"#).unwrap_err();
        assert_eq!(err, ParseError::NotJsonRpc20);
    }

    #[test]
    fn parse__bare_jsonrpc_is_invalid_shape() {
        let err = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0"}"#).unwrap_err();
        assert_eq!(err, ParseError::InvalidShape);
    }

    #[test]
    fn parse__top_level_array_rejected() {
        let err = JsonRpcEnvelope::parse(br#"[{"jsonrpc":"2.0","method":"x"}]"#).unwrap_err();
        assert_eq!(err, ParseError::InvalidShape);
    }

    #[test]
    fn parse__top_level_array_with_leading_ws_rejected() {
        let err = JsonRpcEnvelope::parse(b" [ ]").unwrap_err();
        assert_eq!(err, ParseError::InvalidShape);
    }

    #[test]
    fn parse__both_result_and_error_rejected() {
        let err = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":1,"result":{},"error":{"code":-1,"message":"x"}}"#,
        )
        .unwrap_err();
        assert_eq!(err, ParseError::InvalidShape);
    }

    #[test]
    fn parse__response_without_id_rejected() {
        let err = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","result":{}}"#).unwrap_err();
        assert_eq!(err, ParseError::InvalidShape);
    }

    #[test]
    fn params_as__deserializes_on_match() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":1,"method":"greet","params":{"name":"rod"}}"#,
        )
        .unwrap();
        assert_eq!(env.params_as::<Greet>(), Some(Greet { name: "rod".into() }));
    }

    #[test]
    fn params_as__none_on_mismatch() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":1,"method":"greet","params":{"wrong":1}}"#,
        )
        .unwrap();
        assert!(env.params_as::<Greet>().is_none());
    }

    #[test]
    fn params_as__none_when_absent() {
        let env = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1,"method":"greet"}"#).unwrap();
        assert!(env.params_as::<Greet>().is_none());
    }

    #[test]
    fn result_as__deserializes_on_match() {
        let env =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1,"result":{"name":"rod"}}"#).unwrap();
        assert_eq!(env.result_as::<Greet>(), Some(Greet { name: "rod".into() }));
    }

    #[test]
    fn result_as__none_on_mismatch() {
        let env =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#).unwrap();
        assert!(env.result_as::<Greet>().is_none());
    }

    #[test]
    fn to_bytes__round_trips_request() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":"a","method":"greet","params":{"name":"rod"}}"#,
        )
        .unwrap();
        let again = JsonRpcEnvelope::parse(&env.to_bytes()).unwrap();
        assert_eq!(again.id, Some(JsonRpcId::String("a".into())));
        assert_eq!(again.method.as_deref(), Some("greet"));
        assert_eq!(again.params_as::<Greet>(), Some(Greet { name: "rod".into() }));
    }

    #[test]
    fn to_bytes__notification_omits_id() {
        let env =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"notifications/progress"}"#)
                .unwrap();
        let again = JsonRpcEnvelope::parse(&env.to_bytes()).unwrap();
        assert!(again.id.is_none());
        assert_eq!(again.method.as_deref(), Some("notifications/progress"));
    }

    #[test]
    fn to_bytes__round_trips_result() {
        let env =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":3,"result":{"name":"rod"}}"#).unwrap();
        let again = JsonRpcEnvelope::parse(&env.to_bytes()).unwrap();
        assert_eq!(again.id, Some(JsonRpcId::Number(3)));
        assert!(again.method.is_none());
        assert_eq!(again.result_as::<Greet>(), Some(Greet { name: "rod".into() }));
    }

    #[test]
    fn to_bytes__round_trips_error_with_data() {
        let env = JsonRpcEnvelope::parse(
            br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"parse error","data":[1,2]}}"#,
        )
        .unwrap();
        let again = JsonRpcEnvelope::parse(&env.to_bytes()).unwrap();
        assert_eq!(again.id, Some(JsonRpcId::Null));
        let err = again.error.unwrap();
        assert_eq!(err.code, -32700);
        assert_eq!(err.message, "parse error");
        assert_eq!(err.data.as_ref().map(|d| d.get()), Some("[1,2]"));
    }
}