use crate::wire_format::{MAX_VARINT_BYTES, VARINT_CONTINUATION_BIT, VARINT_PAYLOAD_MASK};
use crate::{ProtobufError, Result};
use ::std::io::Read;
use ::std::iter::Iterator;
use super::Varint;
/// Progress of an in-flight varint decode.
///
/// A fresh state (`DecodeState::new()`) has consumed nothing. A state carried by
/// `DecodeOutcome::Incomplete` has consumed part of a varint and can be fed more
/// bytes (for example through `read_varint_resume`) to continue where the previous
/// input ran out.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct DecodeState {
    // Payload bits accumulated so far; each consumed byte contributes the group
    // placed at `shift`, lowest-order group first.
    decoded_value: u64,
    // Bit offset at which the next byte's payload will be OR-ed in (advances by 7).
    shift: u32,
    // Bytes consumed into this state so far, summed across every `feed` call.
    bytes_consumed: usize,
}
/// Result of feeding bytes into a varint decoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeOutcome {
    /// A byte without the continuation bit was reached; the varint is fully decoded.
    Complete(Varint),
    /// The input ran out part-way through a varint; feed the carried state more
    /// bytes to resume decoding.
    Incomplete(DecodeState),
    /// No byte had been consumed at all when the input ran out.
    Empty,
}
impl DecodeState {
pub(crate) fn new() -> Self {
Self::default()
}
#[allow(dead_code)] pub(crate) fn bytes_consumed(&self) -> usize {
self.bytes_consumed
}
pub(crate) fn feed<I, E>(mut self, iter: I) -> Result<DecodeOutcome>
where
I: Iterator<Item = ::std::result::Result<u8, E>>,
E: Into<ProtobufError>,
{
for byte_result in iter.take(MAX_VARINT_BYTES - self.bytes_consumed) {
let byte = byte_result.map_err(Into::into)?;
self.bytes_consumed += 1;
let value = (byte & VARINT_PAYLOAD_MASK) as u64;
self.decoded_value |= value << self.shift;
if byte & VARINT_CONTINUATION_BIT == 0 {
let result_bytes = self.decoded_value.to_le_bytes();
return Ok(DecodeOutcome::Complete(Varint::new(result_bytes)));
}
self.shift += 7;
}
if self.bytes_consumed == 0 {
return Ok(DecodeOutcome::Empty);
}
if self.bytes_consumed >= MAX_VARINT_BYTES {
return Err(ProtobufError::VarintTooLong);
}
Ok(DecodeOutcome::Incomplete(self))
}
}
/// Iterator that decodes consecutive varints from a fallible byte iterator `I`.
///
/// Built by the `read_varints` extension methods. Each call to `next` decodes one
/// varint. It yields `None` once `I` is exhausted exactly at a varint boundary.
///
/// The iterator bound lives on the `impl` blocks rather than on the type itself,
/// so naming `VarintIterator<I>` never has to re-prove it.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct VarintIterator<I> {
    // Underlying source of `Result<u8, E>` items.
    bytes: I,
}
/// Adapter that wraps every byte of a plain `u8` iterator in `Ok`.
///
/// This lets infallible byte sources be decoded by `VarintIterator`, which expects
/// `Result<u8, E>` items. It is constructed by `IteratorExtVarint::read_varints`.
pub struct ToResultIterator<I> {
    // The wrapped plain-byte iterator.
    inner: I,
}
impl<I> Iterator for ToResultIterator<I>
where
    I: Iterator<Item = u8>,
{
    // The error type is uninhabited: a plain byte can never fail.
    type Item = ::std::result::Result<u8, ::std::convert::Infallible>;

    /// Yields the next byte of the wrapped iterator as `Ok(byte)`.
    fn next(&mut self) -> Option<Self::Item> {
        let byte = self.inner.next()?;
        Some(Ok(byte))
    }
}
impl<I, E> VarintIterator<I>
where
I: Iterator<Item = ::std::result::Result<u8, E>>,
E: Into<ProtobufError>,
{
fn new(bytes: I) -> Self {
Self { bytes }
}
}
impl<I, E> Iterator for VarintIterator<I>
where
    I: Iterator<Item = ::std::result::Result<u8, E>>,
    E: Into<ProtobufError>,
{
    type Item = Result<Varint>;

    /// Decodes the next varint from the underlying bytes.
    ///
    /// Returns `None` when the bytes were exhausted before any byte of a new varint.
    /// A varint cut off part-way is reported as `ProtobufError::UnexpectedEof`.
    /// Errors from the underlying iterator and `VarintTooLong` are passed through.
    fn next(&mut self) -> Option<Self::Item> {
        DecodeState::new()
            .feed(&mut self.bytes)
            .and_then(|outcome| match outcome {
                DecodeOutcome::Complete(varint) => Ok(Some(varint)),
                DecodeOutcome::Empty => Ok(None),
                DecodeOutcome::Incomplete(_) => Err(ProtobufError::UnexpectedEof),
            })
            // Ok(None) -> None, Ok(Some(v)) -> Some(Ok(v)), Err(e) -> Some(Err(e)).
            .transpose()
    }
}
/// Varint-decoding extension methods for iterators over plain bytes
/// (`Iterator<Item = u8>`).
///
/// The blanket implementation lifts each byte into an infallible `Ok`. The only
/// errors these methods report therefore come from decoding itself.
pub trait IteratorExtVarint {
    /// Decodes a single varint.
    ///
    /// Returns `Ok(None)` if the iterator yields no bytes. Returns
    /// `Err(ProtobufError::UnexpectedEof)` if it ends part-way through a varint, and
    /// `Err(ProtobufError::VarintTooLong)` if no terminating byte appears within
    /// `MAX_VARINT_BYTES` bytes.
    fn read_varint(self) -> Result<Option<Varint>>;
    /// Starts decoding a varint. If the iterator ends mid-varint, returns
    /// `DecodeOutcome::Incomplete` instead of an error so decoding can resume later.
    fn read_varint_partial(self) -> Result<DecodeOutcome>;
    /// Continues a decode previously returned as `DecodeOutcome::Incomplete(state)`.
    fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome>;
    /// Returns an iterator that decodes varints back to back until the bytes run out.
    fn read_varints(self) -> VarintIterator<ToResultIterator<Self>>
    where
        Self: Sized + Iterator<Item = u8>;
}
impl<I> IteratorExtVarint for I
where
    I: Iterator<Item = u8>,
{
    fn read_varint(self) -> Result<Option<Varint>> {
        // Plain bytes cannot fail, so they are lifted into `Ok` with an
        // uninhabited error type before reaching the decoder.
        let infallible = self.map(Ok::<u8, ::std::convert::Infallible>);
        match DecodeState::new().feed(infallible)? {
            DecodeOutcome::Empty => Ok(None),
            DecodeOutcome::Complete(varint) => Ok(Some(varint)),
            DecodeOutcome::Incomplete(_) => Err(crate::ProtobufError::UnexpectedEof),
        }
    }

    fn read_varint_partial(self) -> Result<DecodeOutcome> {
        let infallible = self.map(Ok::<u8, ::std::convert::Infallible>);
        DecodeState::new().feed(infallible)
    }

    fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome> {
        let infallible = self.map(Ok::<u8, ::std::convert::Infallible>);
        state.feed(infallible)
    }

    fn read_varints(self) -> VarintIterator<ToResultIterator<Self>>
    where
        Self: Sized,
    {
        let wrapped = ToResultIterator { inner: self };
        VarintIterator::new(wrapped)
    }
}
/// Varint-decoding extension methods for fallible byte iterators
/// (`Iterator<Item = Result<u8, E>>` with `E: Into<ProtobufError>`), such as the
/// iterator returned by `Read::bytes`.
///
/// The first `Err` yielded by the iterator aborts decoding. It is returned after
/// conversion into `ProtobufError`.
pub trait TryIteratorExtVarint {
    /// Decodes a single varint.
    ///
    /// Returns `Ok(None)` if the iterator yields nothing. Returns
    /// `Err(ProtobufError::UnexpectedEof)` if it ends part-way through a varint, and
    /// `Err(ProtobufError::VarintTooLong)` if no terminating byte appears within
    /// `MAX_VARINT_BYTES` bytes.
    fn read_varint(self) -> Result<Option<Varint>>;
    /// Starts decoding a varint. If the iterator ends mid-varint, returns
    /// `DecodeOutcome::Incomplete` instead of an error so decoding can resume later.
    fn read_varint_partial(self) -> Result<DecodeOutcome>;
    /// Continues a decode previously returned as `DecodeOutcome::Incomplete(state)`.
    fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome>;
    /// Returns an iterator that decodes varints back to back until the bytes run out.
    fn read_varints(self) -> VarintIterator<Self>
    where
        Self: Sized + Iterator;
}
impl<I, E> TryIteratorExtVarint for I
where
    I: Iterator<Item = ::std::result::Result<u8, E>>,
    E: Into<ProtobufError>,
{
    fn read_varint(self) -> Result<Option<Varint>> {
        let outcome = DecodeState::new().feed(self)?;
        match outcome {
            DecodeOutcome::Empty => Ok(None),
            DecodeOutcome::Complete(varint) => Ok(Some(varint)),
            DecodeOutcome::Incomplete(_) => Err(crate::ProtobufError::UnexpectedEof),
        }
    }

    fn read_varint_partial(self) -> Result<DecodeOutcome> {
        let fresh = DecodeState::new();
        fresh.feed(self)
    }

    fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome> {
        // The iterator already yields `Result`s, so it is fed to the decoder as-is.
        state.feed(self)
    }

    fn read_varints(self) -> VarintIterator<Self>
    where
        Self: Sized,
    {
        VarintIterator::new(self)
    }
}
/// Varint-decoding extension methods for `std::io::Read` sources.
pub trait ReadExtVarint {
    /// Decodes a single varint from the reader.
    ///
    /// Returns `Ok(None)` if no byte could be read before end of input. Returns
    /// `Err(ProtobufError::UnexpectedEof)` if input ends part-way through a varint.
    /// Other errors from `read_varint_partial` are propagated unchanged.
    fn read_varint(&mut self) -> Result<Option<Varint>> {
        let outcome = self.read_varint_partial()?;
        match outcome {
            DecodeOutcome::Empty => Ok(None),
            DecodeOutcome::Incomplete(_) => Err(crate::ProtobufError::UnexpectedEof),
            DecodeOutcome::Complete(varint) => Ok(Some(varint)),
        }
    }
    /// Starts decoding a varint. If input ends mid-varint, returns
    /// `DecodeOutcome::Incomplete` instead of an error so decoding can resume later.
    fn read_varint_partial(&mut self) -> Result<DecodeOutcome>;
    /// Continues a decode previously returned as `DecodeOutcome::Incomplete(state)`.
    fn read_varint_resume(&mut self, state: DecodeState) -> Result<DecodeOutcome>;
    /// Returns an iterator that decodes varints back to back from the reader.
    fn read_varints(&mut self) -> VarintIterator<::std::io::Bytes<&mut Self>>
    where
        Self: ::std::io::Read;
}
// Every method here pulls bytes through `Read::bytes`, which issues one `read` call
// per byte on an unbuffered reader (what `clippy::unbuffered_bytes` warns about).
// The lint is allowed deliberately: buffering inside these methods would read ahead
// past the varint and drop those bytes when the buffer is discarded. Callers that
// care about throughput should hand in a `BufReader` instead.
impl<R> ReadExtVarint for R
where
    R: Read,
{
    #[allow(clippy::unbuffered_bytes)]
    fn read_varint_partial(&mut self) -> Result<DecodeOutcome> {
        let stream = self.bytes();
        DecodeState::new().feed(stream)
    }

    #[allow(clippy::unbuffered_bytes)]
    fn read_varint_resume(&mut self, state: DecodeState) -> Result<DecodeOutcome> {
        let stream = self.bytes();
        state.feed(stream)
    }

    #[allow(clippy::unbuffered_bytes)]
    fn read_varints(&mut self) -> VarintIterator<::std::io::Bytes<&mut Self>> {
        let stream = self.bytes();
        VarintIterator::new(stream)
    }
}
// Unit tests for the varint readers. Besides the happy paths, these cover the
// length limit, the largest 64-bit value, truncated input, resuming a partial
// decode, and the guarantee that a reader is not advanced past a varint.
#[cfg(test)]
mod tests {
    use super::{DecodeOutcome, DecodeState, IteratorExtVarint, ReadExtVarint, TryIteratorExtVarint};
    use crate::ProtobufError;
    use crate::varint::Varint;
    use crate::wire_format::MAX_VARINT_BYTES;

