use crate::util::{CharEncoding, latin1ify};
use bytes::{BufMut, BytesMut};
use std::{cmp, io};
use tokio_util::codec::{Decoder, Encoder};
/// Line-oriented codec: `decode` splits incoming bytes into `String` frames
/// that end at a `b'\n'` or are cut off at `max_length` bytes, and `encode`
/// writes outgoing text converted through `encoding`.
///
/// Note that the derived `Ord`/`PartialOrd` compare fields in declaration
/// order, so reordering the fields changes those comparisons.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(crate) struct ConfabCodec {
/// Offset into the read buffer where the next `decode` call resumes its
/// newline search; bytes before it were already scanned without a match.
/// Reset to 0 whenever a frame is split off the buffer.
next_index: usize,
/// Upper bound, in bytes, on how much of the buffer `decode` will scan and
/// return as a single frame.
max_length: usize,
/// Character encoding used to turn received bytes into `String`s and
/// outgoing `str`s into bytes.
encoding: CharEncoding,
/// When true, `prepare_line` terminates lines with `"\r\n"` instead of
/// `"\n"`.
crlf: bool,
}
impl ConfabCodec {
    /// Creates a codec with the default settings: a length limit of
    /// `usize::MAX`, UTF-8 encoding, and `"\n"` line endings.
    pub(crate) fn new() -> ConfabCodec {
        Self {
            crlf: false,
            encoding: CharEncoding::Utf8,
            max_length: usize::MAX,
            next_index: 0,
        }
    }

    /// Creates a codec like [`ConfabCodec::new`], except that decoded frames
    /// are limited to `max_length` bytes.
    pub(crate) fn new_with_max_length(max_length: usize) -> Self {
        let mut codec = Self::new();
        codec.max_length = max_length;
        codec
    }

    /// Builder-style setter: returns the codec with its character encoding
    /// replaced by `encoding`.
    pub(crate) fn encoding(mut self, encoding: CharEncoding) -> ConfabCodec {
        self.encoding = encoding;
        self
    }

    /// Builder-style setter: returns the codec with its CRLF flag replaced by
    /// `crlf`.
    pub(crate) fn crlf(mut self, crlf: bool) -> ConfabCodec {
        self.crlf = crlf;
        self
    }

    /// Readies `line` for sending: passes it through `latin1ify` when the
    /// codec's encoding is Latin-1, then appends the configured line ending
    /// (`"\r\n"` if the CRLF flag is set, `"\n"` otherwise).
    pub(crate) fn prepare_line(&self, line: String) -> String {
        let mut terminated = if self.encoding == CharEncoding::Latin1 {
            latin1ify(line)
        } else {
            line
        };
        let ending = if self.crlf { "\r\n" } else { "\n" };
        terminated.push_str(ending);
        terminated
    }
}
impl Decoder for ConfabCodec {
    type Item = String;
    type Error = io::Error;

