use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{cmp, fmt, io, str, usize};
/// Delimiter bytes passed to `AnyDelimiterCodec::new` by the `Default` impl:
/// comma, semicolon, line feed and carriage return.
const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r";
/// Byte sequence passed to `AnyDelimiterCodec::new` by the `Default` impl as
/// the `sequence_writer`: a single comma.
const DEFAULT_SEQUENCE_WRITER: &[u8] = b",";
/// Codec that splits a byte stream into chunks at any one of a set of
/// single-byte delimiters, and terminates encoded chunks with a configurable
/// byte sequence.
// NOTE: field order is significant for the derived `Ord`/`PartialOrd`/`Hash`.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct AnyDelimiterCodec {
    /// Offset into the read buffer at which the next delimiter search starts,
    /// so bytes already scanned by a previous `decode` call are not rescanned.
    next_index: usize,
    /// Maximum number of bytes a chunk may contain before decoding reports
    /// `MaxChunkLengthExceeded`. `usize::MAX` means effectively unlimited.
    max_length: usize,
    /// Set after the length limit was exceeded; while set, the decoder drops
    /// input up to and including the next delimiter.
    is_discarding: bool,
    /// Bytes that each mark the end of a chunk when decoding.
    seek_delimiters: Vec<u8>,
    /// Byte sequence appended after every chunk when encoding.
    sequence_writer: Vec<u8>,
}

impl AnyDelimiterCodec {
    /// Creates a codec without a chunk length limit (`max_length` is
    /// `usize::MAX`).
    ///
    /// `seek_delimiters` are the bytes that end a chunk while decoding;
    /// `sequence_writer` is written after each chunk while encoding.
    pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> Self {
        Self {
            next_index: 0,
            max_length: usize::MAX,
            is_discarding: false,
            seek_delimiters,
            sequence_writer,
        }
    }

    /// Creates a codec like [`AnyDelimiterCodec::new`] but with the chunk
    /// length limit set to `max_length`.
    pub fn new_with_max_length(
        seek_delimiters: Vec<u8>,
        sequence_writer: Vec<u8>,
        max_length: usize,
    ) -> Self {
        let mut codec = Self::new(seek_delimiters, sequence_writer);
        codec.max_length = max_length;
        codec
    }

    /// Returns the configured maximum chunk length.
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}
impl Decoder for AnyDelimiterCodec {
    type Item = Bytes;
    type Error = AnyDelimiterCodecError;

    /// Tries to split one delimiter-terminated chunk off the front of `buf`.
    ///
    /// Returns `Ok(Some(chunk))` (delimiter stripped) when a delimiter is
    /// found, `Ok(None)` when more input is needed, and
    /// `Err(MaxChunkLengthExceeded)` when `buf` holds more than `max_length`
    /// bytes without a delimiter in the scanned window. After such an error
    /// the codec discards input up to and including the next delimiter before
    /// it resumes producing chunks.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
        loop {
            // Never scan more than `max_length + 1` bytes; `saturating_add`
            // keeps the "no limit" value `usize::MAX` from overflowing.
            let scan_end = cmp::min(self.max_length.saturating_add(1), buf.len());
            let delimiters = &self.seek_delimiters;
            let hit = buf[self.next_index..scan_end]
                .iter()
                .position(|byte| delimiters.contains(byte));

            if self.is_discarding {
                match hit {
                    Some(offset) => {
                        // Drop the rest of the oversized chunk together with
                        // its delimiter, then go back to normal decoding.
                        buf.advance(offset + self.next_index + 1);
                        self.is_discarding = false;
                        self.next_index = 0;
                    }
                    None => {
                        // Still inside the oversized chunk: drop what was
                        // scanned and wait for more data if nothing is left.
                        buf.advance(scan_end);
                        self.next_index = 0;
                        if buf.is_empty() {
                            return Ok(None);
                        }
                    }
                }
                continue;
            }

            return match hit {
                Some(offset) => {
                    let delimiter_at = self.next_index + offset;
                    self.next_index = 0;
                    // Take the chunk including its delimiter out of the
                    // buffer, then cut the delimiter byte off the end.
                    let mut chunk = buf.split_to(delimiter_at + 1);
                    chunk.truncate(delimiter_at);
                    Ok(Some(chunk.freeze()))
                }
                None if buf.len() > self.max_length => {
                    // Limit exceeded without a delimiter: report it and start
                    // discarding on the next call.
                    self.is_discarding = true;
                    Err(AnyDelimiterCodecError::MaxChunkLengthExceeded)
                }
                None => {
                    // Remember how far we searched so the next call resumes
                    // from there instead of rescanning.
                    self.next_index = scan_end;
                    Ok(None)
                }
            };
        }
    }

    /// Like `decode`, but when `decode` yields no chunk and `buf` still holds
    /// bytes, returns all of those remaining bytes as a final chunk.
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
        if let Some(frame) = self.decode(buf)? {
            return Ok(Some(frame));
        }
        if buf.is_empty() {
            return Ok(None);
        }
        let rest = buf.split_to(buf.len());
        self.next_index = 0;
        Ok(Some(rest.freeze()))
    }
}
impl<T> Encoder<T> for AnyDelimiterCodec
where
    T: AsRef<str>,
{
    type Error = AnyDelimiterCodecError;

    /// Appends the bytes of `chunk` to `buf`, followed by the codec's
    /// `sequence_writer` bytes. This implementation always returns `Ok(())`.
    fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> {
        let chunk = chunk.as_ref();
        let terminator: &[u8] = &self.sequence_writer;
        // Reserve for the actual terminator length: `sequence_writer` is an
        // arbitrary byte sequence, not necessarily a single byte.
        buf.reserve(chunk.len() + terminator.len());
        buf.put_slice(chunk.as_bytes());
        buf.put_slice(terminator);
        Ok(())
    }
}
impl Default for AnyDelimiterCodec {
    /// Builds a codec from `DEFAULT_SEEK_DELIMITERS` and
    /// `DEFAULT_SEQUENCE_WRITER` via `AnyDelimiterCodec::new`.
    fn default() -> Self {
        let seek_delimiters = Vec::from(DEFAULT_SEEK_DELIMITERS);
        let sequence_writer = Vec::from(DEFAULT_SEQUENCE_WRITER);
        Self::new(seek_delimiters, sequence_writer)
    }
}
/// Errors produced by `AnyDelimiterCodec`.
#[derive(Debug)]
pub enum AnyDelimiterCodecError {
    /// A chunk exceeded the codec's configured maximum length.
    MaxChunkLengthExceeded,
    /// An underlying I/O error.
    Io(io::Error),
}

impl fmt::Display for AnyDelimiterCodecError {
    /// Writes a fixed message for the length error, or the wrapped I/O
    /// error's own display text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MaxChunkLengthExceeded => f.write_str("max chunk length exceeded"),
            Self::Io(inner) => write!(f, "{}", inner),
        }
    }
}

impl From<io::Error> for AnyDelimiterCodecError {
    /// Wraps an I/O error in the `Io` variant.
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

impl std::error::Error for AnyDelimiterCodecError {}