use crate::error::StreamBodyKind;
use crate::StreamBodyError;
use bytes::{Buf, BytesMut};
use std::marker::PhantomData;
/// Decoder for a stream of protobuf messages, each preceded by its body length
/// encoded as a base-128 varint (the framing written by prost's
/// `encode_length_delimited`).
///
/// Messages whose declared length exceeds `max_length` are rejected with
/// `StreamBodyKind::MaxLenReachedError`.
#[derive(Clone, Debug)]
pub struct ProtobufLenPrefixCodec<T> {
    /// Largest accepted message body, in bytes.
    max_length: usize,
    /// Framing state carried between `decode` calls.
    cursor: ProtobufCursor,
    _ph: PhantomData<T>,
}

/// Framing state of the decoder between calls.
#[derive(Clone, Debug)]
struct ProtobufCursor {
    /// `Some(len)` once a length prefix has been consumed from the buffer and the
    /// `len`-byte body is still pending; `None` while the next prefix is awaited.
    /// An `Option` rather than a `0` sentinel, so zero-length messages are not lost.
    current_obj_len: Option<usize>,
}

impl<T> ProtobufLenPrefixCodec<T> {
    /// Creates a codec that rejects any message whose declared length is greater
    /// than `max_length` bytes.
    pub fn new_with_max_length(max_length: usize) -> Self {
        ProtobufLenPrefixCodec {
            max_length,
            cursor: ProtobufCursor {
                current_obj_len: None,
            },
            _ph: PhantomData,
        }
    }

    /// Consumes a varint length prefix from the front of `buf`.
    ///
    /// Returns `Ok(None)` without consuming anything when the prefix is not yet
    /// complete. Returns a `CodecError` when the first 10 bytes contain no
    /// terminating byte, or when `decode_varint_slice` rejects the encoding.
    /// Lengths that do not fit in `usize` saturate to `usize::MAX`, so the caller's
    /// max-length check rejects them instead of them being silently truncated.
    fn read_length_prefix(buf: &mut BytesMut) -> Result<Option<usize>, StreamBodyError> {
        // A u64 varint occupies at most 10 bytes.
        const MAX_VARINT_LEN: usize = 10;
        // The last byte of a varint is the first one with its continuation bit clear.
        let end = match buf.iter().take(MAX_VARINT_LEN).position(|&b| b < 0x80) {
            Some(end) => end,
            None if buf.len() < MAX_VARINT_LEN => return Ok(None),
            None => {
                return Err(StreamBodyError::new(
                    StreamBodyKind::CodecError,
                    None,
                    Some("invalid varint".into()),
                ))
            }
        };
        // The slice ends on a byte < 0x80, which satisfies decode_varint_slice's precondition.
        let (value, consumed) = decode_varint_slice(&buf[..=end])?;
        buf.advance(consumed);
        Ok(Some(value.min(usize::MAX as u64) as usize))
    }
}

impl<T> tokio_util::codec::Decoder for ProtobufLenPrefixCodec<T>
where
    T: prost::Message + Default,
{
    type Item = T;
    type Error = StreamBodyError;

    /// Decodes at most one length-prefixed message from `buf`.
    ///
    /// Returns `Ok(None)` when more bytes are needed. The length prefix and the
    /// body are handled in the same call, so a fully buffered frame is always
    /// yielded immediately.
    ///
    /// # Errors
    /// - `MaxLenReachedError` when the declared length exceeds `max_length`.
    ///   The pending length is kept, so later calls fail the same way.
    /// - `CodecError` for a malformed varint prefix or a body prost cannot decode.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<T>, StreamBodyError> {
        // Reuse the length consumed by an earlier call, or read a fresh prefix.
        let obj_len = if let Some(len) = self.cursor.current_obj_len {
            len
        } else {
            let len = match Self::read_length_prefix(buf)? {
                Some(len) => len,
                None => return Ok(None),
            };
            self.cursor.current_obj_len = Some(len);
            len
        };

        if obj_len > self.max_length {
            return Err(StreamBodyError::new(
                StreamBodyKind::MaxLenReachedError,
                None,
                Some("Max object length reached".into()),
            ));
        }

        // Body still incomplete: keep the pending length and wait for more bytes.
        // Deliberately no `reserve(obj_len)`: the length comes from the peer and
        // `max_length` may be very large.
        if buf.len() < obj_len {
            return Ok(None);
        }

        let obj_bytes = buf.copy_to_bytes(obj_len);
        self.cursor.current_obj_len = None;
        <T as prost::Message>::decode(obj_bytes)
            .map(Some)
            .map_err(|err| {
                StreamBodyError::new(StreamBodyKind::CodecError, Some(Box::new(err)), None)
            })
    }

    /// Decodes remaining frames once the underlying stream has ended, exactly like `decode`.
    ///
    /// Trailing bytes that do not form a complete frame yield `Ok(None)` rather
    /// than an error. NOTE(review): consider reporting such truncation explicitly.
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<T>, StreamBodyError> {
        self.decode(buf)
    }
}
/// Decodes a protobuf base-128 varint from the start of `bytes`.
///
/// Returns the decoded value together with the number of bytes it occupied.
///
/// # Panics
/// Panics if `bytes` is empty, or if `bytes` is at most 10 bytes long and its last
/// byte still has the continuation bit (0x80) set. Callers must only pass slices in
/// which the varint is known to terminate.
///
/// # Errors
/// Returns a `StreamBodyKind::CodecError` ("invalid varint") when the first nine
/// bytes all carry the continuation bit and the 10th byte is 0x02 or greater,
/// i.e. the encoding does not fit in a `u64`.
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), StreamBodyError> {
    assert!(!bytes.is_empty());
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    // Longest varint that can still encode a u64.
    const LIMIT: usize = 10;

    let mut decoded: u64 = 0;
    for (idx, &byte) in bytes.iter().take(LIMIT).enumerate() {
        // The 10th byte may only contribute the single remaining (63rd) bit.
        if idx == LIMIT - 1 && byte >= 0x02 {
            break;
        }
        // Each byte carries 7 payload bits, least-significant group first.
        decoded |= u64::from(byte & 0x7F) << (7 * idx);
        if byte < 0x80 {
            return Ok((decoded, idx + 1));
        }
    }

    Err(StreamBodyError::new(
        StreamBodyKind::CodecError,
        None,
        Some("invalid varint".into()),
    ))
}