Skip to main content

nectar_primitives/file/
sync_joiner.rs

1//! Sync parallel file joiner with BFS fan-out, streaming Read, and Seek.
2//!
3//! Uses a pre-expanded subtree frontier to avoid redundant intermediate node
4//! decryption. Subtrees are processed in parallel via rayon, and `impl Read`
5//! provides bounded-memory streaming.
6
7use std::io::{self, SeekFrom};
8use std::marker::PhantomData;
9
10use bytes::Bytes;
11use rayon::prelude::*;
12
13use crate::bmt::DEFAULT_BODY_SIZE;
14use crate::chunk::ChunkAddress;
15
16use super::error::Result;
17use super::frontier::{SubtreeNode, expand_frontier, read_subtree_bodies};
18use super::mode::{JoinMode, PlainMode};
19use super::tree::{ChunkRange, TreeParams};
20use crate::store::SyncChunkGet;
21
22#[cfg(feature = "encryption")]
23use super::mode::EncryptedMode;
24
25/// Generic joiner parameterized by chunk mode.
26///
27/// Uses BFS fan-out to pre-expand intermediate nodes into a frontier of
28/// roughly equal-sized subtrees. Implements `std::io::Read` and `std::io::Seek`
29/// for bounded-memory streaming.
30pub struct GenericSyncJoiner<G, M: JoinMode, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
31where
32    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
33{
34    getter: G,
35    root: ChunkAddress,
36    context: M::JoinerContext,
37    span: u64,
38    tree: TreeParams<BODY_SIZE>,
39
40    /// Pre-expanded frontier for streaming (computed at construction).
41    subtrees: Vec<SubtreeNode<M>>,
42
43    /// Streaming state for Read impl.
44    read_pos: u64,
45    buffer: Vec<u8>,
46    buffer_pos: usize,
47    subtree_idx: usize,
48
49    _mode: PhantomData<M>,
50}
51
52/// Plain (unencrypted) file joiner.
53pub type SyncJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
54    GenericSyncJoiner<G, PlainMode, BODY_SIZE>;
55
56/// Encrypted file joiner.
57#[cfg(feature = "encryption")]
58pub type EncryptedSyncJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
59    GenericSyncJoiner<G, EncryptedMode, BODY_SIZE>;
60
61impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for GenericSyncJoiner<G, M, BODY_SIZE>
62where
63    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
64    M: JoinMode,
65{
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("GenericSyncJoiner")
68            .field("root", &self.root)
69            .field("span", &self.span)
70            .field("read_pos", &self.read_pos)
71            .finish_non_exhaustive()
72    }
73}
74
75impl<G, M, const BODY_SIZE: usize> GenericSyncJoiner<G, M, BODY_SIZE>
76where
77    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
78    M: JoinMode + Send + Sync,
79{
80    /// Create a joiner from a root reference.
81    pub fn new(getter: G, input: M::RootRef) -> Result<Self> {
82        const { super::constants::assert_valid_body_size::<BODY_SIZE>() };
83
84        let (root, span, context) = super::mode::joiner_init::<M, G, BODY_SIZE>(&getter, input)?;
85        let tree = TreeParams::<BODY_SIZE>::new(span);
86
87        // 2x thread count gives each thread >=2 subtrees for balanced work distribution.
88        let target = rayon::current_num_threads().max(1) * 2;
89        let full_range = tree.chunks_for_range(0, span);
90        let subtrees = expand_frontier::<G, M, BODY_SIZE>(
91            &getter,
92            &root,
93            &context,
94            span,
95            &full_range,
96            target,
97        )?;
98
99        Ok(Self {
100            getter,
101            root,
102            context,
103            span,
104            tree,
105            subtrees,
106            read_pos: 0,
107            buffer: Vec::new(),
108            buffer_pos: 0,
109            subtree_idx: 0,
110            _mode: PhantomData,
111        })
112    }
113
114    /// Total file size.
115    #[inline]
116    pub const fn size(&self) -> u64 {
117        self.span
118    }
119
120    /// Current read position.
121    #[inline]
122    pub const fn position(&self) -> u64 {
123        self.read_pos
124    }
125
126    /// Root address.
127    #[inline]
128    pub const fn root(&self) -> &ChunkAddress {
129        &self.root
130    }
131
132    /// Read a range of bytes, fetching required chunks in parallel.
133    pub fn read_range(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
134        use super::helpers::{ReadRangeCheck, validate_read_range};
135
136        match validate_read_range::<BODY_SIZE>(offset, len, self.span) {
137            ReadRangeCheck::Empty => Ok(Vec::new()),
138            ReadRangeCheck::SingleChunk { offset, actual_len } => {
139                self.read_single_chunk(offset, actual_len)
140            }
141            ReadRangeCheck::MultiChunk { offset, actual_len } => {
142                let chunk_range = self.tree.chunks_for_range(offset, actual_len as u64);
143                let range_start_byte = chunk_range.start * BODY_SIZE as u64;
144                let range_end_byte = chunk_range.end * BODY_SIZE as u64;
145
146                let bodies = self.collect_bodies(&chunk_range, range_start_byte, range_end_byte)?;
147
148                Ok(super::tree::assemble_range(
149                    &self.tree,
150                    offset,
151                    actual_len,
152                    &chunk_range,
153                    &bodies,
154                ))
155            }
156        }
157    }
158
159    /// Read entire file into memory.
160    pub fn read_all(&self) -> Result<Vec<u8>> {
161        self.read_range(0, self.span as usize)
162    }
163
164    /// Filter pre-computed subtrees and collect leaf bodies in parallel.
165    fn collect_bodies(
166        &self,
167        chunk_range: &ChunkRange,
168        range_start_byte: u64,
169        range_end_byte: u64,
170    ) -> Result<Vec<Bytes>> {
171        let getter = &self.getter;
172        let nested: Vec<Vec<Bytes>> = self
173            .subtrees
174            .par_iter()
175            .filter(|st| {
176                st.byte_offset < range_end_byte && st.byte_offset + st.span > range_start_byte
177            })
178            .map(|st| {
179                let mut bodies = Vec::with_capacity((st.span as usize / BODY_SIZE).max(1));
180                read_subtree_bodies::<G, M, BODY_SIZE>(getter, st, chunk_range, &mut bodies)?;
181                Ok(bodies)
182            })
183            .collect::<Result<Vec<Vec<Bytes>>>>()?;
184
185        Ok(nested.into_iter().flat_map(|v| v.into_iter()).collect())
186    }
187
188    fn read_single_chunk(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
189        let body = super::mode::read_chunk_body::<M, G, BODY_SIZE>(
190            &self.getter,
191            &self.root,
192            &self.context,
193            self.span,
194        )?;
195        let start = offset as usize;
196        let end = start + len;
197        Ok(body[start..end].to_vec())
198    }
199
200    /// Process the next batch of subtrees into the internal buffer.
201    fn fill_buffer(&mut self) -> Result<()> {
202        let batch_size = rayon::current_num_threads().max(1);
203        let start_idx = self.subtree_idx;
204        let end_idx = (start_idx + batch_size).min(self.subtrees.len());
205
206        let batch = &self.subtrees[start_idx..end_idx];
207        if batch.is_empty() {
208            return Ok(());
209        }
210
211        let batch_start_byte = batch[0].byte_offset;
212        let last = &batch[batch.len() - 1];
213        let batch_end_byte = (last.byte_offset + last.span).min(self.span);
214        let chunk_range = ChunkRange {
215            start: batch_start_byte / BODY_SIZE as u64,
216            end: batch_end_byte.div_ceil(BODY_SIZE as u64),
217        };
218
219        let getter = &self.getter;
220        let all_bodies = batch
221            .par_iter()
222            .map(|st| {
223                let mut bodies = Vec::with_capacity((st.span as usize / BODY_SIZE).max(1));
224                read_subtree_bodies::<G, M, BODY_SIZE>(getter, st, &chunk_range, &mut bodies)?;
225                Ok(bodies)
226            })
227            .collect::<Result<Vec<Vec<Bytes>>>>()?;
228
229        let estimated = (batch_end_byte - batch_start_byte) as usize;
230        self.buffer.clear();
231        self.buffer.reserve(estimated);
232        for bodies in all_bodies {
233            for body in bodies {
234                self.buffer.extend_from_slice(&body);
235            }
236        }
237        self.buffer_pos = 0;
238        self.subtree_idx = end_idx;
239
240        // After a seek, read_pos may be past the batch start — skip ahead.
241        if self.read_pos > batch_start_byte {
242            self.buffer_pos = (self.read_pos - batch_start_byte) as usize;
243        }
244
245        Ok(())
246    }
247
248    /// Copy bytes from the internal buffer to the caller's buffer.
249    fn drain_buffer(&mut self, buf: &mut [u8]) -> usize {
250        let available = self.buffer.len() - self.buffer_pos;
251        let to_copy = buf.len().min(available);
252        buf[..to_copy].copy_from_slice(&self.buffer[self.buffer_pos..self.buffer_pos + to_copy]);
253        self.buffer_pos += to_copy;
254        self.read_pos += to_copy as u64;
255        to_copy
256    }
257}
258
259impl<G, M, const BODY_SIZE: usize> io::Read for GenericSyncJoiner<G, M, BODY_SIZE>
260where
261    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
262    M: JoinMode + Send + Sync,
263{
264    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
265        if buf.is_empty() || self.read_pos >= self.span {
266            return Ok(0);
267        }
268
269        if self.buffer_pos < self.buffer.len() {
270            return Ok(self.drain_buffer(buf));
271        }
272
273        if self.subtree_idx >= self.subtrees.len() {
274            return Ok(0);
275        }
276
277        self.fill_buffer().map_err(io::Error::other)?;
278
279        if self.buffer.is_empty() {
280            return Ok(0);
281        }
282
283        Ok(self.drain_buffer(buf))
284    }
285}
286
287impl<G, M, const BODY_SIZE: usize> io::Seek for GenericSyncJoiner<G, M, BODY_SIZE>
288where
289    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
290    M: JoinMode + Send + Sync,
291{
292    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
293        self.read_pos = super::resolve_seek_position(pos, self.read_pos, self.span)?;
294        self.buffer.clear();
295        self.buffer_pos = 0;
296        self.subtree_idx = self
297            .subtrees
298            .iter()
299            .position(|st| st.byte_offset + st.span > self.read_pos)
300            .unwrap_or(self.subtrees.len());
301        Ok(self.read_pos)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunk::AnyChunk;
    use crate::file::sync_split;
    use std::collections::HashMap;
    use std::io::{Read, Seek};

    /// `len` bytes of a repeating 0..=255 ramp, so any offset is easy to check.
    fn patterned_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 256) as u8).collect()
    }

    fn split_and_store(data: &[u8]) -> (ChunkAddress, HashMap<ChunkAddress, AnyChunk>) {
        let (root, store) = sync_split::<DEFAULT_BODY_SIZE>(data).unwrap();
        (root, store.into_chunks())
    }

    /// Read `reader` to EOF through a 100-byte buffer, so individual `read`
    /// calls straddle chunk and batch boundaries.
    fn drain_in_small_reads<R: Read>(reader: &mut R) -> Vec<u8> {
        let mut out = Vec::new();
        let mut buf = [0u8; 100];
        loop {
            let n = reader.read(&mut buf).unwrap();
            if n == 0 {
                return out;
            }
            out.extend_from_slice(&buf[..n]);
        }
    }

    // --- Generated shared tests (sync variants) ---
    generate_plain_joiner_tests!(test, SyncJoiner, [], []);

    // --- Sync-only tests: std::io::Read + Seek ---

    #[test]
    fn test_joiner_streaming() {
        let data = patterned_bytes(DEFAULT_BODY_SIZE * 3 + 500);
        let (root, store) = split_and_store(&data);

        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut result = vec![0u8; data.len()];
        joiner.read_exact(&mut result).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_joiner_small_buffer_streaming() {
        let refs_per_chunk = DEFAULT_BODY_SIZE / super::super::constants::REF_SIZE;
        let data = patterned_bytes(DEFAULT_BODY_SIZE * refs_per_chunk);
        let (root, store) = split_and_store(&data);

        let mut joiner = SyncJoiner::new(store, root).unwrap();
        assert_eq!(drain_in_small_reads(&mut joiner), data);
    }

    #[test]
    fn test_joiner_seek_start() {
        let data = b"hello world";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        joiner.seek(SeekFrom::Start(6)).unwrap();
        let result = joiner.read_all().unwrap();
        // read_all always reads from offset 0
        assert_eq!(result, data);

        // But seek + Read trait respects position
        joiner.seek(SeekFrom::Start(6)).unwrap();
        let mut buf = vec![0u8; 5];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"world");
    }

    #[test]
    fn test_joiner_seek_current() {
        let data = patterned_bytes(DEFAULT_BODY_SIZE * 3);
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let offset = DEFAULT_BODY_SIZE + 100;
        joiner.seek(SeekFrom::Start(offset as u64)).unwrap();
        assert_eq!(joiner.position(), offset as u64);

        let mut buf = vec![0u8; 50];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, &data[offset..offset + 50]);

        joiner.seek(SeekFrom::Current(-50)).unwrap();
        let mut buf2 = vec![0u8; 50];
        joiner.read_exact(&mut buf2).unwrap();
        assert_eq!(buf, buf2);
    }

    #[test]
    fn test_joiner_seek_negative() {
        let data = b"test data";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let result = joiner.seek(SeekFrom::Current(-100));
        assert!(result.is_err());
    }

    #[test]
    fn test_joiner_partial_reads() {
        let data = patterned_bytes(DEFAULT_BODY_SIZE * 2 + 500);
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        assert_eq!(drain_in_small_reads(&mut joiner), data);
    }

    #[test]
    fn test_joiner_read_at_eof() {
        let data = b"test data";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let mut buf = vec![0u8; data.len()];
        joiner.read_exact(&mut buf).unwrap();

        let mut buf2 = [0u8; 10];
        let n = joiner.read(&mut buf2).unwrap();
        assert_eq!(n, 0);
    }

    #[test]
    fn test_joiner_seek_from_end() {
        let data = b"hello world";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        assert_eq!(joiner.seek(SeekFrom::End(-5)).unwrap(), 6);
        let mut buf = vec![0u8; 5];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"world");
    }

    #[test]
    fn test_joiner_seek_to_eof_then_rewind() {
        let data = patterned_bytes(DEFAULT_BODY_SIZE * 2 + 7);
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let end = joiner.seek(SeekFrom::End(0)).unwrap();
        assert_eq!(end, data.len() as u64);
        let mut buf = [0u8; 10];
        assert_eq!(joiner.read(&mut buf).unwrap(), 0);

        // Rewinding after reaching EOF must restore full streaming.
        joiner.seek(SeekFrom::Start(0)).unwrap();
        assert_eq!(drain_in_small_reads(&mut joiner), data);
        assert_eq!(joiner.position(), data.len() as u64);
    }

    #[test]
    fn test_joiner_seek_into_later_subtree() {
        // More leaves than one intermediate chunk can reference, so the tree
        // has two levels and the frontier spans several subtrees.
        let refs_per_chunk = DEFAULT_BODY_SIZE / super::super::constants::REF_SIZE;
        let data = patterned_bytes(DEFAULT_BODY_SIZE * (refs_per_chunk + 3) + 11);
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let offset = data.len() - DEFAULT_BODY_SIZE - 123;
        joiner.seek(SeekFrom::Start(offset as u64)).unwrap();
        assert_eq!(drain_in_small_reads(&mut joiner), &data[offset..]);
    }

    #[cfg(feature = "encryption")]
    mod encrypted {
        use super::*;
        use crate::chunk::encryption::EncryptedChunkRef;
        use crate::file::sync_split_encrypted;

        fn encrypted_split_and_store(
            data: &[u8],
        ) -> (EncryptedChunkRef, HashMap<ChunkAddress, AnyChunk>) {
            let (root_ref, store) = sync_split_encrypted::<DEFAULT_BODY_SIZE>(data).unwrap();
            (root_ref, store.into_chunks())
        }

        // --- Generated shared tests (sync variants) ---
        generate_encrypted_joiner_tests!(test, EncryptedSyncJoiner, [], []);

        // --- Sync-only tests: std::io::Read + Seek ---

        #[test]
        fn test_encrypted_joiner_streaming() {
            let data = patterned_bytes(DEFAULT_BODY_SIZE * 65);
            let (root_ref, store) = encrypted_split_and_store(&data);

            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            let mut result = vec![0u8; data.len()];
            joiner.read_exact(&mut result).unwrap();
            assert_eq!(result, data);
        }

        #[test]
        fn test_encrypted_joiner_small_buffer_streaming() {
            let data = patterned_bytes(DEFAULT_BODY_SIZE * 128);
            let (root_ref, store) = encrypted_split_and_store(&data);

            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            assert_eq!(drain_in_small_reads(&mut joiner), data);
        }

        #[test]
        fn test_encrypted_joiner_seek_back_and_forth() {
            let data = patterned_bytes(DEFAULT_BODY_SIZE * 3);
            let (root_ref, store) = encrypted_split_and_store(&data);
            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();

            // Read from middle
            joiner
                .seek(SeekFrom::Start(DEFAULT_BODY_SIZE as u64))
                .unwrap();
            let mut buf1 = vec![0u8; 100];
            joiner.read_exact(&mut buf1).unwrap();
            assert_eq!(&buf1, &data[DEFAULT_BODY_SIZE..DEFAULT_BODY_SIZE + 100]);

            // Seek back to start
            joiner.seek(SeekFrom::Start(0)).unwrap();
            let mut buf2 = vec![0u8; 100];
            joiner.read_exact(&mut buf2).unwrap();
            assert_eq!(&buf2, &data[..100]);
        }
    }
}