use crate::uds_req_res::{UdsRequest, UdsResponse};
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Server-side transport abstraction.
///
/// A transport is a stream whose items pair an incoming [`JsonRpcRequest`]
/// with a one-shot sender for the matching [`JsonRpcResponse`].
/// Implementors must additionally be `Unpin + Send + Sync`.
pub trait JsonRpcServerTransport:
futures::Stream<
Item = (
JsonRpcRequest,
futures::channel::oneshot::Sender<JsonRpcResponse>,
),
> + Unpin
+ Send
+ Sync
{
}
#[async_trait]
pub trait JsonRpcServerHandler: Send + Sync {
async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponseData {
let mut responses = self.handle_batch_request(vec![request]).await;
if responses.len() != 1 {
return JsonRpcResponseData::Error {
error: JsonRpcError {
code: JsonRpcErrorCode::InternalError,
message: format!(
"Internal error: Batch handler returned {} responses instead of 1",
responses.len()
),
data: None,
},
};
}
responses.pop().unwrap()
}
async fn handle_batch_request(&self, requests: Vec<JsonRpcRequest>)
-> Vec<JsonRpcResponseData>;
}
/// A running JSON-RPC server.
///
/// Holds the handle of the Tokio task that serves requests.
pub struct JsonRpcServer {
// Join handle of the spawned serving task.
task_handle: tokio::task::JoinHandle<()>,
}
impl JsonRpcServer {
    /// Spawns a Tokio task that serves requests from `transport` using `handler`.
    ///
    /// Requests are processed one at a time: each request is fully handled
    /// (and its response sent) before the next item is pulled from the
    /// transport. Every response carries the id of the request it answers.
    /// The task ends when the transport stream is exhausted.
    ///
    /// This calls `tokio::spawn`, so it must be invoked from within a Tokio
    /// runtime.
    pub fn new(
        mut transport: Box<dyn JsonRpcServerTransport>,
        handler: Box<dyn JsonRpcServerHandler>,
    ) -> Self {
        let task_handle = tokio::spawn(async move {
            while let Some((request, response_sender)) = transport.next().await {
                let request_id = request.id().clone();
                let response =
                    JsonRpcResponse::new(handler.handle_request(request).await, request_id);
                // `send` fails only when the receiving half was dropped, i.e.
                // nobody is waiting for this response any more (for example
                // the peer went away). That must not panic: a panic here
                // would kill the serving task and silently stop the server
                // for every other client. Drop the response and keep serving.
                let _ = response_sender.send(response);
            }
        });
        Self { task_handle }
    }

