use std::{future::Future, result};
use blake3;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use super::{sync::Outboard, EncodeError, Leaf, Parent};
use crate::{
hash_subtree, iter::BaoChunk, parent_cv, rec::truncate_ranges, split_inner, ChunkNum,
ChunkRangesRef, TreeNode,
};
/// An item produced while traversing/encoding a bao tree over a range selection.
#[derive(Debug, Serialize, Deserialize)]
pub enum EncodedItem {
    /// Total size of the blob in bytes; always the first item sent.
    Size(u64),
    /// A parent node together with the hash pair of its two children.
    Parent(Parent),
    /// A leaf carrying actual data bytes at a byte offset.
    Leaf(Leaf),
    /// Traversal failed (hash mismatch or I/O error), reported in-band.
    Error(EncodeError),
    /// Marks the successful end of a traversal.
    Done,
}
impl From<Leaf> for EncodedItem {
fn from(l: Leaf) -> Self {
Self::Leaf(l)
}
}
impl From<Parent> for EncodedItem {
fn from(p: Parent) -> Self {
Self::Parent(p)
}
}
impl From<EncodeError> for EncodedItem {
fn from(e: EncodeError) -> Self {
Self::Error(e)
}
}
/// Abstraction over a sink that accepts [`EncodedItem`]s one at a time.
pub trait Sender {
    /// Error type produced when a send fails.
    type Error;
    /// Send a single item; the returned future resolves once the item has
    /// been accepted (or with `Self::Error` if it could not be).
    fn send(
        &mut self,
        item: EncodedItem,
    ) -> impl Future<Output = std::result::Result<(), Self::Error>> + '_;
}
/// A tokio mpsc channel sender can be used directly as a [`Sender`].
impl Sender for tokio::sync::mpsc::Sender<EncodedItem> {
    type Error = tokio::sync::mpsc::error::SendError<EncodedItem>;
    fn send(
        &mut self,
        item: EncodedItem,
    ) -> impl Future<Output = std::result::Result<(), Self::Error>> + '_ {
        // Fully qualified call so this unambiguously resolves to tokio's
        // inherent `send`, not to this trait method.
        tokio::sync::mpsc::Sender::send(self, item)
    }
}
/// Traverse and validate the selected `ranges` of `data` against `outboard`,
/// sending all encoded items to `send`.
///
/// The stream always starts with [`EncodedItem::Size`] and ends with either
/// [`EncodedItem::Done`] on success or [`EncodedItem::Error`] if validation
/// failed. Only send failures are returned as `Err` to the caller.
pub async fn traverse_ranges_validated<D, O, F>(
    data: D,
    outboard: O,
    ranges: &ChunkRangesRef,
    send: &mut F,
) -> std::result::Result<(), F::Error>
where
    D: ReadBytesAt,
    O: Outboard,
    F: Sender,
{
    // The total size goes first, even for an empty selection.
    send.send(EncodedItem::Size(outboard.tree().size())).await?;
    let tail = match traverse_ranges_validated_impl(data, outboard, ranges, send).await {
        // Encoding/validation errors are reported in-band as a final item...
        Err(cause) => EncodedItem::Error(cause),
        // ...while send errors abort and propagate to the caller.
        Ok(inner) => {
            inner?;
            EncodedItem::Done
        }
    };
    send.send(tail).await
}
/// Traverse the tree for the given ranges, validating every parent and leaf
/// hash against the outboard, and send encoded items to `send`.
///
/// The nested result distinguishes the two failure modes: the outer `Err` is
/// a validation/IO failure ([`EncodeError`]), the inner `Err` is a send
/// failure that the caller propagates as-is.
async fn traverse_ranges_validated_impl<D, O, F>(
    data: D,
    outboard: O,
    ranges: &ChunkRangesRef,
    send: &mut F,
) -> result::Result<std::result::Result<(), F::Error>, EncodeError>
where
    D: ReadBytesAt,
    O: Outboard,
    F: Sender,
{
    if ranges.is_empty() {
        return Ok(Ok(()));
    }
    // Stack of expected hashes, seeded with the root. The pre-order traversal
    // pushes right before left so the left child is validated first.
    let mut stack: SmallVec<[_; 10]> = SmallVec::<[blake3::Hash; 10]>::new();
    stack.push(outboard.root());
    let tree = outboard.tree();
    // Clamp the requested ranges so they do not extend past the end of the data.
    let ranges = truncate_ranges(ranges, tree.size());
    for item in tree.ranges_pre_order_chunks_iter_ref(ranges, 0) {
        match item {
            BaoChunk::Parent {
                is_root,
                left,
                right,
                node,
                ..
            } => {
                let (l_hash, r_hash) = outboard.load(node)?.unwrap();
                let actual = parent_cv(&l_hash, &r_hash, is_root);
                let expected = stack.pop().unwrap();
                if actual != expected {
                    return Err(EncodeError::ParentHashMismatch(node));
                }
                // Only push the child hashes for branches the traversal will
                // actually visit.
                if right {
                    stack.push(r_hash);
                }
                if left {
                    stack.push(l_hash);
                }
                let item = Parent {
                    node,
                    pair: (l_hash, r_hash),
                };
                if let Err(e) = send.send(item.into()).await {
                    return Ok(Err(e));
                }
            }
            BaoChunk::Leaf {
                start_chunk,
                size,
                is_root,
                ranges,
                ..
            } => {
                let expected = stack.pop().unwrap();
                let start = start_chunk.to_bytes();
                let buffer = data.read_bytes_at(start, size)?;
                if !ranges.is_all() {
                    // Partially selected leaf group: recurse into it to emit
                    // only the selected parts, computing the subtree hash for
                    // validation along the way.
                    let mut out_buf = Vec::new();
                    let actual = traverse_selected_rec(
                        start_chunk,
                        buffer,
                        is_root,
                        ranges,
                        tree.block_size.to_u32(),
                        true,
                        &mut out_buf,
                    );
                    if actual != expected {
                        return Err(EncodeError::LeafHashMismatch(start_chunk));
                    }
                    for item in out_buf.into_iter() {
                        if let Err(e) = send.send(item).await {
                            return Ok(Err(e));
                        }
                    }
                } else {
                    // Fully selected leaf group: hash it and send it whole.
                    let actual = hash_subtree(start_chunk.0, &buffer, is_root);
                    if actual != expected {
                        return Err(EncodeError::LeafHashMismatch(start_chunk));
                    }
                    let item = Leaf {
                        data: buffer,
                        offset: start_chunk.to_bytes(),
                    };
                    if let Err(e) = send.send(item.into()).await {
                        return Ok(Err(e));
                    }
                };
            }
        }
    }
    Ok(Ok(()))
}
/// Recursively traverse the subtree covering `data`, pushing the parents and
/// leaves selected by `query` into `res`, and return the subtree's hash.
///
/// - `start_chunk`: chunk number of the first chunk in `data`
/// - `is_root`: whether this subtree is the root of the whole tree
/// - `min_level`: when the query selects everything, parents below this level
///   are not emitted
/// - `emit_data`: when false, only parent items are emitted, no leaves
pub fn traverse_selected_rec(
    start_chunk: ChunkNum,
    data: Bytes,
    is_root: bool,
    query: &ChunkRangesRef,
    min_level: u32,
    emit_data: bool,
    res: &mut Vec<EncodedItem>,
) -> blake3::Hash {
    use blake3::CHUNK_LEN;
    if data.len() <= CHUNK_LEN {
        // Base case: a single chunk. Emit it as a leaf if requested and the
        // query selects anything here.
        if emit_data && !query.is_empty() {
            res.push(
                Leaf {
                    data: data.clone(),
                    offset: start_chunk.to_bytes(),
                }
                .into(),
            );
        }
        hash_subtree(start_chunk.0, &data, is_root)
    } else {
        // Number of chunks, rounded up, then rounded to the next power of two
        // to get the shape of the (conceptually full) binary subtree.
        let chunks = data.len() / CHUNK_LEN + (data.len() % CHUNK_LEN != 0) as usize;
        let chunks = chunks.next_power_of_two();
        // Level of this parent node; `chunks` is a power of two >= 2 here, so
        // trailing_zeros() >= 1 and this cannot underflow.
        let level = chunks.trailing_zeros() - 1;
        // Split point: half the chunks, i.e. the left subtree is always full.
        let mid = chunks / 2;
        let mid_bytes = mid * CHUNK_LEN;
        let mid_chunk = start_chunk + (mid as u64);
        let (l_ranges, r_ranges) = split_inner(query, start_chunk, mid_chunk);
        let full = query.is_all();
        // Skip parents below min_level when the query covers everything,
        // but never skip when only part of the subtree is selected.
        let emit_parent = !query.is_empty() && (!full || level >= min_level);
        let hash_offset = if emit_parent {
            // Reserve a slot for this parent with placeholder hashes; the
            // child hashes are not known yet and are patched in below.
            let pair = Parent {
                node: TreeNode(0),
                pair: ([0; 32].into(), [0; 32].into()),
            };
            res.push(pair.into());
            Some(res.len() - 1)
        } else {
            None
        };
        let left = traverse_selected_rec(
            start_chunk,
            data.slice(..mid_bytes),
            false,
            l_ranges,
            min_level,
            emit_data,
            res,
        );
        let right = traverse_selected_rec(
            mid_chunk,
            data.slice(mid_bytes..),
            false,
            r_ranges,
            min_level,
            emit_data,
            res,
        );
        // Patch the reserved slot with the real child hashes.
        if let Some(o) = hash_offset {
            let node = TreeNode(0);
            res[o] = Parent {
                node,
                pair: (left, right),
            }
            .into();
        }
        parent_cv(&left, &right, is_root)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        io::{outboard::PreOrderMemOutboard, sync::encode_ranges_validated},
        BlockSize, ChunkRanges,
    };

    /// Concatenate the wire bytes of a stream of items: hash pairs for
    /// parents, payload bytes for leaves, everything else ignored.
    fn flatten(items: Vec<EncodedItem>) -> Vec<u8> {
        let mut out = Vec::new();
        for item in items {
            match item {
                EncodedItem::Parent(Parent { pair: (l, r), .. }) => {
                    out.extend_from_slice(l.as_bytes());
                    out.extend_from_slice(r.as_bytes());
                }
                EncodedItem::Leaf(Leaf { data, .. }) => out.extend_from_slice(&data),
                _ => {}
            }
        }
        out
    }

    /// The async traversal must produce the same wire bytes as the sync
    /// encoder for the same (empty) range selection.
    #[tokio::test]
    async fn smoke() {
        let data = [0u8; 100000];
        let outboard = PreOrderMemOutboard::create(data, BlockSize::from_chunk_log(4));
        let (mut tx, mut rx) = tokio::sync::mpsc::channel(10);
        // Reference encoding via the sync implementation.
        let mut encoded = Vec::new();
        encode_ranges_validated(&data[..], &outboard, &ChunkRanges::empty(), &mut encoded).unwrap();
        // Drive the async traversal on a task; items arrive over the channel.
        tokio::spawn(async move {
            traverse_ranges_validated(&data[..], &outboard, &ChunkRanges::empty(), &mut tx)
                .await
                .unwrap();
        });
        let mut res = Vec::new();
        while let Some(item) = rx.recv().await {
            res.push(item);
        }
        println!("{res:?}");
        assert_eq!(encoded, flatten(res));
    }
}
/// Positioned read that returns owned bytes.
pub trait ReadBytesAt {
    /// Read exactly `size` bytes starting at byte `offset`, or fail.
    fn read_bytes_at(&self, offset: u64, size: usize) -> std::io::Result<Bytes>;
}
/// Blanket [`ReadBytesAt`] implementations for common byte sources.
mod impls {
    use std::io;
    use bytes::Bytes;
    use super::ReadBytesAt;
    /// Impl via `positioned_io::ReadAt` for types that support positioned
    /// reads into a caller-provided buffer.
    macro_rules! impl_read_bytes_at_generic {
        ($($t:ty),*) => {
            $(
                impl ReadBytesAt for $t {
                    fn read_bytes_at(&self, offset: u64, size: usize) -> io::Result<Bytes> {
                        let mut buf = vec![0; size];
                        ::positioned_io::ReadAt::read_exact_at(self, offset, &mut buf)?;
                        Ok(buf.into())
                    }
                }
            )*
        };
    }
    /// Impl for `Bytes`-like types where a zero-copy slice is possible.
    macro_rules! impl_read_bytes_at_special {
        ($($t:ty),*) => {
            $(
                impl ReadBytesAt for $t {
                    fn read_bytes_at(&self, offset: u64, size: usize) -> io::Result<Bytes> {
                        let eof =
                            || io::Error::new(io::ErrorKind::UnexpectedEof, "Read past end of buffer");
                        // Checked conversion: `offset as usize` would silently
                        // truncate on 32-bit targets.
                        let offset = usize::try_from(offset).map_err(|_| eof())?;
                        // Checked addition: `offset + size` could overflow
                        // (panicking in debug builds) before the bounds check.
                        let end = offset.checked_add(size).ok_or_else(eof)?;
                        if end > self.len() {
                            return Err(eof());
                        }
                        Ok(self.slice(offset..end))
                    }
                }
            )*
        };
    }
    impl_read_bytes_at_generic!(&[u8], Vec<u8>);
    impl_read_bytes_at_special!(Bytes, &Bytes, &mut Bytes);
    #[cfg(feature = "fs")]
    impl_read_bytes_at_generic!(std::fs::File);
}