use std::{
io::{self, Read, Seek, SeekFrom, Write},
ops::Range,
result,
};
use blake3::guts::parent_cv;
use bytes::BytesMut;
use range_collections::{range_set::RangeSetRange, RangeSet2, RangeSetRef};
use smallvec::SmallVec;
use crate::{
hash_block, hash_chunk,
io::error::{DecodeError, EncodeError},
iter::{BaoChunk, PreOrderChunkIterRef},
outboard::{Outboard, OutboardMut},
range_ok, BaoTree, BlockSize, ByteNum, ChunkNum, TreeNode,
};
/// Blocking, random-access reads from a data source.
///
/// Both methods take `&mut self` because the blanket implementation below has
/// to move the cursor of the underlying stream.
// `len` is fallible and needs `&mut self`, so a conventional `is_empty`
// companion is intentionally not provided.
#[allow(clippy::len_without_is_empty)]
pub trait SliceReader {
    /// Fills `buf` completely with the bytes that start at `offset`.
    fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> io::Result<()>;

    /// Returns the total length of the data source in bytes.
    fn len(&mut self) -> io::Result<u64>;
}

/// Every seekable reader can be used as a [`SliceReader`].
impl<R: Read + Seek> SliceReader for R {
    /// Seeks to `offset` and then reads exactly `buf.len()` bytes.
    fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> io::Result<()> {
        let target = SeekFrom::Start(offset);
        self.seek(target).and_then(|_| self.read_exact(buf))
    }

    /// Seeks to the end of the stream and reports that position.
    ///
    /// The stream is left positioned at its end afterwards.
    fn len(&mut self) -> io::Result<u64> {
        let end_of_stream = SeekFrom::End(0);
        self.seek(end_of_stream)
    }
}
/// Blocking, random-access writes to a data sink.
pub trait SliceWriter {
    /// Writes all of `src` to the sink, starting at `offset`.
    fn write_at(&mut self, offset: u64, src: &[u8]) -> io::Result<()>;
}

/// Every seekable writer can be used as a [`SliceWriter`].
impl<W: Write + Seek> SliceWriter for W {
    /// Seeks to `offset` and then writes the whole of `src`.
    fn write_at(&mut self, offset: u64, src: &[u8]) -> io::Result<()> {
        let target = SeekFrom::Start(offset);
        self.seek(target).and_then(|_| self.write_all(src))
    }
}
use super::{DecodeResponseItem, Header, Leaf, Parent};
/// Decoding state of a [`DecodeResponseIter`].
#[derive(Debug)]
// NOTE(review): presumably `Content` is much larger than `Header`, which is
// why the lint is silenced instead of boxing the iterator - confirm.
#[allow(clippy::large_enum_variant)]
enum Position<'a> {
/// The length prefix has not been read yet, so no tree can be built.
Header {
/// Chunk ranges requested by the caller.
ranges: &'a RangeSetRef<ChunkNum>,
/// Block size used to build the tree once the size has been read.
block_size: BlockSize,
},
/// The length prefix has been read; `iter` yields the chunks still to decode.
Content { iter: PreOrderChunkIterRef<'a> },
}
/// Iterator that decodes and validates an encoded response read from `R`.
///
/// It yields a header item first, followed by parent and leaf items, and
/// checks every hash pair and every data block against the expected hash
/// before handing the item out.
#[derive(Debug)]
pub struct DecodeResponseIter<'a, R> {
/// Whether the length prefix still has to be read, or the chunk iterator in use.
inner: Position<'a>,
/// Hashes that the upcoming items must match; starts with the root hash.
stack: SmallVec<[blake3::Hash; 10]>,
/// Source of the encoded response.
encoded: R,
/// Scratch buffer that leaf data is read into.
buf: BytesMut,
}
impl<'a, R: Read> DecodeResponseIter<'a, R> {
    /// Creates a decoder for the response to a query for `ranges`.
    ///
    /// `root` is the hash the whole response is validated against, `encoded`
    /// is the stream to decode and `buf` is the scratch buffer that leaf data
    /// is read into.
    pub fn new(
        root: blake3::Hash,
        block_size: BlockSize,
        encoded: R,
        ranges: &'a RangeSetRef<ChunkNum>,
        buf: BytesMut,
    ) -> Self {
        // The root hash is the first expectation that has to be met.
        let mut expectations = SmallVec::new();
        expectations.push(root);
        Self {
            inner: Position::Header { ranges, block_size },
            stack: expectations,
            encoded,
            buf,
        }
    }

