1use std::io::{self, SeekFrom};
8use std::marker::PhantomData;
9
10use bytes::Bytes;
11use rayon::prelude::*;
12
13use crate::bmt::DEFAULT_BODY_SIZE;
14use crate::chunk::ChunkAddress;
15
16use super::error::Result;
17use super::frontier::{SubtreeNode, expand_frontier, read_subtree_bodies};
18use super::mode::{JoinMode, PlainMode};
19use super::tree::{ChunkRange, TreeParams};
20use crate::store::SyncChunkGet;
21
22#[cfg(feature = "encryption")]
23use super::mode::EncryptedMode;
24
25pub struct GenericSyncJoiner<G, M: JoinMode, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
31where
32 G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
33{
34 getter: G,
35 root: ChunkAddress,
36 context: M::JoinerContext,
37 span: u64,
38 tree: TreeParams<BODY_SIZE>,
39
40 subtrees: Vec<SubtreeNode<M>>,
42
43 read_pos: u64,
45 buffer: Vec<u8>,
46 buffer_pos: usize,
47 subtree_idx: usize,
48
49 _mode: PhantomData<M>,
50}
51
52pub type SyncJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
54 GenericSyncJoiner<G, PlainMode, BODY_SIZE>;
55
/// [`GenericSyncJoiner`] fixed to [`EncryptedMode`]; only available with the
/// `encryption` feature.
#[cfg(feature = "encryption")]
pub type EncryptedSyncJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
    GenericSyncJoiner<G, EncryptedMode, BODY_SIZE>;
