use std::{
fmt,
{cmp, collections::HashMap, mem}
};
use bytes::{BufMut, Bytes, BytesMut};
use {
tokio::{io, sync::mpsc::UnboundedSender},
tokio_util::codec::{Decoder, Encoder}
};
use crate::{
err::Error,
{KVLines, Params, Telegram}
};
/// What the decoder expects to find next in the input stream.
///
/// The line-oriented states (`Telegram`, `Params`, `KVLines`) read lines up
/// to a blank line; the remaining states consume a byte count stored in
/// `Codec::remain`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum CodecState {
  /// Topic line followed by `key value` lines (the default state).
  Telegram,
  /// A block of `key value` lines decoded into `Params`.
  Params,
  /// A block of `key value` lines decoded into `KVLines`.
  KVLines,
  /// Raw bytes handed out piecewise as `Input::Chunk`.
  Chunks,
  /// Raw bytes buffered and handed out in one piece as `Input::Bytes`.
  Bytes,
  /// Raw bytes forwarded to an unbounded channel.
  BytesCh,
  /// Raw bytes that are discarded.
  Skip
}
/// An item produced by the [`Codec`] decoder.
///
/// Which variant is produced depends on what the codec was last told to
/// expect.
pub enum Input {
  /// A complete telegram: a topic line plus parameter lines, terminated by
  /// a blank line.
  Telegram(Telegram),
  /// A block of key/value lines requested via [`Codec::expect_kvlines`].
  KVLines(KVLines),
  /// A block of parameter lines requested via [`Codec::expect_params`].
  Params(Params),
  /// A piece of a payload requested via [`Codec::expect_chunks`], together
  /// with the number of payload bytes still outstanding after this piece
  /// (zero for the final piece).
  Chunk(Bytes, u64),
  /// A complete payload requested via [`Codec::expect_bytes`].
  Bytes(Bytes),
  /// The payload requested via [`Codec::expect_bytes_channel`] has been
  /// fully forwarded to its channel.
  BytesChDone,
  /// The bytes requested via [`Codec::skip`] have been discarded.
  SkipDone
}
/// Framing codec for telegram-style line messages and raw byte payloads.
///
/// Decoding is driven by an internal state machine. It starts out expecting
/// telegrams, and the `expect_*`/`skip` methods switch it to another kind of
/// input for the next item.
pub struct Codec {
  /// Offset in the read buffer up to which the current, still incomplete,
  /// line has already been scanned for `\n`.
  next_line_index: usize,

  /// Longest line accepted before decoding fails.
  max_line_length: usize,

  /// Telegram being assembled in the `Telegram` state.
  tg: Telegram,

  /// Parameters being assembled in the `Params` state.
  params: Params,

  /// Key/value lines being assembled in the `KVLines` state.
  kvlines: KVLines,

  /// Current decoder state.
  state: CodecState,

  /// Bytes still expected in the `Chunks`, `Bytes`, `BytesCh` and `Skip`
  /// states.
  remain: u64,

