bao_tree/io/mixed.rs

//! Read from sync, send to tokio sender
use std::{future::Future, result};

use blake3;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;

use super::{sync::Outboard, EncodeError, Leaf, Parent};
use crate::{
    hash_subtree, iter::BaoChunk, parent_cv, rec::truncate_ranges, split_inner, ChunkNum,
    ChunkRangesRef, TreeNode,
};

/// A content item for the bao streaming protocol.
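///
/// # Example
///
/// A consumer typically receives these items one by one and matches on them;
/// a minimal sketch (the handling of each arm is up to the consumer):
///
/// ```ignore
/// match item {
///     EncodedItem::Size(size) => { /* first item: total data size */ }
///     EncodedItem::Parent(parent) => { /* hash pair for a tree node */ }
///     EncodedItem::Leaf(leaf) => { /* chunk data at a byte offset */ }
///     EncodedItem::Error(err) => { /* last item: encoding failed */ }
///     EncodedItem::Done => { /* last item: stream complete */ }
/// }
/// ```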
#[derive(Debug, Serialize, Deserialize)]
pub enum EncodedItem {
    /// total data size, will be the first item
    Size(u64),
    /// a parent node
    Parent(Parent),
    /// a leaf node
    Leaf(Leaf),
    /// an error, will be the last item
    Error(EncodeError),
    /// done, will be the last item
    Done,
}

impl From<Leaf> for EncodedItem {
    fn from(l: Leaf) -> Self {
        Self::Leaf(l)
    }
}

impl From<Parent> for EncodedItem {
    fn from(p: Parent) -> Self {
        Self::Parent(p)
    }
}

impl From<EncodeError> for EncodedItem {
    fn from(e: EncodeError) -> Self {
        Self::Error(e)
    }
}

/// Abstract sender trait for sending encoded items
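///
/// # Example
///
/// A minimal in-memory sender, e.g. for collecting items in tests; a sketch,
/// `VecSender` is not part of this crate:
///
/// ```ignore
/// struct VecSender(Vec<EncodedItem>);
///
/// impl Sender for VecSender {
///     type Error = std::convert::Infallible;
///
///     fn send(
///         &mut self,
///         item: EncodedItem,
///     ) -> impl Future<Output = Result<(), Self::Error>> + '_ {
///         self.0.push(item);
///         async { Ok(()) }
///     }
/// }
/// ```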
pub trait Sender {
    /// Error type
    type Error;
    /// Send an item
    fn send(
        &mut self,
        item: EncodedItem,
    ) -> impl Future<Output = std::result::Result<(), Self::Error>> + '_;
}

impl Sender for tokio::sync::mpsc::Sender<EncodedItem> {
    type Error = tokio::sync::mpsc::error::SendError<EncodedItem>;
    fn send(
        &mut self,
        item: EncodedItem,
    ) -> impl Future<Output = std::result::Result<(), Self::Error>> + '_ {
        tokio::sync::mpsc::Sender::send(self, item)
    }
}

/// Traverse ranges relevant to a query from a reader and outboard to a stream
///
/// This function validates the data before sending.
///
/// It is possible to encode ranges from a partial file and outboard.
/// This will either succeed if the requested ranges are all present, or fail
/// as soon as a range is missing.
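///
/// # Example
///
/// A sketch of streaming a validated encoding through a tokio channel;
/// assumes `data`, `outboard` and `ranges` are already in scope:
///
/// ```ignore
/// let (mut tx, mut rx) = tokio::sync::mpsc::channel(16);
/// tokio::spawn(async move {
///     traverse_ranges_validated(data, outboard, &ranges, &mut tx).await
/// });
/// while let Some(item) = rx.recv().await {
///     // forward the item to the network, a file, ...
/// }
/// ```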
pub async fn traverse_ranges_validated<D, O, F>(
    data: D,
    outboard: O,
    ranges: &ChunkRangesRef,
    send: &mut F,
) -> std::result::Result<(), F::Error>
where
    D: ReadBytesAt,
    O: Outboard,
    F: Sender,
{
    send.send(EncodedItem::Size(outboard.tree().size())).await?;
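    // The impl returns a nested result: the outer error is a validation
    // failure, which becomes a terminal `Error` item on the stream; the
    // inner error is a send failure, which is returned to the caller.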
    let res = match traverse_ranges_validated_impl(data, outboard, ranges, send).await {
        Ok(Ok(())) => EncodedItem::Done,
        Err(cause) => EncodedItem::Error(cause),
        Ok(Err(err)) => return Err(err),
    };
    send.send(res).await
}

/// Encode ranges relevant to a query from a reader and outboard to a sender
///
/// This function validates the data before sending.
///
/// It is possible to encode ranges from a partial file and outboard.
/// This will either succeed if the requested ranges are all present, or fail
/// as soon as a range is missing.
async fn traverse_ranges_validated_impl<D, O, F>(
    data: D,
    outboard: O,
    ranges: &ChunkRangesRef,
    send: &mut F,
) -> result::Result<result::Result<(), F::Error>, EncodeError>
where
    D: ReadBytesAt,
    O: Outboard,
    F: Sender,
{
    if ranges.is_empty() {
        return Ok(Ok(()));
    }
    let mut stack = SmallVec::<[blake3::Hash; 10]>::new();
    stack.push(outboard.root());
    let tree = outboard.tree();
    // canonicalize ranges
    let ranges = truncate_ranges(ranges, tree.size());
    for item in tree.ranges_pre_order_chunks_iter_ref(ranges, 0) {
        match item {
            BaoChunk::Parent {
                is_root,
                left,
                right,
                node,
                ..
            } => {
                let (l_hash, r_hash) = outboard.load(node)?.unwrap();
                let actual = parent_cv(&l_hash, &r_hash, is_root);
                let expected = stack.pop().unwrap();
                if actual != expected {
                    return Err(EncodeError::ParentHashMismatch(node));
                }
                if right {
                    stack.push(r_hash);
                }
                if left {
                    stack.push(l_hash);
                }
                let item = Parent {
                    node,
                    pair: (l_hash, r_hash),
                };
                if let Err(e) = send.send(item.into()).await {
                    return Ok(Err(e));
                }
            }
            BaoChunk::Leaf {
                start_chunk,
                size,
                is_root,
                ranges,
                ..
            } => {
                let expected = stack.pop().unwrap();
                let start = start_chunk.to_bytes();
                let buffer = data.read_bytes_at(start, size)?;
                if !ranges.is_all() {
                    // we need to encode just a part of the data
                    //
                    // write into an out buffer to ensure we detect mismatches
                    // before writing to the output.
                    let mut out_buf = Vec::new();
                    let actual = traverse_selected_rec(
                        start_chunk,
                        buffer,
                        is_root,
                        ranges,
                        tree.block_size.to_u32(),
                        true,
                        &mut out_buf,
                    );
                    if actual != expected {
                        return Err(EncodeError::LeafHashMismatch(start_chunk));
                    }
                    for item in out_buf.into_iter() {
                        if let Err(e) = send.send(item).await {
                            return Ok(Err(e));
                        }
                    }
                } else {
                    let actual = hash_subtree(start_chunk.0, &buffer, is_root);
                    if actual != expected {
                        return Err(EncodeError::LeafHashMismatch(start_chunk));
                    }
                    let item = Leaf {
                        data: buffer,
                        offset: start_chunk.to_bytes(),
                    };
                    if let Err(e) = send.send(item.into()).await {
                        return Ok(Err(e));
                    }
                }
            }
        }
    }
    Ok(Ok(()))
}

/// Encode ranges relevant to a query from a slice and outboard to a buffer.
///
/// This will compute the root hash, so it will have to traverse the entire tree.
/// The `query` parameter just controls which parts of the data are written.
///
/// Except for writing to a buffer, this is the same as [hash_subtree].
/// The `min_level` parameter controls the minimum level that will be emitted as a leaf.
/// Set this to 0 to disable chunk groups entirely.
/// The `emit_data` parameter controls whether the data is written to the buffer.
/// When setting this to false and setting query to `RangeSet::all()`, this can be used
/// to write an outboard.
///
/// `res` will not contain the length prefix, so if you want a bao compatible format,
/// you need to prepend it yourself.
///
/// This is used as a reference implementation in tests, but also to compute hashes
/// below the chunk group size when creating responses for outboards with a chunk group
/// size greater than 0.
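///
/// # Example
///
/// As described above, with `emit_data = false` and a full query this computes
/// the root hash while emitting only parent pairs, i.e. outboard data; a
/// minimal sketch, assuming `data: Bytes` holds the entire blob:
///
/// ```ignore
/// let mut res = Vec::new();
/// let root = traverse_selected_rec(
///     ChunkNum(0),
///     data,
///     true,                // this is the root of the tree
///     &ChunkRanges::all(), // full query
///     0,                   // no chunk groups
///     false,               // emit parents only, no leaf data
///     &mut res,
/// );
/// ```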
pub fn traverse_selected_rec(
    start_chunk: ChunkNum,
    data: Bytes,
    is_root: bool,
    query: &ChunkRangesRef,
    min_level: u32,
    emit_data: bool,
    res: &mut Vec<EncodedItem>,
) -> blake3::Hash {
    use blake3::CHUNK_LEN;
    if data.len() <= CHUNK_LEN {
        if emit_data && !query.is_empty() {
            res.push(
                Leaf {
                    data: data.clone(),
                    offset: start_chunk.to_bytes(),
                }
                .into(),
            );
        }
        hash_subtree(start_chunk.0, &data, is_root)
    } else {
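        // Split as in the BLAKE3 tree: round the chunk count up to the next
        // power of two and halve it, so the left subtree covers the largest
        // power of two number of chunks strictly smaller than the total.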
        let chunks = data.len() / CHUNK_LEN + (data.len() % CHUNK_LEN != 0) as usize;
        let chunks = chunks.next_power_of_two();
        let level = chunks.trailing_zeros() - 1;
        let mid = chunks / 2;
        let mid_bytes = mid * CHUNK_LEN;
        let mid_chunk = start_chunk + (mid as u64);
        let (l_ranges, r_ranges) = split_inner(query, start_chunk, mid_chunk);
        // for empty ranges, we don't want to emit anything.
        // for full ranges where the level is below min_level, we want to emit
        // just the data.
        //
        // todo: maybe call into blake3::hazmat::hash_subtree directly for this case? it would be faster.
        let full = query.is_all();
        let emit_parent = !query.is_empty() && (!full || level >= min_level);
        let hash_offset = if emit_parent {
            // make some room for the hash pair
            let pair = Parent {
                node: TreeNode(0),
                pair: ([0; 32].into(), [0; 32].into()),
            };
            res.push(pair.into());
            Some(res.len() - 1)
        } else {
            None
        };
        // recurse to the left and right to compute the hashes and emit data
        let left = traverse_selected_rec(
            start_chunk,
            data.slice(..mid_bytes),
            false,
            l_ranges,
            min_level,
            emit_data,
            res,
        );
        let right = traverse_selected_rec(
            mid_chunk,
            data.slice(mid_bytes..),
            false,
            r_ranges,
            min_level,
            emit_data,
            res,
        );
        // backfill the hashes if needed
        if let Some(o) = hash_offset {
            // todo: figure out how to get the tree node from the start chunk!
            let node = TreeNode(0);
            res[o] = Parent {
                node,
                pair: (left, right),
            }
            .into();
        }
        parent_cv(&left, &right, is_root)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        io::{outboard::PreOrderMemOutboard, sync::encode_ranges_validated},
        BlockSize, ChunkRanges,
    };

    fn flatten(items: Vec<EncodedItem>) -> Vec<u8> {
        let mut res = Vec::new();
        for item in items {
            match item {
                EncodedItem::Leaf(Leaf { data, .. }) => res.extend_from_slice(&data),
                EncodedItem::Parent(Parent { pair: (l, r), .. }) => {
                    res.extend_from_slice(l.as_bytes());
                    res.extend_from_slice(r.as_bytes());
                }
                _ => {}
            }
        }
        res
    }

    #[tokio::test]
    async fn smoke() {
        let data = [0u8; 100000];
        let outboard = PreOrderMemOutboard::create(data, BlockSize::from_chunk_log(4));
        let (mut tx, mut rx) = tokio::sync::mpsc::channel(10);
        let mut encoded = Vec::new();
        encode_ranges_validated(&data[..], &outboard, &ChunkRanges::empty(), &mut encoded).unwrap();
        tokio::spawn(async move {
            traverse_ranges_validated(&data[..], &outboard, &ChunkRanges::empty(), &mut tx)
                .await
                .unwrap();
        });
        let mut res = Vec::new();
        while let Some(item) = rx.recv().await {
            res.push(item);
        }
        println!("{:?}", res);
        let encoded2 = flatten(res);
        assert_eq!(encoded, encoded2);
    }
}

/// Trait identical to `ReadAt`, but returning `Bytes` instead of reading into a buffer.
///
/// Most implementations forward to the underlying `ReadAt` implementation and
/// allocate a fresh buffer per call; for `Bytes`, `&Bytes` and `&mut Bytes` the
/// result is a zero-copy slice of the existing buffer.
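///
/// # Example
///
/// For `Bytes` the read is a cheap sub-slice; a minimal sketch:
///
/// ```ignore
/// use bytes::Bytes;
///
/// let data = Bytes::from_static(b"hello world");
/// let window = data.read_bytes_at(6, 5)?; // b"world", no copy
/// ```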
pub trait ReadBytesAt {
    /// Version of `ReadAt::read_exact_at` that returns a `Bytes` instead of reading into a buffer.
    fn read_bytes_at(&self, offset: u64, size: usize) -> std::io::Result<Bytes>;
}

mod impls {
    use std::io;

    use bytes::Bytes;

    use super::ReadBytesAt;

    // Macro for generic implementations (allocating a fresh buffer per call)
    macro_rules! impl_read_bytes_at_generic {
        ($($t:ty),*) => {
            $(
                impl ReadBytesAt for $t {
                    fn read_bytes_at(&self, offset: u64, size: usize) -> io::Result<Bytes> {
                        let mut buf = vec![0; size];
                        ::positioned_io::ReadAt::read_exact_at(self, offset, &mut buf)?;
                        Ok(buf.into())
                    }
                }
            )*
        };
    }

    // Macro for special implementations (non-allocating with slice)
    macro_rules! impl_read_bytes_at_special {
        ($($t:ty),*) => {
            $(
                impl ReadBytesAt for $t {
                    fn read_bytes_at(&self, offset: u64, size: usize) -> io::Result<Bytes> {
                        // use checked arithmetic so huge offsets fail with an
                        // error instead of truncating or overflowing
                        let end = offset.checked_add(size as u64).ok_or_else(|| {
                            io::Error::new(io::ErrorKind::UnexpectedEof, "Read past end of buffer")
                        })?;
                        if end > self.len() as u64 {
                            return Err(io::Error::new(
                                io::ErrorKind::UnexpectedEof,
                                "Read past end of buffer",
                            ));
                        }
                        Ok(self.slice(offset as usize..end as usize))
                    }
                }
            )*
        };
    }

    // Apply the macros
    impl_read_bytes_at_generic!(&[u8], Vec<u8>, std::fs::File);
    impl_read_bytes_at_special!(Bytes, &Bytes, &mut Bytes);
}