1use std::io::SeekFrom;
4use std::marker::PhantomData;
5use std::sync::Arc;
6
/// Default number of subtree fetches a new joiner keeps in flight during
/// `read_range`; `GenericJoiner::new` also passes twice this value to
/// `expand_frontier_async` as its expansion target.
const DEFAULT_ASYNC_CONCURRENCY: usize = 8;
9
10#[cfg(feature = "tokio")]
11use bytes::Buf;
12use bytes::Bytes;
13use futures::stream::{self, Stream, StreamExt};
14
15use crate::bmt::DEFAULT_BODY_SIZE;
16use crate::chunk::ChunkAddress;
17
18use super::error::{FileError, Result};
19use super::frontier::{SubtreeNode, expand_frontier_async, read_subtree_bodies_async};
20use super::mode::{JoinMode, PlainMode};
21use super::tree::{ChunkRange, TreeParams};
22use crate::store::ChunkGet;
23
24#[cfg(feature = "encryption")]
25use super::mode::EncryptedMode;
26
/// Reassembles a file from its chunk tree, fetching chunks through an async
/// [`ChunkGet`] implementation.
///
/// `M` selects the join mode (see [`PlainMode`] and, with the `encryption`
/// feature, `EncryptedMode`); `BODY_SIZE` is the chunk body size in bytes.
pub struct GenericJoiner<G, M: JoinMode, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
{
    /// Chunk source; shared (via `Arc`) with the futures created by reads,
    /// streams and the tokio reader adapter.
    getter: Arc<G>,
    /// Address of the root chunk.
    root: ChunkAddress,
    /// Mode-specific state returned by `joiner_init_async`; handed to
    /// `M::decode_body` when a body is decoded.
    context: M::JoinerContext,
    /// Total file length in bytes.
    span: u64,
    /// Tree geometry derived from `span`.
    tree: TreeParams<BODY_SIZE>,
    /// Frontier of subtrees expanded once at construction; range reads and
    /// streams fan out over these nodes.
    subtrees: Vec<SubtreeNode<M>>,
    /// Cursor updated by `seek` (and by the tokio reader adapter).
    position: u64,
    /// Maximum number of subtree fetches kept in flight by `read_range`.
    concurrency: usize,
    _mode: PhantomData<M>,
}
43
/// [`GenericJoiner`] specialised to [`PlainMode`].
pub type Joiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
    GenericJoiner<G, PlainMode, BODY_SIZE>;
47
#[cfg(feature = "encryption")]
/// [`GenericJoiner`] specialised to [`EncryptedMode`]; only available with the
/// `encryption` feature.
pub type EncryptedJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
    GenericJoiner<G, EncryptedMode, BODY_SIZE>;