  /// Destination for payload bytes in the `BytesCh` state; taken once the
  /// payload is complete.
  bytes_tx: Option<UnboundedSender<Bytes>>
}
impl fmt::Debug for Codec {
  /// Show only the decoder state.
  ///
  /// Buffers, partially assembled messages and the channel sender are
  /// deliberately left out. `finish_non_exhaustive` renders that omission as
  /// `..` instead of silently hiding it.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.debug_struct("Codec")
      .field("state", &self.state)
      .finish_non_exhaustive()
  }
}
impl Default for Codec {
fn default() -> Self {
Self::new()
}
}
impl Codec {
#[must_use]
pub fn new() -> Self {
Self {
next_line_index: 0,
max_line_length: usize::MAX,
tg: Telegram::new_uninit(),
params: Params::new(),
kvlines: KVLines::new(),
state: CodecState::Telegram,
remain: 0,
bytes_tx: None
}
}
#[must_use]
pub fn new_with_max_length(max_line_length: usize) -> Self {
Self {
max_line_length,
..Self::new()
}
}
#[must_use]
pub const fn max_line_length(&self) -> usize {
self.max_line_length
}
fn find_newline(&self, buf: &BytesMut) -> (usize, Option<usize>) {
let read_to = cmp::min(self.max_line_length.saturating_add(1), buf.len());
let newline_offset = buf[self.next_line_index..read_to]
.iter()
.position(|b| *b == b'\n');
(read_to, newline_offset)
}
fn decode_telegram_line(&mut self, line: &str) -> Result<(), Error> {
if self.tg.get_topic().is_empty() {
self
.tg
.set_topic(line)
.map_err(|e| Error::Protocol(e.to_string()))?;
} else {
let idx = line.find(' ');
if let Some(idx) = idx {
let (k, v) = line.split_at(idx);
let v = &v[1..v.len()];
self.tg.add_param(k, v)?;
}
}
Ok(())
}
fn get_eol_idx(&mut self, buf: &BytesMut) -> Result<Option<usize>, Error> {
let (read_to, newline_offset) = self.find_newline(buf);
match newline_offset {
Some(offset) => {
let newline_index = offset + self.next_line_index;
self.next_line_index = 0;
Ok(Some(newline_index + 1))
}
None if buf.len() > self.max_line_length => {
Err(Error::Protocol("Exceeded maximum line length.".to_string()))
}
None => {
self.next_line_index = read_to;
Ok(None)
}
}
}
fn decode_telegram_lines(
&mut self,
buf: &mut BytesMut
) -> Result<Option<Telegram>, Error> {
loop {
if let Some(idx) = self.get_eol_idx(buf)? {
let line = buf.split_to(idx);
let line = &line[..line.len() - 1];
let line = utf8(without_carriage_return(line))?;
if line.is_empty() {
let newtg = Telegram::new_uninit();
return Ok(Some(mem::replace(&mut self.tg, newtg)));
}
self.decode_telegram_line(line)?;
} else {
return Ok(None);
}
}
}
fn decode_params_lines(
&mut self,
buf: &mut BytesMut
) -> Result<Option<Params>, Error> {
loop {
if let Some(idx) = self.get_eol_idx(buf)? {
let line = buf.split_to(idx);
let line = &line[..line.len() - 1];
let line = utf8(without_carriage_return(line))?;
if line.is_empty() {
self.state = CodecState::Telegram;
return Ok(Some(mem::take(&mut self.params)));
}
let idx = line.find(' ');
if let Some(idx) = idx {
let (k, v) = line.split_at(idx);
let v = &v[1..v.len()];
self.params.add_param(k, v)?;
}
} else {
return Ok(None);
}
}
}
fn decode_kvlines(
&mut self,
buf: &mut BytesMut
) -> Result<Option<KVLines>, Error> {
loop {
if let Some(idx) = self.get_eol_idx(buf)? {
let line = buf.split_to(idx);
let line = &line[..line.len() - 1];
let line = utf8(without_carriage_return(line))?;
if line.is_empty() {
self.state = CodecState::Telegram;
return Ok(Some(mem::take(&mut self.kvlines)));
}
let idx = line.find(' ');
if let Some(idx) = idx {
let (k, v) = line.split_at(idx);
let v = &v[1..v.len()];
self.kvlines.append(k, v)?;
}
} else {
return Ok(None);
}
}
}
pub fn expect_chunks(&mut self, size: u64) -> Result<(), Error> {
if size == 0 {
return Err(Error::InvalidSize("zero size".to_string()));
}
self.state = CodecState::Chunks;
self.remain = size;
Ok(())
}
#[allow(clippy::missing_panics_doc)]
pub fn expect_bytes(&mut self, size: usize) -> Result<(), Error> {
if size == 0 {
return Err(Error::InvalidSize("zero size".to_string()));
}
self.state = CodecState::Bytes;
self.remain = size.try_into().unwrap();
Ok(())
}
pub fn expect_bytes_channel(
&mut self,
size: u64,
tx: UnboundedSender<Bytes>
) -> Result<(), Error> {
if size == 0 {
return Err(Error::InvalidSize("must not be zero".to_string()));
}
self.state = CodecState::BytesCh;
self.bytes_tx = Some(tx);
self.remain = size;
Ok(())
}
pub const fn expect_params(&mut self) {
self.state = CodecState::Params;
}
pub const fn expect_kvlines(&mut self) {
self.state = CodecState::KVLines;
}
pub fn skip(&mut self, size: u64) -> Result<(), Error> {
if size == 0 {
return Err(Error::InvalidSize("zero size".to_string()));
}
self.state = CodecState::Skip;
self.remain = size;
Ok(())
}
}
/// Interpret `buf` as UTF-8 text.
///
/// Invalid input is reported as an `InvalidData` I/O error; the underlying
/// decoding error details are not preserved.
fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
  match std::str::from_utf8(buf) {
    Ok(text) => Ok(text),
    Err(_) => Err(io::Error::new(
      io::ErrorKind::InvalidData,
      "Unable to decode input as UTF8"
    ))
  }
}
/// Drop a single trailing `\r` from `s`, if present.
///
/// Lines are passed in with their `\n` already removed, so this lets both
/// `\n` and `\r\n` line endings be accepted.
fn without_carriage_return(s: &[u8]) -> &[u8] {
  match s {
    [body @ .., b'\r'] => body,
    _ => s
  }
}
impl Decoder for Codec {
  type Item = Input;
  type Error = crate::err::Error;

