use crate::encode::NextRead;
use crate::{encode, ChunkGroupState};
use crate::{Finalization, Hash, GROUP_SIZE, HEADER_SIZE, MAX_DEPTH, PARENT_SIZE};
use arrayref::array_ref;
use arrayvec::ArrayVec;
use std::cmp;
use std::error;
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::io::SeekFrom;
/// Decode an entire combined encoding in one call, verifying every byte
/// against `hash` and returning the original content.
pub fn decode(encoded: impl AsRef<[u8]>, hash: &Hash) -> io::Result<Vec<u8>> {
    let encoded_bytes = encoded.as_ref();
    // The encoding always begins with the fixed-size length header.
    if encoded_bytes.len() < HEADER_SIZE {
        return Err(Error::Truncated.into());
    }
    let content_len = crate::decode_len(array_ref!(encoded_bytes, 0, HEADER_SIZE));
    // Reject inputs too short to hold the full encoding up front, so we can
    // allocate the output vector without risk of reading past the end later.
    if (encoded_bytes.len() as u128) < encode::encoded_size(content_len) {
        return Err(Error::Truncated.into());
    }
    let mut decoder = Decoder::new(encoded_bytes, hash);
    let mut content = vec![0; content_len as usize];
    decoder.read_exact(&mut content)?;
    // One extra read confirms the decoder agrees we're at EOF.
    let extra = decoder.read(&mut [0])?;
    debug_assert_eq!(extra, 0, "must be EOF");
    Ok(content)
}
/// Tracks the subtree hashes we expect to encounter next while walking the
/// encoded tree in pre-order.
#[derive(Clone)]
struct VerifyState {
    // Expected hashes of upcoming subtrees; the top entry is verified next.
    stack: ArrayVec<Hash, MAX_DEPTH>,
    // Position/state of the pre-order traversal of the encoding.
    parser: encode::ParseState,
    // The root hash, retained so seeks can reset the stack to the root.
    root_hash: Hash,
}
impl VerifyState {
    /// Start verification, expecting the root subtree hash first.
    fn new(hash: &Hash) -> Self {
        let mut stack = ArrayVec::new();
        stack.push(*hash);
        Self {
            stack,
            parser: encode::ParseState::new(),
            root_hash: *hash,
        }
    }
    /// Current position in the decoded content, per the parser.
    fn content_position(&self) -> u64 {
        self.parser.content_position()
    }
    /// What the decoder must read next (header, parent node, chunk, or done).
    fn read_next(&self) -> NextRead {
        self.parser.read_next()
    }
    /// Begin a seek to content offset `seek_to`; returns bookkeeping that
    /// must be applied via `seek_bookkeeping_done`.
    fn seek_next(&self, seek_to: u64) -> encode::SeekBookkeeping {
        self.parser.seek_next(seek_to)
    }
    /// Apply seek bookkeeping to the hash stack: possibly reset to the root,
    /// then pop any subtrees the seek skipped past. Returns the next read.
    fn seek_bookkeeping_done(&mut self, bookkeeping: encode::SeekBookkeeping) -> encode::NextRead {
        if bookkeeping.reset_to_root() {
            self.stack.clear();
            self.stack.push(self.root_hash);
        }
        debug_assert!(self.stack.len() >= bookkeeping.stack_depth());
        while self.stack.len() > bookkeeping.stack_depth() {
            self.stack.pop();
        }
        self.parser.seek_bookkeeping_done(bookkeeping)
    }
    /// Next step toward determining the content length.
    fn len_next(&self) -> encode::LenNext {
        self.parser.len_next()
    }
    /// Feed the raw length header bytes to the parser.
    fn feed_header(&mut self, header: &[u8; HEADER_SIZE]) {
        self.parser.feed_header(header);
    }
    /// Verify a parent node against the hash on top of the stack. On success
    /// the expected parent hash is replaced by its two children, pushed right
    /// then left so that the left child is on top and gets verified first
    /// (matching the pre-order layout of the encoding).
    fn feed_parent(&mut self, parent: &crate::ParentNode) -> Result<(), Error> {
        let finalization = self.parser.finalization();
        let expected_hash: &Hash = self.stack.last().expect("unexpectedly empty stack");
        let left_child: Hash = (*array_ref!(parent, 0, 32)).into();
        let right_child: Hash = (*array_ref!(parent, 32, 32)).into();
        let computed_hash: Hash =
            blake3::guts::parent_cv(&left_child, &right_child, finalization.is_root());
        if expected_hash != &computed_hash {
            return Err(Error::HashMismatch);
        }
        self.stack.pop();
        self.stack.push(right_child);
        self.stack.push(left_child);
        self.parser.advance_parent();
        Ok(())
    }
    /// Verify a chunk group's hash against the top of the stack and pop it.
    fn feed_chunk(&mut self, chunk_hash: &Hash) -> Result<(), Error> {
        let expected_hash = self.stack.last().expect("unexpectedly empty stack");
        if chunk_hash != expected_hash {
            return Err(Error::HashMismatch);
        }
        self.stack.pop();
        self.parser.advance_chunk();
        Ok(())
    }
}
impl fmt::Debug for VerifyState {
    /// Manual `Debug` impl: report only the stack depth (rather than the
    /// hashes themselves) alongside the parser state.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let depth = self.stack.len();
        write!(
            f,
            "VerifyState {{ stack_size: {}, parser: {:?} }}",
            depth, self.parser,
        )
    }
}
/// Errors that can occur while decoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// A computed hash did not match the expected hash; the encoding is
    /// invalid or doesn't match the root hash.
    HashMismatch,
    /// The encoding ended before the declared content length was reached.
    Truncated,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::HashMismatch => write!(f, "hash mismatch"),
