Skip to main content

nectar_primitives/file/
joiner.rs

1//! Async joiner with BFS expansion and concurrent chunk fetching.
2
3use std::io::SeekFrom;
4use std::marker::PhantomData;
5use std::sync::Arc;
6
7/// Default concurrency of a new joiner; `new` also uses twice this value as its frontier expansion target.
8const DEFAULT_ASYNC_CONCURRENCY: usize = 8;
9
10#[cfg(feature = "tokio")]
11use bytes::Buf;
12use bytes::Bytes;
13use futures::stream::{self, Stream, StreamExt};
14
15use crate::bmt::DEFAULT_BODY_SIZE;
16use crate::chunk::ChunkAddress;
17
18use super::error::{FileError, Result};
19use super::frontier::{SubtreeNode, expand_frontier_async, read_subtree_bodies_async};
20use super::mode::{JoinMode, PlainMode};
21use super::tree::{ChunkRange, TreeParams};
22use crate::store::ChunkGet;
23
24#[cfg(feature = "encryption")]
25use super::mode::EncryptedMode;
26
27/// Generic async joiner parameterized by chunk mode.
28pub struct GenericJoiner<G, M: JoinMode, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
29where
30    G: ChunkGet<BODY_SIZE>,
31{
32    getter: Arc<G>,
33    root: ChunkAddress,
34    context: M::JoinerContext,
35    span: u64,
36    tree: TreeParams<BODY_SIZE>,
37    /// Pre-expanded frontier for parallel work distribution (computed once at construction).
38    subtrees: Vec<SubtreeNode<M>>,
39    position: u64,
40    concurrency: usize,
41    _mode: PhantomData<M>,
42}
43
44/// Async joiner for plain (unencrypted) files: [`GenericJoiner`] fixed to [`PlainMode`].
45pub type Joiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
46    GenericJoiner<G, PlainMode, BODY_SIZE>;
47
48/// Async joiner for encrypted files: [`GenericJoiner`] fixed to [`EncryptedMode`].
49#[cfg(feature = "encryption")]
50pub type EncryptedJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
51    GenericJoiner<G, EncryptedMode, BODY_SIZE>;
52
53impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for GenericJoiner<G, M, BODY_SIZE>
54where
55    G: ChunkGet<BODY_SIZE>,
56    M: JoinMode,
57{
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("GenericJoiner")
60            .field("root", &self.root)
61            .field("span", &self.span)
62            .field("position", &self.position)
63            .field("concurrency", &self.concurrency)
64            .finish_non_exhaustive()
65    }
66}
67
68/// Collect leaf bodies for a set of subtrees with concurrent fetching.
69async fn collect_subtree_bodies_async<G, M, const BODY_SIZE: usize>(
70    getter: &Arc<G>,
71    subtrees: Vec<SubtreeNode<M>>,
72    chunk_range: ChunkRange,
73    concurrency: usize,
74) -> Result<Vec<Bytes>>
75where
76    G: ChunkGet<BODY_SIZE>,
77    M: JoinMode + Send + Sync,
78{
79    let bodies: Vec<Bytes> = stream::iter(subtrees)
80        .map(|st| {
81            let getter = Arc::clone(getter);
82            async move {
83                read_subtree_bodies_async::<G, M, BODY_SIZE>(&*getter, &st, &chunk_range).await
84            }
85        })
86        .buffered(concurrency)
87        .collect::<Vec<_>>()
88        .await
89        .into_iter()
90        .collect::<Result<Vec<Vec<Bytes>>>>()?
91        .into_iter()
92        .flatten()
93        .collect();
94    Ok(bodies)
95}
96
97impl<G, M, const BODY_SIZE: usize> GenericJoiner<G, M, BODY_SIZE>
98where
99    G: ChunkGet<BODY_SIZE>,
100    M: JoinMode + Send + Sync,
101{
102    /// Create an async joiner from a root reference.
103    pub async fn new(getter: G, input: M::RootRef) -> Result<Self> {
104        const { super::constants::assert_valid_body_size::<BODY_SIZE>() };
105
106        let (root, span, context) =
107            super::mode::joiner_init_async::<M, G, BODY_SIZE>(&getter, input).await?;
108        let tree = TreeParams::<BODY_SIZE>::new(span);
109
110        let target = DEFAULT_ASYNC_CONCURRENCY * 2;
111        let full_range = tree.chunks_for_range(0, span);
112        let subtrees = expand_frontier_async::<G, M, BODY_SIZE>(
113            &getter,
114            &root,
115            &context,
116            span,
117            &full_range,
118            target,
119        )
120        .await?;
121
122        Ok(Self {
123            getter: Arc::new(getter),
124            root,
125            context,
126            span,
127            tree,
128            subtrees,
129            position: 0,
130            concurrency: DEFAULT_ASYNC_CONCURRENCY,
131            _mode: PhantomData,
132        })
133    }
134
135    /// Set concurrency level for prefetching.
136    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
137        self.concurrency = concurrency.max(1);
138        self
139    }
140
141    /// Total file size: the span resolved when the joiner was constructed.
142    #[inline]
143    pub const fn size(&self) -> u64 {
144        self.span
145    }
146
147    /// Current read position; starts at 0 and is changed by `seek`.
148    #[inline]
149    pub const fn position(&self) -> u64 {
150        self.position
151    }
152
153    /// Address of the root chunk this joiner reads from.
154    #[inline]
155    pub const fn root(&self) -> &ChunkAddress {
156        &self.root
157    }
158
159    /// Read a range of bytes with concurrent fetching using the cached frontier.
160    pub async fn read_range(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
161        Self::read_range_with(
162            &self.getter,
163            &self.subtrees,
164            &self.root,
165            &self.context,
166            self.span,
167            self.tree,
168            self.concurrency,
169            offset,
170            len,
171        )
172        .await
173    }
174
175    /// Read the whole file via `read_range(0, span)`. NOTE(review): `span as usize` truncates if `usize` is narrower than 64 bits.
176    pub async fn read_all(&self) -> Result<Vec<u8>> {
177        self.read_range(0, self.span as usize).await
178    }
179
180    /// Shared read-range implementation used by both `read_range` and `poll_read`.
181    #[allow(
182        clippy::too_many_arguments,
183        reason = "internal helper threading already-decomposed reader state from two call sites"
184    )]
185    async fn read_range_with(
186        getter: &Arc<G>,
187        subtrees: &[SubtreeNode<M>],
188        root: &ChunkAddress,
189        context: &M::JoinerContext,
190        span: u64,
191        tree: TreeParams<BODY_SIZE>,
192        concurrency: usize,
193        offset: u64,
194        len: usize,
195    ) -> Result<Vec<u8>> {
196        use super::helpers::{ReadRangeCheck, validate_read_range};
197
198        let (offset, actual_len) = match validate_read_range::<BODY_SIZE>(offset, len, span) {
199            ReadRangeCheck::Empty => return Ok(Vec::new()),
200            ReadRangeCheck::SingleChunk { offset, actual_len } => {
201                let chunk = getter.get(root).await.map_err(FileError::getter)?;
202                let chunk = chunk.into_content().ok_or(FileError::InvalidChunkType {
203                    type_name: "non-content",
204                })?;
205                let body = M::decode_body::<BODY_SIZE>(chunk, context, span)?;
206                let start = offset as usize;
207                let end = start + actual_len;
208                return Ok(body[start..end].to_vec());
209            }
210            ReadRangeCheck::MultiChunk { offset, actual_len } => (offset, actual_len),
211        };
212
213        let chunk_range = tree.chunks_for_range(offset, actual_len as u64);
214        let range_start_byte = chunk_range.start * BODY_SIZE as u64;
215        let range_end_byte = chunk_range.end * BODY_SIZE as u64;
216
217        let relevant: Vec<_> = subtrees
218            .iter()
219            .filter(|st| {
220                st.byte_offset < range_end_byte && st.byte_offset + st.span > range_start_byte
221            })
222            .cloned()
223            .collect();
224
225        let bodies = collect_subtree_bodies_async::<G, M, BODY_SIZE>(
226            getter,
227            relevant,
228            chunk_range,
229            concurrency,
230        )
231        .await?;
232
233        Ok(super::tree::assemble_range(
234            &tree,
235            offset,
236            actual_len,
237            &chunk_range,
238            &bodies,
239        ))
240    }
241
242    /// Update read position (synchronous — just updates internal state).
243    pub fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
244        self.position = super::resolve_seek_position(pos, self.position, self.span)?;
245        Ok(self.position)
246    }
247
248    /// Convert into a stream of leaf chunk bodies.
249    pub fn into_stream(self) -> impl Stream<Item = Result<Bytes>> {
250        let getter = self.getter;
251        let chunk_range = self.tree.chunks_for_range(0, self.span);
252
253        struct State<M: JoinMode> {
254            subtrees: std::vec::IntoIter<SubtreeNode<M>>,
255            pending: std::vec::IntoIter<Bytes>,
256        }
257
258        let state = State {
259            subtrees: self.subtrees.into_iter(),
260            pending: Vec::new().into_iter(),
261        };
262
263        stream::unfold(state, move |mut state| {
264            let getter = Arc::clone(&getter);
265            async move {
266                // Drain pending leaf bodies from the last subtree.
267                if let Some(body) = state.pending.next() {
268                    return Some((Ok(body), state));
269                }
270
271                // Fetch the next subtree's leaf bodies.
272                let st = state.subtrees.next()?;
273                match read_subtree_bodies_async::<G, M, BODY_SIZE>(&*getter, &st, &chunk_range)
274                    .await
275                {
276                    Ok(bodies) => {
277                        let mut iter = bodies.into_iter();
278                        match iter.next() {
279                            Some(first) => {
280                                state.pending = iter;
281                                Some((Ok(first), state))
282                            }
283                            None => Some((Ok(Bytes::new()), state)),
284                        }
285                    }
286                    Err(e) => Some((Err(e), state)),
287                }
288            }
289        })
290    }
291
292    /// Wrap this joiner in a [`JoinerReader`], starting with an empty buffer and no read in flight.
293    #[cfg(feature = "tokio")]
294    pub fn into_reader(self) -> JoinerReader<G, M, BODY_SIZE> {
295        JoinerReader {
296            joiner: self,
297            buffer: Bytes::new(),
298            future: None,
299        }
300    }
301}
302
/// Adapter exposing a [`GenericJoiner`] through `tokio::io::AsyncRead`.
///
/// Obtained from [`GenericJoiner::into_reader`].
#[cfg(feature = "tokio")]
pub struct JoinerReader<G, M, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
    M: JoinMode,
{
    /// Wrapped joiner; its `position` field serves as the reader's cursor.
    joiner: GenericJoiner<G, M, BODY_SIZE>,
    /// Bytes already fetched but not yet handed to the caller.
    buffer: Bytes,
    /// Range read currently in flight, if any.
    #[allow(clippy::type_complexity)]
    future: Option<std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>>,
}
316
/// Debug output shows the wrapped joiner, the number of buffered bytes and whether a read
/// is currently in flight; the future itself is not printed.
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
    M: JoinMode,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffer_len = self.buffer.len();
        let has_pending_future = self.future.is_some();

        let mut out = f.debug_struct("JoinerReader");
        out.field("joiner", &self.joiner);
        out.field("buffer_len", &buffer_len);
        out.field("has_pending_future", &has_pending_future);
        out.finish()
    }
}
331
// Rationale: nothing in `JoinerReader` points back into the struct itself. The pending
// future lives in its own heap allocation and the other fields are plain data, so the
// reader may be moved freely. (`Unpin` is a safe trait; this is not an `unsafe` obligation.)
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> Unpin for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
    M: JoinMode,
{
}
339
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> tokio::io::AsyncRead for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE> + 'static,
    M: JoinMode + Send + Sync + 'static,
{
    /// Fill `buf` from the joined file, starting at the joiner's current position.
    ///
    /// Each call does one of the following, in this order:
    /// 1. hands out bytes left over from the previous fetch;
    /// 2. reports end of file (returns `Ready(Ok(()))` without writing) once the position
    ///    has reached the span;
    /// 3. starts, or continues polling, a range read of at most `BODY_SIZE` bytes, then
    ///    copies what fits into `buf` and keeps the rest buffered for the next call.
    ///
    /// A failed range read is surfaced as `std::io::Error::other`, and the failed future is
    /// discarded so that the next call starts a fresh read.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::task::Poll;

        let this = self.get_mut();

        // Serve leftover bytes from the previous fetch before starting a new one.
        if !this.buffer.is_empty() {
            let to_copy = this.buffer.len().min(buf.remaining());
            buf.put_slice(&this.buffer[..to_copy]);
            this.buffer.advance(to_copy);
            return Poll::Ready(Ok(()));
        }

        // Returning without writing anything signals end of file.
        if this.joiner.position >= this.joiner.span {
            return Poll::Ready(Ok(()));
        }

        // Start a range read if none is in flight.
        if this.future.is_none() {
            let position = this.joiner.position;
            // Clamp in u64 *before* narrowing. Casting the remaining byte count to `usize`
            // first would truncate it on targets where `usize` is narrower than 64 bits,
            // which can shrink the read to a few bytes or to zero (a false end of file).
            let read_len = (this.joiner.span - position).min(BODY_SIZE as u64) as usize;
            let getter = Arc::clone(&this.joiner.getter);
            let root = this.joiner.root;
            let context = this.joiner.context.clone();
            let span = this.joiner.span;
            let tree = this.joiner.tree;
            let concurrency = this.joiner.concurrency;
            // The future must own its inputs, so the cached frontier is cloned per read.
            let subtrees: Vec<SubtreeNode<M>> = this.joiner.subtrees.clone();

            let fut = async move {
                GenericJoiner::<G, M, BODY_SIZE>::read_range_with(
                    &getter,
                    &subtrees,
                    &root,
                    &context,
                    span,
                    tree,
                    concurrency,
                    position,
                    read_len,
                )
                .await
            };
            this.future = Some(Box::pin(fut));
        }

        let fut = this
            .future
            .as_mut()
            .expect("a read future is always installed by the branch above");
        match fut.as_mut().poll(cx) {
            Poll::Ready(Ok(data)) => {
                this.future = None;
                let bytes = Bytes::from(data);
                // The cursor advances by everything fetched; bytes that do not fit in
                // `buf` are kept in `buffer` and served by the next call.
                this.joiner.position += bytes.len() as u64;
                let to_copy = bytes.len().min(buf.remaining());
                buf.put_slice(&bytes[..to_copy]);
                if to_copy < bytes.len() {
                    this.buffer = bytes.slice(to_copy..);
                }
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => {
                this.future = None;
                Poll::Ready(Err(std::io::Error::other(e)))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
420
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> tokio::io::AsyncSeek for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE> + 'static,
    M: JoinMode + Send + Sync + 'static,
{
    /// Resolve `pos` against the current position and span, then move the cursor there.
    ///
    /// Buffered bytes and any read in flight belong to the old position and are discarded.
    /// If the target cannot be resolved, the error is returned and nothing is changed.
    fn start_seek(self: std::pin::Pin<&mut Self>, pos: SeekFrom) -> std::io::Result<()> {
        let reader = self.get_mut();
        let (current, span) = (reader.joiner.position, reader.joiner.span);
        let target = super::resolve_seek_position(pos, current, span)?;

        reader.joiner.position = target;
        reader.buffer = Bytes::new();
        reader.future = None;
        Ok(())
    }

    /// Seeking is purely a state update, so completion is immediate and reports the
    /// current position.
    fn poll_complete(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<u64>> {
        let position = self.joiner.position;
        std::task::Poll::Ready(Ok(position))
    }
}
443
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::chunk::AnyChunk;
    use crate::file::sync_split;
    use std::collections::HashMap;

    /// Split `data` into plain chunks; returns the root address and the chunk map.
    fn split_and_store(data: &[u8]) -> (ChunkAddress, HashMap<ChunkAddress, AnyChunk>) {
        let (root, store) = sync_split::<DEFAULT_BODY_SIZE>(data).unwrap();
        let chunks = store.into_chunks();
        (root, chunks)
    }

    /// `len` bytes cycling through 0..=255, used as recognisable test content.
    fn cyclic_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 256) as u8).collect()
    }

    // --- Generated shared tests (async variants) ---
    generate_plain_joiner_tests!(tokio::test, Joiner, [async], [await]);

    // --- Async-only tests: Stream, AsyncRead, AsyncSeek ---
    // The whole module is already gated on the `tokio` feature, so the individual tests
    // carry no additional feature gate.

    #[tokio::test]
    async fn test_async_joiner_stream() {
        let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3);
        let (root, store) = split_and_store(&data);

        let joiner = Joiner::new(store, root).await.unwrap();
        let mut bodies = std::pin::pin!(joiner.into_stream());

        let mut recovered = Vec::new();
        while let Some(body) = bodies.next().await {
            recovered.extend_from_slice(&body.unwrap());
        }
        assert_eq!(recovered, data);
    }

    #[tokio::test]
    async fn test_async_reader_small() {
        use tokio::io::AsyncReadExt;

        let data = b"hello world";
        let (root, store) = split_and_store(data);

        let mut reader = Joiner::new(store, root).await.unwrap().into_reader();
        let mut result = Vec::new();
        reader.read_to_end(&mut result).await.unwrap();
        assert_eq!(result, data);
    }

    #[tokio::test]
    async fn test_async_reader_multi_chunk() {
        use tokio::io::AsyncReadExt;

        let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3 + 123);
        let (root, store) = split_and_store(&data);

        let mut reader = Joiner::new(store, root).await.unwrap().into_reader();
        let mut result = Vec::new();
        reader.read_to_end(&mut result).await.unwrap();
        assert_eq!(result, data);
    }

    #[tokio::test]
    async fn test_async_reader_seek() {
        use tokio::io::{AsyncReadExt, AsyncSeekExt};

        let data = b"hello world";
        let (root, store) = split_and_store(data);

        let mut reader = Joiner::new(store, root).await.unwrap().into_reader();

        reader.seek(SeekFrom::Start(6)).await.unwrap();
        let mut buf = [0u8; 5];
        reader.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"world");
    }

    #[tokio::test]
    async fn test_async_reader_seek_back_and_forth() {
        use tokio::io::{AsyncReadExt, AsyncSeekExt};

        let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3);
        let (root, store) = split_and_store(&data);

        let mut reader = Joiner::new(store, root).await.unwrap().into_reader();

        // (seek target, bytes to read, expected content): middle, start, then near the end.
        let cases = [
            (
                SeekFrom::Start(DEFAULT_BODY_SIZE as u64),
                100,
                &data[DEFAULT_BODY_SIZE..DEFAULT_BODY_SIZE + 100],
            ),
            (SeekFrom::Start(0), 100, &data[..100]),
            (SeekFrom::End(-50), 50, &data[data.len() - 50..]),
        ];

        for (target, len, expected) in cases {
            reader.seek(target).await.unwrap();
            let mut buf = vec![0u8; len];
            reader.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, expected);
        }
    }

    #[cfg(feature = "encryption")]
    mod encrypted {
        use super::*;
        use crate::file::sync_split_encrypted;

        /// Split `data` into encrypted chunks; returns the root reference and the chunk map.
        fn encrypted_split_and_store(
            data: &[u8],
        ) -> (
            crate::chunk::encryption::EncryptedChunkRef,
            HashMap<ChunkAddress, AnyChunk>,
        ) {
            let (root_ref, store) = sync_split_encrypted::<DEFAULT_BODY_SIZE>(data).unwrap();
            let chunks = store.into_chunks();
            (root_ref, chunks)
        }

        // --- Generated shared tests (async variants) ---
        generate_encrypted_joiner_tests!(tokio::test, EncryptedJoiner, [async], [await]);

        // --- Async-only tests: Stream ---

        #[tokio::test]
        async fn test_encrypted_async_joiner_stream() {
            let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3);
            let (root_ref, store) = encrypted_split_and_store(&data);

            let joiner = EncryptedJoiner::new(store, root_ref).await.unwrap();
            let mut bodies = std::pin::pin!(joiner.into_stream());

            let mut recovered = Vec::new();
            while let Some(body) = bodies.next().await {
                recovered.extend_from_slice(&body.unwrap());
            }
            assert_eq!(recovered, data);
        }
    }
}