  /// Decode the next item according to the current codec state.
  ///
  /// Line-oriented states return `Ok(None)` until a blank line completes
  /// the item. Byte-oriented states consume at most `self.remain` bytes and
  /// switch back to the telegram state once the expected count is reached.
  fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Input>, Error> {
    match self.state {
      CodecState::Telegram => {
        Ok(self.decode_telegram_lines(buf)?.map(Input::Telegram))
      }
      CodecState::Params => {
        Ok(self.decode_params_lines(buf)?.map(Input::Params))
      }
      CodecState::KVLines => Ok(self.decode_kvlines(buf)?.map(Input::KVLines)),
      CodecState::Chunks => {
        if buf.is_empty() {
          return Ok(None);
        }
        let read_to = cmp::min(self.remain, buf.len() as u64);
        self.remain -= read_to;
        if self.remain == 0 {
          self.state = CodecState::Telegram;
        }
        let len = usize::try_from(read_to).expect("bounded by buf.len()");
        Ok(Some(Input::Chunk(buf.split_to(len).freeze(), self.remain)))
      }
      CodecState::Bytes => {
        let remain: usize = self.remain.try_into().unwrap();
        if buf.len() < remain {
          Ok(None)
        } else {
          self.state = CodecState::Telegram;
          Ok(Some(Input::Bytes(buf.split_to(remain).freeze())))
        }
      }
      CodecState::BytesCh => {
        // Framed calls decode() again right after yielding a frame (and at
        // EOF), frequently with an empty buffer. An empty `Bytes` is the
        // end-of-data marker sent below, so forwarding one here would make
        // the receiver stop before the payload arrives.
        if buf.is_empty() {
          return Ok(None);
        }
        let read_to = cmp::min(self.remain, buf.len() as u64);
        self.remain -= read_to;
        let len = usize::try_from(read_to).expect("bounded by buf.len()");
        let piece = buf.split_to(len).freeze();
        if let Some(ref tx) = self.bytes_tx {
          // Best effort: a dropped receiver just means nobody wants the data.
          let _ = tx.send(piece);
        }
        if self.remain == 0 {
          self.state = CodecState::Telegram;
          if let Some(tx) = self.bytes_tx.take() {
            let _ = tx.send(Bytes::new());
          }
          Ok(Some(Input::BytesChDone))
        } else {
          Ok(None)
        }
      }
      CodecState::Skip => {
        if buf.is_empty() {
          return Ok(None);
        }
        let read_to = cmp::min(self.remain, buf.len() as u64);
        let len = usize::try_from(read_to).expect("bounded by buf.len()");
        let _ = buf.split_to(len);
        self.remain -= read_to;
        if self.remain != 0 {
          return Ok(None);
        }
        self.state = CodecState::Telegram;
        Ok(Some(Input::SkipDone))
      }
    }
  }
}
impl Encoder<&Telegram> for Codec {
type Error = crate::err::Error;
fn encode(
&mut self,
tg: &Telegram,
buf: &mut BytesMut
) -> Result<(), Error> {
tg.encoder_write(buf)?;
Ok(())
}
}
impl Encoder<&Params> for Codec {
type Error = crate::err::Error;
fn encode(
&mut self,
params: &Params,
buf: &mut BytesMut
) -> Result<(), Error> {
params.encoder_write(buf);
Ok(())
}
}
impl Encoder<&HashMap<String, String>> for Codec {
  type Error = crate::err::Error;

