use crate::core::vlq::{decode_slice, VlqError, CONTINUE};
use log::trace;
use snafu::{ensure, OptionExt, ResultExt, Snafu};
use std::fs::File;
use std::io::{BufReader, Bytes, ErrorKind, Read};
use std::path::{Path, PathBuf};
use std::str::{from_utf8, Utf8Error};
/// A forward-only byte reader over a [`Bytes`] iterator that keeps the most recently read byte
/// plus three bytes of lookahead, tracks its position, and can be restricted by a position limit.
pub(crate) struct ByteIter<R: Read> {
/// The underlying byte source; the lookahead slots are refilled from it.
iter: Bytes<R>,
/// Zero-based index of `current`; `None` until the first call to `read`.
position: Option<u64>,
/// The byte returned by the most recent `read`; `None` before the first read and once the
/// source is exhausted.
current: Option<u8>,
/// The byte that the next `read` will return, if any.
peek1: Option<u8>,
/// The byte after `peek1`, if any.
peek2: Option<u8>,
/// The byte after `peek2`, if any.
peek3: Option<u8>,
/// When set, `read` returns `None` once `position` has reached this value.
position_limit: Option<u64>,
/// Caller-managed scratch value, only stored and returned by its accessors.
/// NOTE(review): presumably the last status byte seen, for running-status handling - confirm.
latest_message_byte: Option<u8>,
/// Caller-managed flag; it can be set through its setter but is never cleared in this type.
running_status_detected: bool,
}
/// Errors produced while opening or reading a byte source with `ByteIter`.
///
/// Every `position` field is the reader's position at the time the error was raised (0 when
/// nothing has been read yet), so it is approximate: "around" rather than "exactly at".
#[derive(Debug, Snafu)]
pub(crate) enum ByteError {
    /// The underlying reader reported an I/O error other than an unexpected EOF.
    #[snafu(display("io error around byte {}: {}", position, source))]
    Io {
        position: u64,
        source: std::io::Error,
    },

    /// A byte was required but the source (or the current size limit) was exhausted.
    #[snafu(display("unexpected end reached around byte {}", position))]
    End { position: u64 },