Error::Truncated => write!(f, "truncated encoding"),
}
}
}
// Marker impl so `Error` can be used as a `std::error::Error` trait object.
impl error::Error for Error {}
impl From<Error> for io::Error {
fn from(e: Error) -> io::Error {
match e {
Error::HashMismatch => io::Error::new(io::ErrorKind::InvalidData, "hash mismatch"),
Error::Truncated => io::Error::new(io::ErrorKind::UnexpectedEof, "truncated encoding"),
}
}
}
/// State shared between the sync and async decoder implementations.
#[derive(Clone)]
struct DecoderShared<T, O> {
    // The encoded stream (or just the content stream, in outboard mode).
    input: T,
    // Separate stream holding the header and tree nodes, outboard mode only.
    outboard: Option<O>,
    // Hash verification state driving the traversal.
    state: VerifyState,
    // Holds a chunk group that has been read and verified but not yet handed
    // to the caller. Valid bytes are buf[buf_start..buf_end].
    buf: [u8; GROUP_SIZE],
    buf_start: usize,
    buf_end: usize,
}
impl<T, O> DecoderShared<T, O> {
    /// Create a fresh shared state with an empty buffer.
    fn new(input: T, outboard: Option<O>, hash: &Hash) -> Self {
        Self {
            input,
            outboard,
            state: VerifyState::new(hash),
            buf: [0; GROUP_SIZE],
            buf_start: 0,
            buf_end: 0,
        }
    }
    /// The caller-visible content position: the parser's position minus any
    /// verified bytes still sitting in the buffer.
    fn adjusted_content_position(&self) -> u64 {
        let buffered = self.buf_len() as u64;
        self.state.content_position() - buffered
    }
    /// Number of verified bytes buffered but not yet consumed.
    fn buf_len(&self) -> usize {
        self.buf_end - self.buf_start
    }
    /// Discard any buffered bytes.
    fn clear_buf(&mut self) {
        self.buf_start = 0;
        self.buf_end = 0;
    }
}
impl<T: Read, O: Read> DecoderShared<T, O> {
    /// Copy buffered, already-verified bytes into `output`, advancing the
    /// buffer cursor. Returns the number of bytes copied.
    fn take_buffered_bytes(&mut self, output: &mut [u8]) -> usize {
        let take = cmp::min(self.buf_len(), output.len());
        output[..take].copy_from_slice(&self.buf[self.buf_start..self.buf_start + take]);
        self.buf_start += take;
        take
    }
    /// Read the length header (from the outboard stream if present) and feed
    /// it to the parser.
    fn get_and_feed_header(&mut self) -> io::Result<()> {
        debug_assert_eq!(0, self.buf_len());
        let mut header = [0; HEADER_SIZE];
        if let Some(outboard) = &mut self.outboard {
            outboard.read_exact(&mut header)?;
        } else {
            self.input.read_exact(&mut header)?;
        }
        self.state.feed_header(&header);
        Ok(())
    }
    /// Read the next parent node (a pair of child hashes) from the tree
    /// stream, without verifying it.
    fn get_parent(&mut self) -> io::Result<crate::ParentNode> {
        debug_assert_eq!(0, self.buf_len());
        let mut parent = [0; PARENT_SIZE];
        if let Some(outboard) = &mut self.outboard {
            outboard.read_exact(&mut parent)?;
        } else {
            self.input.read_exact(&mut parent)?;
        }
        Ok(parent)
    }
    /// Read a parent node and verify it against the expected hash on the
    /// stack, failing with `HashMismatch` (as an `io::Error`) if it's bad.
    fn get_and_feed_parent(&mut self) -> io::Result<()> {
        let parent = self.get_parent()?;
        self.state.feed_parent(&parent)?;
        Ok(())
    }
    /// Read, hash, and verify one chunk group into the internal buffer,
    /// first consuming `parents_to_read` parent nodes. On success the buffer
    /// holds the verified bytes with the first `skip` bytes already consumed.
    fn buffer_verified_chunk(
        &mut self,
        size: usize,
        finalization: Finalization,
        skip: usize,
        index: u64,
        parents_to_read: usize,
    ) -> io::Result<()> {
        debug_assert_eq!(0, self.buf_len());
        self.buf_start = 0;
        self.buf_end = 0;
        for _ in 0..parents_to_read {
            self.get_and_feed_parent()?;
        }
        let buf_slice = &mut self.buf[..size];
        self.input.read_exact(buf_slice)?;
        let hash = ChunkGroupState::new(index)
            .update(buf_slice)
            .finalize(finalization.is_root());
        self.state.feed_chunk(&hash)?;
        // The cursors are only set after verification succeeds, so a failed
        // verification leaves the buffer logically empty.
        self.buf_start = skip;
        self.buf_end = size;
        Ok(())
    }
    /// The shared `Read` implementation: drain the buffer if non-empty,
    /// otherwise drive the parser until a chunk is produced or input ends.
    fn read(&mut self, output: &mut [u8]) -> io::Result<usize> {
        // Short-circuit zero-length reads so that a 0 return below
        // unambiguously means EOF.
        if output.is_empty() {
            return Ok(0);
        }
        // Leftover verified bytes are returned without any IO.
        if self.buf_len() > 0 {
            return Ok(self.take_buffered_bytes(output));
        }
        loop {
            match self.state.read_next() {
                NextRead::Done => {
                    return Ok(0);
                }
                NextRead::Header => {
                    self.get_and_feed_header()?;
                }
                NextRead::Parent => {
                    self.get_and_feed_parent()?;
                }
                NextRead::Chunk {
                    size,
                    finalization,
                    skip,
                    index,
                } => {
                    debug_assert_eq!(self.buf_len(), 0);
                    // If the caller's buffer can hold the whole chunk group
                    // and nothing needs skipping, read, hash, and verify
                    // directly in the output to avoid an extra copy.
                    let (read_buf, direct_output) = if output.len() >= size && skip == 0 {
                        (&mut output[..size], true)
                    } else {
                        (&mut self.buf[..size], false)
                    };
                    self.input.read_exact(read_buf)?;
                    let chunk_hash = ChunkGroupState::new(index)
                        .update(read_buf)
                        .finalize(finalization.is_root());
                    self.state.feed_chunk(&chunk_hash)?;
                    if direct_output {
                        return Ok(size);
                    } else {
                        self.buf_start = skip;
                        self.buf_end = size;
                        return Ok(self.take_buffered_bytes(output));
                    }
                }
            }
        }
    }
    /// Perform the IO portion of one seek step, i.e. whatever read `next`
    /// demands. Returns `true` when the seek is complete.
    fn handle_seek_read(&mut self, next: NextRead) -> io::Result<bool> {
        debug_assert_eq!(0, self.buf_len());
        match next {
            NextRead::Header => self.get_and_feed_header()?,
            NextRead::Parent => self.get_and_feed_parent()?,
            NextRead::Chunk {
                size,
                finalization,
                skip,
                index,
            } => {
                // During a seek the chunk is verified but its bytes are not
                // retained; the debug_assert below relies on skip covering
                // the whole chunk in this path.
                self.buffer_verified_chunk(
                    size,
                    finalization,
                    skip,
                    index,
                    0,
                )?;
                debug_assert_eq!(0, self.buf_len());
            }
            NextRead::Done => return Ok(true),
        }
        Ok(false)
    }
}
impl<T: Read + Seek, O: Read + Seek> DecoderShared<T, O> {
    /// Perform any underlying stream seeks the parser requested, then report
    /// the bookkeeping back to the verifier. Returns the next read to do.
    fn handle_seek_bookkeeping(
        &mut self,
        bookkeeping: encode::SeekBookkeeping,
    ) -> io::Result<NextRead> {
        // In outboard mode the underlying seek is a (content, outboard)
        // position pair; in combined mode it's a single encoding offset.
        if let Some(outboard) = &mut self.outboard {
            if let Some((content_pos, outboard_pos)) = bookkeeping.underlying_seek_outboard() {
                self.input.seek(SeekFrom::Start(content_pos))?;
                outboard.seek(SeekFrom::Start(outboard_pos))?;
            }
        } else if let Some(encoding_position) = bookkeeping.underlying_seek() {
            let position_u64: u64 = encode::cast_offset(encoding_position)?;
            self.input.seek(SeekFrom::Start(position_u64))?;
        }
        let next = self.state.seek_bookkeeping_done(bookkeeping);
        Ok(next)
    }
}
impl<T, O> fmt::Debug for DecoderShared<T, O> {
    /// Debug output omits the readers and the buffer contents, showing only
    /// whether we're in outboard mode plus the verifier and buffer cursors.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let is_outboard = self.outboard.is_some();
        write!(
            f,
            "DecoderShared {{ is_outboard: {}, state: {:?}, buf_start: {}, buf_end: {} }}",
            is_outboard, self.state, self.buf_start, self.buf_end,
        )
    }
}
#[cfg(feature = "tokio_io")]
mod tokio_io {
use crate::ChunkGroupState;
use super::{DecoderShared, Hash, NextRead};
use futures::{future, ready};
use std::{
cmp,
convert::TryInto,
io,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, ReadBuf};
impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> DecoderShared<T, O> {
fn write_output(&mut self, buf: &mut ReadBuf<'_>) {
let n = cmp::min(buf.remaining(), self.buf_len());
buf.put_slice(&self.buf[self.buf_start..self.buf_start + n]);
self.buf_start += n;
}
fn poll_read_header(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if let NextRead::Header = self.state.read_next() {
ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?;
self.state.feed_header(self.buf[0..8].try_into().unwrap());
self.clear_buf();
}
Poll::Ready(Ok(()))
}
fn poll_input(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
loop {
match self.state.read_next() {
NextRead::Done => {
break Poll::Ready(Ok(()));
}
NextRead::Header => {
ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?;
self.state.feed_header(self.buf[0..8].try_into().unwrap());
self.clear_buf();
}
NextRead::Parent => {
ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?;
self.state
.feed_parent(&self.buf[0..64].try_into().unwrap())?;
self.clear_buf();
}
NextRead::Chunk {
size,
finalization,
skip,
index,
} => {
ready!(self.poll_fill_buffer_from_input(size, cx))?;
let read_buf = &self.buf[0..size];
let chunk_hash = ChunkGroupState::new(index)
.update(read_buf)
.finalize(finalization.is_root());
self.state.feed_chunk(&chunk_hash)?;
self.buf_start = skip;
debug_assert!(self.buf_len() > 0 || size == 0);
break Poll::Ready(Ok(()));
}
}
}
}
fn poll_handle_seek_read(
&mut self,
next: NextRead,
cx: &mut Context,
) -> Poll<io::Result<bool>> {
Poll::Ready(Ok(match next {
NextRead::Header => {
ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?;
self.state.feed_header(self.buf[0..8].try_into().unwrap());
self.clear_buf();
false
}
NextRead::Parent => {
ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?;
self.state
.feed_parent(&self.buf[0..64].try_into().unwrap())?;
self.clear_buf();
false
}
NextRead::Chunk {
size,
finalization,
skip: _,
index,
} => {
ready!(self.poll_fill_buffer_from_input(size, cx))?;
let read_buf = &self.buf[0..size];
let chunk_hash = ChunkGroupState::new(index)
.update(read_buf)
.finalize(finalization.is_root());
self.state.feed_chunk(&chunk_hash)?;
self.clear_buf();
false
}
NextRead::Done => true, }))
}
fn poll_fill_buffer_from_input(
&mut self,
size: usize,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
let mut buf = ReadBuf::new(&mut self.buf[..size]);
buf.advance(self.buf_end);
let src = &mut self.input;
while buf.remaining() > 0 {
ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?;
self.buf_end = buf.filled().len();
}
Poll::Ready(Ok(()))
}
fn poll_fill_buffer_from_outboard(
&mut self,
size: usize,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
let mut buf = ReadBuf::new(&mut self.buf[..size]);
buf.advance(self.buf_end);
let src = self.outboard.as_mut().unwrap();
while buf.remaining() > 0 {
ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?;
self.buf_end = buf.filled().len();
}
Poll::Ready(Ok(()))
}
fn poll_fill_buffer_from_input_or_outboard(
&mut self,
size: usize,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
if self.outboard.is_some() {
self.poll_fill_buffer_from_outboard(size, cx)
} else {
self.poll_fill_buffer_from_input(size, cx)
}
}
}
// Boxed and pinned so the state machine below can move the shared state
// between variants cheaply.
type BoxedDecoderShared<T, O> = Pin<Box<DecoderShared<T, O>>>;
/// State machine for `AsyncDecoder`: either filling the internal buffer from
/// the input, or draining verified bytes to the caller.
#[derive(Debug)]
enum DecoderState<T: AsyncRead + Unpin, O: AsyncRead + Unpin> {
    // Waiting on input to produce the next verified chunk group.
    Reading(BoxedDecoderShared<T, O>),
    // Buffered verified bytes are available for the caller.
    Output(BoxedDecoderShared<T, O>),
    // Transient placeholder while ownership moves between variants.
    Invalid,
}
impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> DecoderState<T, O> {
    /// Move the state out, leaving `Invalid` behind temporarily.
    fn take(&mut self) -> Self {
        std::mem::replace(self, DecoderState::Invalid)
    }
    /// Alternate between filling the shared buffer (`Reading`) and draining
    /// it into the caller's buffer (`Output`).
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Zero-size destination: nothing to do.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        loop {
            match self.take() {
                Self::Reading(mut shared) => {
                    let res = shared.poll_input(cx);
                    if let Poll::Ready(Ok(())) = res {
                        // Input produced data (or hit the end of the
                        // encoding); switch to draining the buffer.
                        *self = Self::Output(shared);
                        continue;
                    }
                    // Pending or error: stay in Reading and surface it.
                    *self = Self::Reading(shared);
                    break res;
                }
                Self::Output(mut shared) => {
                    shared.write_output(buf);
                    *self = if shared.buf_len() == 0 {
                        shared.clear_buf();
                        Self::Reading(shared)
                    } else {
                        Self::Output(shared)
                    };
                    break Poll::Ready(Ok(()));
                }
                DecoderState::Invalid => {
                    break Poll::Ready(Ok(()));
                }
            }
        }
    }
}
/// An async decoder verifying a combined or outboard encoding, implementing
/// `tokio::io::AsyncRead`.
#[derive(Debug)]
pub struct AsyncDecoder<T: AsyncRead + Unpin, O: AsyncRead + Unpin>(DecoderState<T, O>);
impl<T: AsyncRead + Unpin> AsyncDecoder<T, T> {
    /// Create a decoder for a combined encoding (tree and content in one
    /// stream), verified against `hash`.
    pub fn new(inner: T, hash: &Hash) -> Self {
        let state = DecoderShared::new(inner, None, hash);
        Self(DecoderState::Reading(Box::pin(state)))
    }
}
impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> AsyncDecoder<T, O> {
    /// Create a decoder that reads content from `inner` and tree data from
    /// the separate `outboard` stream.
    pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self {
        let state = DecoderShared::new(inner, Some(outboard), hash);
        Self(DecoderState::Reading(Box::pin(state)))
    }
}
impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> AsyncRead for AsyncDecoder<T, O> {
    /// Delegate to the internal state machine.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.0).poll_read(cx, buf)
    }
}
/// Shared state for `AsyncSliceDecoder`.
#[derive(Clone, Debug)]
struct SliceDecoderInner<T: AsyncRead + Unpin> {
    shared: DecoderShared<T, T>,
    // Content offset where the requested slice begins.
    slice_start: u64,
    // Bytes of the requested slice not yet returned to the caller.
    slice_remaining: u64,
    // A zero-length slice still contains one chunk that must be verified;
    // this flag forces a single read whose output is discarded.
    need_fake_read: bool,
}
impl<T: AsyncRead + Unpin> SliceDecoderInner<T> {
    /// The total content length, if the header has been parsed yet.
    fn content_len(&self) -> Option<u64> {
        self.shared.state.parser.content_len()
    }
    /// Fill the shared buffer with the next run of verified slice bytes,
    /// trimming anything past the end of the requested slice.
    fn poll_input(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // If we haven't reached the start of the slice yet, drive the seek
        // state machine. No underlying seeks are issued here; the slice
        // stream is expected to contain exactly the nodes these reads
        // consume, in order.
        if self.shared.state.content_position() < self.slice_start {
            loop {
                let bookkeeping = self.shared.state.seek_next(self.slice_start);
                let next = self.shared.state.seek_bookkeeping_done(bookkeeping);
                let done = ready!(self.shared.poll_handle_seek_read(next, cx))?;
                if done {
                    break;
                }
            }
            debug_assert_eq!(0, self.shared.buf_len());
        }
        if self.need_fake_read {
            // Zero-length slice: verify one chunk but discard its bytes.
            ready!(self.shared.poll_input(cx))?;
            self.shared.clear_buf();
            self.need_fake_read = false;
        } else if self.slice_remaining > 0 {
            ready!(self.shared.poll_input(cx))?;
            let len = self.shared.buf_len() as u64;
            if len <= self.slice_remaining {
                self.slice_remaining -= len;
            } else {
                // Drop buffered bytes past the end of the slice.
                self.shared.buf_end -= (len - self.slice_remaining) as usize;
                self.slice_remaining = 0;
            }
        };
        Poll::Ready(Ok(()))
    }
}
/// State machine for `AsyncSliceDecoder`, mirroring `DecoderState`.
#[derive(Clone, Debug)]
enum SliceDecoderState<T: AsyncRead + Unpin> {
    // Waiting for the next run of verified slice bytes.
    Reading(Box<SliceDecoderInner<T>>),
    // Buffered verified bytes are available for the caller.
    Output(Box<SliceDecoderInner<T>>),
    // Transient placeholder while ownership moves between variants.
    Taken,
}
impl<T: AsyncRead + Unpin> SliceDecoderState<T> {
    /// Move the state out, leaving `Taken` behind temporarily.
    fn take(&mut self) -> Self {
        std::mem::replace(self, SliceDecoderState::Taken)
    }
}
/// An async decoder for slices produced by `SliceExtractor`, implementing
/// `tokio::io::AsyncRead`.
#[derive(Clone, Debug)]
pub struct AsyncSliceDecoder<T: AsyncRead + Unpin>(SliceDecoderState<T>);
impl<T: AsyncRead + Unpin> AsyncSliceDecoder<T> {
    /// Create a decoder for a slice covering `slice_start..slice_start +
    /// slice_len`, verified against `hash`. The parameters must match the
    /// ones the slice was extracted with.
    pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self {
        let state = SliceDecoderInner {
            shared: DecoderShared::new(inner, None, hash),
            slice_start,
            slice_remaining: slice_len,
            need_fake_read: slice_len == 0,
        };
        Self(SliceDecoderState::Reading(Box::new(state)))
    }
    /// Read and verify the length header, returning the total content length
    /// of the original input (not the slice length).
    pub async fn read_size(&mut self) -> io::Result<u64> {
        if let SliceDecoderState::Reading(state) = &mut self.0 {
            future::poll_fn(|cx| state.shared.poll_read_header(cx)).await?;
        }
        Ok(match &self.0 {
            SliceDecoderState::Reading(state) => state.content_len().unwrap(),
            SliceDecoderState::Output(state) => state.content_len().unwrap(),
            SliceDecoderState::Taken => unreachable!(),
        })
    }
    /// Consume the decoder, returning the underlying reader.
    pub fn into_inner(self) -> T {
        match self.0 {
            SliceDecoderState::Reading(state) => state.shared.input,
            SliceDecoderState::Output(state) => state.shared.input,
            SliceDecoderState::Taken => unreachable!(),
        }
    }
}
impl<T: AsyncRead + Unpin> AsyncRead for AsyncSliceDecoder<T> {
    /// Alternate between filling the internal buffer (`Reading`) and copying
    /// it out to the caller (`Output`), mirroring `DecoderState::poll_read`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        loop {
            match self.0.take() {
                SliceDecoderState::Reading(mut state) => match state.poll_input(cx) {
                    Poll::Ready(Ok(())) => {
                        self.0 = SliceDecoderState::Output(state);
                        continue;
                    }
                    Poll::Ready(Err(e)) => {
                        self.0 = SliceDecoderState::Reading(state);
                        break Poll::Ready(Err(e));
                    }
                    Poll::Pending => {
                        self.0 = SliceDecoderState::Reading(state);
                        break Poll::Pending;
                    }
                },
                SliceDecoderState::Output(mut state) => {
                    state.shared.write_output(buf);
                    if state.shared.buf_len() == 0 {
                        state.shared.clear_buf();
                        self.0 = SliceDecoderState::Reading(state)
                    } else {
                        self.0 = SliceDecoderState::Output(state)
                    };
                    break Poll::Ready(Ok(()));
                }
                SliceDecoderState::Taken => {
                    unreachable!();
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read};
    use tokio::io::AsyncReadExt;
    use super::*;
    use crate::{
        decode::{make_test_input, SliceDecoder},
        encode, GROUP_SIZE, HEADER_SIZE,
    };
    // Round-trip: encode, then async-decode the combined encoding.
    #[tokio::test]
    async fn test_async_decode() {
        for &case in crate::test::TEST_CASES {
            use tokio::io::AsyncReadExt;
            println!("case {case}");
            let input = make_test_input(case);
            let (encoded, hash) = { encode::encode(&input) };
            let mut output = Vec::new();
            let mut reader = AsyncDecoder::new(&encoded[..], &hash);
            reader.read_to_end(&mut output).await.unwrap();
            assert_eq!(input, output);
        }
    }
    // Round-trip through the outboard (separate tree) encoding.
    #[tokio::test]
    async fn test_async_decode_outboard() {
        for &case in crate::test::TEST_CASES {
            use tokio::io::AsyncReadExt;
            println!("case {case}");
            let input = make_test_input(case);
            let (outboard, hash) = { encode::outboard(&input) };
            let mut output = Vec::new();
            let mut reader = AsyncDecoder::new_outboard(&input[..], &outboard[..], &hash);
            reader.read_to_end(&mut output).await.unwrap();
            assert_eq!(input, output);
        }
    }
    // Extract slices at many (start, len) combinations, check that combined
    // and outboard extractors agree, and async-decode each slice.
    #[tokio::test]
    async fn test_async_slices() {
        for &case in crate::test::TEST_CASES {
            let input = make_test_input(case);
            let (encoded, hash) = encode::encode(&input);
            let (outboard, outboard_hash) = encode::outboard(&input);
            assert_eq!(hash, outboard_hash);
            for &slice_start in crate::test::TEST_CASES {
                let expected_start = cmp::min(input.len(), slice_start);
                let slice_lens = [0, 1, 2, GROUP_SIZE - 1, GROUP_SIZE, GROUP_SIZE + 1];
                for &slice_len in slice_lens.iter() {
                    println!("\ncase {case} start {slice_start} len {slice_len}");
                    let expected_end = cmp::min(input.len(), slice_start + slice_len);
                    let expected_output = &input[expected_start..expected_end];
                    let mut slice = Vec::new();
                    let mut extractor = encode::SliceExtractor::new(
                        Cursor::new(&encoded),
                        slice_start as u64,
                        slice_len as u64,
                    );
                    extractor.read_to_end(&mut slice).unwrap();
                    let mut slice_from_outboard = Vec::new();
                    let mut extractor = encode::SliceExtractor::new_outboard(
                        Cursor::new(&input),
                        Cursor::new(&outboard),
                        slice_start as u64,
                        slice_len as u64,
                    );
                    extractor.read_to_end(&mut slice_from_outboard).unwrap();
                    assert_eq!(slice, slice_from_outboard);
                    let mut output = Vec::new();
                    let mut reader = AsyncSliceDecoder::new(
                        &*slice,
                        &hash,
                        slice_start as u64,
                        slice_len as u64,
                    );
                    reader.read_to_end(&mut output).await.unwrap();
                    assert_eq!(expected_output, &*output);
                }
            }
        }
    }
    // Flipping any 32-byte-aligned byte of a slice must fail async decoding
    // with InvalidData.
    #[tokio::test]
    async fn test_async_corrupted_slice() {
        let input = make_test_input(20_000);
        let slice_start = 5_000;
        let slice_len = 10_000;
        let (encoded, hash) = encode::encode(&input);
        let mut slice = Vec::new();
        let mut extractor = encode::SliceExtractor::new(
            Cursor::new(&encoded),
            slice_start as u64,
            slice_len as u64,
        );
        extractor.read_to_end(&mut slice).unwrap();
        // Sanity check: the uncorrupted slice decodes correctly.
        let mut output = Vec::new();
        let mut reader =
            SliceDecoder::new(&*slice, &hash, slice_start as u64, slice_len as u64);
        reader.read_to_end(&mut output).unwrap();
        assert_eq!(&input[slice_start..][..slice_len], &*output);
        let (outboard, outboard_hash) = encode::outboard(&input);
        assert_eq!(hash, outboard_hash);
        let mut slice_from_outboard = Vec::new();
        let mut extractor = encode::SliceExtractor::new_outboard(
            Cursor::new(&input),
            Cursor::new(&outboard),
            slice_start as u64,
            slice_len as u64,
        );
        extractor.read_to_end(&mut slice_from_outboard).unwrap();
        assert_eq!(slice, slice_from_outboard);
        let mut i = HEADER_SIZE;
        while i < slice.len() {
            let mut slice_clone = slice.clone();
            slice_clone[i] ^= 1;
            let mut reader = AsyncSliceDecoder::new(
                &*slice_clone,
                &hash,
                slice_start as u64,
                slice_len as u64,
            );
            output.clear();
            let err = reader.read_to_end(&mut output).await.unwrap_err();
            assert_eq!(io::ErrorKind::InvalidData, err.kind());
            i += 32;
        }
    }
}
}
#[cfg(feature = "tokio_io")]
pub use tokio_io::{AsyncDecoder, AsyncSliceDecoder};
/// An incremental decoder implementing `Read` and `Seek`, verifying every
/// byte it returns against the root hash.
#[derive(Clone, Debug)]
pub struct Decoder<T: Read, O: Read> {
    shared: DecoderShared<T, O>,
}
impl<T: Read> Decoder<T, T> {
pub fn new(inner: T, hash: &Hash) -> Self {
Self {
shared: DecoderShared::new(inner, None, hash),
}
}
}
impl<T: Read, O: Read> Decoder<T, O> {
pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self {
Self {
shared: DecoderShared::new(inner, Some(outboard), hash),
}
}
pub fn into_inner(self) -> (T, Option<O>) {
(self.shared.input, self.shared.outboard)
}
}
impl<T: Read, O: Read> Read for Decoder<T, O> {
    /// Return verified content bytes, delegating to the shared state.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.shared.read(buf)
    }
}
impl<T: Read + Seek, O: Read + Seek> Seek for Decoder<T, O> {
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Any buffered bytes belong to the old position; drop them.
        self.shared.clear_buf();
        // Resolve `pos` into an absolute content offset.
        let seek_to = match pos {
            SeekFrom::Start(offset) => offset,
            SeekFrom::End(offset) => {
                // Seeking from the end requires learning the verified
                // content length first, which may itself perform reads and
                // underlying seeks.
                let content_len = loop {
                    match self.shared.state.len_next() {
                        encode::LenNext::Seek(bookkeeping) => {
                            let next_read = self.shared.handle_seek_bookkeeping(bookkeeping)?;
                            let done = self.shared.handle_seek_read(next_read)?;
                            debug_assert!(!done);
                        }
                        encode::LenNext::Len(len) => break len,
                    }
                };
                add_offset(content_len, offset)?
            }
            SeekFrom::Current(offset) => {
                add_offset(self.shared.adjusted_content_position(), offset)?
            }
        };
        // Drive the seek state machine: underlying-stream bookkeeping, then
        // verification reads, until the parser reports we're in position.
        loop {
            let bookkeeping = self.shared.state.seek_next(seek_to);
            let next_read = self.shared.handle_seek_bookkeeping(bookkeeping)?;
            let done = self.shared.handle_seek_read(next_read)?;
            if done {
                return Ok(seek_to);
            }
        }
    }
}
/// Apply a signed `offset` to an unsigned stream `position`, as needed for
/// `SeekFrom::End` and `SeekFrom::Current`.
///
/// The arithmetic is done in `i128` so it cannot overflow; results outside
/// the `u64` range are reported as `InvalidInput` errors rather than
/// wrapping.
fn add_offset(position: u64, offset: i64) -> io::Result<u64> {
    let sum = position as i128 + offset as i128;
    if sum < 0 {
        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "seek before beginning",
        ))
    } else if sum > u64::MAX as i128 {
        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "seek target overflowed u64",
        ))
    } else {
        Ok(sum as u64)
    }
}
/// A decoder for slices produced by `SliceExtractor`, implementing `Read`.
#[derive(Debug)]
pub struct SliceDecoder<T: Read> {
    shared: DecoderShared<T, T>,
    // Content offset where the requested slice begins.
    slice_start: u64,
    // Bytes of the requested slice not yet returned to the caller.
    slice_remaining: u64,
    // A zero-length slice still needs one verifying read; see `read`.
    need_fake_read: bool,
}
impl<T: Read> SliceDecoder<T> {
    /// Create a decoder for a slice covering `slice_start..slice_start +
    /// slice_len`, verified against `hash`. The parameters must match the
    /// ones the slice was extracted with.
    pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self {
        // An empty slice still requires one verifying read; see `read`.
        let need_fake_read = slice_len == 0;
        Self {
            shared: DecoderShared::new(inner, None, hash),
            slice_start,
            slice_remaining: slice_len,
            need_fake_read,
        }
    }
    /// Consume the decoder, returning the underlying reader.
    pub fn into_inner(self) -> T {
        self.shared.input
    }
}
impl<T: Read> Read for SliceDecoder<T> {
    fn read(&mut self, output: &mut [u8]) -> io::Result<usize> {
        // If we haven't reached the start of the slice yet, drive the seek
        // state machine. No underlying seeks are performed here; the slice
        // stream is expected to contain exactly the nodes these reads
        // consume, in order.
        if self.shared.state.content_position() < self.slice_start {
            loop {
                let bookkeeping = self.shared.state.seek_next(self.slice_start);
                let next = self.shared.state.seek_bookkeeping_done(bookkeeping);
                let done = self.shared.handle_seek_read(next)?;
                if done {
                    break;
                }
            }
            debug_assert_eq!(0, self.shared.buf_len());
        }
        if self.need_fake_read {
            // Zero-length slice: one chunk still has to be verified, even
            // though no bytes are returned to the caller.
            self.shared.read(&mut [0])?;
            self.need_fake_read = false;
            Ok(0)
        } else {
            // Cap the read at the remaining slice length, so we never hand
            // out bytes past the end of the slice.
            let cap = cmp::min(self.slice_remaining, output.len() as u64) as usize;
            let capped_output = &mut output[..cap];
            let n = self.shared.read(capped_output)?;
            self.slice_remaining -= n as u64;
            Ok(n)
        }
    }
}
/// Create a test input of `len` bytes whose contents vary by position, so
/// that assertion failures can distinguish one interior offset from another
/// (unlike all-zero or repeating inputs).
///
/// A counter is appended in the smallest big-endian width it fits in
/// (1, 2, 4, or 8 bytes), then the result is trimmed to exactly `len`.
#[cfg(test)]
pub(crate) fn make_test_input(len: usize) -> Vec<u8> {
    let mut ret = Vec::new();
    let mut counter = 0u64;
    while ret.len() < len {
        if counter < u8::MAX as u64 {
            ret.push(counter as u8);
        } else if counter < u16::MAX as u64 {
            ret.extend_from_slice(&(counter as u16).to_be_bytes());
        } else if counter < u32::MAX as u64 {
            ret.extend_from_slice(&(counter as u32).to_be_bytes());
        } else {
            ret.extend_from_slice(&counter.to_be_bytes());
        }
        counter += 1;
    }
    // The last write may overshoot; trim to the exact requested length.
    ret.truncate(len);
    ret
}
#[cfg(test)]
mod test {
    use rand::prelude::*;
    use rand_chacha::ChaChaRng;
    use std::io;
    use std::io::prelude::*;
    use std::io::Cursor;
    use super::*;
    use crate::encode;
    // Round-trip: encode then decode the combined encoding in one shot.
    #[test]
    fn test_decode() {
        for &case in crate::test::TEST_CASES {
            println!("case {case}");
            let input = make_test_input(case);
            let (encoded, hash) = { encode::encode(&input) };
            let output = decode(&encoded, &hash).unwrap();
            assert_eq!(input, output);
            // The output vector should be allocated exactly, with no slack.
            assert_eq!(output.len(), output.capacity());
        }
    }
    // Round-trip through the outboard (separate tree) encoding.
    #[test]
    fn test_decode_outboard() {
        for &case in crate::test::TEST_CASES {
            println!("case {case}");
            let input = make_test_input(case);
            let (outboard, hash) = { encode::outboard(&input) };
            let mut output = Vec::new();
            let mut reader = Decoder::new_outboard(&input[..], &outboard[..], &hash);
            reader.read_to_end(&mut output).unwrap();
            assert_eq!(input, output);
        }
    }
    // Flipping a bit in the header, a parent node, or a chunk must fail
    // decoding with InvalidData.
    #[test]
    fn test_decoders_corrupted() {
        for &case in crate::test::TEST_CASES {
            println!("case {case}");
            let input = make_test_input(case);
            let (encoded, hash) = encode::encode(&input);
            // Don't tweak the header in this test, because that usually
            // causes a truncation error instead; pick offsets that exist.
            let mut tweaks = Vec::new();
            if encoded.len() > HEADER_SIZE {
                tweaks.push(HEADER_SIZE);
            }
            if encoded.len() > HEADER_SIZE + PARENT_SIZE {
                tweaks.push(HEADER_SIZE + PARENT_SIZE);
            }
            if encoded.len() > GROUP_SIZE {
                tweaks.push(GROUP_SIZE);
            }
            for tweak in tweaks {
                println!("tweak {tweak}");
                let mut bad_encoded = encoded.clone();
                bad_encoded[tweak] ^= 1;
                let err = decode(&bad_encoded, &hash).unwrap_err();
                assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            }
        }
    }
    // Seek via Start, End, and Current to every test offset, then read to
    // the end and compare against the input suffix.
    #[test]
    fn test_seek() {
        for &input_len in crate::test::TEST_CASES {
            println!();
            println!("input_len {input_len}");
            let input = make_test_input(input_len);
            let (encoded, hash) = encode::encode(&input);
            for &seek in crate::test::TEST_CASES {
                println!("seek {seek}");
                let mut seek_froms = Vec::new();
                seek_froms.push(SeekFrom::Start(seek as u64));
                seek_froms.push(SeekFrom::End(seek as i64 - input_len as i64));
                seek_froms.push(SeekFrom::Current(seek as i64));
                for seek_from in seek_froms {
                    println!("seek_from {seek_from:?}");
                    let mut decoder = Decoder::new(Cursor::new(&encoded), &hash);
                    let mut output = Vec::new();
                    decoder.seek(seek_from).expect("seek error");
                    decoder.read_to_end(&mut output).expect("decoder error");
                    let input_start = cmp::min(seek, input.len());
                    assert_eq!(
                        &input[input_start..],
                        &output[..],
                        "output doesn't match input"
                    );
                }
            }
        }
    }
    // Many random seeks on one decoder instance, each followed by reading
    // one chunk group and comparing against the input.
    #[test]
    fn test_repeated_random_seeks() {
        // A chunk group number like this (37) with consecutive zeroes and
        // ones is a pretty good test case.
        let input_len = 0b100101 * GROUP_SIZE;
        println!("\n\ninput_len {input_len}");
        let mut prng = ChaChaRng::from_seed([0; 32]);
        let input = make_test_input(input_len);
        let (encoded, hash) = encode::encode(&input);
        let mut decoder = Decoder::new(Cursor::new(&encoded), &hash);
        for _ in 0..1000 {
            let seek = prng.gen_range(0..input_len + 1);
            println!("\nseek {seek}");
            decoder
                .seek(SeekFrom::Start(seek as u64))
                .expect("seek error");
            // Clone the decoder so that repeated seeks don't interfere with
            // each other via the take() below.
            let mut output = Vec::new();
            decoder
                .clone()
                .take(GROUP_SIZE as u64)
                .read_to_end(&mut output)
                .expect("decoder error");
            let input_start = cmp::min(seek, input_len);
            let input_end = cmp::min(input_start + GROUP_SIZE, input_len);
            assert_eq!(
                &input[input_start..input_end],
                &output[..],
                "output doesn't match input"
            );
        }
    }
    // An empty encoding decodes to nothing under the right hash, and fails
    // with InvalidData under a wrong hash.
    #[test]
    fn test_invalid_zero_length() {
        let (zero_encoded, zero_hash) = encode::encode(b"");
        let one_hash = blake3::hash(b"x");
        // Decoding the empty tree with the right hash succeeds.
        let mut output = Vec::new();
        let mut decoder = Decoder::new(&*zero_encoded, &zero_hash);
        decoder.read_to_end(&mut output).unwrap();
        assert_eq!(output.len(), 0);
        // Decoding with any other hash fails.
        let mut output = Vec::new();
        let mut decoder = Decoder::new(&*zero_encoded, &one_hash);
        let result = decoder.read_to_end(&mut output);
        assert!(result.is_err(), "a bad hash is supposed to fail!");
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }
    // Corrupt a chunk in the middle, confirm reads fail there, then seek
    // past the corruption and confirm decoding resumes correctly.
    #[test]
    fn test_seeking_around_invalid_data() {
        for &case in crate::test::TEST_CASES {
            // Only test cases with at least a few chunk groups.
            if case <= 2 * GROUP_SIZE {
                continue;
            }
            println!("\ncase {case}");
            let input = make_test_input(case);
            let (mut encoded, hash) = encode::encode(&input);
            println!("encoded len {}", encoded.len());
            // Corrupt a chunk near the middle of the encoding.
            let tweak_chunk = encode::count_chunks(case as u64) / 2;
            let tweak_position = tweak_chunk as usize * GROUP_SIZE;
            println!("tweak position {tweak_position}");
            // Compute the encoded offset of that chunk by walking the
            // pre-order layout (parents before each chunk).
            let mut tweak_encoded_offset = HEADER_SIZE;
            for chunk in 0..tweak_chunk {
                tweak_encoded_offset +=
                    encode::pre_order_parent_nodes(chunk, case as u64) as usize * PARENT_SIZE;
                tweak_encoded_offset += GROUP_SIZE;
            }
            tweak_encoded_offset +=
                encode::pre_order_parent_nodes(tweak_chunk, case as u64) as usize * PARENT_SIZE;
            println!("tweak encoded offset {tweak_encoded_offset}");
            encoded[tweak_encoded_offset] ^= 1;
            // Read up to the tweak point successfully.
            let mut decoder = Decoder::new(Cursor::new(&encoded), &hash);
            let mut output = vec![0; tweak_position];
            decoder.read_exact(&mut output).unwrap();
            assert_eq!(&input[..tweak_position], &*output);
            // The next read hits the corrupted chunk.
            let mut buf = [0; GROUP_SIZE];
            let res = decoder.read(&mut buf);
            assert_eq!(res.unwrap_err().kind(), io::ErrorKind::InvalidData);
            // Seeking past the corruption lets decoding resume.
            let new_start = tweak_position + GROUP_SIZE;
            decoder.seek(SeekFrom::Start(new_start as u64)).unwrap();
            let mut output = Vec::new();
            decoder.read_to_end(&mut output).unwrap();
            assert_eq!(&input[new_start..], &*output);
        }
    }
    // Seeking to EOF still verifies the final chunk, so a bad hash or a
    // corrupted final byte must fail the seek itself.
    #[test]
    fn test_invalid_eof_seek() {
        for &case in crate::test::TEST_CASES {
            let input = make_test_input(case);
            let (encoded, hash) = encode::encode(&input);
            // A valid EOF seek reads nothing afterwards.
            let mut output = Vec::new();
            let mut decoder = Decoder::new(Cursor::new(&encoded), &hash);
            decoder.seek(SeekFrom::Start(case as u64)).unwrap();
            decoder.read_to_end(&mut output).unwrap();
            assert_eq!(output.len(), 0);
            // EOF seek with a corrupted root hash fails during the seek.
            let mut bad_hash_bytes = *hash.as_bytes();
            bad_hash_bytes[0] ^= 1;
            let bad_hash = bad_hash_bytes.into();
            let mut decoder = Decoder::new(Cursor::new(&encoded), &bad_hash);
            let result = decoder.seek(SeekFrom::Start(case as u64));
            assert!(result.is_err(), "a bad hash is supposed to fail!");
            assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
            // EOF seek over a corrupted last byte also fails.
            if case > 0 {
                let mut bad_encoded = encoded.clone();
                *bad_encoded.last_mut().unwrap() ^= 1;
                let mut decoder = Decoder::new(Cursor::new(&bad_encoded), &hash);
                let result = decoder.seek(SeekFrom::Start(case as u64));
                assert!(result.is_err(), "a bad hash is supposed to fail!");
                assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
            }
        }
    }
    // Extract and decode slices at many (start, len) combinations, checking
    // combined and outboard extractors agree.
    #[test]
    fn test_slices() {
        for &case in crate::test::TEST_CASES {
            let input = make_test_input(case);
            let (encoded, hash) = encode::encode(&input);
            let (outboard, outboard_hash) = encode::outboard(&input);
            assert_eq!(hash, outboard_hash);
            for &slice_start in crate::test::TEST_CASES {
                let expected_start = cmp::min(input.len(), slice_start);
                let slice_lens = [0, 1, 2, GROUP_SIZE - 1, GROUP_SIZE, GROUP_SIZE + 1];
                for &slice_len in slice_lens.iter() {
                    println!("\ncase {case} start {slice_start} len {slice_len}");
                    let expected_end = cmp::min(input.len(), slice_start + slice_len);
                    let expected_output = &input[expected_start..expected_end];
                    let mut slice = Vec::new();
                    let mut extractor = encode::SliceExtractor::new(
                        Cursor::new(&encoded),
                        slice_start as u64,
                        slice_len as u64,
                    );
                    extractor.read_to_end(&mut slice).unwrap();
                    let mut slice_from_outboard = Vec::new();
                    let mut extractor = encode::SliceExtractor::new_outboard(
                        Cursor::new(&input),
                        Cursor::new(&outboard),
                        slice_start as u64,
                        slice_len as u64,
                    );
                    extractor.read_to_end(&mut slice_from_outboard).unwrap();
                    assert_eq!(slice, slice_from_outboard);
                    let mut output = Vec::new();
                    let mut reader =
                        SliceDecoder::new(&*slice, &hash, slice_start as u64, slice_len as u64);
                    reader.read_to_end(&mut output).unwrap();
                    assert_eq!(expected_output, &*output);
                }
            }
        }
    }
    // Flipping any 32-byte-aligned byte of a slice must fail decoding with
    // InvalidData.
    #[test]
    fn test_corrupted_slice() {
        let input = make_test_input(20_000);
        let slice_start = 5_000;
        let slice_len = 10_000;
        let (encoded, hash) = encode::encode(&input);
        let mut slice = Vec::new();
        let mut extractor = encode::SliceExtractor::new(
            Cursor::new(&encoded),
            slice_start as u64,
            slice_len as u64,
        );
        extractor.read_to_end(&mut slice).unwrap();
        // Sanity check: the uncorrupted slice decodes correctly.
        let mut output = Vec::new();
        let mut reader = SliceDecoder::new(&*slice, &hash, slice_start as u64, slice_len as u64);
        reader.read_to_end(&mut output).unwrap();
        assert_eq!(&input[slice_start..][..slice_len], &*output);
        let (outboard, outboard_hash) = encode::outboard(&input);
        assert_eq!(hash, outboard_hash);
        let mut slice_from_outboard = Vec::new();
        let mut extractor = encode::SliceExtractor::new_outboard(
            Cursor::new(&input),
            Cursor::new(&outboard),
            slice_start as u64,
            slice_len as u64,
        );
        extractor.read_to_end(&mut slice_from_outboard).unwrap();
        assert_eq!(slice, slice_from_outboard);
        let mut i = HEADER_SIZE;
        while i < slice.len() {
            let mut slice_clone = slice.clone();
            slice_clone[i] ^= 1;
            let mut reader =
                SliceDecoder::new(&*slice_clone, &hash, slice_start as u64, slice_len as u64);
            output.clear();
            let err = reader.read_to_end(&mut output).unwrap_err();
            assert_eq!(io::ErrorKind::InvalidData, err.kind());
            i += 32;
        }
    }
    // A slice spanning the whole input is identical to the full encoding.
    #[test]
    fn test_slice_entire() {
        for &case in crate::test::TEST_CASES {
            println!("case {case}");
            let input = make_test_input(case);
            let (encoded, _) = encode::encode(&input);
            let (outboard, _) = encode::outboard(&input);
            let mut slice = Vec::new();
            let mut extractor = encode::SliceExtractor::new_outboard(
                Cursor::new(&input),
                Cursor::new(&outboard),
                0,
                case as u64,
            );
            extractor.read_to_end(&mut slice).unwrap();
            assert_eq!(encoded, slice);
        }
    }
    // into_inner returns the underlying readers for all decoder flavors.
    #[test]
    fn test_into_inner() {
        let v = vec![1u8, 2, 3];
        let hash = [0; 32].into();
        let decoder = Decoder::new(io::Cursor::new(v.clone()), &hash);
        let (inner_reader, outboard_reader) = decoder.into_inner();
        assert!(outboard_reader.is_none());
        let slice_decoder = SliceDecoder::new(inner_reader, &hash, 0, 0);
        assert_eq!(slice_decoder.into_inner().into_inner(), v);
        let outboard_decoder = Decoder::new_outboard(&b""[..], &b""[..], &hash);
        let (_, outboard_reader) = outboard_decoder.into_inner();
        assert!(outboard_reader.is_some());
    }
}