  /// Write `data` as one `key value` line per entry, followed by an empty
  /// line that terminates the block.
  ///
  /// Entries are written in the map's iteration order, verbatim.
  /// NOTE(review): nothing checks keys for spaces or keys/values for `\n`.
  /// Such entries would corrupt the framing on the receiving side; confirm
  /// that callers guarantee clean data.
  fn encode(
    &mut self,
    data: &HashMap<String, String>,
    buf: &mut BytesMut
  ) -> Result<(), Error> {
    // Each entry costs its key, a space, its value and a newline; one more
    // byte is needed for the terminating blank line.
    let needed = data
      .iter()
      .map(|(key, value)| key.len() + value.len() + 2)
      .sum::<usize>()
      + 1;
    buf.reserve(needed);

    for (key, value) in data {
      buf.extend_from_slice(key.as_bytes());
      buf.put_u8(b' ');
      buf.extend_from_slice(value.as_bytes());
      buf.put_u8(b'\n');
    }
    buf.put_u8(b'\n');

    Ok(())
  }
}
impl Encoder<&KVLines> for Codec {
type Error = crate::err::Error;
fn encode(
&mut self,
kvlines: &KVLines,
buf: &mut BytesMut
) -> Result<(), Error> {
kvlines.encoder_write(buf);
Ok(())
}
}
impl Encoder<Bytes> for Codec {
  type Error = crate::err::Error;

  /// Append the raw payload `data` to `buf` unchanged; never fails.
  fn encode(
    &mut self,
    data: Bytes,
    buf: &mut BytesMut
  ) -> Result<(), crate::err::Error> {
    // `extend_from_slice` grows `buf` as needed before copying.
    buf.extend_from_slice(&data);
    Ok(())
  }
}
impl Encoder<&[u8]> for Codec {
  type Error = crate::err::Error;

  /// Append the raw bytes in `data` to `buf` unchanged; never fails.
  fn encode(
    &mut self,
    data: &[u8],
    buf: &mut BytesMut
  ) -> Result<(), crate::err::Error> {
    // `extend_from_slice` grows `buf` as needed before copying.
    buf.extend_from_slice(data);
    Ok(())
  }
}
#[cfg(test)]
mod tests {
  use {
    futures::sink::SinkExt, tokio::sync::mpsc::unbounded_channel,
    tokio_stream::StreamExt, tokio_test::io::Builder,
    tokio_util::codec::Framed
  };

  use bytes::BytesMut;

  use super::{Bytes, Codec, Input, Telegram};

  /// A telegram consisting only of a topic line.
  #[tokio::test]
  async fn valid_no_params() {
    let mut mock = Builder::new();
    mock.read(b"hello\n\n");
    let mut frm = Framed::new(mock.build(), Codec::new());
    while let Some(o) = frm.next().await {
      let Input::Telegram(tg) = o.unwrap() else {
        panic!("Not a Telegram");
      };
      assert_eq!(tg.get_topic(), "hello");
      let map = tg.into_params().into_inner();
      assert_eq!(map.len(), 0);
    }
  }

  /// A telegram with two parameter lines.
  #[tokio::test]
  async fn valid_with_params() {
    let mut mock = Builder::new();
    mock.read(b"hello\nmurky_waters off\nwrong_impression cows\n\n");
    let mut frm = Framed::new(mock.build(), Codec::new());
    while let Some(o) = frm.next().await {
      let Input::Telegram(tg) = o.unwrap() else {
        panic!("Not a Telegram");
      };
      assert_eq!(tg.get_topic(), "hello");
      let map = tg.into_params().into_inner();
      assert_eq!(map.len(), 2);
      assert_eq!(map.get("murky_waters").unwrap(), "off");
      assert_eq!(map.get("wrong_impression").unwrap(), "cows");
    }
  }

