use std::io::{self, SeekFrom};
use std::marker::PhantomData;
use bytes::Bytes;
use rayon::prelude::*;
use crate::bmt::DEFAULT_BODY_SIZE;
use crate::chunk::ChunkAddress;
use super::error::Result;
use super::frontier::{SubtreeNode, expand_frontier, read_subtree_bodies};
use super::mode::{JoinMode, PlainMode};
use super::tree::{ChunkRange, TreeParams};
use crate::store::SyncChunkGet;
#[cfg(feature = "encryption")]
use super::mode::EncryptedMode;
/// Synchronous joiner that reassembles a chunked file from its root reference,
/// generic over the join mode `M` (see the `SyncJoiner` / `EncryptedSyncJoiner`
/// aliases for the plain and encrypted variants).
///
/// Random-access reads (`read_range`, `read_all`) and streaming reads through
/// `io::Read` / `io::Seek` share one frontier of subtrees computed when the
/// joiner is created.
pub struct GenericSyncJoiner<G, M: JoinMode, const BODY_SIZE: usize = DEFAULT_BODY_SIZE>
where
    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
{
    // ---- chunk source and root -------------------------------------------
    /// Store every chunk is fetched from.
    getter: G,
    /// Address of the root chunk.
    root: ChunkAddress,
    /// Mode-specific state produced by `joiner_init` for this root.
    context: M::JoinerContext,

    // ---- file geometry -----------------------------------------------------
    /// Total length of the joined file in bytes.
    span: u64,
    /// Tree shape derived from `span`.
    tree: TreeParams<BODY_SIZE>,
    /// Frontier subtrees, in file order, that reads fetch in parallel.
    subtrees: Vec<SubtreeNode<M>>,

    // ---- streaming cursor --------------------------------------------------
    /// Absolute file offset of the next byte `read` returns.
    read_pos: u64,
    /// Decoded bytes of the most recently loaded batch of subtrees.
    buffer: Vec<u8>,
    /// Offset inside `buffer` that corresponds to `read_pos`.
    buffer_pos: usize,
    /// Index of the first subtree not yet loaded into `buffer`.
    subtree_idx: usize,

    _mode: PhantomData<M>,
}
/// Synchronous joiner for unencrypted content.
pub type SyncJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
    GenericSyncJoiner<G, PlainMode, BODY_SIZE>;

/// Synchronous joiner for encrypted content (requires the `encryption` feature).
#[cfg(feature = "encryption")]
pub type EncryptedSyncJoiner<G, const BODY_SIZE: usize = DEFAULT_BODY_SIZE> =
    GenericSyncJoiner<G, EncryptedMode, BODY_SIZE>;
