use crate::Error;
use anyhow::anyhow;
use bytes::{BufMut, BytesMut};
use serde_json::value::Value;
use std::str::FromStr;
use std::{io, str};
use tokio_util::codec::{Decoder, Encoder};
use crate::messages::JsonRpc;
use crate::messages::{Notification, Request};
/// Framing codec for a byte stream whose messages are separated by an empty
/// line, i.e. two consecutive `\n` bytes.
///
/// Decoding yields each message as a `String` without the terminator; encoding
/// appends the `\n\n` terminator after the given text.
#[derive(Default)]
pub struct MultiLineCodec {
    /// Offset into the read buffer where the next delimiter search starts.
    /// After a `decode` call that finds no delimiter, this is set to the last
    /// buffered byte (which may be the first half of a split `\n\n`) so
    /// already-scanned bytes are not rescanned; it is reset to 0 whenever a
    /// frame is split off.
    search_pos: usize,
}
/// Borrows `buf` as a UTF-8 string slice.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] when `buf`
/// is not well-formed UTF-8; the underlying decode error is not preserved.
fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
    match str::from_utf8(buf) {
        Ok(text) => Ok(text),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Unable to decode input as UTF8",
        )),
    }
}
impl Decoder for MultiLineCodec {
    type Item = String;
    type Error = Error;

    /// Extracts the next `\n\n`-terminated frame from `buf`.
    ///
    /// When a terminator is found, the frame *and* its terminator are removed
    /// from `buf` and the frame text (without the terminator) is returned.
    /// When no terminator is buffered yet, `buf` is left untouched and
    /// `Ok(None)` is returned; the scan offset is remembered so that, assuming
    /// the caller only appends to `buf` between calls, the next call examines
    /// only new bytes plus the last old byte (which may be the first `\n` of a
    /// split terminator).
    ///
    /// # Errors
    ///
    /// Returns an error if the extracted frame is not valid UTF-8. The frame
    /// is still consumed from `buf` in that case.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
        let start = self.search_pos;
        // Absolute index of the first byte of the first `\n\n` at or after
        // `start`, if any. `get` yields `None` when `start` is past the end.
        let found = buf
            .get(start..)
            .and_then(|tail| tail.windows(2).position(|pair| pair == b"\n\n"))
            .map(|offset| start + offset);

        match found {
            Some(pos) => {
                let frame = buf.split_to(pos + 2);
                self.search_pos = 0;
                Ok(Some(utf8(&frame[..pos])?.to_owned()))
            }
            None => {
                self.search_pos = buf.len().saturating_sub(1);
                Ok(None)
            }
        }
    }
}
impl<T> Encoder<T> for MultiLineCodec
where
    T: AsRef<str>,
{
    type Error = Error;

    /// Appends `line` followed by the `\n\n` frame terminator to `buf`.
    ///
    /// The text is written verbatim and is not checked for embedded `\n\n`
    /// sequences; this module's decoder would split a frame at the first one.
    /// Never fails.
    fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
        const TERMINATOR: &[u8] = b"\n\n";
        let payload = line.as_ref().as_bytes();
        buf.reserve(payload.len() + TERMINATOR.len());
        buf.put_slice(payload);
        buf.put_slice(TERMINATOR);
        Ok(())
    }
}
/// Codec that frames messages with [`MultiLineCodec`] (`\n\n`-terminated) and
/// treats each frame as one JSON value: encoding serializes a value into a
/// frame, decoding parses a frame into a `serde_json::Value`.
#[derive(Default)]
pub struct JsonCodec {
    /// Underlying `\n\n` framing codec.
    inner: MultiLineCodec,
}
impl<T> Encoder<T> for JsonCodec
where
    T: Into<Value>,
{
    type Error = Error;

    /// Serializes `msg` as JSON and writes it to `buf` as one `\n\n`-terminated
    /// frame.
    fn encode(&mut self, msg: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
        let value: Value = msg.into();
        // `Value`'s `Display` (used by `to_string`) emits compact single-line
        // JSON with control characters escaped, so the text cannot contain the
        // `\n\n` frame delimiter.
        let text = value.to_string();
        self.inner.encode(text, buf)
    }
}
impl Decoder for JsonCodec {
type Item = Value;
type Error = Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
match self.inner.decode(buf) {
Ok(None) => Ok(None),
Err(e) => Err(e),
Ok(Some(s)) => {
if let Ok(v) = Value::from_str(&s) {
Ok(Some(v))
} else {
Err(anyhow!("failed to parse JSON"))
}
}
}
}
}
/// Decoder that reads JSON values via [`JsonCodec`] and deserializes each one
/// into a `JsonRpc<Notification, Request>` message.
#[derive(Default)]
pub(crate) struct JsonRpcCodec {
    /// Underlying JSON-value codec (itself built on `\n\n` framing).
    inner: JsonCodec,
}
impl Decoder for JsonRpcCodec {
    type Item = JsonRpc<Notification, Request>;
    type Error = Error;

