use std::str::FromStr;
use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::{WireEvent, WireRequest, WireResponse};
/// Kind tag (fifth header byte) marking a frame whose payload is a `WireRequest`.
pub const KIND_REQUEST: u8 = 1;
/// Kind tag (fifth header byte) marking a frame whose payload is a `WireResponse`.
pub const KIND_RESPONSE: u8 = 2;
/// Kind tag (fifth header byte) marking a frame whose payload is a `WireEvent`.
pub const KIND_EVENT: u8 = 3;
/// Largest payload length, in bytes, that the decoder accepts (64 MiB).
/// A length prefix above this value makes `decode` return an error.
pub const MAX_FRAME_PAYLOAD: usize = 64 * 1024 * 1024;
/// Serialization format used for frame payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Encoding {
    /// MessagePack payloads.
    Msgpack,
    /// JSON payloads.
    Json,
}

impl Encoding {
    /// Returns the lowercase name of this encoding (`"msgpack"` or `"json"`).
    pub fn wire_name(self) -> &'static str {
        match self {
            Self::Json => "json",
            Self::Msgpack => "msgpack",
        }
    }
}

impl FromStr for Encoding {
    type Err = std::convert::Infallible;

    /// Parses an encoding name; this never fails.
    ///
    /// `"json"` (compared ASCII case-insensitively, no trimming) selects
    /// [`Encoding::Json`]. Every other input, including the empty string and
    /// unrecognized names, falls back to [`Encoding::Msgpack`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.eq_ignore_ascii_case("json") {
            return Ok(Self::Json);
        }
        Ok(Self::Msgpack)
    }
}
/// A single decoded message, tagged by the kind of payload it carries.
///
/// Each variant corresponds to one of the `KIND_*` tag bytes used in the
/// frame header.
#[derive(Debug)]
pub enum Frame {
/// A request payload (`KIND_REQUEST`).
Request(WireRequest),
/// A response payload (`KIND_RESPONSE`).
Response(WireResponse),
/// An event payload (`KIND_EVENT`).
Event(WireEvent),
}
/// Codec that converts between [`Frame`] values and length-prefixed byte
/// frames, implementing the tokio-util `Decoder` and `Encoder` traits.
pub struct FrameCodec {
// Serialization format applied to every payload this codec reads or writes.
encoding: Encoding,
}
impl FrameCodec {
pub fn new(encoding: Encoding) -> Self { Self { encoding } }
pub fn msgpack() -> Self { Self::new(Encoding::Msgpack) }
pub fn json() -> Self { Self::new(Encoding::Json) }
}
impl Decoder for FrameCodec {
    type Item = Frame;
    type Error = anyhow::Error;

    /// Attempts to decode one frame from the front of `src`.
    ///
    /// Wire layout: a 4-byte big-endian payload length, a 1-byte kind tag,
    /// then `length` payload bytes.
    ///
    /// Returns `Ok(None)` while the header or payload is still incomplete;
    /// in the latter case capacity for the missing bytes is reserved.
    ///
    /// # Errors
    ///
    /// * The declared length exceeds [`MAX_FRAME_PAYLOAD`] (nothing is consumed).
    /// * The kind tag is not one of the `KIND_*` constants.
    /// * The payload fails to deserialize with the configured encoding.
    ///
    /// In the last two cases the whole frame has already been removed from
    /// `src` when the error is returned.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Four length bytes followed by one kind byte.
        const HEADER_LEN: usize = 5;

        if src.len() < HEADER_LEN {
            return Ok(None);
        }

        // Peek at the header through a slice cursor so `src` stays untouched
        // until the complete frame is known to be buffered.
        let mut header: &[u8] = &src[..HEADER_LEN];
        let length = header.get_u32() as usize;
        let kind = header.get_u8();

        if length > MAX_FRAME_PAYLOAD {
            return Err(anyhow::anyhow!("frame payload too large: {} bytes (max {})", length, MAX_FRAME_PAYLOAD));
        }

        let frame_len = HEADER_LEN + length;
        let buffered = src.len();
        if buffered < frame_len {
            src.reserve(frame_len - buffered);
            return Ok(None);
        }

        // Detach the whole frame from the buffer, then drop its header.
        let mut payload = src.split_to(frame_len);
        payload.advance(HEADER_LEN);

        let decoded = match kind {
            KIND_REQUEST => self.unmarshal(&payload).map(Frame::Request),
            KIND_RESPONSE => self.unmarshal(&payload).map(Frame::Response),
            KIND_EVENT => self.unmarshal(&payload).map(Frame::Event),
            other => Err(anyhow::anyhow!("unknown frame kind: {}", other)),
        };
        decoded.map(Some)
    }
}
impl Encoder<Frame> for FrameCodec {
type Error = anyhow::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
let (kind, payload) = match &frame {
Frame::Request(r) => (KIND_REQUEST, self.marshal(r)?),
Frame::Response(r) => (KIND_RESPONSE, self.marshal(r)?),
Frame::Event(e) => (KIND_EVENT, self.marshal(e)?),
};
dst.reserve(5 + payload.len());
dst.put_u32(payload.len() as u32); dst.put_u8(kind); dst.put_slice(&payload);
Ok(())
}
}
impl FrameCodec {
    /// Serializes `v` into bytes using the codec's configured encoding.
    ///
    /// MessagePack output is produced with `rmp_serde::to_vec_named`; JSON
    /// output with `serde_json::to_vec`. Serializer errors are converted
    /// into `anyhow::Error`.
    fn marshal<T: serde::Serialize>(&self, v: &T) -> anyhow::Result<Vec<u8>> {
        let bytes = match self.encoding {
            Encoding::Json => serde_json::to_vec(v)?,
            Encoding::Msgpack => rmp_serde::to_vec_named(v)?,
        };
        Ok(bytes)
    }

    /// Deserializes a `T` from `data` using the codec's configured encoding.
    ///
    /// Deserializer errors are converted into `anyhow::Error`.
    fn unmarshal<T: serde::de::DeserializeOwned>(&self, data: &[u8]) -> anyhow::Result<T> {
        let value: T = match self.encoding {
            Encoding::Json => serde_json::from_slice(data)?,
            Encoding::Msgpack => rmp_serde::from_slice(data)?,
        };
        Ok(value)
    }
}