/// Shows the root, the file span and the streaming cursor; the getter, the
/// mode context and the internal buffers are deliberately left out.
impl<G, M, const BODY_SIZE: usize> std::fmt::Debug for GenericSyncJoiner<G, M, BODY_SIZE>
where
    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
    M: JoinMode,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GenericSyncJoiner");
        out.field("root", &self.root);
        out.field("span", &self.span);
        out.field("read_pos", &self.read_pos);
        // `..` in the output signals that more fields exist.
        out.finish_non_exhaustive()
    }
}
impl<G, M, const BODY_SIZE: usize> GenericSyncJoiner<G, M, BODY_SIZE>
where
    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
    M: JoinMode + Send + Sync,
{
    /// Creates a joiner for the file referenced by `input`.
    ///
    /// The root is resolved through `joiner_init` (root address, file span and
    /// mode-specific context). Then `expand_frontier` splits the chunk tree
    /// into the subtrees that every later read fetches in parallel. The
    /// requested subtree count is twice the number of rayon worker threads
    /// (at least 2).
    ///
    /// # Errors
    ///
    /// Propagates any error from resolving the root or from fetching the
    /// chunks needed to expand the frontier.
    pub fn new(getter: G, input: M::RootRef) -> Result<Self> {
        // Compile-time check of `BODY_SIZE` (see `assert_valid_body_size`).
        const { super::constants::assert_valid_body_size::<BODY_SIZE>() };
        let (root, span, context) = super::mode::joiner_init::<M, G, BODY_SIZE>(&getter, input)?;
        let tree = TreeParams::<BODY_SIZE>::new(span);
        // More subtrees than threads, presumably so rayon has slack to balance
        // subtrees of uneven cost.
        let target = rayon::current_num_threads().max(1) * 2;
        let full_range = tree.chunks_for_range(0, span);
        let subtrees = expand_frontier::<G, M, BODY_SIZE>(
            &getter,
            &root,
            &context,
            span,
            &full_range,
            target,
        )?;
        Ok(Self {
            getter,
            root,
            context,
            span,
            tree,
            subtrees,
            read_pos: 0,
            buffer: Vec::new(),
            buffer_pos: 0,
            subtree_idx: 0,
            _mode: PhantomData,
        })
    }

    /// Total length of the joined file in bytes.
    #[inline]
    pub const fn size(&self) -> u64 {
        self.span
    }

    /// Current streaming cursor, advanced by `Read` and set by `Seek`.
    #[inline]
    pub const fn position(&self) -> u64 {
        self.read_pos
    }

    /// Address of the root chunk.
    #[inline]
    pub const fn root(&self) -> &ChunkAddress {
        &self.root
    }

    /// Reads `len` bytes starting at `offset` without touching the streaming
    /// cursor.
    ///
    /// `validate_read_range` decides the actual length (presumably clamped to
    /// the end of the file). It also classifies the request as empty, served
    /// by the root chunk alone, or spanning several leaf chunks. In the last
    /// case the overlapping subtrees are fetched in parallel and
    /// `assemble_range` cuts out the requested bytes.
    ///
    /// # Errors
    ///
    /// Propagates any chunk-fetch error.
    pub fn read_range(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        use super::helpers::{ReadRangeCheck, validate_read_range};
        match validate_read_range::<BODY_SIZE>(offset, len, self.span) {
            ReadRangeCheck::Empty => Ok(Vec::new()),
            ReadRangeCheck::SingleChunk { offset, actual_len } => {
                self.read_single_chunk(offset, actual_len)
            }
            ReadRangeCheck::MultiChunk { offset, actual_len } => {
                let chunk_range = self.tree.chunks_for_range(offset, actual_len as u64);
                let range_start_byte = chunk_range.start * BODY_SIZE as u64;
                let range_end_byte = chunk_range.end * BODY_SIZE as u64;
                let bodies = self.collect_bodies(&chunk_range, range_start_byte, range_end_byte)?;
                Ok(super::tree::assemble_range(
                    &self.tree,
                    offset,
                    actual_len,
                    &chunk_range,
                    &bodies,
                ))
            }
        }
    }

    /// Reads the entire file, independently of the streaming cursor.
    ///
    /// # Errors
    ///
    /// Propagates any chunk-fetch error.
    pub fn read_all(&self) -> Result<Vec<u8>> {
        // NOTE(review): this `as` truncates `span` where `usize` is narrower than
        // 64 bits; a file that large could not be held in one `Vec` there anyway.
        self.read_range(0, self.span as usize)
    }

    /// Fetches the leaf bodies that cover `chunk_range`, in file order.
    ///
    /// Only frontier subtrees overlapping `[range_start_byte, range_end_byte)`
    /// are read, in parallel. rayon's `collect` into a `Vec` keeps the order of
    /// `self.subtrees`, and each subtree's bodies are appended in turn.
    fn collect_bodies(
        &self,
        chunk_range: &ChunkRange,
        range_start_byte: u64,
        range_end_byte: u64,
    ) -> Result<Vec<Bytes>> {
        let getter = &self.getter;
        let nested: Vec<Vec<Bytes>> = self
            .subtrees
            .par_iter()
            .filter(|st| {
                st.byte_offset < range_end_byte && st.byte_offset + st.span > range_start_byte
            })
            .map(|st| {
                // Size by the overlap with the request, not the whole subtree:
                // a small read inside a huge subtree must not reserve a slot
                // for every leaf of that subtree.
                let capacity = Self::overlapping_chunks(st.byte_offset, st.span, chunk_range);
                let mut bodies = Vec::with_capacity(capacity);
                read_subtree_bodies::<G, M, BODY_SIZE>(getter, st, chunk_range, &mut bodies)?;
                Ok(bodies)
            })
            .collect::<Result<Vec<Vec<Bytes>>>>()?;
        Ok(nested.into_iter().flatten().collect())
    }

    /// Counts the leaf chunks of the byte span `[byte_offset, byte_offset + span)`
    /// that also lie inside `range`. Used only as an allocation hint.
    fn overlapping_chunks(byte_offset: u64, span: u64, range: &ChunkRange) -> usize {
        let body = BODY_SIZE as u64;
        let first = (byte_offset / body).max(range.start);
        let last = (byte_offset + span).div_ceil(body).min(range.end);
        // An unrepresentable count just means "no pre-reservation".
        usize::try_from(last.saturating_sub(first)).unwrap_or(0)
    }

    /// Returns `len` bytes at `offset` from the root chunk's body. Used for
    /// `ReadRangeCheck::SingleChunk` requests.
    ///
    /// # Panics
    ///
    /// Panics if the body is shorter than `offset + len`. `validate_read_range`
    /// is presumably what rules that out.
    fn read_single_chunk(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        let body = super::mode::read_chunk_body::<M, G, BODY_SIZE>(
            &self.getter,
            &self.root,
            &self.context,
            self.span,
        )?;
        // `offset` lies inside a single chunk body here, so it fits in `usize`.
        let start = offset as usize;
        let end = start + len;
        Ok(body[start..end].to_vec())
    }

    /// Decodes the next batch of frontier subtrees into `buffer`.
    ///
    /// A batch is up to `rayon::current_num_threads()` consecutive subtrees
    /// starting at `subtree_idx`. They are read in parallel and their bodies
    /// are concatenated in order. Afterwards `subtree_idx` points past the
    /// batch, and `buffer_pos` is the offset of `read_pos` inside the batch
    /// (non-zero after a seek into the middle of a subtree). Does nothing when
    /// no subtrees remain.
    ///
    /// NOTE(review): the whole batch is materialised in memory, so peak usage
    /// scales with subtree size. For very large files a bounded window may be
    /// preferable.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `read_subtree_bodies`.
    fn fill_buffer(&mut self) -> Result<()> {
        let batch_size = rayon::current_num_threads().max(1);
        let start_idx = self.subtree_idx;
        let end_idx = (start_idx + batch_size).min(self.subtrees.len());
        let batch = self.subtrees.get(start_idx..end_idx).unwrap_or(&[]);
        let (Some(first), Some(last)) = (batch.first(), batch.last()) else {
            return Ok(());
        };
        let batch_start_byte = first.byte_offset;
        let batch_end_byte = (last.byte_offset + last.span).min(self.span);
        let chunk_range = ChunkRange {
            start: batch_start_byte / BODY_SIZE as u64,
            end: batch_end_byte.div_ceil(BODY_SIZE as u64),
        };
        let getter = &self.getter;
        let all_bodies = batch
            .par_iter()
            .map(|st| {
                let capacity = Self::overlapping_chunks(st.byte_offset, st.span, &chunk_range);
                let mut bodies = Vec::with_capacity(capacity);
                read_subtree_bodies::<G, M, BODY_SIZE>(getter, st, &chunk_range, &mut bodies)?;
                Ok(bodies)
            })
            .collect::<Result<Vec<Vec<Bytes>>>>()?;

        let estimated = batch_end_byte.saturating_sub(batch_start_byte);
        self.buffer.clear();
        self.buffer.reserve(usize::try_from(estimated).unwrap_or(0));
        for body in all_bodies.iter().flatten() {
            self.buffer.extend_from_slice(body);
        }
        self.subtree_idx = end_idx;
        // After a seek the cursor may sit inside the first subtree; skip the
        // bytes that precede it.
        let skip = self.read_pos.saturating_sub(batch_start_byte);
        self.buffer_pos = usize::try_from(skip).unwrap_or(usize::MAX);
        Ok(())
    }

    /// Copies as many buffered bytes as fit into `buf` and returns the count.
    ///
    /// Advances `buffer_pos` and `read_pos` by the same amount. Never copies
    /// past the logical end of the file. Returns 0 instead of panicking if
    /// `buffer_pos` has run past the end of the buffer.
    fn drain_buffer(&mut self, buf: &mut [u8]) -> usize {
        let available = self.buffer.len().saturating_sub(self.buffer_pos);
        let left_in_file =
            usize::try_from(self.span.saturating_sub(self.read_pos)).unwrap_or(usize::MAX);
        let to_copy = buf.len().min(available).min(left_in_file);
        if to_copy == 0 {
            // Also avoids slicing at a start index beyond `buffer.len()`.
            return 0;
        }
        let start = self.buffer_pos;
        buf[..to_copy].copy_from_slice(&self.buffer[start..start + to_copy]);
        self.buffer_pos += to_copy;
        self.read_pos += to_copy as u64;
        to_copy
    }
}
impl<G, M, const BODY_SIZE: usize> io::Read for GenericSyncJoiner<G, M, BODY_SIZE>
where
    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
    M: JoinMode + Send + Sync,
{
    /// Copies bytes at the cursor into `buf`.
    ///
    /// Serves from the decoded batch buffer while it has data, and otherwise
    /// loads the next batch of subtrees first. Returns `Ok(0)` for an empty
    /// `buf`, at end of file, or when no subtrees remain. Errors from chunk
    /// fetching are wrapped with `io::Error::other`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let at_eof = self.read_pos >= self.span;
        if at_eof || buf.is_empty() {
            return Ok(0);
        }

        let has_buffered = self.buffer_pos < self.buffer.len();
        if !has_buffered {
            let no_more_subtrees = self.subtree_idx >= self.subtrees.len();
            if no_more_subtrees {
                return Ok(0);
            }
            self.fill_buffer().map_err(io::Error::other)?;
            if self.buffer.is_empty() {
                return Ok(0);
            }
        }

        Ok(self.drain_buffer(buf))
    }
}
impl<G, M, const BODY_SIZE: usize> io::Seek for GenericSyncJoiner<G, M, BODY_SIZE>
where
    G: SyncChunkGet<BODY_SIZE> + Clone + Send + Sync,
    M: JoinMode + Send + Sync,
{
    /// Moves the streaming cursor and returns the new absolute position.
    ///
    /// If the target byte is already in the decoded batch buffer, only the
    /// offset inside the buffer is adjusted and nothing is re-fetched. This
    /// keeps `Seek::stream_position` (which calls `seek(SeekFrom::Current(0))`)
    /// and short relative hops cheap. Otherwise the buffer is dropped, and the
    /// next read starts from the first subtree that ends after the target.
    ///
    /// # Errors
    ///
    /// Returns whatever `resolve_seek_position` reports for an invalid target
    /// (e.g. a resulting offset before the start of the file).
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let target = super::resolve_seek_position(pos, self.read_pos, self.span)?;

        // `drain_buffer` advances `read_pos` and `buffer_pos` together, and
        // `fill_buffer` aligns `buffer_pos` with `read_pos`. Their difference
        // is therefore the file offset of `buffer[0]`.
        let buffered_len = self.buffer.len() as u64;
        let reusable_start = self
            .read_pos
            .checked_sub(self.buffer_pos as u64)
            .filter(|&start| target >= start && target - start < buffered_len);
        if let Some(start) = reusable_start {
            // `target - start < buffer.len()`, so the cast is lossless.
            self.buffer_pos = (target - start) as usize;
            self.read_pos = target;
            return Ok(target);
        }

        self.read_pos = target;
        self.buffer.clear();
        self.buffer_pos = 0;
        self.subtree_idx = self
            .subtrees
            .iter()
            .position(|st| st.byte_offset + st.span > target)
            .unwrap_or(self.subtrees.len());
        Ok(target)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunk::AnyChunk;
    use crate::file::sync_split;
    use std::collections::HashMap;
    use std::io::{Read, Seek};

    /// Splits `data` into chunks and returns the root address plus the chunk map,
    /// which serves as the joiner's getter.
    fn split_and_store(data: &[u8]) -> (ChunkAddress, HashMap<ChunkAddress, AnyChunk>) {
        let (root, store) = sync_split::<DEFAULT_BODY_SIZE>(data).unwrap();
        (root, store.into_chunks())
    }

    generate_plain_joiner_tests!(test, SyncJoiner, [], []);

    #[test]
    fn test_joiner_streaming() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3 + 500)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut result = vec![0u8; data.len()];
        joiner.read_exact(&mut result).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_joiner_small_buffer_streaming() {
        let refs_per_chunk = DEFAULT_BODY_SIZE / super::super::constants::REF_SIZE;
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * refs_per_chunk)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut result = Vec::new();
        let mut buf = [0u8; 100];
        loop {
            let n = joiner.read(&mut buf).unwrap();
            if n == 0 {
                break;
            }
            result.extend_from_slice(&buf[..n]);
        }
        assert_eq!(result, data);
    }

    #[test]
    fn test_joiner_seek_start() {
        let data = b"hello world";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        joiner.seek(SeekFrom::Start(6)).unwrap();
        // `read_all` ignores the streaming cursor.
        let result = joiner.read_all().unwrap();
        assert_eq!(result, data);
        joiner.seek(SeekFrom::Start(6)).unwrap();
        let mut buf = vec![0u8; 5];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"world");
    }

    #[test]
    fn test_joiner_seek_current() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let offset = DEFAULT_BODY_SIZE + 100;
        joiner.seek(SeekFrom::Start(offset as u64)).unwrap();
        assert_eq!(joiner.position(), offset as u64);
        let mut buf = vec![0u8; 50];
        joiner.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, &data[offset..offset + 50]);
        joiner.seek(SeekFrom::Current(-50)).unwrap();
        let mut buf2 = vec![0u8; 50];
        joiner.read_exact(&mut buf2).unwrap();
        assert_eq!(buf, buf2);
    }

    #[test]
    fn test_joiner_seek_negative() {
        let data = b"test data";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let result = joiner.seek(SeekFrom::Current(-100));
        assert!(result.is_err());
    }

    #[test]
    fn test_joiner_partial_reads() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 2 + 500)
            .map(|i| (i % 256) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut recovered = Vec::new();
        let mut buf = [0u8; 100];
        loop {
            let n = joiner.read(&mut buf).unwrap();
            if n == 0 {
                break;
            }
            recovered.extend_from_slice(&buf[..n]);
        }
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_joiner_read_at_eof() {
        let data = b"test data";
        let (root, store) = split_and_store(data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut buf = vec![0u8; data.len()];
        joiner.read_exact(&mut buf).unwrap();
        let mut buf2 = [0u8; 10];
        let n = joiner.read(&mut buf2).unwrap();
        assert_eq!(n, 0);
    }

    /// `stream_position()` goes through `seek(SeekFrom::Current(0))`. It must
    /// report the cursor and must not disturb the reads that follow.
    #[test]
    fn test_sync_joiner_stream_position_tracks_reads() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3 + 17)
            .map(|i| (i % 251) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut result = Vec::with_capacity(data.len());
        let mut buf = [0u8; 333];
        loop {
            assert_eq!(joiner.stream_position().unwrap(), result.len() as u64);
            let n = joiner.read(&mut buf).unwrap();
            if n == 0 {
                break;
            }
            result.extend_from_slice(&buf[..n]);
        }
        assert_eq!(result, data);
    }

    /// Seeks that land inside and outside the currently decoded batch must
    /// both yield the right bytes and the right cursor.
    #[test]
    fn test_sync_joiner_seek_within_buffered_batch() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 4)
            .map(|i| (i % 251) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let mut joiner = SyncJoiner::new(store, root).unwrap();
        let mut prime = vec![0u8; 200];
        joiner.read_exact(&mut prime).unwrap();
        assert_eq!(&prime, &data[..200]);
        let body = DEFAULT_BODY_SIZE as u64;
        for pos in [150, 0, 199, body - 5, body * 3 + 7, 42] {
            joiner.seek(SeekFrom::Start(pos)).unwrap();
            let mut small = [0u8; 10];
            joiner.read_exact(&mut small).unwrap();
            let p = pos as usize;
            assert_eq!(&small[..], &data[p..p + 10]);
            assert_eq!(joiner.position(), pos + 10);
        }
    }

    /// A non-aligned range crossing several leaf chunks, read without moving
    /// the streaming cursor.
    #[test]
    fn test_sync_joiner_read_range_spanning_chunks() {
        let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 5 + 123)
            .map(|i| (i % 251) as u8)
            .collect();
        let (root, store) = split_and_store(&data);
        let joiner = SyncJoiner::new(store, root).unwrap();
        let offset = DEFAULT_BODY_SIZE - 7;
        let len = DEFAULT_BODY_SIZE * 2 + 20;
        let got = joiner.read_range(offset as u64, len).unwrap();
        assert_eq!(got, &data[offset..offset + len]);
        assert_eq!(joiner.position(), 0);
    }

    #[cfg(feature = "encryption")]
    mod encrypted {
        use super::*;
        use crate::chunk::encryption::EncryptedChunkRef;
        use crate::file::sync_split_encrypted;

        /// Encrypted counterpart of `split_and_store`.
        fn encrypted_split_and_store(
            data: &[u8],
        ) -> (EncryptedChunkRef, HashMap<ChunkAddress, AnyChunk>) {
            let (root_ref, store) = sync_split_encrypted::<DEFAULT_BODY_SIZE>(data).unwrap();
            (root_ref, store.into_chunks())
        }

        generate_encrypted_joiner_tests!(test, EncryptedSyncJoiner, [], []);

        #[test]
        fn test_encrypted_joiner_streaming() {
            let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 65)
                .map(|i| (i % 256) as u8)
                .collect();
            let (root_ref, store) = encrypted_split_and_store(&data);
            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            let mut result = vec![0u8; data.len()];
            joiner.read_exact(&mut result).unwrap();
            assert_eq!(result, data);
        }

        #[test]
        fn test_encrypted_joiner_small_buffer_streaming() {
            let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 128)
                .map(|i| (i % 256) as u8)
                .collect();
            let (root_ref, store) = encrypted_split_and_store(&data);
            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            let mut result = Vec::new();
            let mut buf = [0u8; 100];
            loop {
                let n = joiner.read(&mut buf).unwrap();
                if n == 0 {
                    break;
                }
                result.extend_from_slice(&buf[..n]);
            }
            assert_eq!(result, data);
        }

        #[test]
        fn test_encrypted_joiner_seek_back_and_forth() {
            let data: Vec<u8> = (0..DEFAULT_BODY_SIZE * 3)
                .map(|i| (i % 256) as u8)
                .collect();
            let (root_ref, store) = encrypted_split_and_store(&data);
            let mut joiner = EncryptedSyncJoiner::new(store, root_ref).unwrap();
            joiner
                .seek(SeekFrom::Start(DEFAULT_BODY_SIZE as u64))
                .unwrap();
            let mut buf1 = vec![0u8; 100];
            joiner.read_exact(&mut buf1).unwrap();
            assert_eq!(&buf1, &data[DEFAULT_BODY_SIZE..DEFAULT_BODY_SIZE + 100]);
            joiner.seek(SeekFrom::Start(0)).unwrap();
            let mut buf2 = vec![0u8; 100];
            joiner.read_exact(&mut buf2).unwrap();
            assert_eq!(&buf2, &data[..100]);
        }
    }
}