  /// A space in the topic line is rejected as a protocol error.
  #[tokio::test]
  #[should_panic(
    expected = "Protocol(\"Bad format; Invalid topic character\")"
  )]
  async fn bad_topic() {
    let mut mock = Builder::new();
    mock.read(b"hel lo\n\n");
    let mut frm = Framed::new(mock.build(), Codec::new());
    let e = frm.next().await.unwrap();
    e.unwrap();
  }

  /// Several telegrams delivered in a single read are decoded in order.
  #[tokio::test]
  async fn multiple() {
    let mut mock = Builder::new();
    mock.read(b"hello\nfoo bar\n\nworld\nholy cows\n\nfinal\nthe thing\n\n");
    let mut frm = Framed::new(mock.build(), Codec::new());

    let o = frm.next().await.unwrap().unwrap();
    let Input::Telegram(tg) = o else {
      panic!("Unexpectedly not Input::Telegram");
    };
    assert_eq!(tg.get_topic(), "hello");
    assert_eq!(tg.get_str("foo"), Some("bar"));

    let o = frm.next().await.unwrap().unwrap();
    let Input::Telegram(tg) = o else {
      panic!("Unexpectedly not Input::Telegram");
    };
    assert_eq!(tg.get_topic(), "world");
    assert_eq!(tg.get_str("holy"), Some("cows"));

    let o = frm.next().await.unwrap().unwrap();
    let Input::Telegram(tg) = o else {
      panic!("Unexpectedly not Input::Telegram");
    };
    assert_eq!(tg.get_topic(), "final");
    assert_eq!(tg.get_str("the"), Some("thing"));
  }

  /// A telegram announcing a payload length, followed by that payload.
  #[tokio::test]
  async fn tg_followed_by_buf() {
    let mut mock = Builder::new();
    mock.read(b"hello\nlen 4\n\n1234");
    let mut frm = Framed::new(mock.build(), Codec::new());

    let Some(o) = frm.next().await else {
      panic!("No frame");
    };
    let Input::Telegram(tg) = o.unwrap() else {
      panic!("Not a Telegram");
    };
    assert_eq!(tg.get_topic(), "hello");
    assert_eq!(tg.get_fromstr::<usize, _>("len").unwrap().unwrap(), 4);
    frm.codec_mut().expect_bytes(4).unwrap();

    while let Some(o) = frm.next().await {
      let Input::Bytes(buf) = o.unwrap() else {
        panic!("Not a Buf");
      };
      assert_eq!(buf, "1234");
    }
  }

  /// Telegram, payload, telegram: the codec returns to telegram mode after
  /// the payload.
  #[tokio::test]
  async fn tg_buf_tg() {
    let mut mock = Builder::new();
    mock.read(b"hello\nlen 4\n\n1234world\nfoo bar\n\n");
    let mut frm = Framed::new(mock.build(), Codec::new());

    let o = frm.next().await.unwrap().unwrap();
    let Input::Telegram(tg) = o else {
      panic!("Unexpectedly not Input::Telegram(_)");
    };
    assert_eq!(tg.get_topic(), "hello");
    let len = tg.get_fromstr::<usize, _>("len").unwrap().unwrap();
    assert_eq!(len, 4);
    frm.codec_mut().expect_bytes(len).unwrap();

    let o = frm.next().await.unwrap().unwrap();
    let Input::Bytes(buf) = o else {
      panic!("Unexpectedly not Input::Bytes(_)");
    };
    assert_eq!(buf, "1234");

    let o = frm.next().await.unwrap().unwrap();
    let Input::Telegram(tg) = o else {
      panic!("Unexpectedly not Input::Telegram(_)");
    };
    assert_eq!(tg.get_topic(), "world");
    assert_eq!(tg.get_str("foo"), Some("bar"));
  }

  /// A payload forwarded through a channel arrives intact and is followed
  /// by the empty end marker.
  #[tokio::test]
  async fn expect_bytes_ch() {
    let (client, server) = tokio::io::duplex(64);
    let mut frmin = Framed::new(server, Codec::new());
    let mut frmout = Framed::new(client, Codec::new());

    let jh = tokio::task::spawn(async move {
      let o = frmin.next().await.unwrap().unwrap();
      let Input::Telegram(tg) = o else {
        panic!("Unexpectedly not Input::Telegram(_)");
      };
      assert_eq!(tg.as_ref(), "ReqToSend");
      let len = tg.get_fromstr::<u64, _>("Len").unwrap().unwrap();
      let (tx, mut rx) = unbounded_channel();
      frmin.codec_mut().expect_bytes_channel(len, tx).unwrap();

      let jh = tokio::task::spawn(async move {
        let mut inbuf = BytesMut::new();
        inbuf.reserve(16);
        loop {
          let buf = rx.recv().await.unwrap();
          // An empty buffer marks the end of the payload.
          if buf.is_empty() {
            break;
          }
          inbuf.extend_from_slice(&buf);
        }
        let buf = inbuf.freeze();
        assert_eq!(buf, "0123456789abcdef");
      });

      let o = frmin.next().await.unwrap().unwrap();
      let Input::BytesChDone = o else {
        panic!("Unexpectedly not Input::BytesChDone");
      };
      jh.await.unwrap();
    });

    let len = 16;
    let mut tg = Telegram::new("ReqToSend");
    tg.add_param("Len", len).unwrap();
    frmout.send(&tg).await.unwrap();
    frmout.send(Bytes::from("0123")).await.unwrap();
    frmout.send(Bytes::from("4567")).await.unwrap();
    frmout.send(Bytes::from("89ab")).await.unwrap();
    frmout.send(Bytes::from("cdef")).await.unwrap();
    jh.await.unwrap();
  }
}