    #[test]
    fn test_read_varint_from_iterator() {
        let input = [0x96, 0x01];
        let iter = input.iter().copied();
        let outcome = iter.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_iterator_ext_varint_trait() {
        let bytes = vec![0x96, 0x01];
        let iter = bytes.into_iter();
        let outcome = iter.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_iterator_ext_varint_empty() {
        let outcome = IteratorExtVarint::read_varint_partial(::std::iter::empty()).unwrap();
        assert_eq!(outcome, DecodeOutcome::Empty);
    }

    #[test]
    fn test_iterator_ext_varint_read_varints() {
        let bytes = vec![0x96, 0x01, 0x7F, 0x01];
        let iter = bytes.into_iter();
        let varints: Vec<Varint> = iter.read_varints().collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(varints.len(), 3);
        assert_eq!(varints[0].to_uint64(), 150);
        assert_eq!(varints[1].to_uint64(), 127);
        assert_eq!(varints[2].to_uint64(), 1);
    }

    #[test]
    fn test_iterator_ext_varint_incomplete() {
        let bytes = vec![0x80u8];
        let result = IteratorExtVarint::read_varint_partial(bytes.into_iter());
        assert!(result.is_ok());
        assert!(matches!(
            result,
            Ok(DecodeOutcome::Incomplete(state)) if state.bytes_consumed() == 1
        ));
    }

    #[test]
    fn test_read_ext_varint_from_slice() {
        let mut slice = &[0x96u8, 0x01][..];
        let outcome = slice.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_read_ext_varint_from_empty_slice() {
        let mut slice: &[u8] = &[];
        let outcome = slice.read_varint_partial().unwrap();
        assert_eq!(outcome, DecodeOutcome::Empty);
    }

    #[test]
    fn test_read_ext_varint_from_slice_incomplete() {
        let mut slice = &[0x80u8][..];
        let result = slice.read_varint_partial();
        assert!(result.is_ok());
        assert!(matches!(result, Ok(DecodeOutcome::Incomplete(_))));
    }

    #[test]
    fn test_decode_state_feed_incomplete() {
        use ::std::convert::Infallible;
        let slice = &[0x80u8][..];
        let result = DecodeState::new().feed(slice.iter().copied().map(Ok::<u8, Infallible>));
        assert!(result.is_ok());
        assert!(matches!(result, Ok(DecodeOutcome::Incomplete(_))));
    }

    #[test]
    fn test_decode_state_feed_resume_complete() {
        use ::std::convert::Infallible;
        let buf1 = &[0x80u8][..];
        let Ok(DecodeOutcome::Incomplete(state)) =
            DecodeState::new().feed(buf1.iter().copied().map(Ok::<u8, Infallible>))
        else {
            panic!("Expected Incomplete");
        };
        let buf2 = &[0x01u8][..];
        let Ok(DecodeOutcome::Complete(varint)) =
            state.feed(buf2.iter().copied().map(Ok::<u8, Infallible>))
        else {
            panic!("Expected Complete");
        };
        assert_eq!(varint.to_uint64(), 128);
    }

    #[test]
    fn test_decode_state_feed_resume_incomplete() {
        use ::std::convert::Infallible;
        let buf1 = &[0x80u8][..];
        let Ok(DecodeOutcome::Incomplete(state)) =
            DecodeState::new().feed(buf1.iter().copied().map(Ok::<u8, Infallible>))
        else {
            panic!("Expected Incomplete");
        };
        let buf2 = &[0x80u8][..];
        let result = state.feed(buf2.iter().copied().map(Ok::<u8, Infallible>));
        assert!(result.is_ok());
        assert!(matches!(result, Ok(DecodeOutcome::Incomplete(_))));
    }

    #[test]
    fn test_try_iterator_ext_varint() {
        use ::std::io::{Cursor, Read};
        let data = vec![0x96, 0x01];
        let reader = Cursor::new(data);
        let iter = reader.bytes();
        let outcome = TryIteratorExtVarint::read_varint_partial(iter).unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_try_iterator_ext_varint_empty() {
        use ::std::io::{Cursor, Read};
        let data = vec![];
        let reader = Cursor::new(data);
        let iter = reader.bytes();
        let outcome = TryIteratorExtVarint::read_varint_partial(iter).unwrap();
        assert_eq!(outcome, DecodeOutcome::Empty);
    }

    #[test]
    fn test_try_iterator_ext_varint_error() {
        use ::std::io::ErrorKind;
        let error = ::std::io::Error::new(ErrorKind::UnexpectedEof, "test error");
        let iter = ::std::iter::once(Err(error));
        let result = TryIteratorExtVarint::read_varint_partial(iter);
        assert!(result.is_err());
        if let Err(ProtobufError::IoError(io_err)) = result {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof);
        } else {
            panic!("Expected IoError");
        }
    }

    #[test]
    fn test_try_iterator_ext_varint_read_varints() {
        use ::std::io::{Cursor, Read};
        let data = vec![0x96, 0x01, 0x7F, 0x01];
        let reader = Cursor::new(data);
        let iter = reader.bytes();
        let varints: Vec<Varint> = iter.read_varints().collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(varints.len(), 3);
        assert_eq!(varints[0].to_uint64(), 150);
        assert_eq!(varints[1].to_uint64(), 127);
        assert_eq!(varints[2].to_uint64(), 1);
    }

    #[test]
    fn test_read_ext_varint_trait() {
        use ::std::io::Cursor;
        let input = [0x96, 0x01];
        let mut reader = Cursor::new(input);
        let outcome = reader.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_read_ext_varint_read_varints() {
        use ::std::io::Cursor;
        let data = vec![0x96, 0x01, 0x7F, 0x01];
        let mut reader = Cursor::new(data);
        let varints: Vec<Varint> = reader
            .read_varints()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(varints.len(), 3);
        assert_eq!(varints[0].to_uint64(), 150);
        assert_eq!(varints[1].to_uint64(), 127);
        assert_eq!(varints[2].to_uint64(), 1);
    }

    #[test]
    fn test_write_varint_roundtrip() {
        use crate::varint::WriteExtVarint;
        let test_values = vec![0, 1, 127, 128, 150, 255, 256, 65535, 0x7FFFFFFF];
        for &value in &test_values {
            let varint = Varint::from_uint64(value);
            let mut buffer = Vec::new();
            buffer.write_varint(&varint).unwrap();
            let iter = buffer.iter().copied();
            let outcome = iter.read_varint_partial().unwrap();
            let decoded_varint = match outcome {
                DecodeOutcome::Complete(v) => v,
                _ => panic!("expected Complete"),
            };
            let decoded_value = decoded_varint.to_uint64();
            assert_eq!(decoded_value, value, "Roundtrip failed for value {}", value);
        }
    }

    #[test]
    fn test_iterator_ext_varint_too_long() {
        // Every byte keeps the continuation bit set, so no varint terminates within
        // the MAX_VARINT_BYTES limit.
        let bytes = vec![0x80u8; MAX_VARINT_BYTES + 1];
        let result = IteratorExtVarint::read_varint_partial(bytes.into_iter());
        assert!(matches!(result, Err(ProtobufError::VarintTooLong)));
    }

    #[test]
    fn test_iterator_ext_varint_max_u64() {
        // Nine full 7-bit groups plus a final 1 bit cover all 64 bits.
        let bytes = vec![0xFFu8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01];
        let varint = IteratorExtVarint::read_varint(bytes.into_iter())
            .unwrap()
            .expect("expected a varint");
        assert_eq!(varint.to_uint64(), u64::MAX);
    }

    #[test]
    fn test_iterator_ext_read_varint_truncated() {
        let result = IteratorExtVarint::read_varint(vec![0x80u8].into_iter());
        assert!(matches!(result, Err(ProtobufError::UnexpectedEof)));
    }

    #[test]
    fn test_iterator_ext_read_varint_empty() {
        let result = IteratorExtVarint::read_varint(::std::iter::empty());
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_iterator_ext_varint_resume() {
        // 300 = 0b10_0101100 -> [0xAC, 0x02], split across two inputs.
        let state = match IteratorExtVarint::read_varint_partial(vec![0xACu8].into_iter()).unwrap() {
            DecodeOutcome::Incomplete(state) => state,
            other => panic!("expected Incomplete, got {:?}", other),
        };
        let outcome =
            IteratorExtVarint::read_varint_resume(vec![0x02u8].into_iter(), state).unwrap();
        assert!(matches!(outcome, DecodeOutcome::Complete(v) if v.to_uint64() == 300));
    }

    #[test]
    fn test_read_ext_read_varint_stops_at_varint_boundary() {
        // Each call must leave the slice positioned right after the decoded varint.
        let mut slice = &[0x96u8, 0x01, 0x05][..];
        let first = slice.read_varint().unwrap().expect("expected first varint");
        assert_eq!(first.to_uint64(), 150);
        let second = slice.read_varint().unwrap().expect("expected second varint");
        assert_eq!(second.to_uint64(), 5);
        assert!(matches!(slice.read_varint(), Ok(None)));
    }
}