use bytes::{BufMut, BytesMut};
use serde::{Deserialize, Serialize};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
/// Upper bound, in bytes, on a single NDJSON line and on unterminated input
/// buffered while waiting for a newline (32 MiB).
const MAX_LINE_SIZE: usize = 32 << 20;
/// Identifier correlating a JSON-RPC request with its response.
///
/// Deserialized untagged: a JSON integer (fitting in `i64`) becomes `Number`,
/// a JSON string becomes `String`. `Eq` and `Hash` let ids be used as keys,
/// e.g. in a map of in-flight requests.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum RequestId {
    /// Integer id; tried first because of variant order.
    Number(i64),
    /// String id.
    String(String),
}
/// A JSON-RPC request (or notification) as exchanged with the peer.
///
/// In JSON-RPC 2.0 a request that carries no `id` is a notification. Optional
/// members are left out of the serialized form when they are `None`.
// NOTE(review): the allow is presumably needed because `serde_json::Value`
// is on the project's clippy disallowed-types list — confirm.
#[allow(clippy::disallowed_types)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct JsonRpcRequest {
    /// Protocol version string as sent by the peer (not validated here).
    pub jsonrpc: String,
    /// Request id; absent for notifications.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<RequestId>,
    /// Name of the method to invoke.
    pub method: String,
    /// Arbitrary JSON parameters, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
/// A JSON-RPC 2.0 response object.
///
/// Exactly one of `result` / `error` is expected to be set; the
/// `JsonRpcResponse::success` and `JsonRpcResponse::error` constructors
/// guarantee that.
// NOTE(review): the allow is presumably needed because `serde_json::Value`
// is on the project's clippy disallowed-types list — confirm.
#[allow(clippy::disallowed_types)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct JsonRpcResponse {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Id of the request being answered.
    ///
    /// Always serialized: `None` is written as `"id": null`, which JSON-RPC 2.0
    /// requires when the request id could not be determined (for example after
    /// a parse error). Omitting the member entirely violates the spec.
    /// A missing or `null` id still deserializes as `None`.
    pub id: Option<RequestId>,
    /// Successful result; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error details; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// The `error` member of a failed JSON-RPC response.
// NOTE(review): the allow is presumably needed because `serde_json::Value`
// is on the project's clippy disallowed-types list — confirm.
#[allow(clippy::disallowed_types)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct JsonRpcError {
    /// Numeric error code.
    pub code: i64,
    /// Short human-readable description.
    pub message: String,
    /// Optional extra error payload; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
/// Any JSON-RPC message read from the peer: either a request or a response.
///
/// Deserialization is untagged and tries variants in declaration order.
/// `Request` needs both `jsonrpc` and `method`. `Response` needs only
/// `jsonrpc`, so any object that is not a request falls through to it.
/// Do not reorder the variants.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub(crate) enum JsonRpcMessage {
    /// A request or notification (has a `method`).
    Request(JsonRpcRequest),
    /// A response to an earlier request.
    Response(JsonRpcResponse),
}
impl JsonRpcResponse {
#[allow(clippy::disallowed_types)]
pub fn success(id: Option<RequestId>, result: serde_json::Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}
pub fn error(id: Option<RequestId>, error: JsonRpcError) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(error),
}
}
}
/// Newline-delimited JSON codec that decodes incoming `JsonRpcMessage`s and
/// encodes outgoing `JsonRpcResponse`s.
///
/// The codec carries no state, so it is freely `Copy`, and `Default` is
/// equivalent to [`NdJsonCodec::new`].
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct NdJsonCodec;

impl NdJsonCodec {
    /// Creates a codec. Every instance behaves identically.
    pub const fn new() -> Self {
        Self
    }
}
impl Decoder for NdJsonCodec {
type Item = JsonRpcMessage;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
let newline_pos = src.iter().position(|b| *b == b'\n');
match newline_pos {
Some(pos) => {
let line = src.split_to(pos + 1);
let end = if pos > 0 && line[pos - 1] == b'\r' {
pos - 1
} else {
pos
};
let trimmed = &line[..end];
if trimmed.is_empty() {
continue;
}
if trimmed.len() > MAX_LINE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"line of {} bytes exceeds {} byte limit",
trimmed.len(),
MAX_LINE_SIZE
),
));
}
match serde_json::from_slice::<JsonRpcMessage>(trimmed) {
Ok(msg) => return Ok(Some(msg)),
Err(_) => continue,
}
}
None => {
if src.len() > MAX_LINE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"accumulated {} bytes without newline, exceeds {} byte limit",
src.len(),
MAX_LINE_SIZE
),
));
}
return Ok(None);
}
}
}
}
}
impl Encoder<JsonRpcResponse> for NdJsonCodec {
    type Error = io::Error;

    /// Appends `item` to `dst` as compact JSON followed by a single `\n`.
    ///
    /// # Errors
    ///
    /// `InvalidData` if serialization fails. In that case `dst` is truncated
    /// back to its length on entry, so no half-written frame is left for the
    /// transport to flush ahead of later messages.
    fn encode(&mut self, item: JsonRpcResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let start = dst.len();
        if let Err(err) = serde_json::to_writer((&mut *dst).writer(), &item) {
            // Roll back any bytes the failed serialization already appended.
            dst.truncate(start);
            return Err(io::Error::new(io::ErrorKind::InvalidData, err));
        }
        dst.put_u8(b'\n');
        Ok(())
    }
}