    /// Decodes the next JSON frame and deserializes it into a JSON-RPC message.
    ///
    /// Returns `Ok(None)` while no complete frame is buffered.
    ///
    /// # Errors
    ///
    /// Propagates framing, UTF-8 and JSON parse errors from the inner
    /// [`JsonCodec`], and returns an error if the JSON value does not
    /// deserialize into `JsonRpc<Notification, Request>`.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
        let value = match self.inner.decode(buf)? {
            Some(value) => value,
            None => return Ok(None),
        };
        let message: Self::Item = serde_json::from_value(value)?;
        Ok(Some(message))
    }
}
#[cfg(test)]
mod test {
    use super::{JsonCodec, MultiLineCodec};
    use bytes::{BufMut, BytesMut};
    use serde_json::json;
    use tokio_util::codec::{Decoder, Encoder};

    #[test]
    fn test_ml_decoder() {
        struct Test(String, Option<String>, String);
        let tests = vec![
            Test("".to_string(), None, "".to_string()),
            Test(
                "{\"hello\":\"world\"}\n\nremainder".to_string(),
                Some("{\"hello\":\"world\"}".to_string()),
                "remainder".to_string(),
            ),
            Test(
                "{\"hello\":\"world\"}\n\n{}\n\nremainder".to_string(),
                Some("{\"hello\":\"world\"}".to_string()),
                "{}\n\nremainder".to_string(),
            ),
        ];
        for t in tests.iter() {
            let mut buf = BytesMut::new();
            buf.put_slice(t.0.as_bytes());
            let mut codec = MultiLineCodec::default();
            let mut remainder = BytesMut::new();
            remainder.put_slice(t.2.as_bytes());
            assert_eq!(codec.decode(&mut buf).unwrap(), t.1);
            assert_eq!(buf, remainder);
        }
    }

    /// A `\n\n` terminator split across two reads must still be found; this
    /// exercises the resumable `search_pos` scan.
    #[test]
    fn test_ml_decoder_incremental() {
        let mut codec = MultiLineCodec::default();
        let mut buf = BytesMut::new();

        buf.put_slice(b"{\"a\":1}\n");
        assert_eq!(codec.decode(&mut buf).unwrap(), None);

        buf.put_slice(b"\n{\"b\":2}");
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some("{\"a\":1}".to_string())
        );
        assert_eq!(codec.decode(&mut buf).unwrap(), None);

        buf.put_slice(b"\n\n");
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some("{\"b\":2}".to_string())
        );
        assert!(buf.is_empty());
    }

    /// Invalid UTF-8 is reported as an error, and the offending frame is
    /// consumed so decoding can continue with the following bytes.
    #[test]
    fn test_ml_decoder_invalid_utf8() {
        let mut codec = MultiLineCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(&[0xff, b'\n', b'\n', b'x']);
        assert!(codec.decode(&mut buf).is_err());
        assert_eq!(&buf[..], b"x");
    }

    #[test]
    fn test_ml_encoder() {
        let tests = vec!["test", ""];
        for t in tests.iter() {
            let mut buf = BytesMut::new();
            let mut codec = MultiLineCodec::default();
            let mut expected = BytesMut::new();
            expected.put_slice(t.as_bytes());
            expected.put_u8(b'\n');
            expected.put_u8(b'\n');
            codec.encode(t, &mut buf).unwrap();
            assert_eq!(buf, expected);
        }
    }

    #[test]
    fn test_json_codec() {
        let tests = vec![
            json!({"hello": "world"}),
            json!({"multi": "line\n\ntext"}),
            json!([1, 2, 3]),
        ];
        for t in tests.iter() {
            let mut codec = JsonCodec::default();
            let mut buf = BytesMut::new();
            codec.encode(t.clone(), &mut buf).unwrap();
            let decoded = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(&decoded, t);
            assert!(buf.is_empty());
        }
    }

    /// A frame that is not valid JSON yields the "failed to parse JSON" error
    /// and is consumed from the buffer.
    #[test]
    fn test_json_codec_invalid() {
        let mut codec = JsonCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(b"not json\n\n");
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.to_string(), "failed to parse JSON");
        assert!(buf.is_empty());
    }
}