    /// Tries to split one frame off the front of `buf`.
    ///
    /// * If a `b'\n'` occurs within the first `max_length` bytes, everything
    ///   up to and including it is removed from `buf` and returned.
    /// * Otherwise, if `buf` already holds at least `max_length` bytes, a
    ///   frame of at most `max_length` bytes is cut off and returned without
    ///   a trailing newline.  For UTF-8 the cut is moved back so that it does
    ///   not land inside a multi-byte sequence that starts near the limit.
    /// * Otherwise `Ok(None)` is returned and the scan position is remembered
    ///   so the next call does not rescan the same bytes.
    ///
    /// Every returned frame consumes at least one byte of `buf`, so a caller
    /// that keeps calling `decode` while it yields frames always terminates.
    /// This implementation never returns `Err`.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, io::Error> {
        // A limit of zero could only ever produce empty frames that consume
        // nothing, which would make the caller spin forever; treat it as 1.
        let max_length = cmp::max(self.max_length, 1);
        let read_to = cmp::min(max_length, buf.len());
        let newline_offset = buf[self.next_index..read_to]
            .iter()
            .position(|b| *b == b'\n');
        match newline_offset {
            Some(offset) => {
                // `offset` is relative to where this scan started.
                let newline_index = offset + self.next_index;
                self.next_index = 0;
                let line = buf.split_to(newline_index + 1);
                let line = self.encoding.decode(line.into());
                Ok(Some(line))
            }
            None if buf.len() >= max_length => {
                // No newline within the limit: emit an over-long line piece.
                self.next_index = 0;
                let i = if self.encoding.is_utf8() {
                    match find_final_char_boundary(&buf[..max_length]) {
                        // The whole window is one incomplete character (only
                        // possible for limits below 4).  Cutting at 0 would
                        // consume nothing, so cut at the limit instead; the
                        // frame then ends mid-character and its rendering is
                        // left to `CharEncoding::decode`.
                        0 => max_length,
                        boundary => boundary,
                    }
                } else {
                    max_length
                };
                let line = buf.split_to(i);
                let line = self.encoding.decode(line.into());
                Ok(Some(line))
            }
            None => {
                // Need more data; remember how far we have already looked.
                self.next_index = read_to;
                Ok(None)
            }
        }
    }

    /// Like `decode`, but for use once the input has ended: when `decode`
    /// finds no frame and `buf` is non-empty, the remaining bytes are
    /// returned as a final frame without a trailing newline.
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, io::Error> {
        if let Some(frame) = self.decode(buf)? {
            return Ok(Some(frame));
        }
        if buf.is_empty() {
            return Ok(None);
        }
        let rest = buf.split_to(buf.len());
        self.next_index = 0;
        Ok(Some(self.encoding.decode(rest.into())))
    }
}
impl<T> Encoder<T> for ConfabCodec
where
T: AsRef<str>,
{
type Error = io::Error;
fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), io::Error> {
let line = self.encoding.encode(line.as_ref());
buf.reserve(line.len());
buf.put(&*line);
Ok(())
}
}
impl Default for ConfabCodec {
fn default() -> Self {
Self::new()
}
}
/// Returns the length of the longest prefix of `buf` that does not end in the
/// middle of a UTF-8 multi-byte sequence.
///
/// The buffer is scanned backwards over at most three trailing continuation
/// bytes (`0x80..=0xBF`).  If the byte that stops the scan is a lead byte
/// whose sequence needs more bytes than follow it, the index of that lead
/// byte is returned so the incomplete character is excluded.  In every other
/// case (complete sequence, stray continuation bytes, bytes that are not
/// valid lead bytes) the full length of `buf` is returned.
fn find_final_char_boundary(buf: &[u8]) -> usize {
    let total = buf.len();
    // `trailing` counts how many bytes come after the one being inspected.
    for (trailing, &byte) in buf.iter().rev().enumerate() {
        match byte {
            // Continuation byte: keep walking back towards its lead byte.
            0x80..=0xBF if trailing < 3 => continue,
            // Lead byte of a 2-, 3- or 4-byte sequence that is cut short.
            0xC0..=0xDF if trailing < 1 => return total - trailing - 1,
            0xE0..=0xEF if trailing < 2 => return total - trailing - 1,
            0xF0..=0xF7 if trailing < 3 => return total - trailing - 1,
            // Anything else ends the scan with nothing to trim.
            _ => return total,
        }
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds the codec shared by the decoding tests: a 32-byte length limit
    /// with every other setting left at its default.
    fn limited_codec() -> ConfabCodec {
        ConfabCodec::new_with_max_length(32)
    }

    // Each case pairs an input with the expected boundary index.
    #[rstest]
    #[case(b"", 0)]
    #[case(b"foo", 3)]
    #[case(b"foo\xE2\x98\x83", 6)]
    #[case(b"foo\xE2\x98", 3)]
    #[case(b"foo\xE2", 3)]
    #[case(b"foo\x98\x83", 5)]
    #[case(b"\x98\x83", 2)]
    #[case(b"\x80\x98\x83", 3)]
    #[case(b"\x80\x80\x98\x83", 4)]
    #[case(b"foo\xC0\x80", 5)]
    #[case(b"foo\xC0\x80\x80", 6)]
    #[case(b"foo\xC0", 3)]
    #[case(b"foo\xF0\x80\x80", 3)]
    #[case(b"foo\x80\x80\x80", 6)]
    #[case(b"foo\x80\x80\x80\x80", 7)]
    #[case(b"foo\xFF", 4)]
    #[case(b"foo\xFC", 4)]
    #[case(b"foo\xFC\x80\x80\x80", 7)]
    #[case(b"foo\xFC\x80\x80\x80\x80\x80", 9)]
    fn test_find_final_char_boundary(#[case] input: &[u8], #[case] expected: usize) {
        let boundary = find_final_char_boundary(input);
        assert_eq!(boundary, expected);
    }

    // A newline well inside the limit ends the frame; the rest stays put.
    #[test]
    fn test_decode_end_before_limit() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("This is test text.\nAnd so is this.\n");
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("This is test text.\n"));
        assert_eq!(buf, "And so is this.\n");
    }

    // The newline is the 32nd byte, i.e. the last one the limit allows.
    #[test]
    fn test_decode_end_at_limit() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("123456789.abcdefghi.123456789.a\nbcdef");
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("123456789.abcdefghi.123456789.a\n"));
        assert_eq!(buf, "bcdef");
    }

    // The newline is the 33rd byte, so the frame is cut just before it.
    #[test]
    fn test_decode_end_right_after_limit() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("123456789.abcdefghi.123456789.ab\ncdef");
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("123456789.abcdefghi.123456789.ab"));
        assert_eq!(buf, "\ncdef");
    }

    // The newline lies several bytes past the limit.
    #[test]
    fn test_decode_end_after_limit() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("123456789.abcdefghi.123456789.abcdef\n");
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("123456789.abcdefghi.123456789.ab"));
        assert_eq!(buf, "cdef\n");
    }

    // Exactly 32 bytes and no newline: the whole buffer becomes one frame.
    #[test]
    fn test_decode_max_length_no_end() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("123456789.abcdefghi.123456789.ab");
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("123456789.abcdefghi.123456789.ab"));
        assert_eq!(buf, "");
    }

    // 33 bytes and no newline: 32 are emitted and one is left behind.
    #[test]
    fn test_decode_max_length_plus_1_no_end() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("123456789.abcdefghi.123456789.abc");
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("123456789.abcdefghi.123456789.ab"));
        assert_eq!(buf, "c");
    }

    // 31 bytes and no newline: nothing is emitted and the scan position is
    // remembered.
    #[test]
    fn test_decode_max_length_minus_1_no_end() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from("123456789.abcdefghi.123456789.a");
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf, "123456789.abcdefghi.123456789.a");
        assert_eq!(codec.next_index, 31);
    }

    // A three-byte UTF-8 character straddles the limit: the frame stops
    // before it and the character stays in the buffer intact.
    #[test]
    fn test_decode_over_max_length_straddling_utf8() {
        let mut codec = limited_codec();
        let mut buf = BytesMut::from(&b"123456789.abcdefghi.123456789.\xE2\x98\x83"[..]);
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(frame.as_deref(), Some("123456789.abcdefghi.123456789."));
        assert_eq!(buf, &b"\xE2\x98\x83"[..]);
    }

    // Same bytes under Latin-1: the cut happens exactly at the limit.
    #[test]
    fn test_decode_over_max_length_straddling_utf8_in_latin1() {
        let mut codec = limited_codec().encoding(CharEncoding::Latin1);
        let mut buf = BytesMut::from(&b"123456789.abcdefghi.123456789.\xE2\x98\x83"[..]);
        let frame = codec.decode(&mut buf).unwrap();
        assert_eq!(
            frame.as_deref(),
            Some("123456789.abcdefghi.123456789.\u{e2}\u{98}")
        );
        assert_eq!(buf, &b"\x83"[..]);
    }
}