    /// Stops the server by consuming and dropping it.
    pub fn stop(self) {
        drop(self);
    }
}
/// Dropping the server aborts its background serving task.
impl Drop for JsonRpcServer {
    fn drop(&mut self) {
        let serving_task = &self.task_handle;
        serving_task.abort();
    }
}
/// Client-side transport abstraction, generic over the transport's error type `E`.
#[async_trait]
pub trait JsonRpcClientTransport<E> {
/// Sends a single request and resolves to its response, or to a transport error.
async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, E>;
/// Sends several requests as one batch and resolves to the responses, or to a
/// transport error.
///
/// NOTE(review): whether responses are returned in request order is not
/// specified here — confirm against the implementations.
async fn send_batch_request(
&self,
requests: Vec<JsonRpcRequest>,
) -> Result<Vec<JsonRpcResponse>, E>;
}
/// Protocol version marker carried in the `jsonrpc` field of requests and responses.
///
/// The only variant is (de)serialized as the string `"2.0"`.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
enum JsonRpcVersion {
#[serde(rename = "2.0")]
V2,
}
/// A JSON-RPC request object.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct JsonRpcRequest {
// Protocol version marker.
jsonrpc: JsonRpcVersion,
// Name of the method to invoke.
method: String,
// Optional structured parameters; the key is omitted from the serialized
// output when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<JsonRpcStructuredValue>,
// Request identifier; the server copies it into the response it sends back.
id: JsonRpcId,
}
// Lets `JsonRpcRequest` be used wherever a `UdsRequest` is required.
// The impl body is empty: no trait items are defined or overridden here.
impl UdsRequest for JsonRpcRequest {}
impl JsonRpcRequest {
pub fn new(method: String, params: Option<JsonRpcStructuredValue>, id: JsonRpcId) -> Self {
Self {
jsonrpc: JsonRpcVersion::V2,
method,
params,
id,
}
}
pub fn method(&self) -> &str {
&self.method
}
pub fn params(&self) -> Option<&JsonRpcStructuredValue> {
self.params.as_ref()
}
pub fn id(&self) -> &JsonRpcId {
&self.id
}
}
/// Identifier attached to a JSON-RPC request and echoed in its response.
#[derive(Debug, Clone, PartialEq)]
pub enum JsonRpcId {
    /// Numeric identifier.
    Number(i32),
    /// Textual identifier.
    String(String),
    /// The JSON `null` identifier.
    Null,
}
impl JsonRpcId {
fn to_json_value(&self) -> serde_json::Value {
match self {
JsonRpcId::Number(n) => serde_json::Value::Number((*n).into()),
JsonRpcId::String(s) => serde_json::Value::String(s.clone()),
JsonRpcId::Null => serde_json::Value::Null,
}
}
}
impl serde::Serialize for JsonRpcId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.to_json_value().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for JsonRpcId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
serde_json::Value::deserialize(deserializer).and_then(|value| {
if value.is_i64() {
Ok(JsonRpcId::Number(value.as_i64().unwrap() as i32))
} else if value.is_string() {
Ok(JsonRpcId::String(value.as_str().unwrap().to_string()))
} else if value.is_null() {
Ok(JsonRpcId::Null)
} else {
Err(serde::de::Error::custom("Invalid JSON-RPC ID"))
}
})
}
}
/// A structured JSON value: either an object or an array.
///
/// Used as the type of request `params`. Being `untagged`, it is
/// (de)serialized as the bare JSON object or array with no wrapper.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(untagged)]
pub enum JsonRpcStructuredValue {
// A JSON object (string keys mapped to arbitrary JSON values).
Object(serde_json::Map<String, serde_json::Value>),
// A JSON array of arbitrary JSON values.
Array(Vec<serde_json::Value>),
}
impl JsonRpcStructuredValue {
pub fn into_value(self) -> serde_json::Value {
match self {
JsonRpcStructuredValue::Object(object) => serde_json::Value::Object(object),
JsonRpcStructuredValue::Array(array) => serde_json::Value::Array(array),
}
}
}
/// A JSON-RPC response object.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub struct JsonRpcResponse {
// Protocol version marker.
jsonrpc: JsonRpcVersion,
// Success or error payload; flattened so that its `result` / `error` key
// sits directly in the response object instead of under a `data` key.
#[serde(flatten)]
data: JsonRpcResponseData,
// Identifier of the request this response answers.
id: JsonRpcId,
}
impl UdsResponse for JsonRpcResponse {
fn request_parse_error_response() -> Self {
Self::new(
JsonRpcResponseData::Error {
error: JsonRpcError {
code: JsonRpcErrorCode::ParseError,
message: "Failed to parse request".to_string(),
data: None,
},
},
JsonRpcId::Null,
)
}
fn internal_error_response(msg: String) -> Self {
Self::new(
JsonRpcResponseData::Error {
error: JsonRpcError {
code: JsonRpcErrorCode::InternalError,
message: msg,
data: None,
},
},
JsonRpcId::Null,
)
}
}
impl JsonRpcResponse {
pub fn new(data: JsonRpcResponseData, id: JsonRpcId) -> Self {
Self {
jsonrpc: JsonRpcVersion::V2,
data,
id,
}
}
pub fn data(&self) -> &JsonRpcResponseData {
&self.data
}
}
/// Payload of a response: either a successful result or an error.
///
/// Being `untagged`, the variant is not named in the JSON; a success is
/// written as a `result` key and a failure as an `error` key.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
#[serde(untagged)]
pub enum JsonRpcResponseData {
// Successful call; `result` holds an arbitrary JSON value.
Success { result: serde_json::Value },
// Failed call; `error` describes what went wrong.
Error { error: JsonRpcError },
}
/// Error object carried by a failed response.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct JsonRpcError {
// Numeric error code.
code: JsonRpcErrorCode,
// Error message text.
message: String,
// Optional additional JSON data; the key is omitted from the serialized
// output when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<serde_json::Value>,
}
impl JsonRpcError {
pub fn new(code: JsonRpcErrorCode, message: String, data: Option<serde_json::Value>) -> Self {
Self {
code,
message,
data,
}
}
pub fn code(&self) -> JsonRpcErrorCode {
self.code
}
}
/// Error codes used in JSON-RPC error objects.
///
/// The named variants map to fixed numeric codes in the serde impls of this
/// type; every other number is represented by [`JsonRpcErrorCode::Custom`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JsonRpcErrorCode {
    /// Code `-32700`.
    ParseError,
    /// Code `-32600`.
    InvalidRequest,
    /// Code `-32601`.
    MethodNotFound,
    /// Code `-32602`.
    InvalidParams,
    /// Code `-32603`.
    InternalError,
    /// Any other numeric code.
    Custom(i32),
}
impl Serialize for JsonRpcErrorCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let code = match *self {
JsonRpcErrorCode::ParseError => -32700,
JsonRpcErrorCode::InvalidRequest => -32600,
JsonRpcErrorCode::MethodNotFound => -32601,
JsonRpcErrorCode::InvalidParams => -32602,
JsonRpcErrorCode::InternalError => -32603,
JsonRpcErrorCode::Custom(c) => c,
};
serializer.serialize_i32(code)
}
}
impl<'de> Deserialize<'de> for JsonRpcErrorCode {
fn deserialize<D>(deserializer: D) -> Result<JsonRpcErrorCode, D::Error>
where
D: serde::Deserializer<'de>,
{
let code = i32::deserialize(deserializer)?;
match code {
-32700 => Ok(JsonRpcErrorCode::ParseError),
-32600 => Ok(JsonRpcErrorCode::InvalidRequest),
-32601 => Ok(JsonRpcErrorCode::MethodNotFound),
-32602 => Ok(JsonRpcErrorCode::InvalidParams),
-32603 => Ok(JsonRpcErrorCode::InternalError),
_ => Ok(JsonRpcErrorCode::Custom(code)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks both directions: `json_string` must parse into `value`, and
    /// `value` must serialize back to exactly `json_string`.
    fn assert_json_serialization<'a, T>(value: T, json_string: &'a str)
    where
        T: Serialize + Deserialize<'a> + PartialEq + std::fmt::Debug,
    {
        let parsed: T = serde_json::from_str(json_string).unwrap();
        assert_eq!(parsed, value);

        let rendered = serde_json::to_string(&value).unwrap();
        assert_eq!(rendered, json_string);
    }

    #[test]
    fn serialize_and_deserialize_json_rpc_request() {
        // No params, null id: the `params` key must be absent.
        let bare = JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Null);
        assert_json_serialization(
            bare,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"id\":null}",
        );

        // Object params.
        let object_params =
            JsonRpcStructuredValue::Object(serde_json::from_str("{\"key_type\":\"rsa\"}").unwrap());
        let with_object = JsonRpcRequest::new(
            "get_public_key".to_string(),
            Some(object_params),
            JsonRpcId::Null,
        );
        assert_json_serialization(
            with_object,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"params\":{\"key_type\":\"rsa\"},\"id\":null}",
        );

        // Array params with mixed element types.
        let array_params = JsonRpcStructuredValue::Array(vec![
            serde_json::from_str("1").unwrap(),
            serde_json::from_str("\"2\"").unwrap(),
            serde_json::from_str("{\"3\":true}").unwrap(),
        ]);
        let with_array = JsonRpcRequest::new(
            "fetch_values".to_string(),
            Some(array_params),
            JsonRpcId::Null,
        );
        assert_json_serialization(
            with_array,
            "{\"jsonrpc\":\"2.0\",\"method\":\"fetch_values\",\"params\":[1,\"2\",{\"3\":true}],\"id\":null}",
        );

        // Numeric id.
        let numeric_id =
            JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Number(1234));
        assert_json_serialization(
            numeric_id,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"id\":1234}",
        );

        // String id.
        let string_id = JsonRpcRequest::new(
            "get_foo_string".to_string(),
            None,
            JsonRpcId::String("foo".to_string()),
        );
        assert_json_serialization(
            string_id,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_foo_string\",\"id\":\"foo\"}",
        );
    }