    /// Returns the current content of the internal scratch buffer.
    pub fn buffer(&self) -> &[u8] {
        &self.buf[..]
    }

    /// Returns the tree being decoded, or `None` while the header is unread.
    pub fn tree(&self) -> Option<&BaoTree> {
        if let Position::Content { iter } = &self.inner {
            Some(iter.tree())
        } else {
            None
        }
    }

    /// Decodes the next item, `Ok(None)` signalling the end of the response.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, on a query range that does not fit the size from
    /// the header, and on any hash mismatch.
    ///
    /// # Panics
    ///
    /// NOTE(review): when a hash mismatch is reported the expected hash has
    /// already been popped, so calling this again after such an error can hit
    /// the `unwrap()` on an empty stack. Callers should stop at the first error.
    fn next0(&mut self) -> result::Result<Option<DecodeResponseItem>, DecodeError> {
        let chunks = match &mut self.inner {
            Position::Header { ranges, block_size } => {
                // Copy both values out so that `self.inner` can be replaced below.
                let (ranges, block_size) = (*ranges, *block_size);
                let size = read_len(&mut self.encoded)?;
                if !range_ok(ranges, size.chunks()) {
                    return Err(DecodeError::InvalidQueryRange);
                }
                let tree = BaoTree::new(size, block_size);
                let iter = tree.ranges_pre_order_chunks_iter_ref(ranges, 0);
                self.inner = Position::Content { iter };
                return Ok(Some(Header { size }.into()));
            }
            Position::Content { iter } => iter,
        };

        let chunk = match chunks.next() {
            Some(chunk) => chunk,
            None => return Ok(None),
        };

        let item: DecodeResponseItem = match chunk {
            BaoChunk::Parent {
                node,
                is_root,
                left,
                right,
            } => {
                // The pair is read before the expectation is popped, so an I/O
                // error leaves the stack untouched.
                let pair = read_parent(&mut self.encoded)?;
                let expected = self.stack.pop().unwrap();
                let (l_hash, r_hash) = pair;
                let computed = parent_cv(&l_hash, &r_hash, is_root);
                if expected != computed {
                    return Err(DecodeError::ParentHashMismatch(node));
                }
                // Right goes in first so that the left child is popped first.
                if right {
                    self.stack.push(r_hash);
                }
                if left {
                    self.stack.push(l_hash);
                }
                Parent { node, pair }.into()
            }
            BaoChunk::Leaf {
                start_chunk,
                size,
                is_root,
            } => {
                self.buf.resize(size, 0);
                self.encoded.read_exact(&mut self.buf)?;
                let computed = hash_block(start_chunk, &self.buf, is_root);
                let expected = self.stack.pop().unwrap();
                if expected != computed {
                    return Err(DecodeError::LeafHashMismatch(start_chunk));
                }
                let offset = start_chunk.to_bytes();
                // Hand the validated bytes out and leave `buf` empty for reuse.
                let data = self.buf.split().freeze();
                Leaf { offset, data }.into()
            }
        };
        Ok(Some(item))
    }
}
/// Yields decoded items; an error is yielded as `Some(Err(..))`.
impl<'a, R: Read> Iterator for DecodeResponseIter<'a, R> {
    type Item = result::Result<DecodeResponseItem, DecodeError>;

