use crate::{errors::*, message::Message};
use bytes::{buf::BufMut, BytesMut};
use error_chain::bail;
use tokio_util::codec::{Decoder, Encoder};
use std::{convert::TryInto, mem, result::Result as StdResult, u16, u32, u8};
/// Reads a big-endian `$type` integer from the front of `$src` into
/// `$assign_to`, unless `$assign_to` already holds a value.
///
/// When fewer than `size_of::<$type>()` bytes are buffered, nothing is
/// consumed and the *enclosing function* returns `Ok(None)`. The caller can
/// then retry once more bytes arrive.
///
/// `$src` and `$assign_to` are evaluated more than once, so pass plain places
/// such as `src` or `self.field`.
macro_rules! read_int_frame {
    ($src:expr, $assign_to:expr, $type:ty) => {
        if $assign_to.is_none() {
            let width = mem::size_of::<$type>();
            if $src.len() >= width {
                let raw = $src.split_to(width);
                // `raw` is exactly `width` bytes long, so converting it into a
                // fixed-size array cannot fail.
                let value = <$type>::from_be_bytes((*raw).try_into().unwrap());
                $assign_to = Some(value);
            } else {
                return Ok(None);
            }
        }
    };
}
/// Reads `$len` bytes from the front of `$src` as text into `$assign_to`,
/// unless `$assign_to` already holds a value.
///
/// Invalid UTF-8 is not rejected: `String::from_utf8_lossy` replaces it with
/// U+FFFD. When fewer than `$len` bytes are buffered, nothing is consumed and
/// the *enclosing function* returns `Ok(None)`.
macro_rules! read_str_frame {
    ($src:expr, $assign_to:expr, $len:expr) => {
        if $assign_to.is_none() {
            let wanted: usize = $len;
            if $src.len() >= wanted {
                let raw = $src.split_to(wanted);
                $assign_to = Some(String::from_utf8_lossy(&raw).into_owned());
            } else {
                return Ok(None);
            }
        }
    };
}
/// Codec for [`Message`] frames.
///
/// Encoding keeps no state. Decoding is resumable: each field below is filled
/// in as soon as enough bytes are buffered, and all of them are cleared again
/// once a complete frame has been returned.
#[derive(Default, Debug)]
pub struct Codec {
    /// Leading one-byte message discriminant, once read.
    discriminant: Option<u8>,
    /// Big-endian `u16` byte length of the namespace, once read.
    ns_length: Option<u16>,
    /// Namespace text, once `ns_length` bytes have been read.
    namespace: Option<String>,
    /// Big-endian `u32` byte length of the payload; only read for events.
    data_length: Option<u32>,
    /// Payload text, once `data_length` bytes have been read.
    data: Option<String>,
}
impl Encoder<Message> for Codec {
    type Error = Error;

    /// Appends the wire encoding of `message` to `dst`.
    ///
    /// Layout, integers big-endian:
    /// `[discriminant: u8][namespace length: u16][namespace bytes]`, followed
    /// for `Message::Event` only by `[data length: u32][data bytes]`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::OversizedNamespace` if the namespace is longer than
    ///   `u16::MAX` bytes.
    /// * `ErrorKind::OversizedData` if an event payload is longer than
    ///   `u32::MAX` bytes.
    ///
    /// Both limits are checked before anything is written, so `dst` is left
    /// unchanged when an error is returned.
    fn encode(&mut self, message: Message, dst: &mut BytesMut) -> StdResult<(), Self::Error> {
        let namespace = message.namespace();
        // Only events carry a payload.
        let payload = match &message {
            Message::Event(_, data) => Some(data.as_bytes()),
            _ => None,
        };

        // Validate every length up front so that an error can never leave a
        // partially written frame in `dst`.
        if namespace.len() > u16::MAX as usize {
            bail!(ErrorKind::OversizedNamespace);
        }
        if let Some(data) = payload {
            if data.len() > u32::MAX as usize {
                bail!(ErrorKind::OversizedData);
            }
        }

        // 1-byte discriminant + 2-byte length + namespace, plus a 4-byte
        // length + data for events.
        let frame_len = 3 + namespace.len() + payload.map_or(0, |data| 4 + data.len());
        dst.reserve(frame_len);
        dst.put_u8(message.poor_mans_discriminant());
        // Lossless: bounded by `u16::MAX` above.
        dst.put_u16(namespace.len() as u16);
        dst.extend_from_slice(namespace.as_bytes());
        if let Some(data) = payload {
            // Lossless: bounded by `u32::MAX` above.
            dst.put_u32(data.len() as u32);
            dst.extend_from_slice(data);
        }
        Ok(())
    }
}
impl Decoder for Codec {
    type Item = Message;
    type Error = Error;