    /// Bytes that were expected to be UTF-8 text were not valid UTF-8.
    #[snafu(display(
        "expected string but found non-utf8 encoded bytes around {}: {}",
        position,
        source
    ))]
    Str { position: u64, source: Utf8Error },

    /// A four-byte tag was read but did not match the expected tag.
    #[snafu(display(
        "expected tag '{}' but found '{}' near position {}",
        expected,
        found,
        position
    ))]
    Tag {
        expected: String,
        found: String,
        position: u64,
    },

    /// A variable-length quantity kept signalling continuation for too many bytes.
    #[snafu(display("too many bytes while reading vlq around {}", position))]
    VlqTooBig { position: u64 },

    /// The bytes of a variable-length quantity could not be decoded.
    #[snafu(display("problem decoding vlq around {}: {}", position, source))]
    VlqDecode { position: u64, source: VlqError },

    /// A specific byte value was required but a different one was read.
    #[snafu(display(
        "incorrect byte value around {}: expected '{:#X}', found '{:#X}'",
        position,
        expected,
        found,
    ))]
    ReadExpect {
        expected: u8,
        found: u8,
        position: u64,
    },

    /// The file at `path` could not be opened.
    #[snafu(display("unable to open '{}': {}", path.display(), source))]
    FileOpen {
        path: PathBuf,
        source: std::io::Error,
    },
}
/// Result alias used by every fallible operation in this file; the error is always a `ByteError`.
pub(crate) type ByteResult<T> = std::result::Result<T, ByteError>;
/// Number of bytes in one kibibyte.
const KB: usize = 1024;
/// Number of bytes in one mebibyte.
///
/// Used as the `BufReader` capacity, which is measured in bytes (not bits), so no
/// bits-per-byte factor belongs in this value.
const MB: usize = 1024 * KB;
impl ByteIter<BufReader<File>> {
    /// Opens the file at `path` and wraps it in a `ByteIter`.
    ///
    /// The file is read through a `BufReader` whose capacity is `MB` bytes.
    ///
    /// # Errors
    ///
    /// Returns `ByteError::FileOpen` if the file cannot be opened, or whatever error
    /// `ByteIter::new` reports while priming the lookahead.
    pub(crate) fn new_file<P: AsRef<Path>>(path: P) -> ByteResult<Self> {
        let path = path.as_ref();
        let file = File::open(path).context(FileOpenSnafu { path })?;
        let reader = BufReader::with_capacity(MB, file);
        Self::new(reader.bytes())
    }
}
impl<R: Read> ByteIter<R> {
/// Creates a `ByteIter` over `iter`, eagerly pulling up to three bytes to fill the lookahead.
///
/// No byte counts as "read" yet: `position` and `current` start out as `None`.
///
/// # Errors
///
/// Returns `ByteError::Io` (reported at position 0) if the source fails while the lookahead
/// is being filled.
pub(crate) fn new(mut iter: Bytes<R>) -> ByteResult<Self> {
    // Fill the three lookahead slots in order; a short source simply leaves `None`s behind.
    let mut lookahead = [None; 3];
    for slot in lookahead.iter_mut() {
        *slot = Self::next_impl(&mut iter, 0)?;
    }
    let [peek1, peek2, peek3] = lookahead;

    Ok(Self {
        iter,
        peek1,
        peek2,
        peek3,
        current: None,
        position: None,
        position_limit: None,
        latest_message_byte: None,
        running_status_detected: false,
    })
}
fn next_impl(iter: &mut Bytes<R>, position: u64) -> ByteResult<Option<u8>> {
match iter.next() {
None => Ok(None),
Some(result) => match result {
Ok(val) => Ok(Some(val)),
Err(ref e) if e.kind() == ErrorKind::UnexpectedEof => Ok(None),
Err(e) => Err(e).context(IoSnafu { position }),
},
}
}
pub(crate) fn read(&mut self) -> ByteResult<Option<u8>> {
if let Some(position_limit) = self.position_limit {
if let Some(position) = self.position {
if position >= position_limit {
return Ok(None);
}
}
}
if self.current.is_none() {
self.position = Some(0);
} else if self.current.is_some() {
self.position = Some(self.position.unwrap_or(0) + 1);
}
let return_val = self.peek1;
self.current = self.peek1;
self.peek1 = self.peek2;
self.peek2 = self.peek3;
let next_opt = self.iter.next();
let next_result = match next_opt {
None => {
self.peek3 = None;
trace!(
"read {:#x} at position {}",
return_val.unwrap_or(0),
self.position.unwrap_or(0)
);
return Ok(return_val);
}
Some(r) => r,
};
let e = match next_result {
Ok(ok) => {
self.peek3 = Some(ok);
trace!(
"read {:#x} at position {}",
return_val.unwrap_or(0),
self.position.unwrap_or(0)
);
return Ok(return_val);
}
Err(e) => {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
self.peek3 = None;
trace!(
"read {:#x} at position {}",
return_val.unwrap_or(0),
self.position.unwrap_or(0)
);
return Ok(return_val);
}
e
}
};
Err(e).context(IoSnafu {
position: self.position.unwrap_or(0),
})
}
/// Reads one byte, treating "no more bytes" as an error.
///
/// # Errors
///
/// Returns `ByteError::End` if `read` yields `None`, or any error `read` itself reports.
pub(crate) fn read_or_die(&mut self) -> ByteResult<u8> {
    let maybe_byte = self.read()?;
    // The position is sampled after the read so the error points at where reading stopped.
    let position = self.position.unwrap_or(0);
    maybe_byte.context(EndSnafu { position })
}
pub(crate) fn read2(&mut self) -> ByteResult<[u8; 2]> {
let mut retval = [0u8; 2];
retval[0] = self.read()?.context(EndSnafu {
position: self.position.unwrap_or(0),
})?;
retval[1] = self.read()?.context(EndSnafu {
position: self.position.unwrap_or(0),
})?;
Ok(retval)
}
pub(crate) fn read4(&mut self) -> ByteResult<[u8; 4]> {
let mut retval = [0u8; 4];
retval[0] = self.read()?.context(EndSnafu {
position: self.position.unwrap_or(0),
})?;
retval[1] = self.read()?.context(EndSnafu {
position: self.position.unwrap_or(0),
})?;
retval[2] = self.read()?.context(EndSnafu {
position: self.position.unwrap_or(0),
})?;
retval[3] = self.read()?.context(EndSnafu {
position: self.position.unwrap_or(0),
})?;
Ok(retval)
}
/// Reads two bytes and interprets them as a big-endian `u16`.
///
/// # Errors
///
/// Propagates any error from `read2`.
pub(crate) fn read_u16(&mut self) -> ByteResult<u16> {
    self.read2().map(u16::from_be_bytes)
}
/// Reads four bytes and interprets them as a big-endian `u32`.
///
/// # Errors
///
/// Propagates any error from `read4`.
pub(crate) fn read_u32(&mut self) -> ByteResult<u32> {
    self.read4().map(u32::from_be_bytes)
}
/// Reads the raw bytes of one variable-length quantity.
///
/// Bytes are collected until one is read whose `CONTINUE` bits are not all set; that final
/// byte is included in the result. At least one byte is always read and at most five are
/// accepted.
///
/// # Errors
///
/// Returns `ByteError::VlqTooBig` if a sixth byte would be needed, `ByteError::End` if the
/// data runs out first, or any error from `read`.
pub(crate) fn read_vlq_bytes(&mut self) -> ByteResult<Vec<u8>> {
    let mut collected = Vec::new();
    loop {
        // Checked before every read, so the failing position is that of the last byte read.
        ensure!(
            collected.len() <= 4,
            VlqTooBigSnafu {
                position: self.position.unwrap_or(0)
            }
        );
        let byte = self.read_or_die()?;
        collected.push(byte);
        if byte & CONTINUE != CONTINUE {
            break;
        }
    }
    Ok(collected)
}
/// Reads one variable-length quantity and decodes it to a `u32`.
///
/// # Errors
///
/// Propagates any error from `read_vlq_bytes`, and returns `ByteError::VlqDecode` if
/// `decode_slice` rejects the collected bytes.
pub(crate) fn read_vlq_u32(&mut self) -> ByteResult<u32> {
    let raw = self.read_vlq_bytes()?;
    let position = self.position.unwrap_or(0);
    let value = decode_slice(&raw).context(VlqDecodeSnafu { position })?;
    trace!("decoded vlq value {} from {} bytes", value, raw.len());
    Ok(value)
}
/// Returns the byte produced by the most recent `read`, or `None` if there is none (before the
/// first read, or after the source ran out).
pub(crate) fn current(&self) -> Option<u8> {
self.current
}
/// Returns the byte the next `read` would produce, without consuming it.
///
/// Only the lookahead is inspected; any size limit is not taken into account here.
///
/// # Errors
///
/// Returns `ByteError::End` if there is no next byte.
pub(crate) fn peek_or_die(&self) -> ByteResult<u8> {
    let position = self.position.unwrap_or(0);
    self.peek1.context(EndSnafu { position })
}
/// Reports whether reading has reached the end.
///
/// This is `true` when a size limit is set and the position has reached it, or when there is
/// no current byte (which is also the case before the first `read`).
///
/// In debug builds, with a limit set, this asserts that at least one byte has been read and
/// that the position has not moved past the limit.
pub(crate) fn is_end(&self) -> bool {
    match self.position_limit {
        Some(limit) => {
            let position = self.position.unwrap_or(0);
            debug_assert!(self.position.is_some());
            debug_assert!(position <= limit);
            position >= limit || self.current.is_none()
        }
        None => self.current.is_none(),
    }
}
/// Reads four bytes and checks that they spell `expected_tag`.
///
/// # Errors
///
/// Returns `ByteError::Str` if the four bytes are not valid UTF-8, `ByteError::Tag` if they
/// differ from `expected_tag`, or any error from `read4`.
pub(crate) fn expect_tag(&mut self, expected_tag: &str) -> ByteResult<()> {
    let raw = self.read4()?;
    let position = self.position.unwrap_or(0);
    let found = from_utf8(&raw).context(StrSnafu { position })?;
    ensure!(
        found == expected_tag,
        TagSnafu {
            expected: expected_tag,
            found,
            position
        }
    );
    Ok(())
}
/// Restricts reading to `size` more bytes beyond the current position.
///
/// The limit is stored as `position + size`; once the position reaches it, `read` returns
/// `None` until `clear_size_limit` is called. The addition saturates, so an oversized `size`
/// yields the largest possible limit rather than overflowing.
///
/// NOTE(review): before the first `read` the position is taken as 0, so the limit ends up one
/// byte more generous than `size` - confirm callers always read at least one byte first.
pub(crate) fn set_size_limit(&mut self, size: u64) {
    let base = self.position.unwrap_or(0);
    self.position_limit = Some(base.saturating_add(size));
}
/// Removes any limit set by `set_size_limit`, so `read` is bounded only by the source again.
pub(crate) fn clear_size_limit(&mut self) {
self.position_limit = None
}
/// Reads one byte and checks that it equals `expected`.
///
/// # Errors
///
/// Returns `ByteError::ReadExpect` if a different byte was read, `ByteError::End` if no byte
/// was available, or any error from `read`.
pub(crate) fn read_expect(&mut self, expected: u8) -> ByteResult<()> {
    let found = self.read_or_die()?;
    let position = self.position.unwrap_or(0);
    ensure!(
        found == expected,
        ReadExpectSnafu {
            expected,
            found,
            position
        }
    );
    Ok(())
}
/// Reads exactly `num_bytes` bytes and returns them in stream order.
///
/// # Errors
///
/// Returns `ByteError::End` if the data runs out before `num_bytes` bytes have been read, or
/// any error from `read`.
pub(crate) fn read_n(&mut self, num_bytes: usize) -> ByteResult<Vec<u8>> {
    // `num_bytes` is whatever length the caller decoded, and nothing has yet confirmed that
    // many bytes exist. Cap the up-front reservation so a bogus length cannot force a huge
    // allocation; the vector still grows as needed for genuinely large reads.
    const MAX_PREALLOCATION: usize = 64 * 1024;

    let mut bytes = Vec::with_capacity(num_bytes.min(MAX_PREALLOCATION));
    for _ in 0..num_bytes {
        bytes.push(self.read_or_die()?);
    }
    debug_assert_eq!(num_bytes, bytes.len());
    Ok(bytes)
}
/// Stores `value` as the latest message byte; the reader itself never interprets it.
pub(crate) fn set_latest_message_byte(&mut self, value: Option<u8>) {
self.latest_message_byte = value;
}
/// Returns the value last stored with `set_latest_message_byte`, or `None` if none was stored.
pub(crate) fn latest_message_byte(&self) -> Option<u8> {
self.latest_message_byte
}
/// Raises the running-status flag. There is no way to lower it again on this type.
pub(crate) fn set_running_status_detected(&mut self) {
self.running_status_detected = true;
}
/// Returns `true` once `set_running_status_detected` has been called.
pub(crate) fn is_running_status_detected(&self) -> bool {
self.running_status_detected
}
}
/// Exercises construction, the lookahead window, position tracking and the size limit.
#[test]
fn byte_iter_test() {
use std::io::Cursor;
let bytes = [0x00u8, 0x01, 0x02, 0x03, 0x04, 0x10, 0x20, 0x30, 0x40];
let cursor = Cursor::new(bytes);
let mut iter = ByteIter::new(cursor.bytes()).unwrap();
// Nothing has been read yet, so there is no current byte.
assert!(iter.current.is_none());
// The first read yields byte 0 and leaves bytes 1..=3 in the lookahead.
assert_eq!(0x00, iter.read().unwrap().unwrap());
assert_eq!(0x00, iter.current.unwrap());
assert_eq!(0x01, iter.peek1.unwrap());
assert_eq!(0x02, iter.peek2.unwrap());
assert_eq!(0x03, iter.peek3.unwrap());
// Three bytes consumed in total, so the current byte sits at index 2.
assert_eq!([0x01, 0x02], iter.read2().unwrap());
assert_eq!(2, iter.position.unwrap());
// Allow two more bytes (indices 3 and 4) and then refuse further reads.
iter.set_size_limit(2);
assert!(!iter.is_end());
assert_eq!(0x03, iter.read().unwrap().unwrap());
assert_eq!(0x04, iter.read().unwrap().unwrap());
assert_eq!(0x04, iter.current().unwrap());
assert!(iter.read().unwrap().is_none());
// Lifting the limit resumes reading at the very next byte.
iter.clear_size_limit();
assert_eq!(0x10, iter.read().unwrap().unwrap());
}