52
53impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for GenericJoiner<G, M, BODY_SIZE>
54where
55 G: ChunkGet<BODY_SIZE>,
56 M: JoinMode,
57{
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("GenericJoiner")
60 .field("root", &self.root)
61 .field("span", &self.span)
62 .field("position", &self.position)
63 .field("concurrency", &self.concurrency)
64 .finish_non_exhaustive()
65 }
66}
67
68async fn collect_subtree_bodies_async<G, M, const BODY_SIZE: usize>(
70 getter: &Arc<G>,
71 subtrees: Vec<SubtreeNode<M>>,
72 chunk_range: ChunkRange,
73 concurrency: usize,
74) -> Result<Vec<Bytes>>
75where
76 G: ChunkGet<BODY_SIZE>,
77 M: JoinMode + Send + Sync,
78{
79 let bodies: Vec<Bytes> = stream::iter(subtrees)
80 .map(|st| {
81 let getter = Arc::clone(getter);
82 async move {
83 read_subtree_bodies_async::<G, M, BODY_SIZE>(&*getter, &st, &chunk_range).await
84 }
85 })
86 .buffered(concurrency)
87 .collect::<Vec<_>>()
88 .await
89 .into_iter()
90 .collect::<Result<Vec<Vec<Bytes>>>>()?
91 .into_iter()
92 .flatten()
93 .collect();
94 Ok(bodies)
95}
96
impl<G, M, const BODY_SIZE: usize> GenericJoiner<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
    M: JoinMode + Send + Sync,
{
    /// Opens the file referenced by `input`.
    ///
    /// Resolves the root, span and mode context via `joiner_init_async`,
    /// derives the tree geometry from the span, and expands the tree into a
    /// frontier of subtrees that later reads fan out over. The cursor starts
    /// at 0 and concurrency at `DEFAULT_ASYNC_CONCURRENCY`.
    ///
    /// # Errors
    /// Propagates errors from root resolution and frontier expansion.
    pub async fn new(getter: G, input: M::RootRef) -> Result<Self> {
        // Inline const block: evaluated at compile time for each BODY_SIZE,
        // presumably rejecting unsupported body sizes.
        const { super::constants::assert_valid_body_size::<BODY_SIZE>() };

        let (root, span, context) =
            super::mode::joiner_init_async::<M, G, BODY_SIZE>(&getter, input).await?;
        let tree = TreeParams::<BODY_SIZE>::new(span);

        // Expansion target is derived from the *default* concurrency; a later
        // `with_concurrency` call does not re-expand the frontier.
        // NOTE(review): presumably the desired frontier size — see
        // `expand_frontier_async`.
        let target = DEFAULT_ASYNC_CONCURRENCY * 2;
        let full_range = tree.chunks_for_range(0, span);
        let subtrees = expand_frontier_async::<G, M, BODY_SIZE>(
            &getter,
            &root,
            &context,
            span,
            &full_range,
            target,
        )
        .await?;

        Ok(Self {
            getter: Arc::new(getter),
            root,
            context,
            span,
            tree,
            subtrees,
            position: 0,
            concurrency: DEFAULT_ASYNC_CONCURRENCY,
            _mode: PhantomData,
        })
    }

    /// Sets how many subtree fetches `read_range` keeps in flight.
    /// Values below 1 are clamped to 1.
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency.max(1);
        self
    }

    /// Total file size in bytes.
    #[inline]
    pub const fn size(&self) -> u64 {
        self.span
    }

    /// Current cursor position, as last set by [`seek`](Self::seek).
    #[inline]
    pub const fn position(&self) -> u64 {
        self.position
    }

    /// Address of the root chunk.
    #[inline]
    pub const fn root(&self) -> &ChunkAddress {
        &self.root
    }

    /// Reads up to `len` bytes starting at `offset`. Does not move the cursor.
    ///
    /// Bounds handling (empty / clamped ranges) is delegated to
    /// `validate_read_range`.
    ///
    /// # Errors
    /// Propagates chunk fetch, chunk type and decode errors.
    pub async fn read_range(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        Self::read_range_with(
            &self.getter,
            &self.subtrees,
            &self.root,
            &self.context,
            self.span,
            self.tree,
            self.concurrency,
            offset,
            len,
        )
        .await
    }

    /// Reads the whole file.
    ///
    /// NOTE(review): `self.span as usize` truncates on targets where `usize`
    /// is narrower than 64 bits, so a file larger than `usize::MAX` would be
    /// read only partially there.
    pub async fn read_all(&self) -> Result<Vec<u8>> {
        self.read_range(0, self.span as usize).await
    }

    /// Shared implementation of `read_range`, taking the joiner state as
    /// separate arguments so the tokio reader can call it from an owned
    /// (`'static`) future without borrowing the joiner.
    #[allow(
        clippy::too_many_arguments,
        reason = "internal helper threading already-decomposed reader state from two call sites"
    )]
    async fn read_range_with(
        getter: &Arc<G>,
        subtrees: &[SubtreeNode<M>],
        root: &ChunkAddress,
        context: &M::JoinerContext,
        span: u64,
        tree: TreeParams<BODY_SIZE>,
        concurrency: usize,
        offset: u64,
        len: usize,
    ) -> Result<Vec<u8>> {
        use super::helpers::{ReadRangeCheck, validate_read_range};

        let (offset, actual_len) = match validate_read_range::<BODY_SIZE>(offset, len, span) {
            ReadRangeCheck::Empty => return Ok(Vec::new()),
            // Fetches and decodes the root chunk itself and slices it.
            // NOTE(review): presumably only returned when the root is the
            // file's only chunk — confirm in `validate_read_range`.
            ReadRangeCheck::SingleChunk { offset, actual_len } => {
                let chunk = getter.get(root).await.map_err(FileError::getter)?;
                let chunk = chunk.into_content().ok_or(FileError::InvalidChunkType {
                    type_name: "non-content",
                })?;
                let body = M::decode_body::<BODY_SIZE>(chunk, context, span)?;
                let start = offset as usize;
                let end = start + actual_len;
                return Ok(body[start..end].to_vec());
            }
            ReadRangeCheck::MultiChunk { offset, actual_len } => (offset, actual_len),
        };

        // Chunk-aligned byte window covering the requested range.
        let chunk_range = tree.chunks_for_range(offset, actual_len as u64);
        let range_start_byte = chunk_range.start * BODY_SIZE as u64;
        let range_end_byte = chunk_range.end * BODY_SIZE as u64;

        // Keep only frontier subtrees whose extent
        // [byte_offset, byte_offset + span) overlaps that window.
        let relevant: Vec<_> = subtrees
            .iter()
            .filter(|st| {
                st.byte_offset < range_end_byte && st.byte_offset + st.span > range_start_byte
            })
            .cloned()
            .collect();

        let bodies = collect_subtree_bodies_async::<G, M, BODY_SIZE>(
            getter,
            relevant,
            chunk_range,
            concurrency,
        )
        .await?;

        // Builds the output from the fetched bodies; presumably trims them to
        // `offset..offset + actual_len` (see `assemble_range`).
        Ok(super::tree::assemble_range(
            &tree,
            offset,
            actual_len,
            &chunk_range,
            &bodies,
        ))
    }

    /// Moves the cursor according to `pos` and returns the new position.
    ///
    /// # Errors
    /// Returns the `io::Error` produced by `resolve_seek_position` for an
    /// invalid target; the cursor is left unchanged in that case.
    pub fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
        self.position = super::resolve_seek_position(pos, self.position, self.span)?;
        Ok(self.position)
    }

    /// Consumes the joiner and yields the file's chunk bodies, subtree by
    /// subtree in frontier order (the tests check that concatenating the
    /// items reproduces the file).
    ///
    /// Subtrees are fetched one at a time; `concurrency` is not used here.
    /// A failing subtree yields a single `Err` item and the stream then
    /// continues with the next subtree. A subtree that produces no bodies
    /// yields one empty `Bytes`.
    pub fn into_stream(self) -> impl Stream<Item = Result<Bytes>> {
        let getter = self.getter;
        let chunk_range = self.tree.chunks_for_range(0, self.span);

        // Unfold state: subtrees not yet fetched, plus bodies of the current
        // subtree not yet yielded.
        struct State<M: JoinMode> {
            subtrees: std::vec::IntoIter<SubtreeNode<M>>,
            pending: std::vec::IntoIter<Bytes>,
        }

        let state = State {
            subtrees: self.subtrees.into_iter(),
            pending: Vec::new().into_iter(),
        };

        stream::unfold(state, move |mut state| {
            let getter = Arc::clone(&getter);
            async move {
                // Drain the current subtree's bodies before fetching another.
                if let Some(body) = state.pending.next() {
                    return Some((Ok(body), state));
                }

                // No subtrees left: end of stream.
                let st = state.subtrees.next()?;
                match read_subtree_bodies_async::<G, M, BODY_SIZE>(&*getter, &st, &chunk_range)
                    .await
                {
                    Ok(bodies) => {
                        let mut iter = bodies.into_iter();
                        match iter.next() {
                            Some(first) => {
                                state.pending = iter;
                                Some((Ok(first), state))
                            }
                            None => Some((Ok(Bytes::new()), state)),
                        }
                    }
                    Err(e) => Some((Err(e), state)),
                }
            }
        })
    }

    /// Wraps the joiner in a [`JoinerReader`] implementing tokio's
    /// `AsyncRead` and `AsyncSeek`, starting at the current cursor position.
    #[cfg(feature = "tokio")]
    pub fn into_reader(self) -> JoinerReader<G, M, BODY_SIZE> {
        JoinerReader {
            joiner: self,
            buffer: Bytes::new(),
            future: None,
        }
    }
}
302
#[cfg(feature = "tokio")]
/// Async reader over a [`GenericJoiner`] implementing tokio's `AsyncRead` and
/// `AsyncSeek`; created by [`GenericJoiner::into_reader`].
pub struct JoinerReader<G, M: JoinMode, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
{
    /// Wrapped joiner; seeks resolve against its `position` field.
    joiner: GenericJoiner<G, M, BODY_SIZE>,
    /// Bytes from the last completed fetch that did not fit in the caller's
    /// buffer; served before any new fetch is started.
    buffer: Bytes,
    /// Range fetch started by `poll_read` and not yet completed, if any;
    /// cleared on completion and by `start_seek`.
    #[allow(clippy::type_complexity)]
    future: Option<std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>>,
}
316
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE>,
    M: JoinMode,
{
    /// Shows the wrapped joiner plus a summary of the reader's transient
    /// state: how many bytes are buffered and whether a fetch is in flight.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            joiner,
            buffer,
            future,
        } = self;
        let buffered_bytes = buffer.len();
        let fetch_in_flight = future.is_some();

        f.debug_struct("JoinerReader")
            .field("joiner", joiner)
            .field("buffer_len", &buffered_bytes)
            .field("has_pending_future", &fetch_in_flight)
            .finish()
    }
}
331
#[cfg(feature = "tokio")]
// `JoinerReader` is declared `Unpin` unconditionally. Its only future is
// already heap-pinned (`Pin<Box<dyn Future>>`), so moving the reader never
// moves that future. This lets `poll_read` / `start_seek` use `Pin::get_mut`
// without adding `Unpin` bounds on the type parameters (the auto impl would
// otherwise depend on `M` and `M::JoinerContext` being `Unpin`).
impl<G: ChunkGet<BODY_SIZE>, M: JoinMode, const BODY_SIZE: usize> Unpin
    for JoinerReader<G, M, BODY_SIZE>
{
}
339
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> tokio::io::AsyncRead for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE> + 'static,
    M: JoinMode + Send + Sync + 'static,
{
    /// Copies the next bytes of the file into `buf`.
    ///
    /// Leftover bytes from the previous fetch are served first. Otherwise a
    /// fetch of up to `BODY_SIZE` bytes at the current position is started
    /// (or resumed) and polled. End of file is signalled by returning
    /// `Ready(Ok(()))` without filling `buf`.
    ///
    /// `joiner.position` is advanced only by the bytes actually handed to the
    /// caller, so it always equals the logical stream position. Bytes parked
    /// in `buffer` are not counted until they are copied out. This keeps
    /// `SeekFrom::Current` and `stream_position()` correct after a read that
    /// did not consume a whole fetch.
    ///
    /// # Errors
    /// A failed fetch is reported as an `io::Error` wrapping the joiner error;
    /// the position is left unchanged so the read can be retried.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::task::Poll;

        let this = self.get_mut();

        // Nothing can be copied into a full buffer; don't start a fetch for it.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        // Serve bytes left over from the previous fetch.
        if !this.buffer.is_empty() {
            let to_copy = this.buffer.len().min(buf.remaining());
            buf.put_slice(&this.buffer[..to_copy]);
            this.buffer.advance(to_copy);
            this.joiner.position += to_copy as u64;
            return Poll::Ready(Ok(()));
        }

        // EOF: return without writing anything.
        if this.joiner.position >= this.joiner.span {
            return Poll::Ready(Ok(()));
        }

        if this.future.is_none() {
            let position = this.joiner.position;
            // Clamp in u64 before narrowing so a huge remainder cannot be
            // truncated on targets where `usize` is 32 bits.
            let read_len = (this.joiner.span - position).min(BODY_SIZE as u64) as usize;
            // Clone everything the fetch needs so the boxed future is
            // `'static` and does not borrow the reader.
            let getter = Arc::clone(&this.joiner.getter);
            let root = this.joiner.root;
            let context = this.joiner.context.clone();
            let span = this.joiner.span;
            let tree = this.joiner.tree;
            let concurrency = this.joiner.concurrency;
            let subtrees: Vec<SubtreeNode<M>> = this.joiner.subtrees.clone();

            let fut = async move {
                GenericJoiner::<G, M, BODY_SIZE>::read_range_with(
                    &getter,
                    &subtrees,
                    &root,
                    &context,
                    span,
                    tree,
                    concurrency,
                    position,
                    read_len,
                )
                .await
            };
            this.future = Some(Box::pin(fut));
        }

        let fut = this
            .future
            .as_mut()
            .expect("fetch future is initialised just above");
        match fut.as_mut().poll(cx) {
            Poll::Ready(Ok(data)) => {
                this.future = None;
                let bytes = Bytes::from(data);
                let to_copy = bytes.len().min(buf.remaining());
                buf.put_slice(&bytes[..to_copy]);
                // Count only what the caller received; the remainder is
                // counted when it is drained from `buffer`.
                this.joiner.position += to_copy as u64;
                if to_copy < bytes.len() {
                    this.buffer = bytes.slice(to_copy..);
                }
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => {
                this.future = None;
                Poll::Ready(Err(std::io::Error::other(e)))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
420
#[cfg(feature = "tokio")]
impl<G, M, const BODY_SIZE: usize> tokio::io::AsyncSeek for JoinerReader<G, M, BODY_SIZE>
where
    G: ChunkGet<BODY_SIZE> + 'static,
    M: JoinMode + Send + Sync + 'static,
{
    /// Resolves `pos` against the current position and file size, then drops
    /// any buffered bytes and any in-flight fetch, since both belong to the
    /// old position. The seek completes synchronously.
    ///
    /// # Errors
    /// Returns the `io::Error` from `resolve_seek_position` for an invalid
    /// target, leaving the reader untouched.
    fn start_seek(self: std::pin::Pin<&mut Self>, pos: SeekFrom) -> std::io::Result<()> {
        let reader = self.get_mut();
        let target =
            super::resolve_seek_position(pos, reader.joiner.position, reader.joiner.span)?;

        reader.joiner.position = target;
        reader.buffer = Bytes::new();
        reader.future = None;
        Ok(())
    }

    /// Always ready: reports the position established by `start_seek`.
    fn poll_complete(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<u64>> {
        let settled = self.joiner.position;
        std::task::Poll::Ready(Ok(settled))
    }
}
443
#[cfg(all(test, feature = "tokio"))]
mod tests {
    // NOTE(review): the imports and helper names below are presumably also
    // relied on by the `generate_*_joiner_tests!` macros (defined elsewhere).
    use super::*;
    use crate::chunk::AnyChunk;
    use crate::file::sync_split;
    use std::collections::HashMap;

    /// Splits `data` with `sync_split` at the default body size and returns
    /// the root address plus the resulting chunk map, which the tests pass
    /// directly to `Joiner::new` as the chunk getter.
    fn split_and_store(data: &[u8]) -> (ChunkAddress, HashMap<ChunkAddress, AnyChunk>) {
        let (root, store) = sync_split::<DEFAULT_BODY_SIZE>(data).unwrap();
        (root, store.into_chunks())
    }

    // Shared plain-mode joiner tests, instantiated for the async `Joiner`.
    generate_plain_joiner_tests!(tokio::test, Joiner, [async], [await]);

    /// Concatenating every item of `into_stream` reproduces a 3-chunk file.
    #[tokio::test]
    async fn test_async_joiner_stream() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);

        let joiner = Joiner::new(store, root).await.unwrap();
        let chunks: Vec<Result<Bytes>> = joiner.into_stream().collect().await;

        let mut recovered = Vec::new();
        for chunk in chunks {
            recovered.extend_from_slice(&chunk.unwrap());
        }
        assert_eq!(recovered, data);
    }

    // NOTE(review): this and the following `#[cfg(feature = "tokio")]`
    // attributes are redundant — the whole module is already gated on it.
    /// `read_to_end` through the tokio reader on a single-chunk file.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_async_reader_small() {
        use tokio::io::AsyncReadExt;

        let data = b"hello world";
        let (root, store) = split_and_store(data);

        let joiner = Joiner::new(store, root).await.unwrap();
        let mut reader = joiner.into_reader();
        let mut result = Vec::new();
        reader.read_to_end(&mut result).await.unwrap();
        assert_eq!(result, data);
    }

    /// `read_to_end` on a file spanning several chunks with a partial tail.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_async_reader_multi_chunk() {
        use tokio::io::AsyncReadExt;

        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3 + 123)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);

        let joiner = Joiner::new(store, root).await.unwrap();
        let mut reader = joiner.into_reader();
        let mut result = Vec::new();
        reader.read_to_end(&mut result).await.unwrap();
        assert_eq!(result, data);
    }

    /// Seeking from the start, then reading an exact slice.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_async_reader_seek() {
        use tokio::io::{AsyncReadExt, AsyncSeekExt};

        let data = b"hello world";
        let (root, store) = split_and_store(data);

        let joiner = Joiner::new(store, root).await.unwrap();
        let mut reader = joiner.into_reader();

        reader.seek(SeekFrom::Start(6)).await.unwrap();
        let mut buf = vec![0u8; 5];
        reader.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"world");
    }

    /// Seeks forward, back to the start, and relative to the end, checking
    /// the bytes read after each seek.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_async_reader_seek_back_and_forth() {
        use tokio::io::{AsyncReadExt, AsyncSeekExt};

        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);

        let joiner = Joiner::new(store, root).await.unwrap();
        let mut reader = joiner.into_reader();

        // Jump to the start of the second chunk.
        reader
            .seek(SeekFrom::Start(DEFAULT_BODY_SIZE as u64))
            .await
            .unwrap();
        let mut buf1 = vec![0u8; 100];
        reader.read_exact(&mut buf1).await.unwrap();
        assert_eq!(&buf1, &data[DEFAULT_BODY_SIZE..DEFAULT_BODY_SIZE + 100]);

        // Back to the beginning; leftover buffered bytes must be discarded.
        reader.seek(SeekFrom::Start(0)).await.unwrap();
        let mut buf2 = vec![0u8; 100];
        reader.read_exact(&mut buf2).await.unwrap();
        assert_eq!(&buf2, &data[..100]);

        // Relative to the end of the file.
        reader.seek(SeekFrom::End(-50)).await.unwrap();
        let mut buf3 = vec![0u8; 50];
        reader.read_exact(&mut buf3).await.unwrap();
        assert_eq!(&buf3, &data[data.len() - 50..]);
    }

    #[cfg(feature = "encryption")]
    mod encrypted {
        use super::*;
        use crate::file::sync_split_encrypted;

        /// Encrypted counterpart of `split_and_store`: returns the encrypted
        /// root reference and the resulting chunk map.
        fn encrypted_split_and_store(
            data: &[u8],
        ) -> (
            crate::chunk::encryption::EncryptedChunkRef,
            HashMap<ChunkAddress, AnyChunk>,
        ) {
            let (root_ref, store) = sync_split_encrypted::<DEFAULT_BODY_SIZE>(data).unwrap();
            (root_ref, store.into_chunks())
        }

        // Shared encrypted-mode joiner tests, instantiated for the async
        // `EncryptedJoiner`.
        generate_encrypted_joiner_tests!(tokio::test, EncryptedJoiner, [async], [await]);

        /// Concatenating every item of `into_stream` reproduces an encrypted
        /// 3-chunk file.
        #[tokio::test]
        async fn test_encrypted_async_joiner_stream() {
            let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3)
                .map(|i| (i % 256) as u8)
                .collect();
            let (root_ref, store) = encrypted_split_and_store(&data);

            let joiner = EncryptedJoiner::new(store, root_ref).await.unwrap();
            let chunks: Vec<Result<Bytes>> = joiner.into_stream().collect().await;

            let mut recovered = Vec::new();
            for chunk in chunks {
                recovered.extend_from_slice(&chunk.unwrap());
            }
            assert_eq!(recovered, data);
        }
    }
}