use crate::{
decrypt::decrypt_chunk, get_root_data_map_parallel, utils::extract_hashes, ChunkInfo, DataMap,
Result, STREAM_DECRYPT_BATCH_SIZE,
};
use bytes::Bytes;
use std::collections::HashMap;
use std::ops::Range;
use xor_name::XorName;
/// Streaming decryptor for a self-encrypted file described by a `DataMap`.
///
/// Yields decrypted chunks in file order through its `Iterator` impl, and
/// supports random access via `get_range` and the `range_*` helpers.
/// `F` is the caller-supplied chunk fetcher (see `streaming_decrypt`).
pub struct DecryptionStream<F> {
// Chunk metadata from the root data map; sorted by `index` in `new`.
chunk_infos: Vec<ChunkInfo>,
// Source-content hashes of all chunks; passed to `decrypt_chunk` as key material.
src_hashes: Vec<XorName>,
// Child level of the root data map (`root_map.child()`, 0 when absent).
child_level: usize,
// Caller-supplied fetcher: maps (chunk index, dst hash) pairs to encrypted contents.
get_chunk_parallel: F,
// Offset into `chunk_infos` where the next batch starts.
current_batch_start: usize,
// Decrypted chunks of the batch currently being iterated.
current_batch_chunks: Vec<Bytes>,
// Position of the next chunk to yield from `current_batch_chunks`.
current_batch_index: usize,
}
impl<F> DecryptionStream<F>
where
F: Fn(&[(usize, XorName)]) -> Result<Vec<(usize, Bytes)>>,
{
pub fn new(data_map: &DataMap, get_chunk_parallel: F) -> Result<Self> {
let root_map = if data_map.is_child() {
get_root_data_map_parallel(data_map.clone(), &get_chunk_parallel)?
} else {
data_map.clone()
};
let child_level = root_map.child().unwrap_or(0);
let mut chunk_infos = root_map.infos().to_vec();
chunk_infos.sort_by_key(|info| info.index);
let src_hashes = extract_hashes(&root_map);
Ok(Self {
chunk_infos,
src_hashes,
child_level,
get_chunk_parallel,
current_batch_start: 0,
current_batch_chunks: Vec::new(),
current_batch_index: 0,
})
}
fn fetch_next_batch(&mut self) -> Result<bool> {
if self.current_batch_start >= self.chunk_infos.len() {
return Ok(false); }
let batch_end =
(self.current_batch_start + *STREAM_DECRYPT_BATCH_SIZE).min(self.chunk_infos.len());
let batch_infos = self
.chunk_infos
.get(self.current_batch_start..batch_end)
.ok_or_else(|| {
crate::Error::Generic(format!(
"batch range {}..{} out of bounds for chunk_infos (len {})",
self.current_batch_start,
batch_end,
self.chunk_infos.len()
))
})?;
let batch_hashes: Vec<_> = batch_infos
.iter()
.map(|info| (info.index, info.dst_hash))
.collect();
let mut fetched_chunks = (self.get_chunk_parallel)(&batch_hashes)?;
fetched_chunks.sort_by_key(|(index, _content)| *index);
self.current_batch_chunks.clear();
for (info, (_index, encrypted_content)) in
batch_infos.iter().zip(fetched_chunks.into_iter())
{
let decrypted_chunk = decrypt_chunk(
info.index,
&encrypted_content,
&self.src_hashes,
self.child_level,
)?;
self.current_batch_chunks.push(decrypted_chunk);
}
self.current_batch_start = batch_end;
self.current_batch_index = 0;
Ok(true)
}
pub fn file_size(&self) -> usize {
self.chunk_infos
.iter()
.fold(0, |acc, chunk| acc + chunk.src_size)
}
pub fn get_range(&self, start: usize, len: usize) -> Result<Bytes> {
let file_size = self.file_size();
if start >= file_size {
return Ok(Bytes::new());
}
let end_pos = std::cmp::min(start + len, file_size);
let actual_len = end_pos - start;
if actual_len == 0 {
return Ok(Bytes::new());
}
let start_chunk = self.get_chunk_index_from_infos(start);
let end_chunk = self.get_chunk_index_from_infos(end_pos.saturating_sub(1));
let mut required_hashes = Vec::new();
for chunk_info in &self.chunk_infos {
if chunk_info.index >= start_chunk && chunk_info.index <= end_chunk {
required_hashes.push((chunk_info.index, chunk_info.dst_hash));
}
}
required_hashes.sort_by_key(|(index, _)| *index);
let fetched_chunks = (self.get_chunk_parallel)(&required_hashes)?;
let chunk_map: HashMap<usize, Bytes> = fetched_chunks.into_iter().collect();
let mut all_bytes = Vec::new();
for chunk_index in start_chunk..=end_chunk {
if let Some(encrypted_content) = chunk_map.get(&chunk_index) {
let decrypted = decrypt_chunk(
chunk_index,
encrypted_content,
&self.src_hashes,
self.child_level,
)?;
all_bytes.extend_from_slice(&decrypted);
}
}
let bytes = Bytes::from(all_bytes);
let start_chunk_pos = self.get_chunk_start_position(start_chunk);
let internal_offset = start - start_chunk_pos;
if internal_offset >= bytes.len() {
return Ok(Bytes::new());
}
let available_len = bytes.len() - internal_offset;
let range_len = std::cmp::min(actual_len, available_len);
let result = bytes.slice(internal_offset..internal_offset + range_len);
Ok(result)
}
fn get_chunk_start_position(&self, chunk_index: usize) -> usize {
self.chunk_infos
.iter()
.filter(|info| info.index < chunk_index)
.fold(0, |acc, chunk| acc + chunk.src_size)
}
fn get_chunk_index_from_infos(&self, position: usize) -> usize {
let mut accumulated_size = 0;
for chunk_info in &self.chunk_infos {
if position >= accumulated_size && position < accumulated_size + chunk_info.src_size {
return chunk_info.index;
}
accumulated_size += chunk_info.src_size;
}
if let Some(last_chunk) = self.chunk_infos.last() {
last_chunk.index
} else {
0 }
}
}
impl<F> Iterator for DecryptionStream<F>
where
    F: Fn(&[(usize, XorName)]) -> Result<Vec<(usize, Bytes)>>,
{
    type Item = Result<Bytes>;

    /// Yields the next decrypted chunk, fetching a fresh batch on demand.
    ///
    /// Returns `Some(Err(_))` when fetching or decrypting a batch fails, and
    /// `None` once every chunk has been produced.
    fn next(&mut self) -> Option<Self::Item> {
        // Refill the buffer when the current batch is exhausted.
        if self.current_batch_index >= self.current_batch_chunks.len() {
            match self.fetch_next_batch() {
                Ok(true) => {}
                Ok(false) => return None,
                Err(e) => return Some(Err(e)),
            }
        }
        // `Bytes::clone` is a cheap reference-count bump, not a byte copy.
        let chunk = self
            .current_batch_chunks
            .get(self.current_batch_index)?
            .clone();
        self.current_batch_index += 1;
        Some(Ok(chunk))
    }
}
impl<F> DecryptionStream<F>
where
F: Fn(&[(usize, XorName)]) -> Result<Vec<(usize, Bytes)>>,
{
pub fn range(&self, range: Range<usize>) -> Result<Bytes> {
let len = range.end.saturating_sub(range.start);
self.get_range(range.start, len)
}
pub fn range_from(&self, start: usize) -> Result<Bytes> {
let file_size = self.file_size();
let len = file_size.saturating_sub(start);
self.get_range(start, len)
}
pub fn range_to(&self, end: usize) -> Result<Bytes> {
self.get_range(0, end)
}
pub fn range_full(&self) -> Result<Bytes> {
let file_size = self.file_size();
self.get_range(0, file_size)
}
pub fn range_inclusive(&self, start: usize, end: usize) -> Result<Bytes> {
let len = end.saturating_sub(start) + 1; self.get_range(start, len)
}
}
/// Creates a `DecryptionStream` over `data_map`.
///
/// `get_chunk_parallel` is invoked with `(chunk index, dst hash)` pairs and
/// must return the corresponding encrypted chunk contents.
///
/// # Errors
/// Fails if a child data map cannot be resolved to its root map.
pub fn streaming_decrypt<F>(
data_map: &DataMap,
get_chunk_parallel: F,
) -> Result<DecryptionStream<F>>
where
F: Fn(&[(usize, XorName)]) -> Result<Vec<(usize, Bytes)>>,
{
DecryptionStream::new(data_map, get_chunk_parallel)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{encrypt, test_helpers::random_bytes, Error};
use std::collections::HashMap;
#[test]
// Round trip: encrypt 50 KB, stream-decrypt it chunk by chunk, and verify the
// reassembled bytes equal the original plaintext.
fn test_streaming_decrypt_basic() -> Result<()> {
let original_data = random_bytes(50_000); let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
// In-memory chunk store keyed by content hash.
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// Fetcher: resolves every requested hash, failing fast on the first miss.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
} else {
return Err(Error::Generic(format!(
"Chunk not found: {}",
hex::encode(hash)
)));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
// Drain the iterator and reassemble the plaintext.
let mut decrypted_data = Vec::new();
for chunk_result in stream {
let chunk = chunk_result?;
decrypted_data.extend_from_slice(&chunk);
}
assert_eq!(decrypted_data, original_data.to_vec());
Ok(())
}
#[test]
// Sequentially decrypt a 5 MB (multi-chunk) file: the reassembled plaintext
// must match and iteration must have yielded more than one chunk.
fn test_streaming_decrypt_large_file() -> Result<()> {
let original_data = random_bytes(5_000_000); let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
// In-memory chunk store keyed by content hash.
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// Fetcher: resolves every requested hash, failing fast on the first miss.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
} else {
return Err(Error::Generic(format!(
"Chunk not found: {}",
hex::encode(hash)
)));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
let mut decrypted_data = Vec::new();
// Count yielded chunks to confirm the file really spanned several chunks.
let mut chunk_count = 0;
for chunk_result in stream {
let chunk = chunk_result?;
decrypted_data.extend_from_slice(&chunk);
chunk_count += 1;
}
assert_eq!(decrypted_data, original_data.to_vec());
assert!(chunk_count > 1, "Should have processed multiple chunks");
Ok(())
}
#[test]
// A missing chunk must surface as an `Err` item from the iterator rather than
// being silently skipped.
fn test_streaming_decrypt_error_handling() -> Result<()> {
let original_data = random_bytes(10_000);
let (data_map, encrypted_chunks) = encrypt(original_data)?;
let mut storage = HashMap::new();
// Deliberately omit the last encrypted chunk from storage.
for (i, chunk) in encrypted_chunks.iter().enumerate() {
if i < encrypted_chunks.len() - 1 {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
}
// Fetcher errors out when it encounters the missing chunk.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
} else {
return Err(Error::Generic(format!(
"Chunk not found: {}",
hex::encode(hash)
)));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
let mut found_error = false;
for chunk_result in stream {
match chunk_result {
Ok(_chunk) => {
}
Err(_e) => {
found_error = true;
break;
}
}
}
assert!(
found_error,
"Should have encountered an error for missing chunk"
);
Ok(())
}
#[test]
// Read a window from the middle of an encrypted file via `get_range` and
// compare it against the corresponding slice of the plaintext.
fn test_random_access_basic() -> Result<()> {
    let original_data = random_bytes(10_000);
    let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
    // In-memory chunk store keyed by content hash.
    let storage: HashMap<_, _> = encrypted_chunks
        .iter()
        .map(|chunk| {
            (
                crate::hash::content_hash(&chunk.content),
                chunk.content.to_vec(),
            )
        })
        .collect();
    // Fetcher: resolves every requested hash, failing fast on the first miss.
    let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
        hashes
            .iter()
            .map(|&(index, hash)| {
                storage
                    .get(&hash)
                    .map(|data| (index, Bytes::from(data.clone())))
                    .ok_or_else(|| {
                        Error::Generic(format!("Chunk not found: {}", hex::encode(hash)))
                    })
            })
            .collect()
    };
    let stream = streaming_decrypt(&data_map, get_chunks)?;
    let range_start = 1000;
    let range_len = 500;
    let range_data = stream.get_range(range_start, range_len)?;
    assert_eq!(range_data.len(), range_len);
    assert_eq!(
        range_data.as_ref(),
        &original_data[range_start..range_start + range_len]
    );
    Ok(())
}
#[test]
// Exercise every `range_*` convenience wrapper against plaintext slices.
fn test_random_access_convenience_methods() -> Result<()> {
let original_data = random_bytes(5_000);
let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// NOTE(review): this fetcher silently skips unknown hashes; all chunks are
// present here, so the skip path is never taken.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
let range_data = stream.range(1000..2000)?;
assert_eq!(range_data.as_ref(), &original_data[1000..2000]);
let from_data = stream.range_from(3000)?;
assert_eq!(from_data.as_ref(), &original_data[3000..]);
let to_data = stream.range_to(1500)?;
assert_eq!(to_data.as_ref(), &original_data[..1500]);
let full_data = stream.range_full()?;
assert_eq!(full_data.as_ref(), &original_data[..]);
// `range_inclusive(500, 999)` covers plaintext bytes 500..=999, i.e. 500..1000.
let inclusive_data = stream.range_inclusive(500, 999)?;
assert_eq!(inclusive_data.as_ref(), &original_data[500..1000]);
Ok(())
}
#[test]
// Boundary behavior of `get_range`: reads past EOF, at EOF, partially past
// EOF, zero-length, and at the very start of the file.
fn test_random_access_edge_cases() -> Result<()> {
let original_data = random_bytes(1_000);
let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// NOTE(review): silently skips unknown hashes; all chunks are present here.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
// Start beyond EOF -> empty.
let beyond_range = stream.get_range(2000, 500)?;
assert_eq!(beyond_range.len(), 0);
// Start exactly at EOF -> empty.
let at_end = stream.get_range(1000, 100)?;
assert_eq!(at_end.len(), 0);
// Request crossing EOF is clamped to the available 50 bytes.
let partial_exceed = stream.get_range(950, 100)?;
assert_eq!(partial_exceed.len(), 50); assert_eq!(partial_exceed.as_ref(), &original_data[950..]);
// Zero-length request -> empty.
let zero_len = stream.get_range(500, 0)?;
assert_eq!(zero_len.len(), 0);
let at_start = stream.get_range(0, 100)?;
assert_eq!(at_start.as_ref(), &original_data[0..100]);
Ok(())
}
#[test]
// Ranges that span multiple chunks, and ranges fully inside one chunk, must
// both reproduce the exact plaintext slice.
fn test_random_access_chunk_boundaries() -> Result<()> {
let original_data = random_bytes(5_000_000); let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// NOTE(review): silently skips unknown hashes; all chunks are present here.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
// 2 MB read starting at 1 MB should cross at least one chunk boundary.
let cross_boundary = stream.get_range(1_000_000, 2_000_000)?; assert_eq!(cross_boundary.len(), 2_000_000);
assert_eq!(
cross_boundary.as_ref(),
&original_data[1_000_000..3_000_000]
);
// Small read expected to land inside a single chunk.
let within_chunk = stream.get_range(500_000, 1000)?;
assert_eq!(within_chunk.len(), 1000);
assert_eq!(within_chunk.as_ref(), &original_data[500_000..501_000]);
Ok(())
}
#[test]
// `file_size` must report the original plaintext length, and a full-length
// `get_range` must return the whole file.
fn test_random_access_file_size() -> Result<()> {
let original_data = random_bytes(1234); let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// NOTE(review): silently skips unknown hashes; all chunks are present here.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
}
}
Ok(results)
};
let stream = streaming_decrypt(&data_map, get_chunks)?;
assert_eq!(stream.file_size(), 1234);
let full_file = stream.get_range(0, 1234)?;
assert_eq!(full_file.len(), 1234);
assert_eq!(full_file.as_ref(), &original_data[..]);
Ok(())
}
#[test]
// Random access over the full range and sequential iteration must produce
// byte-identical output, both equal to the original plaintext.
fn test_random_access_vs_sequential() -> Result<()> {
let original_data = random_bytes(100_000);
let (data_map, encrypted_chunks) = encrypt(original_data.clone())?;
let mut storage = HashMap::new();
for chunk in encrypted_chunks {
let hash = crate::hash::content_hash(&chunk.content);
let _ = storage.insert(hash, chunk.content.to_vec());
}
// NOTE(review): silently skips unknown hashes; all chunks are present here.
let get_chunks = |hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> {
let mut results = Vec::new();
for &(index, hash) in hashes {
if let Some(data) = storage.get(&hash) {
results.push((index, Bytes::from(data.clone())));
}
}
Ok(results)
};
// First stream borrows the fetcher; the second takes it by value.
let stream = streaming_decrypt(&data_map, &get_chunks)?;
let random_access_data = stream.range_full()?;
let stream2 = streaming_decrypt(&data_map, get_chunks)?;
let mut sequential_data = Vec::new();
for chunk_result in stream2 {
sequential_data.extend_from_slice(&chunk_result?);
}
assert_eq!(random_access_data.as_ref(), &original_data[..]);
assert_eq!(sequential_data, original_data.to_vec());
assert_eq!(random_access_data.as_ref(), &sequential_data[..]);
Ok(())
}
#[test]
// Regression diagnostic for a reported arithmetic-underflow scenario: with a
// ~16.4 GB mock data map, `get_chunk_start_position(start_chunk)` must never
// exceed the requested start position (which would make the
// `start - start_chunk_pos` offset computation in `get_range` underflow).
// Uses a hand-built data map and a stub fetcher — no real decryption happens.
fn test_chunk_boundary_underflow_reproduction() -> Result<()> {
let file_size = 16404310194u64 as usize; let start_position = 4194304u64 as usize;
let max_chunk_size = crate::MAX_CHUNK_SIZE * 2;
println!("Testing with file_size: {file_size}, start_position: {start_position}");
let num_chunks = crate::utils::get_num_chunks_with_variable_max(file_size, max_chunk_size);
println!("Total number of chunks: {num_chunks}");
// Build chunk infos with the crate's own per-chunk size calculation; hashes
// are synthetic placeholders since content is never fetched.
let mut chunk_infos = Vec::new();
let mut accumulated_size = 0;
for chunk_index in 0..num_chunks {
let chunk_size = crate::utils::get_chunk_size_with_variable_max(
file_size,
chunk_index,
max_chunk_size,
);
let chunk_info = ChunkInfo {
index: chunk_index,
dst_hash: crate::hash::content_hash(&[chunk_index as u8]), src_hash: crate::hash::content_hash(&[(chunk_index + 1) as u8]), src_size: chunk_size,
};
chunk_infos.push(chunk_info);
accumulated_size += chunk_size;
}
assert_eq!(
accumulated_size, file_size,
"Mock data map total size should match file size"
);
let data_map = DataMap::new(chunk_infos);
// Stub fetcher: never called for position math, returns no chunks.
let get_chunk_parallel =
|_hashes: &[(usize, XorName)]| -> Result<Vec<(usize, Bytes)>> { Ok(Vec::new()) };
// Construct the stream directly (bypassing `new`) so no fetching occurs.
let mock_stream = DecryptionStream {
chunk_infos: data_map.infos().to_vec(),
src_hashes: vec![crate::hash::content_hash(&[0u8]); num_chunks], child_level: 0,
get_chunk_parallel,
current_batch_start: 0,
current_batch_chunks: Vec::new(),
current_batch_index: 0,
};
let start_chunk_index = mock_stream.get_chunk_index_from_infos(start_position);
println!(
"Calculated start_chunk_index using get_chunk_index_from_infos: {start_chunk_index}"
);
let start_chunk_pos = mock_stream.get_chunk_start_position(start_chunk_index);
println!("start_chunk_pos: {start_chunk_pos}");
println!("start_position: {start_position}");
// The chunk's start must not lie past the requested position; if it does,
// `start_position - start_chunk_pos` would underflow.
if start_chunk_pos <= start_position {
println!("✓ start_chunk_pos <= start_position (as expected)");
if start_chunk_pos < start_position {
let diff = start_position - start_chunk_pos;
println!("Difference: {diff}");
// The offset into the chunk can be at most one chunk's worth of bytes.
assert!(
diff <= max_chunk_size,
"Difference {} should be less than {}, but got {}",
diff,
max_chunk_size,
diff
);
println!("✓ Difference {diff} is less than max_chunk_size");
} else {
println!("start_chunk_pos exactly equals start_position");
}
} else {
let would_underflow = start_chunk_pos - start_position;
panic!(
"❌ start_chunk_pos ({}) > start_position ({}) by {}, this would cause underflow!",
start_chunk_pos, start_position, would_underflow
);
}
let internal_offset = start_position - start_chunk_pos;
println!("Calculated internal_offset: {internal_offset}");
let chunk_size = mock_stream
.chunk_infos
.iter()
.find(|info| info.index == start_chunk_index)
.map(|info| info.src_size)
.unwrap_or(0);
println!("Chunk {start_chunk_index} size: {chunk_size}");
// The offset must fall strictly inside the chosen chunk.
assert!(
internal_offset < chunk_size,
"internal_offset {} should be less than chunk size {}",
internal_offset,
chunk_size
);
println!("✓ Test passed: No underflow condition detected");
Ok(())
}
}