    #[test]
    fn serialize_and_deserialize_json_rpc_response() {
        // Builds an `InternalError` payload with message "foo" and the given data.
        let internal_error = |data: Option<serde_json::Value>| JsonRpcResponseData::Error {
            error: JsonRpcError {
                code: JsonRpcErrorCode::InternalError,
                message: "foo".to_string(),
                data,
            },
        };

        // Success payload.
        let success = JsonRpcResponseData::Success {
            result: serde_json::from_str("\"foo\"").unwrap(),
        };
        assert_json_serialization(
            JsonRpcResponse::new(success, JsonRpcId::Null),
            "{\"jsonrpc\":\"2.0\",\"result\":\"foo\",\"id\":null}",
        );

        // Error without data: the `data` key must be absent.
        assert_json_serialization(
            JsonRpcResponse::new(internal_error(None), JsonRpcId::Null),
            "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32603,\"message\":\"foo\"},\"id\":null}",
        );

        // Error with data.
        let extra = Some(serde_json::from_str("\"bar\"").unwrap());
        assert_json_serialization(
            JsonRpcResponse::new(internal_error(extra), JsonRpcId::Null),
            "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32603,\"message\":\"foo\",\"data\":\"bar\"},\"id\":null}",
        );
    }

    #[test]
    fn serialize_and_deserialize_id() {
        let cases = vec![
            (JsonRpcId::Number(1234), "1234"),
            (JsonRpcId::String("foo".to_string()), "\"foo\""),
            (JsonRpcId::Null, "null"),
        ];
        for (id, json) in cases {
            assert_json_serialization(id, json);
        }
    }

    #[test]
    fn serialize_and_deserialize_error_code() {
        let cases = vec![
            (JsonRpcErrorCode::ParseError, "-32700"),
            (JsonRpcErrorCode::InvalidRequest, "-32600"),
            (JsonRpcErrorCode::MethodNotFound, "-32601"),
            (JsonRpcErrorCode::InvalidParams, "-32602"),
            (JsonRpcErrorCode::InternalError, "-32603"),
            (JsonRpcErrorCode::Custom(1234), "1234"),
        ];
        for (code, json) in cases {
            assert_json_serialization(code, json);
        }
    }
}