    /// Decodes at most one [`Message`] frame from the front of `src`.
    ///
    /// Each field is removed from `src` as soon as it is fully buffered and is
    /// parked on `self`. A frame that arrives across several calls therefore
    /// resumes where the previous call stopped. `Ok(None)` means the frame is
    /// not complete yet.
    ///
    /// # Errors
    ///
    /// Returns "Unknown Message discriminant" when the leading byte is
    /// rejected by `Message::test_poor_mans_discriminant`. At that point the
    /// byte has already been removed from `src`, and the rest of the frame is
    /// left unread.
    fn decode(&mut self, src: &mut BytesMut) -> StdResult<Option<Self::Item>, Self::Error> {
        // Discriminant of `Message::Event`, the only variant with a payload.
        const EVENT_DISCRIMINANT: u8 = 4;

        // Header byte: validated before anything else is read.
        read_int_frame!(src, self.discriminant, u8);
        self.discriminant = self
            .discriminant
            .filter(Message::test_poor_mans_discriminant);
        let discriminant = match self.discriminant {
            Some(d) => d,
            None => {
                bail!("Unknown Message discriminant");
            }
        };

        // Namespace: u16 length prefix, then that many bytes.
        read_int_frame!(src, self.ns_length, u16);
        let ns_length = usize::from(self.ns_length.unwrap());
        read_str_frame!(src, self.namespace, ns_length);

        // Payload: u32 length prefix, then that many bytes (events only).
        if discriminant == EVENT_DISCRIMINANT {
            read_int_frame!(src, self.data_length, u32);
            let data_length = self.data_length.unwrap() as usize;
            read_str_frame!(src, self.data, data_length);
        }

        // The frame is complete: clear all per-frame state before handing it
        // back, so the next call starts on a fresh frame.
        self.discriminant = None;
        self.ns_length = None;
        self.data_length = None;
        let namespace = self.namespace.take().unwrap();
        let data = self.data.take();
        Ok(Some(Message::from_poor_mans_discriminant(
            discriminant,
            namespace.into(),
            data,
        )))
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the `Codec` wire format: encoding, whole-frame decoding,
// incremental (partial-buffer) decoding and back-to-back frames.
use super::*;
use bytes::Bytes;
// A namespace-only message encodes as: discriminant byte (0 here), a big-endian
// u16 length (`\0\r` == 13) and the 13 namespace bytes.
#[test]
fn test_encode_nodata_ok() {
let msg = Message::Provide("/my/namespace".into());
let mut bytes = BytesMut::new();
let mut encoder = Codec::default();
encoder
.encode(msg, &mut bytes)
.expect("Failed to encode message");
assert_eq!(bytes, Bytes::from("\0\0\r/my/namespace"));
}
// An event uses discriminant byte 4. After the namespace it also carries a
// big-endian u32 payload length (`\0\0\0\x10` == 16) followed by the payload.
#[test]
fn test_encode_event_ok() {
let msg = Message::Event("/my/namespace".into(), "abc, easy as 123".into());
let mut bytes = BytesMut::new();
let mut encoder = Codec::default();
encoder
.encode(msg, &mut bytes)
.expect("Failed to encode message");
assert_eq!(
bytes,
Bytes::from("\x04\0\r/my/namespace\0\0\0\x10abc, easy as 123")
);
}
// A namespace one byte longer than `u16::MAX` must be rejected with
// `ErrorKind::OversizedNamespace`.
#[test]
fn test_encode_oversized_namespace() {
// The length is built with `as` casts; silence clippy's preference for `From`.
#[allow(clippy::cast_lossless)]
let long_str = String::from_utf8(vec![0; (u16::MAX as u32 + 1) as usize]).unwrap();
let msg = Message::Unsubscribe(long_str.into());
let mut bytes = BytesMut::new();
let mut encoder = Codec::default();
match encoder
.encode(msg, &mut bytes)
.err()
.expect("Test passed unexpectedly")
.kind()
{
ErrorKind::OversizedNamespace => (),
// NOTE(review): this arm is reached when encoding fails with a *different*
// error kind, so the panic message is misleading.
_ => panic!("Test passed unexpectedly"),
}
}
// A payload one byte longer than `u32::MAX` must be rejected with
// `ErrorKind::OversizedData`. It is marked `#[ignore]`, presumably because it
// allocates u32::MAX + 1 bytes (4 GiB).
// NOTE(review): on 32-bit targets `(u32::MAX as u64 + 1) as usize` truncates
// to 0, so this test cannot pass there.
#[test]
#[ignore]
fn test_encode_oversized_data() {
// The length is built with `as` casts; silence clippy's preference for `From`.
#[allow(clippy::cast_lossless)]
let long_str = String::from_utf8(vec![0; (u32::MAX as u64 + 1) as usize]).unwrap();
let msg = Message::Event("/".into(), long_str);
let mut bytes = BytesMut::new();
let mut encoder = Codec::default();
match encoder
.encode(msg, &mut bytes)
.err()
.expect("Test passed unexpectedly")
.kind()
{
ErrorKind::OversizedData => (),
// NOTE(review): reached on a *different* error kind; message is misleading.
_ => panic!("Test passed unexpectedly"),
}
}
// A complete frame with discriminant 1 decodes to `Message::Revoke`.
#[test]
fn test_decode_ok() {
let mut bytes = BytesMut::from("\x01\0\r/my/namespace");
let mut decoder = Codec::default();
let msg = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert_eq!(msg, Some(Message::Revoke("/my/namespace".into())));
}
// 0x09 is not an accepted discriminant, so decoding must fail on the first byte.
#[test]
fn test_decode_invalid_discriminant() {
let mut bytes = BytesMut::from("\x09");
let mut decoder = Codec::default();
match decoder.decode(&mut bytes) {
Ok(_) => panic!("Failed to detect invalid Message discriminant"),
Err(e) => assert_eq!(e.description(), "Unknown Message discriminant"),
}
}
// Feeds an event frame to the decoder a few bytes at a time. Every call made
// before the frame is complete must return `Ok(None)`, and the final call must
// yield the reassembled message.
#[test]
fn test_decode_partial() {
let mut bytes = BytesMut::new();
let mut decoder = Codec::default();
// Empty buffer: nothing to decode yet.
let response = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert!(response.is_none());
// Discriminant byte only.
bytes.put_u8(Message::Event("/".into(), String::new()).poor_mans_discriminant());
let response = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert!(response.is_none());
// Namespace length (13) plus only the first 8 namespace bytes.
bytes.put_u16(13);
bytes.extend_from_slice(b"/my/name");
let response = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert!(response.is_none());
// Rest of the namespace; the payload length is still missing.
bytes.extend_from_slice(b"space");
let response = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert!(response.is_none());
// Payload length (5) plus only the first payload byte.
bytes.put_u32(5);
bytes.extend_from_slice(b"a");
let response = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert!(response.is_none());
// Remaining payload bytes complete the frame.
bytes.extend_from_slice(b"bcde");
let msg = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert_eq!(
msg,
Some(Message::Event("/my/namespace".into(), "abcde".into()))
);
}
// Decodes two frames in sequence with the same decoder, checking that state
// left over from the first frame does not leak into the second.
#[test]
fn test_decode_multiple() {
let mut decoder = Codec::default();
let mut bytes = BytesMut::from("\x01\0\r/my/namespace");
let msg = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert_eq!(msg, Some(Message::Revoke("/my/namespace".into())));
// Second frame: an event (discriminant 4) for "/moo" with payload "cow".
bytes.put_u8(4);
bytes.put_u16(4);
bytes.extend_from_slice(b"/moo");
bytes.put_u32(3);
bytes.extend_from_slice(b"cow");
let msg = decoder
.decode(&mut bytes)
.expect("Failed to decode message");
assert_eq!(msg, Some(Message::Event("/moo".into(), "cow".into())));
}
}