    fn next(&mut self) -> Option<Self::Item> {
        // Turn `Result<Option<_>>` inside out: `Ok(None)` ends the iteration.
        match self.next0() {
            Ok(Some(item)) => Some(Ok(item)),
            Ok(None) => None,
            Err(cause) => Some(Err(cause)),
        }
    }
}
/// Encodes the chunk `ranges` of `data` into `encoded` without validating.
///
/// The output is the little-endian size of the data, followed by the hash
/// pairs and data blocks of the requested ranges in pre-order. Hashes are
/// copied from `outboard` as they are; nothing is checked against the root.
///
/// # Errors
///
/// Returns `SizeMismatch` if the length of `data` differs from the size
/// recorded in the outboard, `InvalidQueryRange` if `ranges` fails the
/// `range_ok` check, and forwards I/O errors.
///
/// # Panics
///
/// Panics if the outboard has no hash pair for a node that is needed.
pub fn encode_ranges<D: SliceReader, O: Outboard, W: Write>(
    data: D,
    outboard: O,
    ranges: &RangeSetRef<ChunkNum>,
    encoded: W,
) -> result::Result<(), EncodeError> {
    let (mut data, mut encoded) = (data, encoded);
    let file_len = ByteNum(data.len()?);
    let tree = outboard.tree();
    if file_len != tree.size {
        return Err(EncodeError::SizeMismatch);
    }
    if !range_ok(ranges, tree.chunks()) {
        return Err(EncodeError::InvalidQueryRange);
    }
    // One scratch buffer, reused for every leaf.
    let mut scratch = vec![0u8; tree.chunk_group_bytes().to_usize()];
    // The stream starts with the size as 8 little-endian bytes.
    encoded.write_all(&tree.size.0.to_le_bytes())?;
    for item in tree.ranges_pre_order_chunks_iter_ref(ranges, 0) {
        match item {
            BaoChunk::Leaf {
                start_chunk, size, ..
            } => {
                let block = &mut scratch[..size];
                data.read_at(start_chunk.to_bytes().0, block)?;
                encoded.write_all(block)?;
            }
            BaoChunk::Parent { node, .. } => {
                let (l_hash, r_hash) = outboard.load(node)?.unwrap();
                for hash in [&l_hash, &r_hash] {
                    encoded.write_all(hash.as_bytes())?;
                }
            }
        }
    }
    Ok(())
}
/// Encodes the chunk `ranges` of `data` into `encoded`, validating as it goes.
///
/// Produces the same stream as `encode_ranges`, but every hash pair and every
/// data block is checked against the expected hash, starting from the root
/// hash of `outboard`, before it is written.
///
/// # Errors
///
/// Returns `SizeMismatch` if the length of `data` differs from the size
/// recorded in the outboard, `InvalidQueryRange` if `ranges` fails the
/// `range_ok` check, `ParentHashMismatch` / `LeafHashMismatch` when validation
/// fails, and forwards I/O errors. Output written before a mismatch was
/// detected stays in `encoded`.
///
/// # Panics
///
/// Panics if the outboard has no hash pair for a node that is needed.
pub fn encode_ranges_validated<D: SliceReader, O: Outboard, W: Write>(
    data: D,
    outboard: O,
    ranges: &RangeSetRef<ChunkNum>,
    encoded: W,
) -> result::Result<(), EncodeError> {
    // Hashes that upcoming items must match; the root is the first one.
    let mut expectations = SmallVec::<[blake3::Hash; 10]>::new();
    expectations.push(outboard.root());
    let (mut data, mut encoded) = (data, encoded);
    let file_len = ByteNum(data.len()?);
    let tree = outboard.tree();
    if file_len != tree.size {
        return Err(EncodeError::SizeMismatch);
    }
    if !range_ok(ranges, tree.chunks()) {
        return Err(EncodeError::InvalidQueryRange);
    }
    // One scratch buffer, reused for every leaf.
    let mut scratch = vec![0u8; tree.chunk_group_bytes().to_usize()];
    // The stream starts with the size as 8 little-endian bytes.
    encoded.write_all(&tree.size.0.to_le_bytes())?;
    for item in tree.ranges_pre_order_chunks_iter_ref(ranges, 0) {
        match item {
            BaoChunk::Leaf {
                start_chunk,
                size,
                is_root,
            } => {
                let expected = expectations.pop().unwrap();
                let block = &mut scratch[..size];
                data.read_at(start_chunk.to_bytes().0, block)?;
                let computed = hash_block(start_chunk, block, is_root);
                if computed != expected {
                    return Err(EncodeError::LeafHashMismatch(start_chunk));
                }
                encoded.write_all(block)?;
            }
            BaoChunk::Parent {
                node,
                is_root,
                left,
                right,
            } => {
                let (l_hash, r_hash) = outboard.load(node)?.unwrap();
                let computed = parent_cv(&l_hash, &r_hash, is_root);
                let expected = expectations.pop().unwrap();
                if computed != expected {
                    return Err(EncodeError::ParentHashMismatch(node));
                }
                // Right goes in first so that the left child is popped first.
                if right {
                    expectations.push(r_hash);
                }
                if left {
                    expectations.push(l_hash);
                }
                for hash in [&l_hash, &r_hash] {
                    encoded.write_all(hash.as_bytes())?;
                }
            }
        }
    }
    Ok(())
}
/// Decodes an encoded response, writing data to `target` and hashes to `outboard`.
///
/// The response is validated against the root hash of `outboard`. The decoded
/// size is stored with `set_size`, hash pairs with `save`, and every data
/// block is written to `target` at its byte offset.
///
/// # Errors
///
/// Returns the first decoding, validation or I/O error. Anything stored
/// before the error occurred is left in place.
///
// NOTE(review): this function is `async` but contains no `.await` and performs
// blocking reads and writes. The signature is kept for compatibility; confirm
// whether `async` is intended here.
pub async fn decode_response_into<R, O, W>(
    ranges: &RangeSetRef<ChunkNum>,
    encoded: R,
    mut outboard: O,
    mut target: W,
) -> io::Result<()>
where
    O: OutboardMut,
    R: Read,
    W: SliceWriter,
{
    let block_size = outboard.tree().block_size;
    let scratch = BytesMut::with_capacity(block_size.bytes());
    let root = outboard.root();
    let mut items = DecodeResponseIter::new(root, block_size, encoded, ranges, scratch);
    // Stops at the first item that is an error or that cannot be stored.
    items.try_for_each(|item| -> io::Result<()> {
        match item? {
            DecodeResponseItem::Leaf(Leaf { offset, data }) => {
                target.write_at(offset.0, &data)?;
            }
            DecodeResponseItem::Parent(Parent { node, pair }) => {
                outboard.save(node, &pair)?;
            }
            DecodeResponseItem::Header(Header { size }) => {
                outboard.set_size(size)?;
            }
        }
        Ok(())
    })
}
/// Copies the given byte `ranges` of `from` to the same offsets in `to`.
///
/// An open-ended range is taken to end at `from.len()`.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidInput`] if a range does
/// not lie within `from`; this used to panic. Errors from seeking or writing
/// `to` are forwarded. Ranges written before the error stay written.
pub fn write_ranges(
    from: impl AsRef<[u8]>,
    mut to: impl Write + Seek,
    ranges: &RangeSetRef<u64>,
) -> io::Result<()> {
    let from = from.as_ref();
    let len = from.len() as u64;
    for range in ranges.iter() {
        let (start, end) = match range {
            RangeSetRange::RangeFrom(x) => (*x.start, len),
            RangeSetRange::Range(x) => (*x.start, *x.end),
        };
        // Checked conversion and checked slicing: a range that does not fit
        // into `from` is reported to the caller before anything is written
        // for it.
        let bytes = usize::try_from(start)
            .ok()
            .zip(usize::try_from(end).ok())
            .and_then(|(s, e)| from.get(s..e))
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "range {}..{} is outside of the {} available bytes",
                        start, end, len
                    ),
                )
            })?;
        to.seek(SeekFrom::Start(start))?;
        to.write_all(bytes)?;
    }
    Ok(())
}
/// Computes a post-order outboard for `size` bytes read from `data`.
///
/// The hash pairs are written to `outboard` in post-order, followed by `size`
/// as 8 little-endian bytes. Returns the root hash.
///
/// # Errors
///
/// Forwards errors from reading `data` (including it being shorter than
/// `size`) and from writing `outboard`.
pub fn outboard_post_order(
    data: &mut impl Read,
    size: u64,
    block_size: BlockSize,
    outboard: &mut impl Write,
) -> io::Result<blake3::Hash> {
    let tree = BaoTree::new_with_start_chunk(ByteNum(size), block_size, ChunkNum(0));
    let scratch_len = tree.chunk_group_bytes().to_usize();
    let mut scratch = vec![0u8; scratch_len];
    let root = outboard_post_order_impl(tree, data, outboard, &mut scratch)?;
    // The size goes last, after all hash pairs.
    let size_suffix = size.to_le_bytes();
    outboard.write_all(&size_suffix)?;
    Ok(root)
}
/// Writes the hash pairs of `tree` to `outboard` in post-order.
///
/// Leaf data is read sequentially from `data` into `buffer`, which is expected
/// to be exactly one chunk group long (checked in debug builds only). The size
/// is not written here. Returns the root hash.
///
/// # Errors
///
/// Forwards errors from reading `data` and from writing `outboard`.
pub(crate) fn outboard_post_order_impl(
    tree: BaoTree,
    data: &mut impl Read,
    outboard: &mut impl Write,
    buffer: &mut [u8],
) -> io::Result<blake3::Hash> {
    // Hashes of finished subtrees that are still waiting for their parent.
    let mut pending = SmallVec::<[blake3::Hash; 10]>::new();
    debug_assert!(buffer.len() == tree.chunk_group_bytes().to_usize());
    for item in tree.post_order_chunks_iter() {
        let hash = match item {
            BaoChunk::Leaf {
                start_chunk,
                size,
                is_root,
            } => {
                let block = &mut buffer[..size];
                data.read_exact(block)?;
                hash_block(start_chunk, block, is_root)
            }
            BaoChunk::Parent { is_root, .. } => {
                // In post-order the right child was finished last.
                let right = pending.pop().unwrap();
                let left = pending.pop().unwrap();
                outboard.write_all(left.as_bytes())?;
                outboard.write_all(right.as_bytes())?;
                parent_cv(&left, &right, is_root)
            }
        };
        pending.push(hash);
    }
    debug_assert_eq!(pending.len(), 1);
    Ok(pending.pop().unwrap())
}
/// Hashes `data_len` bytes read from `data` as a subtree starting at `start_chunk`.
///
/// A tree with `BlockSize(0)` is walked in post-order. A node is hashed as the
/// root only if `is_root` is set and the node is the root of that tree. `buf`
/// is scratch space that must be able to hold the largest leaf.
///
/// # Errors
///
/// Forwards errors from reading `data`.
pub(crate) fn blake3_hash_inner(
    mut data: impl Read,
    data_len: ByteNum,
    start_chunk: ChunkNum,
    is_root: bool,
    buf: &mut [u8],
) -> std::io::Result<blake3::Hash> {
    // Hashes of finished subtrees that are still waiting for their parent.
    let mut pending = SmallVec::<[blake3::Hash; 10]>::new();
    let tree = BaoTree::new_with_start_chunk(data_len, BlockSize(0), start_chunk);
    for item in tree.post_order_chunks_iter() {
        let hash = match item {
            BaoChunk::Parent {
                is_root: node_is_root,
                ..
            } => {
                // In post-order the right child was finished last.
                let right = pending.pop().unwrap();
                let left = pending.pop().unwrap();
                parent_cv(&left, &right, is_root && node_is_root)
            }
            BaoChunk::Leaf {
                size,
                is_root: leaf_is_root,
                start_chunk: leaf_start,
            } => {
                let block = &mut buf[..size];
                data.read_exact(block)?;
                hash_chunk(leaf_start, block, is_root && leaf_is_root)
            }
        };
        pending.push(hash);
    }
    debug_assert_eq!(pending.len(), 1);
    Ok(pending.pop().unwrap())
}
/// Reads the 8-byte little-endian length prefix from `from`.
fn read_len(from: &mut impl Read) -> std::io::Result<ByteNum> {
    let mut raw = [0u8; 8];
    from.read_exact(&mut raw)
        .map(|()| ByteNum(u64::from_le_bytes(raw)))
}
/// Reads a pair of 32-byte hashes (left, then right) from `from`.
fn read_parent(from: &mut impl Read) -> std::io::Result<(blake3::Hash, blake3::Hash)> {
    let mut raw = [0u8; 64];
    // A single read of both hashes.
    from.read_exact(&mut raw)?;
    let (left, right) = raw.split_at(32);
    // Both halves are exactly 32 bytes long, so the conversion cannot fail.
    let to_hash = |half: &[u8]| blake3::Hash::from(<[u8; 32]>::try_from(half).unwrap());
    Ok((to_hash(left), to_hash(right)))
}
/// Reads the byte `range` of `from` into the front of `buf`.
///
/// Returns the filled part of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than the range.
fn read_range<'a>(
    from: &mut (impl Read + Seek),
    range: Range<ByteNum>,
    buf: &'a mut [u8],
) -> std::io::Result<&'a [u8]> {
    let Range { start, end } = range;
    let wanted = (end - start).to_usize();
    from.seek(std::io::SeekFrom::Start(start.0))?;
    let window = &mut buf[..wanted];
    from.read_exact(window)?;
    Ok(window)
}
/// Computes the chunk ranges of `reader` that are valid according to `outboard`.
///
/// The tree is walked from the root. A subtree whose hash pair does not match
/// the hash expected by its parent is skipped without error; data blocks that
/// hash to the expected value are added to the result.
///
/// # Errors
///
/// Only I/O errors from loading hash pairs or reading data are returned.
/// Hash mismatches just shrink the resulting range set.
pub fn valid_file_ranges<O, R>(outboard: &O, reader: R) -> io::Result<RangeSet2<ChunkNum>>
where
O: Outboard,
R: Read + Seek,
{
// State shared by all levels of the recursion.
struct RecursiveValidator<'a, O: Outboard, R: Read + Seek> {
// Geometry of the data being validated.
tree: BaoTree,
// Passed to `right_descendant` when descending to the right.
valid_nodes: TreeNode,
// Chunk ranges validated so far.
res: RangeSet2<ChunkNum>,
// Source of the hash pairs.
outboard: &'a O,
// Source of the data.
reader: R,
// Scratch space for one block of data.
buffer: Vec<u8>,
}
impl<'a, O: Outboard, R: Read + Seek> RecursiveValidator<'a, O, R> {
// Validates the subtree rooted at `node`, which must hash to `parent_hash`.
fn validate_rec(
&mut self,
parent_hash: &blake3::Hash,
node: TreeNode,
is_root: bool,
) -> io::Result<()> {
if let Some((l_hash, r_hash)) = self.outboard.load(node)? {
let actual = parent_cv(&l_hash, &r_hash, is_root);
if &actual != parent_hash {
// The pair does not match: nothing below this node is added.
return Ok(());
}
if let Some(leaf) = node.as_leaf() {
// The pair is trusted now; check both data blocks below it independently.
let (s, m, e) = self.tree.leaf_byte_ranges3(leaf);
let l_data = read_range(&mut self.reader, s..m, &mut self.buffer)?;
let actual = hash_block(s.chunks(), l_data, false);
if actual == l_hash {
self.res |= RangeSet2::from(s.chunks()..m.chunks());
}
let r_data = read_range(&mut self.reader, m..e, &mut self.buffer)?;
let actual = hash_block(m.chunks(), r_data, false);
if actual == r_hash {
self.res |= RangeSet2::from(m.chunks()..e.chunks());
}
} else {
// Inner node: each child has to match its half of the pair.
let left = node.left_child().unwrap();
self.validate_rec(&l_hash, left, false)?;
let right = node.right_descendant(self.valid_nodes).unwrap();
self.validate_rec(&r_hash, right, false)?;
}
} else if let Some(leaf) = node.as_leaf() {
// No pair stored for this leaf node: only the block `s..m` is read and
// it is compared with `parent_hash` directly.
let (s, m, _) = self.tree.leaf_byte_ranges3(leaf);
let l_data = read_range(&mut self.reader, s..m, &mut self.buffer)?;
let actual = hash_block(s.chunks(), l_data, is_root);
if actual == *parent_hash {
self.res |= RangeSet2::from(s.chunks()..m.chunks());
}
};
Ok(())
}
}
let tree = outboard.tree();
let root_hash = outboard.root();
let mut validator = RecursiveValidator {
tree,
valid_nodes: tree.filled_size(),
res: RangeSet2::empty(),
outboard,
reader,
buffer: vec![0; tree.block_size.bytes()],
};
// Start at the root, which has to match the root hash of the outboard.
validator.validate_rec(&root_hash, tree.root(), true)?;
Ok(validator.res)
}