60
61impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for GenericSyncJoiner<G, M, BODY_SIZE>
62where
63 G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
64 M: JoinMode,
65{
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("GenericSyncJoiner")
68 .field("root", &self.root)
69 .field("span", &self.span)
70 .field("read_pos", &self.read_pos)
71 .finish_non_exhaustive()
72 }
73}
74
75impl<G, M, const BODY_SIZE: usize> GenericSyncJoiner<G, M, BODY_SIZE>
76where
77 G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
78 M: JoinMode + Send + Sync,
79{
80 pub fn new(getter: G, input: M::RootRef) -> Result<Self> {
82 const { super::constants::assert_valid_body_size::<BODY_SIZE>() };
83
84 let (root, span, context) = super::mode::joiner_init::<M, G, BODY_SIZE>(&getter, input)?;
85 let tree = TreeParams::<BODY_SIZE>::new(span);
86
87 let target = rayon::current_num_threads().max(1) * 2;
89 let full_range = tree.chunks_for_range(0, span);
90 let subtrees = expand_frontier::<G, M, BODY_SIZE>(
91 &getter,
92 &root,
93 &context,
94 span,
95 &full_range,
96 target,
97 )?;
98
99 Ok(Self {
100 getter,
101 root,
102 context,
103 span,
104 tree,
105 subtrees,
106 read_pos: 0,
107 buffer: Vec::new(),
108 buffer_pos: 0,
109 subtree_idx: 0,
110 _mode: PhantomData,
111 })
112 }
113
114 #[inline]
116 pub const fn size(&self) -> u64 {
117 self.span
118 }
119
120 #[inline]
122 pub const fn position(&self) -> u64 {
123 self.read_pos
124 }
125
126 #[inline]
128 pub const fn root(&self) -> &ChunkAddress {
129 &self.root
130 }
131
132 pub fn read_range(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
134 use super::helpers::{ReadRangeCheck, validate_read_range};
135
136 match validate_read_range::<BODY_SIZE>(offset, len, self.span) {
137 ReadRangeCheck::Empty => Ok(Vec::new()),
138 ReadRangeCheck::SingleChunk { offset, actual_len } => {
139 self.read_single_chunk(offset, actual_len)
140 }
141 ReadRangeCheck::MultiChunk { offset, actual_len } => {
142 let chunk_range = self.tree.chunks_for_range(offset, actual_len as u64);
143 let range_start_byte = chunk_range.start * BODY_SIZE as u64;
144 let range_end_byte = chunk_range.end * BODY_SIZE as u64;
145
146 let bodies = self.collect_bodies(&chunk_range, range_start_byte, range_end_byte)?;
147
148 Ok(super::tree::assemble_range(
149 &self.tree,
150 offset,
151 actual_len,
152 &chunk_range,
153 &bodies,
154 ))
155 }
156 }
157 }
158
159 pub fn read_all(&self) -> Result<Vec<u8>> {
161 self.read_range(0, self.span as usize)
162 }
163
164 fn collect_bodies(
166 &self,
167 chunk_range: &ChunkRange,
168 range_start_byte: u64,
169 range_end_byte: u64,
170 ) -> Result<Vec<Bytes>> {
171 let getter = &self.getter;
172 let nested: Vec<Vec<Bytes>> = self
173 .subtrees
174 .par_iter()
175 .filter(|st| {
176 st.byte_offset < range_end_byte && st.byte_offset + st.span > range_start_byte
177 })
178 .map(|st| {
179 let mut bodies = Vec::with_capacity((st.span as usize / BODY_SIZE).max(1));
180 read_subtree_bodies::<G, M, BODY_SIZE>(getter, st, chunk_range, &mut bodies)?;
181 Ok(bodies)
182 })
183 .collect::<Result<Vec<Vec<Bytes>>>>()?;
184
185 Ok(nested.into_iter().flat_map(|v| v.into_iter()).collect())
186 }
187
188 fn read_single_chunk(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
189 let body = super::mode::read_chunk_body::<M, G, BODY_SIZE>(
190 &self.getter,
191 &self.root,
192 &self.context,
193 self.span,
194 )?;
195 let start = offset as usize;
196 let end = start + len;
197 Ok(body[start..end].to_vec())
198 }
199
200 fn fill_buffer(&mut self) -> Result<()> {
202 let batch_size = rayon::current_num_threads().max(1);
203 let start_idx = self.subtree_idx;
204 let end_idx = (start_idx + batch_size).min(self.subtrees.len());
205
206 let batch = &self.subtrees[start_idx..end_idx];
207 if batch.is_empty() {
208 return Ok(());
209 }
210
211 let batch_start_byte = batch[0].byte_offset;
212 let last = &batch[batch.len() - 1];
213 let batch_end_byte = (last.byte_offset + last.span).min(self.span);
214 let chunk_range = ChunkRange {
215 start: batch_start_byte / BODY_SIZE as u64,
216 end: batch_end_byte.div_ceil(BODY_SIZE as u64),
217 };
218
219 let getter = &self.getter;
220 let all_bodies = batch
221 .par_iter()
222 .map(|st| {
223 let mut bodies = Vec::with_capacity((st.span as usize / BODY_SIZE).max(1));
224 read_subtree_bodies::<G, M, BODY_SIZE>(getter, st, &chunk_range, &mut bodies)?;
225 Ok(bodies)
226 })
227 .collect::<Result<Vec<Vec<Bytes>>>>()?;
228
229 let estimated = (batch_end_byte - batch_start_byte) as usize;
230 self.buffer.clear();
231 self.buffer.reserve(estimated);
232 for bodies in all_bodies {
233 for body in bodies {
234 self.buffer.extend_from_slice(&body);
235 }
236 }
237 self.buffer_pos = 0;
238 self.subtree_idx = end_idx;
239
240 if self.read_pos > batch_start_byte {
242 self.buffer_pos = (self.read_pos - batch_start_byte) as usize;
243 }
244
245 Ok(())
246 }
247
248 fn drain_buffer(&mut self, buf: &mut [u8]) -> usize {
250 let available = self.buffer.len() - self.buffer_pos;
251 let to_copy = buf.len().min(available);
252 buf[..to_copy].copy_from_slice(&self.buffer[self.buffer_pos..self.buffer_pos + to_copy]);
253 self.buffer_pos += to_copy;
254 self.read_pos += to_copy as u64;
255 to_copy
256 }
257}
258
259impl<G, M, const BODY_SIZE: usize> io::Read for GenericSyncJoiner<G, M, BODY_SIZE>
260where
261 G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
262 M: JoinMode + Send + Sync,
263{
264 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
265 if buf.is_empty() || self.read_pos >= self.span {
266 return Ok(0);
267 }
268
269 if self.buffer_pos < self.buffer.len() {
270 return Ok(self.drain_buffer(buf));
271 }
272
273 if self.subtree_idx >= self.subtrees.len() {
274 return Ok(0);
275 }
276
277 self.fill_buffer().map_err(io::Error::other)?;
278
279 if self.buffer.is_empty() {
280 return Ok(0);
281 }
282
283 Ok(self.drain_buffer(buf))
284 }
285}
286
287impl<G, M, const BODY_SIZE: usize> io::Seek for GenericSyncJoiner<G, M, BODY_SIZE>
288where
289 G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
290 M: JoinMode + Send + Sync,
291{
292 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
293 self.read_pos = super::resolve_seek_position(pos, self.read_pos, self.span)?;
294 self.buffer.clear();
295 self.buffer_pos = 0;
296 self.subtree_idx = self
297 .subtrees
298 .iter()
299 .position(|st| st.byte_offset + st.span > self.read_pos)
300 .unwrap_or(self.subtrees.len());
301 Ok(self.read_pos)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunk::AnyChunk;
    use crate::file::sync_split;
    use std::collections::HashMap;
    use std::io::{Read, Seek};

    /// Splits `data` with the default body size and returns the root address
    /// together with the resulting chunk map, which serves as the getter.
    fn split_and_store(data: &[u8]) -> (ChunkAddress, HashMap<ChunkAddress, AnyChunk>) {
        let (root, store) = sync_split::<DEFAULT_BODY_SIZE>(data).unwrap();
        (root, store.into_chunks())
    }

    /// `len` bytes following the repeating pattern 0, 1, ..., 255, 0, 1, ...
    fn cyclic_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 256) as u8).collect()
    }

    /// Reads `reader` to the end through a 100-byte scratch buffer and returns
    /// everything that was read.
    fn drain_in_small_reads<R: Read>(reader: &mut R) -> Vec<u8> {
        let mut collected = Vec::new();
        let mut scratch = [0u8; 100];
        loop {
            match reader.read(&mut scratch).unwrap() {
                0 => return collected,
                n => collected.extend_from_slice(&scratch[..n]),
            }
        }
    }

    generate_plain_joiner_tests!(test, SyncJoiner, [], []);

    /// A single `read_exact` over a multi-chunk payload returns the input.
    #[test]
    fn test_joiner_streaming() {
        let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3 + 500);
        let (root, store) = split_and_store(&data);

        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut result = vec![0u8; data.len()];
        joiner.read_exact(&mut result).unwrap();
        assert_eq!(result, data);
    }

    /// Streaming a payload of `refs_per_chunk` full chunks through tiny reads.
    #[test]
    fn test_joiner_small_buffer_streaming() {
        let refs_per_chunk = DEFAULT_BODY_SIZE / super::super::constants::REF_SIZE;
        let data = cyclic_bytes(DEFAULT_BODY_SIZE * refs_per_chunk);
        let (root, store) = split_and_store(&data);

        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let result = drain_in_small_reads(&mut joiner);
        assert_eq!(result, data);
    }

    /// `read_all` ignores the stream position, while streaming reads honour it.
    #[test]
    fn test_joiner_seek_start() {
        let data = b"hello world";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        joiner.seek(SeekFrom::Start(6)).unwrap();
        let result = joiner.read_all().unwrap();
        assert_eq!(result, data);

        joiner.seek(SeekFrom::Start(6)).unwrap();
        let mut buf = vec![0u8; 5];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"world");
    }

    /// Seeking backwards relative to the cursor re-reads the same bytes.
    #[test]
    fn test_joiner_seek_current() {
        let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3);
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let offset = DEFAULT_BODY_SIZE + 100;
        joiner.seek(SeekFrom::Start(offset as u64)).unwrap();
        assert_eq!(joiner.position(), offset as u64);

        let mut buf = vec![0u8; 50];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, &data[offset..offset + 50]);

        joiner.seek(SeekFrom::Current(-50)).unwrap();
        let mut buf2 = vec![0u8; 50];
        joiner.read_exact(&mut buf2).unwrap();
        assert_eq!(buf, buf2);
    }

    /// Seeking before the start of the stream is rejected.
    #[test]
    fn test_joiner_seek_negative() {
        let data = b"test data";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let result = joiner.seek(SeekFrom::Current(-100));
        assert!(result.is_err());
    }

    /// Many short reads over a payload with a partial last chunk.
    #[test]
    fn test_joiner_partial_reads() {
        let data = cyclic_bytes(DEFAULT_BODY_SIZE * 2 + 500);
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let recovered = drain_in_small_reads(&mut joiner);
        assert_eq!(recovered, data);
    }

    /// Reading after the payload has been consumed yields zero bytes.
    #[test]
    fn test_joiner_read_at_eof() {
        let data = b"test data";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();

        let mut buf = vec![0u8; data.len()];
        joiner.read_exact(&mut buf).unwrap();

        let mut buf2 = [0u8; 10];
        let n = joiner.read(&mut buf2).unwrap();
        assert_eq!(n, 0);
    }

    #[cfg(feature = "encryption")]
    mod encrypted {
        use super::*;
        use crate::chunk::encryption::EncryptedChunkRef;
        use crate::file::sync_split_encrypted;

        /// Encrypted counterpart of `split_and_store`.
        fn encrypted_split_and_store(
            data: &[u8],
        ) -> (EncryptedChunkRef, HashMap<ChunkAddress, AnyChunk>) {
            let (root_ref, store) = sync_split_encrypted::<DEFAULT_BODY_SIZE>(data).unwrap();
            (root_ref, store.into_chunks())
        }

        generate_encrypted_joiner_tests!(test, EncryptedSyncJoiner, [], []);

        /// A single `read_exact` over 65 chunks of encrypted data.
        #[test]
        fn test_encrypted_joiner_streaming() {
            let data = cyclic_bytes(DEFAULT_BODY_SIZE * 65);
            let (root_ref, store) = encrypted_split_and_store(&data);

            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            let mut result = vec![0u8; data.len()];
            joiner.read_exact(&mut result).unwrap();
            assert_eq!(result, data);
        }

        /// Tiny reads over 128 chunks of encrypted data.
        #[test]
        fn test_encrypted_joiner_small_buffer_streaming() {
            let data = cyclic_bytes(DEFAULT_BODY_SIZE * 128);
            let (root_ref, store) = encrypted_split_and_store(&data);

            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            let result = drain_in_small_reads(&mut joiner);
            assert_eq!(result, data);
        }

        /// Forward seek followed by a seek back to the start.
        #[test]
        fn test_encrypted_joiner_seek_back_and_forth() {
            let data = cyclic_bytes(DEFAULT_BODY_SIZE * 3);
            let (root_ref, store) = encrypted_split_and_store(&data);
            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();

            joiner
                .seek(SeekFrom::Start(DEFAULT_BODY_SIZE as u64))
                .unwrap();
            let mut buf1 = vec![0u8; 100];
            joiner.read_exact(&mut buf1).unwrap();
            assert_eq!(&buf1, &data[DEFAULT_BODY_SIZE..DEFAULT_BODY_SIZE + 100]);

            joiner.seek(SeekFrom::Start(0)).unwrap();
            let mut buf2 = vec![0u8; 100];
            joiner.read_exact(&mut buf2).unwrap();
            assert_eq!(&buf2, &data[..100]);
        }
    }
}