use std::fs::OpenOptions;
use std::io::{Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use xet_client::cas_client::Client;
use xet_client::cas_types::FileRange;
use xet_client::chunk_cache::ChunkCache;
use xet_core_structures::merklehash::MerkleHash;
use xet_runtime::config::ReconstructionConfig;
use xet_runtime::core::{XetRuntime, xet_config};
use xet_runtime::utils::ClosureGuard;
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
use super::data_writer::{DataWriter, DownloadStream, SequentialWriter, UnorderedDownloadStream};
use super::error::{FileReconstructionError, Result};
use super::reconstruction_terms::ReconstructionTermManager;
use super::run_state::{RunError, RunState};
use crate::progress_tracking::ItemProgressUpdater;
/// Reconstructs a single file (or a byte range of it), identified by its
/// `MerkleHash`, by fetching the file's terms through a CAS `Client` and handing
/// each term's data to a `DataWriter`.
///
/// Configure an instance with the consuming `with_*` builder methods, then drive it
/// with one of the `reconstruct_to_*` entry points.
pub struct FileReconstructor {
    /// Client used to fetch reconstruction terms and term data.
    client: Arc<dyn Client>,
    /// Hash of the file being reconstructed.
    file_hash: MerkleHash,
    /// Requested byte range in absolute file offsets; `None` means the whole file.
    byte_range: Option<FileRange>,
    /// Optional progress reporter. `new` installs a standalone one in `cfg(test)`
    /// builds and none otherwise.
    progress_updater: Option<Arc<ItemProgressUpdater>>,
    /// Reconstruction tuning parameters (fetch sizes, download buffer sizes,
    /// vectored-write toggle, ...).
    config: Arc<ReconstructionConfig>,
    /// Optional chunk cache handed to each term's data task.
    chunk_cache: Option<Arc<dyn ChunkCache>>,
    /// Caller-supplied download buffer semaphore. When `None`, the runtime-wide
    /// semaphore is used and resized according to the number of active downloads.
    custom_buffer_semaphore: Option<Arc<AdjustableSemaphore>>,
    /// Token used to cancel an in-flight reconstruction.
    cancellation_token: CancellationToken,
}
impl FileReconstructor {
    /// Creates a reconstructor for the whole file `file_hash`, fetched through `client`.
    ///
    /// Defaults:
    /// - no byte range (the full file);
    /// - a clone of the `reconstruction` section of the global `xet_config()`;
    /// - no chunk cache;
    /// - the runtime-wide download buffer semaphore;
    /// - a fresh cancellation token.
    ///
    /// The progress updater is `None`, except in `cfg(test)` builds, where a
    /// standalone updater is attached.
    pub fn new(client: &Arc<dyn Client>, file_hash: MerkleHash) -> Self {
        Self {
            client: client.clone(),
            file_hash,
            byte_range: None,
            progress_updater: default_progress_updater(),
            config: Arc::new(xet_config().reconstruction.clone()),
            chunk_cache: None,
            custom_buffer_semaphore: None,
            cancellation_token: CancellationToken::new(),
        }
    }

    /// Restricts reconstruction to `byte_range`, given in absolute file offsets.
    pub fn with_byte_range(self, byte_range: FileRange) -> Self {
        Self {
            byte_range: Some(byte_range),
            ..self
        }
    }

    /// Attaches `progress_updater`. It is forwarded to the run state, the term
    /// manager, and each term's data task.
    pub fn with_progress_updater(self, progress_updater: Arc<ItemProgressUpdater>) -> Self {
        Self {
            progress_updater: Some(progress_updater),
            ..self
        }
    }

    /// Uses `cache` as the chunk cache passed to each term's data task.
    pub fn with_chunk_cache(self, cache: Arc<dyn ChunkCache>) -> Self {
        Self {
            chunk_cache: Some(cache),
            ..self
        }
    }

    /// Replaces the reconstruction configuration with a clone of `config`.
    pub fn with_config(self, config: impl AsRef<ReconstructionConfig>) -> Self {
        Self {
            config: Arc::new(config.as_ref().clone()),
            ..self
        }
    }

    /// Bounds download buffer memory with `semaphore` instead of the runtime-wide
    /// semaphore.
    ///
    /// With a custom semaphore, the per-download resizing driven by the runtime's
    /// active-download count is not performed.
    pub fn with_buffer_semaphore(self, semaphore: Arc<AdjustableSemaphore>) -> Self {
        Self {
            custom_buffer_semaphore: Some(semaphore),
            ..self
        }
    }

    /// Uses `token` to cancel this reconstruction.
    pub fn with_cancellation_token(self, token: CancellationToken) -> Self {
        Self {
            cancellation_token: token,
            ..self
        }
    }

    /// Reconstructs into the file at `path`, creating the file and any missing
    /// parent directories as needed.
    ///
    /// Data is written sequentially starting at `write_offset`. When that is
    /// `None`, writing starts at the start of the configured byte range (0 for a
    /// full-file reconstruction), so the range lands at its own offset in the file.
    ///
    /// If `truncate_file` is `false`, the file is not truncated and bytes outside
    /// the written region are preserved. This lets separate calls fill disjoint
    /// ranges of one file.
    ///
    /// Returns the number of bytes scheduled for writing. Returns `Ok(0)` if the
    /// run was cancelled without a recorded error.
    ///
    /// # Errors
    /// Fails if creating the parent directories, opening the file, or seeking
    /// fails, or if the reconstruction itself fails.
    pub async fn reconstruct_to_file(self, path: &Path, write_offset: Option<u64>, truncate_file: bool) -> Result<u64> {
        info!(
            file_hash = %self.file_hash,
            byte_range = ?self.byte_range,
            path = %path.display(),
            write_offset = ?write_offset,
            "Reconstructing file to disk"
        );
        // NOTE(review): directory creation, open and seek below are blocking
        // `std::fs` calls made from within an async fn.
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let mut file = OpenOptions::new().write(true).create(true).truncate(truncate_file).open(path)?;
        // By default a range is written at its own offset within the file.
        let default_write_position = self.byte_range.map_or(0, |r| r.start);
        let seek_position = write_offset.unwrap_or(default_write_position);
        // A freshly opened (non-append) file is already positioned at 0.
        if seek_position > 0 {
            file.seek(SeekFrom::Start(seek_position))?;
        }
        let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
        let data_writer = SequentialWriter::new(file, self.config.use_vectored_write, run_state.clone());
        self.run(data_writer, run_state, false).await
    }

    /// Reconstructs sequentially into `writer`.
    ///
    /// Returns the number of bytes scheduled for writing. Returns `Ok(0)` if the
    /// run was cancelled without a recorded error.
    pub async fn reconstruct_to_writer<W: Write + Send + 'static>(self, writer: W) -> Result<u64> {
        info!(
            file_hash = %self.file_hash,
            byte_range = ?self.byte_range,
            "Reconstructing file to writer"
        );
        let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
        let data_writer = SequentialWriter::new(writer, self.config.use_vectored_write, run_state.clone());
        self.run(data_writer, run_state, false).await
    }

    /// Consumes the reconstructor and returns a `DownloadStream` built from it and
    /// a new `RunState`. How the reconstruction is driven is up to `DownloadStream`.
    pub fn reconstruct_to_stream(self) -> DownloadStream {
        let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
        DownloadStream::new(self, run_state)
    }

    /// Consumes the reconstructor and returns an `UnorderedDownloadStream` built
    /// from it and a new `RunState`.
    // NOTE(review): presumably this yields term data out of file order — confirm
    // against `UnorderedDownloadStream`.
    pub fn reconstruct_to_unordered_stream(self) -> UnorderedDownloadStream {
        let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
        UnorderedDownloadStream::new(self, run_state)
    }

    /// Runs the reconstruction into `data_writer` and maps the internal `RunError`
    /// onto the public `Result`:
    ///
    /// - `RunError::Cancelled` becomes `Ok(0)`, unless `run_state.check_error()`
    ///   reports an error, which is then propagated.
    /// - `RunError::Error(e)` is recorded in the shared `run_state` (an `Arc` also
    ///   held by the data writer) and returned.
    ///
    /// `is_streaming` only affects the debug-build progress checks, which are
    /// skipped for streaming runs.
    pub(crate) async fn run(
        self,
        data_writer: Box<dyn DataWriter>,
        run_state: Arc<RunState>,
        is_streaming: bool,
    ) -> Result<u64> {
        match self.run_impl(data_writer, &run_state, is_streaming).await {
            Ok(v) => Ok(v),
            Err(RunError::Cancelled) => {
                run_state.check_error()?;
                Ok(0)
            },
            Err(RunError::Error(e)) => {
                run_state.set_error(e.clone());
                Err(e)
            },
        }
    }

    /// Core reconstruction loop. It pulls batches of file terms from the term
    /// manager, reserves download buffer memory for each term, and schedules each
    /// term's data on `data_writer`. It then finishes the writer.
    ///
    /// Returns the total bytes scheduled. Cancellation observed at a `select!`
    /// point returns the result of `run_state.check_run_state()` mapped to 0.
    async fn run_impl(
        self,
        mut data_writer: Box<dyn DataWriter>,
        run_state: &RunState,
        _is_streaming: bool,
    ) -> std::result::Result<u64, RunError> {
        // `file_hash`, `progress_updater` and `cancellation_token` are read from
        // `run_state` instead. Every caller in this file builds `run_state` from
        // those same fields.
        let Self {
            client,
            byte_range,
            config,
            chunk_cache,
            custom_buffer_semaphore,
            ..
        } = self;
        // Bail out before any network work if the run state already reports a problem.
        run_state.check_run_state()?;
        let file_hash = *run_state.file_hash();
        // No explicit range means the whole file.
        let requested_range = byte_range.unwrap_or_else(FileRange::full);
        // Produces the terms covering `requested_range` in batches, consumed by the loop below.
        let mut term_manager = ReconstructionTermManager::new(
            config.clone(),
            client.clone(),
            file_hash,
            requested_range,
            run_state.progress_updater().cloned(),
        )
        .await?;
        // Download buffer memory is bounded by a semaphore whose permits are bytes;
        // each term acquires `term_size` permits. A caller-supplied semaphore is used
        // as-is. Otherwise the runtime-wide one is shared and resized below.
        let using_global_memory_limit = custom_buffer_semaphore.is_none();
        let download_buffer_semaphore = custom_buffer_semaphore
            .unwrap_or_else(|| XetRuntime::current().common().reconstruction_download_buffer.clone());
        // Declaration order matters: locals drop in reverse order. On exit, the
        // decrement guard therefore runs before any unused remainder of the seed
        // permit is released.
        let mut seed_buffer_permit;
        let _download_count_decrement_guard;
        if using_global_memory_limit {
            // Register this download, then grow the shared buffer's total permits to
            // `base + n * perfile` bytes, capped at `limit`. `n` counts active
            // downloads including this one.
            let active_downloads = XetRuntime::current().common().active_downloads.clone();
            let n = active_downloads.fetch_add(1, Ordering::Relaxed) + 1;
            let base = config.download_buffer_size.as_u64();
            let perfile = config.download_buffer_perfile_size.as_u64();
            let limit = config.download_buffer_limit.as_u64();
            let target = base.saturating_add(n.saturating_mul(perfile)).min(limit);
            // The optional permit returned here (presumably covering the permits just
            // added) seeds the first terms' buffer reservations.
            seed_buffer_permit = download_buffer_semaphore.increment_permits_to_target(target);
            let buffer_sem = download_buffer_semaphore.clone();
            // When the guard is dropped (assuming `ClosureGuard` runs its closure on
            // drop), i.e. on every exit from this function, unregister this download
            // and shrink the target to match the remaining count.
            _download_count_decrement_guard = Some(ClosureGuard::new(move || {
                let n = active_downloads.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
                let target = base.saturating_add(n.saturating_mul(perfile)).min(limit);
                buffer_sem.decrement_permits_to_target(target);
            }));
        } else {
            seed_buffer_permit = None;
            _download_count_decrement_guard = None;
        }
        // Term byte ranges are absolute file offsets. The writer is given offsets
        // relative to the start of the requested range.
        let range_start_offset = requested_range.start;
        loop {
            // `biased`: cancellation is checked before polling for the next batch.
            let maybe_file_terms = tokio::select! {
                biased;
                _ = run_state.cancelled() => {
                    return run_state.check_run_state().map(|_| 0);
                }
                result = term_manager.next_file_terms() => result?
            };
            // `None`: all term batches for the range have been produced.
            let Some(file_terms) = maybe_file_terms else {
                break;
            };
            run_state.check_run_state()?;
            run_state.record_new_block();
            for file_term in file_terms {
                run_state.check_run_state()?;
                let term_size = file_term.byte_range.end - file_term.byte_range.start;
                debug!(
                    file_hash = %file_hash,
                    xorb_hash = %file_term.xorb_block.xorb_hash,
                    term_byte_range = ?(file_term.byte_range.start, file_term.byte_range.end),
                    term_size,
                    "Processing file term"
                );
                // Reserve `term_size` bytes of buffer. First try to split them off the
                // seed permit. If that yields nothing, drop the seed permit (releasing
                // its remainder) and wait on the semaphore, abandoning the wait on
                // cancellation.
                let buffer_permit = match seed_buffer_permit.as_mut().and_then(|rp| rp.split(term_size)) {
                    Some(split) => split,
                    None => {
                        seed_buffer_permit = None;
                        tokio::select! {
                            biased;
                            _ = run_state.cancelled() => {
                                return run_state.check_run_state().map(|_| 0);
                            }
                            result = download_buffer_semaphore.acquire_many(term_size) => {
                                result.map_err(|e| {
                                    FileReconstructionError::InternalError(format!(
                                        "Error acquiring download buffer permit: {e}"
                                    ))
                                })?
                            }
                        }
                    },
                };
                // Obtain the source of this term's data; the writer consumes it later.
                let data_future = file_term
                    .get_data_task(client.clone(), run_state.progress_updater().cloned(), chunk_cache.clone())
                    .await?;
                // Debug-only sanity check: the term's chunk range must be one of the
                // ranges referenced by its xorb block.
                #[cfg(debug_assertions)]
                {
                    let refs = &file_term.xorb_block.references;
                    assert!(refs.iter().any(|r| r.term_chunks == file_term.xorb_chunk_range));
                }
                let relative_byte_range = FileRange::new(
                    file_term.byte_range.start - range_start_offset,
                    file_term.byte_range.end - range_start_offset,
                );
                // The buffer permit travels with the data source. Presumably it is
                // held until the writer is done with this term's bytes — confirm in
                // the `DataWriter` implementations.
                data_writer
                    .set_next_term_data_source(relative_byte_range, Some(buffer_permit), data_future)
                    .await?;
                run_state.record_new_term(term_size);
            }
        }
        run_state.log_progress("All term blocks received and scheduled for writing");
        // Complete the writer; it reports how many bytes it wrote.
        let bytes_written = data_writer.finish().await?;
        let total_bytes_scheduled = run_state.total_bytes_scheduled();
        debug_assert_eq!(
            bytes_written, total_bytes_scheduled,
            "Bytes written ({bytes_written}) should match total bytes scheduled ({total_bytes_scheduled})"
        );
        run_state.log_progress("File reconstruction completed successfully");
        // Debug builds, non-streaming runs only:
        // - the progress updater must report completion;
        // - for a bounded explicit range, it must have counted exactly the range length.
        #[cfg(debug_assertions)]
        if !_is_streaming && let Some(updater) = run_state.progress_updater() {
            updater.assert_complete();
            if let Some(byte_range) = byte_range
                && byte_range.end < u64::MAX
            {
                assert_eq!(updater.total_bytes_completed(), byte_range.end - byte_range.start);
            }
        }
        Ok(total_bytes_scheduled)
    }
}
/// Test builds: `FileReconstructor::new` attaches a standalone progress updater by
/// default. Presumably this is so the debug-build progress assertions in the
/// reconstruction loop run in every test.
#[cfg(test)]
fn default_progress_updater() -> Option<Arc<ItemProgressUpdater>> {
    let standalone = ItemProgressUpdater::new_standalone("test");
    Option::Some(standalone)
}
/// Non-test builds: no progress updater is attached unless the caller supplies one.
#[cfg(not(test))]
fn default_progress_updater() -> Option<Arc<ItemProgressUpdater>> {
    Option::default()
}
#[cfg(test)]
mod tests {
use std::io::{Cursor, Write};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use xet_client::cas_client::{ClientTestingUtils, DirectAccessClient, LocalClient, RandomFileContents};
use xet_client::cas_types::FileRange;
use xet_runtime::core::XetRuntime;
use super::*;
use crate::progress_tracking::ItemProgressUpdater;
const TEST_CHUNK_SIZE: usize = 101;
fn test_config() -> ReconstructionConfig {
let mut config = ReconstructionConfig::default();
config.min_reconstruction_fetch_size = xet_runtime::utils::ByteSize::from("100");
config.max_reconstruction_fetch_size = xet_runtime::utils::ByteSize::from("400");
config.min_prefetch_buffer = xet_runtime::utils::ByteSize::from("800");
config
}
async fn setup_test_file(term_spec: &[(u64, (u64, u64))]) -> (Arc<LocalClient>, RandomFileContents) {
let client = LocalClient::temporary().await.unwrap();
let file_contents = client.upload_random_file(term_spec, TEST_CHUNK_SIZE).await.unwrap();
(client, file_contents)
}
async fn reconstruct_to_vec(
client: &Arc<LocalClient>,
file_hash: MerkleHash,
byte_range: Option<FileRange>,
config: &ReconstructionConfig,
semaphore: Option<Arc<AdjustableSemaphore>>,
) -> Result<Vec<u8>> {
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
let writer = StaticCursorWriter(buffer.clone());
let mut reconstructor =
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
if let Some(range) = byte_range {
reconstructor = reconstructor.with_byte_range(range);
}
if let Some(sem) = semaphore {
reconstructor = reconstructor.with_buffer_semaphore(sem);
}
reconstructor.reconstruct_to_writer(writer).await?;
let data = buffer.lock().unwrap().get_ref().clone();
Ok(data)
}
async fn reconstruct_to_file(
client: &Arc<LocalClient>,
file_hash: MerkleHash,
byte_range: Option<FileRange>,
config: &ReconstructionConfig,
) -> Result<Vec<u8>> {
let temp_dir = tempfile::tempdir().unwrap();
let file_path = temp_dir.path().join("output.bin");
let mut reconstructor =
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
if let Some(range) = byte_range {
reconstructor = reconstructor.with_byte_range(range);
}
reconstructor.reconstruct_to_file(&file_path, None, false).await?;
let file_data = std::fs::read(&file_path)?;
let start = byte_range.map(|r| r.start as usize).unwrap_or(0);
Ok(file_data[start..].to_vec())
}
async fn reconstruct_to_file_at_specific_offset(
client: &Arc<LocalClient>,
file_hash: MerkleHash,
byte_range: Option<FileRange>,
config: &ReconstructionConfig,
) -> Result<Vec<u8>> {
let offset = 9u64;
let temp_dir = tempfile::tempdir().unwrap();
let file_path = temp_dir.path().join("output.bin");
let mut reconstructor =
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
if let Some(range) = byte_range {
reconstructor = reconstructor.with_byte_range(range);
}
reconstructor.reconstruct_to_file(&file_path, Some(offset), false).await?;
let file_data = std::fs::read(&file_path)?;
Ok(file_data[offset as usize..].to_vec())
}
async fn reconstruct_to_file_at_offset_zero(
client: &Arc<LocalClient>,
file_hash: MerkleHash,
byte_range: Option<FileRange>,
config: &ReconstructionConfig,
) -> Result<Vec<u8>> {
let temp_dir = tempfile::tempdir().unwrap();
let file_path = temp_dir.path().join("output.bin");
let mut reconstructor =
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
if let Some(range) = byte_range {
reconstructor = reconstructor.with_byte_range(range);
}
reconstructor.reconstruct_to_file(&file_path, Some(0), false).await?;
let file_data = std::fs::read(&file_path)?;
Ok(file_data)
}
struct StaticCursorWriter(Arc<std::sync::Mutex<Cursor<Vec<u8>>>>);
impl std::io::Write for StaticCursorWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.lock().unwrap().write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.0.lock().unwrap().flush()
}
}
async fn reconstruct_and_verify_full(
client: &Arc<LocalClient>,
file_contents: &RandomFileContents,
base_config: ReconstructionConfig,
) {
let expected = &file_contents.data;
let h = file_contents.file_hash;
for use_vectored in [false, true] {
let mut config = base_config.clone();
config.use_vectored_write = use_vectored;
let vec_result = reconstruct_to_vec(client, h, None, &config, None).await.unwrap();
assert_eq!(vec_result, *expected, "vec failed (vectored={use_vectored})");
let file_result = reconstruct_to_file(client, h, None, &config).await.unwrap();
assert_eq!(file_result, *expected, "file failed (vectored={use_vectored})");
let file_offset_result = reconstruct_to_file_at_offset_zero(client, h, None, &config).await.unwrap();
assert_eq!(file_offset_result, *expected, "file_at_offset_zero failed (vectored={use_vectored})");
let file_specific_result = reconstruct_to_file_at_specific_offset(client, h, None, &config).await.unwrap();
assert_eq!(file_specific_result, *expected, "file_at_specific_offset failed (vectored={use_vectored})");
}
}
async fn reconstruct_and_verify_range(
client: &Arc<LocalClient>,
file_contents: &RandomFileContents,
range: FileRange,
base_config: ReconstructionConfig,
) {
let expected = &file_contents.data[range.start as usize..range.end as usize];
for use_vectored in [false, true] {
let mut config = base_config.clone();
config.use_vectored_write = use_vectored;
let vec_result = reconstruct_to_vec(client, file_contents.file_hash, Some(range), &config, None)
.await
.expect("reconstruct_to_vec should succeed");
assert_eq!(vec_result, expected, "vec failed (vectored={use_vectored})");
let file_result = reconstruct_to_file(client, file_contents.file_hash, Some(range), &config)
.await
.expect("reconstruct_to_file should succeed");
assert_eq!(file_result, expected, "file failed (vectored={use_vectored})");
let file_offset_result =
reconstruct_to_file_at_offset_zero(client, file_contents.file_hash, Some(range), &config)
.await
.expect("reconstruct_to_file_at_offset_zero should succeed");
assert_eq!(file_offset_result, expected, "file_at_offset failed (vectored={use_vectored})");
}
}
#[tokio::test]
async fn test_single_term_full_reconstruction() {
let (client, file_contents) = setup_test_file(&[(1, (0, 3))]).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_multiple_terms_same_xorb_full_reconstruction() {
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))]).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_multiple_xorbs_full_reconstruction() {
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 2)), (3, (0, 4))]).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_large_file_many_terms_full_reconstruction() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=10).map(|i| (i, (0, 5))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_interleaved_xorbs_full_reconstruction() {
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (2, (0, 2)), (1, (2, 4)), (2, (2, 4))]).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_single_chunk_file() {
let (client, file_contents) = setup_test_file(&[(1, (0, 1))]).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_many_small_terms_different_xorbs() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=20).map(|i| (i, (0, 1))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_progress_tracker_records_full_reconstruction_bytes() {
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 2))]).await;
let config = test_config();
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
let writer = StaticCursorWriter(buffer.clone());
let progress_updater = ItemProgressUpdater::new_standalone("file");
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(&config)
.with_progress_updater(progress_updater.clone())
.reconstruct_to_writer(writer)
.await
.unwrap();
assert_eq!(bytes_written, file_contents.data.len() as u64);
}
#[tokio::test]
async fn test_progress_tracker_records_partial_range_bytes() {
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
let config = test_config();
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
let expected_bytes = range.end - range.start;
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
let writer = StaticCursorWriter(buffer.clone());
let progress_updater = ItemProgressUpdater::new_standalone("file");
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(&config)
.with_byte_range(range)
.with_progress_updater(progress_updater.clone())
.reconstruct_to_writer(writer)
.await
.unwrap();
assert_eq!(bytes_written, expected_bytes);
}
#[tokio::test]
async fn test_external_progress_tracker_incremental_discovery() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=5).map(|i| (i, (0, 3))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
let config = test_config();
let task = ItemProgressUpdater::new_standalone("test_file.bin");
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
let writer = StaticCursorWriter(buffer.clone());
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(&config)
.with_progress_updater(task.clone())
.reconstruct_to_writer(writer)
.await
.unwrap();
assert_eq!(bytes_written, file_contents.data.len() as u64);
task.assert_complete();
assert_eq!(task.total_bytes_completed(), file_contents.data.len() as u64);
}
#[tokio::test]
async fn test_external_progress_tracker_final_size_upfront() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=5).map(|i| (i, (0, 3))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
let config = test_config();
let file_size = file_contents.data.len() as u64;
let task = ItemProgressUpdater::new_standalone("test_file.bin");
task.update_item_size(file_size, true);
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
let writer = StaticCursorWriter(buffer.clone());
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(&config)
.with_progress_updater(task.clone())
.reconstruct_to_writer(writer)
.await
.unwrap();
assert_eq!(bytes_written, file_size);
assert_eq!(task.total_bytes_completed(), file_size);
task.assert_complete();
}
#[tokio::test]
async fn test_range_first_half() {
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(0, file_len / 2);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_second_half() {
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 2, file_len);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_middle() {
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_single_byte_start() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
let range = FileRange::new(0, 1);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_single_byte_end() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len - 1, file_len);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_single_byte_middle() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
let file_len = file_contents.data.len() as u64;
let mid = file_len / 2;
let range = FileRange::new(mid, mid + 1);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_few_bytes_from_start() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(3, file_len);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_few_bytes_before_end() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(0, file_len - 3);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_small_slice_in_middle() {
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 3, file_len / 3 + 10);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_spanning_multiple_terms() {
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 3)), (3, (0, 3))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(10, file_len - 10);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_within_single_term() {
let (client, file_contents) = setup_test_file(&[(1, (0, 10)), (2, (0, 10))]).await;
let first_term_size = file_contents.terms[0].data.len() as u64;
let range = FileRange::new(5, first_term_size - 5);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_range_crossing_term_boundary() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5)), (2, (0, 5))]).await;
let first_term_size = file_contents.terms[0].data.len() as u64;
let range = FileRange::new(first_term_size - 10, first_term_size + 10);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_large_file_range_first_portion() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=15).map(|i| (i, (0, 4))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(0, file_len / 3);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_large_file_range_last_portion() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=15).map(|i| (i, (0, 4))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len * 2 / 3, file_len);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_large_file_range_middle_portion() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=15).map(|i| (i, (0, 4))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 3, file_len * 2 / 3);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_complex_mixed_pattern_full() {
let term_spec = &[
(1, (0, 3)),
(2, (0, 2)),
(1, (3, 5)),
(3, (1, 4)),
(2, (4, 6)),
(1, (0, 2)),
];
let (client, file_contents) = setup_test_file(term_spec).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_complex_mixed_pattern_partial_range() {
let term_spec = &[
(1, (0, 3)),
(2, (0, 2)),
(1, (3, 5)),
(3, (1, 4)),
(2, (4, 6)),
(1, (0, 2)),
];
let (client, file_contents) = setup_test_file(term_spec).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
reconstruct_and_verify_range(&client, &file_contents, range, test_config()).await;
}
#[tokio::test]
async fn test_overlapping_chunk_ranges() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5)), (1, (1, 3)), (1, (2, 4))]).await;
reconstruct_and_verify_full(&client, &file_contents, test_config()).await;
}
#[tokio::test]
async fn test_non_contiguous_chunks() {
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6))]).await;
let config = test_config();
let result = reconstruct_to_vec(&client, file_contents.file_hash, None, &config, None)
.await
.unwrap();
assert_eq!(result, file_contents.data);
}
#[tokio::test]
async fn test_default_config_full_reconstruction() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5)), (2, (0, 3))]).await;
reconstruct_and_verify_full(&client, &file_contents, ReconstructionConfig::default()).await;
}
#[tokio::test]
async fn test_default_config_partial_range() {
let (client, file_contents) = setup_test_file(&[(1, (0, 5)), (2, (0, 3))]).await;
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
reconstruct_and_verify_range(&client, &file_contents, range, ReconstructionConfig::default()).await;
}
struct TimeAdvancingWriter {
buffer: Arc<std::sync::Mutex<Vec<u8>>>,
advance_duration: Duration,
write_count: Arc<AtomicUsize>,
}
impl TimeAdvancingWriter {
fn new(advance_duration: Duration) -> Self {
Self {
buffer: Arc::new(std::sync::Mutex::new(Vec::new())),
advance_duration,
write_count: Arc::new(AtomicUsize::new(0)),
}
}
}
impl Write for TimeAdvancingWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let bytes_written = self.buffer.lock().unwrap().write(buf)?;
self.write_count.fetch_add(1, Ordering::Relaxed);
let advance_duration = self.advance_duration;
tokio::runtime::Handle::current().block_on(async {
tokio::time::advance(advance_duration).await;
});
Ok(bytes_written)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
fn url_refresh_test_config() -> ReconstructionConfig {
let mut config = ReconstructionConfig::default();
config.min_reconstruction_fetch_size = xet_runtime::utils::ByteSize::from("50");
config.max_reconstruction_fetch_size = xet_runtime::utils::ByteSize::from("100");
config.min_prefetch_buffer = xet_runtime::utils::ByteSize::from("50");
config
}
#[tokio::test(start_paused = true)]
async fn test_url_refresh_on_expiration() {
let term_spec = &[(1, (0, 2)), (2, (0, 2)), (3, (0, 2))];
let (client, file_contents) = setup_test_file(term_spec).await;
let url_expiration = Duration::from_secs(1);
client.set_fetch_term_url_expiration(url_expiration);
let time_advance = Duration::from_secs(2);
let writer = TimeAdvancingWriter::new(time_advance);
let writer_buffer = writer.buffer.clone();
let write_count = writer.write_count.clone();
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(url_refresh_test_config())
.with_buffer_semaphore(tiny_semaphore)
.reconstruct_to_writer(writer)
.await
.expect("Reconstruction should succeed with URL refresh");
let reconstructed = writer_buffer.lock().unwrap().clone();
assert_eq!(reconstructed.len(), file_contents.data.len());
assert_eq!(reconstructed, file_contents.data);
assert!(write_count.load(Ordering::Relaxed) >= term_spec.len());
}
#[tokio::test(start_paused = true)]
async fn test_url_refresh_same_xorb_multiple_terms() {
let term_spec = &[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))];
let (client, file_contents) = setup_test_file(term_spec).await;
client.set_fetch_term_url_expiration(Duration::from_secs(1));
let writer = TimeAdvancingWriter::new(Duration::from_secs(2));
let writer_buffer = writer.buffer.clone();
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(url_refresh_test_config())
.with_buffer_semaphore(tiny_semaphore)
.reconstruct_to_writer(writer)
.await
.expect("Reconstruction should succeed");
let reconstructed = writer_buffer.lock().unwrap().clone();
assert_eq!(reconstructed, file_contents.data);
}
#[tokio::test(start_paused = true)]
async fn test_url_refresh_large_file_multiple_blocks() {
let term_spec: Vec<(u64, (u64, u64))> = (1..=5).map(|i| (i, (0, 3))).collect();
let (client, file_contents) = setup_test_file(&term_spec).await;
client.set_fetch_term_url_expiration(Duration::from_secs(1));
let writer = TimeAdvancingWriter::new(Duration::from_secs(2));
let writer_buffer = writer.buffer.clone();
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(url_refresh_test_config())
.with_buffer_semaphore(tiny_semaphore)
.reconstruct_to_writer(writer)
.await
.expect("Reconstruction should succeed");
let reconstructed = writer_buffer.lock().unwrap().clone();
assert_eq!(reconstructed, file_contents.data);
}
#[tokio::test(start_paused = true)]
async fn test_no_url_expiration_control() {
let term_spec = &[(1, (0, 2)), (2, (0, 2)), (3, (0, 2))];
let (client, file_contents) = setup_test_file(term_spec).await;
client.set_fetch_term_url_expiration(Duration::from_secs(3600));
let writer = TimeAdvancingWriter::new(Duration::from_millis(100));
let writer_buffer = writer.buffer.clone();
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
.with_config(url_refresh_test_config())
.with_buffer_semaphore(tiny_semaphore)
.reconstruct_to_writer(writer)
.await
.expect("Reconstruction should succeed");
let reconstructed = writer_buffer.lock().unwrap().clone();
assert_eq!(reconstructed, file_contents.data);
}
/// URL-refresh test restricted to the middle half of the file: URLs expire after 1s while
/// the writer is constructed with a 2s step, so URLs are expected to expire mid-run, yet the
/// returned slice must match the corresponding bytes of the upload.
#[tokio::test(start_paused = true)]
async fn test_url_refresh_partial_range() {
    let (client, contents) = setup_test_file(&[(1, (0, 5)), (2, (0, 5))]).await;
    let total = contents.data.len() as u64;
    client.set_fetch_term_url_expiration(Duration::from_secs(1));

    let writer = TimeAdvancingWriter::new(Duration::from_secs(2));
    let output = writer.buffer.clone();
    let single_permit = AdjustableSemaphore::new(1, (0, 1));
    let range = FileRange::new(total / 4, total * 3 / 4);
    let dyn_client: Arc<dyn Client> = client.clone();

    FileReconstructor::new(&dyn_client, contents.file_hash)
        .with_byte_range(range)
        .with_config(url_refresh_test_config())
        .with_buffer_semaphore(single_permit)
        .reconstruct_to_writer(writer)
        .await
        .expect("Reconstruction should succeed");

    let bytes = output.lock().unwrap().clone();
    let (lo, hi) = (range.start as usize, range.end as usize);
    assert_eq!(bytes, &contents.data[lo..hi]);
}
/// Verifies that an oversized `increment_total_permits` request on the shared download
/// buffer semaphore leaves its total at the configured limit, and that a subsequent
/// reconstruction with a larger per-file buffer neither changes the total nor leaves an
/// active download registered.
#[test]
fn test_dynamic_buffer_scaling_noop_increment_preserves_total_permits() {
    use xet_runtime::utils::ByteSize;

    let mut runtime_config = xet_runtime::config::XetConfig::new();
    {
        let recon = &mut runtime_config.reconstruction;
        recon.download_buffer_size = ByteSize::from("1kb");
        recon.download_buffer_limit = ByteSize::from("4kb");
    }
    let expected_total = runtime_config.reconstruction.download_buffer_limit.as_u64();

    let runtime = XetRuntime::new_with_config(runtime_config).unwrap();
    runtime
        .bridge_sync(async move {
            let (client, contents) = setup_test_file(&[(1, (0, 2)), (2, (0, 2)), (3, (0, 2))]).await;
            let sem = XetRuntime::current().common().reconstruction_download_buffer.clone();

            // The returned value is dropped immediately; the test expects the total to stay
            // at the configured limit afterwards.
            let increment = sem.increment_total_permits(u64::MAX).unwrap();
            drop(increment);
            assert_eq!(sem.total_permits(), expected_total);

            let mut recon_config = test_config();
            recon_config.download_buffer_perfile_size = ByteSize::from("8kb");
            let output = reconstruct_to_vec(&client, contents.file_hash, None, &recon_config, None)
                .await
                .unwrap();

            assert_eq!(output, contents.data);
            assert_eq!(sem.total_permits(), expected_total);
            assert_eq!(XetRuntime::current().common().active_downloads.load(Ordering::Relaxed), 0);
        })
        .unwrap();
}
/// Reconstructs `range` of the file identified by `file_hash` into the file at `file_path`
/// using `config`, returning whatever `FileReconstructor::reconstruct_to_file` returns.
async fn reconstruct_range_to_file_path(
    client: &Arc<LocalClient>,
    file_hash: MerkleHash,
    file_path: &std::path::Path,
    range: FileRange,
    config: ReconstructionConfig,
) -> Result<u64> {
    let dyn_client: Arc<dyn Client> = client.clone();
    let reconstructor = FileReconstructor::new(&dyn_client, file_hash)
        .with_byte_range(range)
        .with_config(config);
    reconstructor.reconstruct_to_file(file_path, None, false).await
}
#[tokio::test]
async fn test_file_concurrent_non_overlapping_range_writes() {
const NUM_WRITERS: usize = 16;
const LARGE_CHUNK_SIZE: usize = 4096;
let term_spec: Vec<(u64, (u64, u64))> = (1..=16).map(|i| (i, (0, 16))).collect();
let client = LocalClient::temporary().await.unwrap();
let file_contents = client.upload_random_file(&term_spec, LARGE_CHUNK_SIZE).await.unwrap();
let file_len = file_contents.data.len() as u64;
let temp_dir = tempfile::tempdir().unwrap();
let file_path = temp_dir.path().join("output.bin");
std::fs::write(&file_path, vec![0u8; file_len as usize]).unwrap();
let mut config = ReconstructionConfig::default();
config.min_reconstruction_fetch_size = xet_runtime::utils::ByteSize::from("32kb");
config.max_reconstruction_fetch_size = xet_runtime::utils::ByteSize::from("128kb");
let chunk_size = file_len / NUM_WRITERS as u64;
let ranges: Vec<FileRange> = (0..NUM_WRITERS)
.map(|i| {
let start = i as u64 * chunk_size;
let end = if i == NUM_WRITERS - 1 {
file_len
} else {
(i as u64 + 1) * chunk_size
};
FileRange::new(start, end)
})
.collect();
let mut join_set = tokio::task::JoinSet::new();
for range in ranges {
let client = client.clone();
let file_hash = file_contents.file_hash;
let file_path = file_path.clone();
let config = config.clone();
join_set.spawn(async move {
FileReconstructor::new(&(client as Arc<dyn Client>), file_hash)
.with_byte_range(range)
.with_config(config)
.reconstruct_to_file(&file_path, None, false)
.await
});
}
while let Some(result) = join_set.join_next().await {
result.unwrap().unwrap();
}
let reconstructed = std::fs::read(&file_path).unwrap();
assert_eq!(reconstructed.len(), file_contents.data.len());
assert_eq!(reconstructed, file_contents.data);
}
/// Reconstructing the middle third of a file into an already-populated output must
/// overwrite only that third, leaving the bytes before and after it untouched.
#[tokio::test]
async fn test_file_writes_preserve_existing_content() {
    let (client, contents) = setup_test_file(&[(1, (0, 10))]).await;
    let total = contents.data.len() as u64;

    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("output.bin");
    // Position-dependent pattern so stray writes outside the range are detectable.
    let pattern: Vec<u8> = (0..total).map(|i| (i % 251) as u8).collect();
    std::fs::write(&path, &pattern).unwrap();

    let (lo, hi) = (total / 3, 2 * total / 3);
    reconstruct_range_to_file_path(&client, contents.file_hash, &path, FileRange::new(lo, hi), test_config())
        .await
        .unwrap();

    let on_disk = std::fs::read(&path).unwrap();
    let (lo, hi) = (lo as usize, hi as usize);
    assert_eq!(&on_disk[..lo], &pattern[..lo]);
    assert_eq!(&on_disk[hi..], &pattern[hi..]);
    assert_eq!(&on_disk[lo..hi], &contents.data[lo..hi]);
}
/// Full reconstruction of a file built from three disjoint chunk ranges of one xorb.
#[tokio::test]
async fn test_triple_disjoint_ranges_full() {
    let spec = &[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))];
    let (client, contents) = setup_test_file(spec).await;
    reconstruct_and_verify_full(&client, &contents, test_config()).await;
}
/// Reconstructs the middle half of a file built from three disjoint chunk ranges of one xorb.
#[tokio::test]
async fn test_triple_disjoint_ranges_partial() {
    let spec = &[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))];
    let (client, contents) = setup_test_file(spec).await;
    let total = contents.data.len() as u64;
    let middle_half = FileRange::new(total / 4, total * 3 / 4);
    reconstruct_and_verify_range(&client, &contents, middle_half, test_config()).await;
}
/// Full reconstruction where terms alternate between two xorbs, each contributing three
/// disjoint chunk ranges.
#[tokio::test]
async fn test_multi_xorb_interleaved_disjoint() {
    let spec = &[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6)), (1, (8, 10)), (2, (8, 10))];
    let (client, contents) = setup_test_file(spec).await;
    reconstruct_and_verify_full(&client, &contents, test_config()).await;
}
/// Reconstructs the middle third of a file whose terms alternate between two xorbs with
/// disjoint chunk ranges.
#[tokio::test]
async fn test_multi_xorb_interleaved_disjoint_partial() {
    let spec = &[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6)), (1, (8, 10)), (2, (8, 10))];
    let (client, contents) = setup_test_file(spec).await;
    let total = contents.data.len() as u64;
    let middle_third = FileRange::new(total / 3, total * 2 / 3);
    reconstruct_and_verify_range(&client, &contents, middle_third, test_config()).await;
}
/// Full reconstruction of a file built from four disjoint chunk ranges of one xorb.
#[tokio::test]
async fn test_four_disjoint_ranges() {
    let (client, contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10)), (1, (12, 14))]).await;
    reconstruct_and_verify_full(&client, &contents, test_config()).await;
}
/// Full reconstruction mixing adjacent chunk ranges (0..3 and 3..5) with a disjoint one
/// (8..10), all from the same xorb.
#[tokio::test]
async fn test_mixed_contiguous_and_disjoint() {
    let spec = &[
        (1, (0, 3)),
        (1, (3, 5)),
        (1, (8, 10)),
    ];
    let (client, contents) = setup_test_file(spec).await;
    reconstruct_and_verify_full(&client, &contents, test_config()).await;
}
/// Full reconstruction with terms drawn out of order from three xorbs, including ranges
/// that skip chunks and revisit earlier chunks of a xorb.
#[tokio::test]
async fn test_complex_three_xorb_disjoint() {
    let spec = &[(1, (0, 2)), (2, (0, 3)), (3, (2, 5)), (1, (5, 8)), (2, (6, 8)), (3, (0, 2))];
    let (client, contents) = setup_test_file(spec).await;
    reconstruct_and_verify_full(&client, &contents, test_config()).await;
}
/// Four disjoint ranges from one xorb, with the local client limited to 2 ranges per fetch.
#[tokio::test]
async fn test_local_client_max_ranges_2_disjoint() {
    let client = LocalClient::temporary().await.unwrap();
    client.set_max_ranges_per_fetch(2);
    let contents = client
        .upload_random_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10)), (1, (12, 14))], TEST_CHUNK_SIZE)
        .await
        .unwrap();

    let recon_config = test_config();
    let output = reconstruct_to_vec(&client, contents.file_hash, None, &recon_config, None)
        .await
        .unwrap();
    assert_eq!(output, contents.data.as_ref());
}
/// Interleaved disjoint ranges from two xorbs, with the local client limited to 1 range per fetch.
#[tokio::test]
async fn test_local_client_max_ranges_1_multi_xorb() {
    let client = LocalClient::temporary().await.unwrap();
    client.set_max_ranges_per_fetch(1);
    let contents = client
        .upload_random_file(&[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6))], TEST_CHUNK_SIZE)
        .await
        .unwrap();

    let recon_config = test_config();
    let output = reconstruct_to_vec(&client, contents.file_hash, None, &recon_config, None)
        .await
        .unwrap();
    assert_eq!(output, contents.data.as_ref());
}
mod cancellation_tests {
    use tokio_util::sync::CancellationToken;

    use super::*;

    /// A token that is already cancelled when reconstruction starts yields 0 bytes written.
    #[tokio::test]
    async fn test_cancellation_token_before_start() {
        let (client, contents) = setup_test_file(&[(1, (0, 3))]).await;
        let recon_config = test_config();
        let token = CancellationToken::new();
        token.cancel();

        let sink = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
        let writer = StaticCursorWriter(Arc::clone(&sink));
        let dyn_client: Arc<dyn Client> = client.clone();

        let bytes_written = FileReconstructor::new(&dyn_client, contents.file_hash)
            .with_config(&recon_config)
            .with_cancellation_token(token)
            .reconstruct_to_writer(writer)
            .await
            .unwrap();
        assert_eq!(bytes_written, 0);
    }

    /// Test writer that records every byte it accepts and cancels `cancel_token` once
    /// `write` has been called `cancel_after_writes` times (or more).
    struct CancellingWriter {
        buffer: Arc<std::sync::Mutex<Vec<u8>>>,
        cancel_token: CancellationToken,
        write_count: AtomicUsize,
        cancel_after_writes: usize,
    }

    impl Write for CancellingWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let accepted = self.buffer.lock().unwrap().write(buf)?;
            let calls_so_far = self.write_count.fetch_add(1, Ordering::Relaxed) + 1;
            if calls_so_far >= self.cancel_after_writes {
                self.cancel_token.cancel();
            }
            Ok(accepted)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Cancelling from inside the writer after its first write stops the run early: the
    /// reconstruction reports 0 bytes written and the writer received fewer bytes than the
    /// full file.
    #[tokio::test]
    async fn test_cancellation_token_during_reconstruction() {
        let term_spec: Vec<(u64, (u64, u64))> = (1..=10).map(|xorb| (xorb, (0, 5))).collect();
        let (client, contents) = setup_test_file(&term_spec).await;
        let recon_config = test_config();
        let token = CancellationToken::new();

        let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
        let writer = CancellingWriter {
            buffer: Arc::clone(&captured),
            cancel_token: token.clone(),
            write_count: AtomicUsize::new(0),
            cancel_after_writes: 1,
        };
        let single_permit = AdjustableSemaphore::new(1, (1, 1));
        let dyn_client: Arc<dyn Client> = client.clone();

        let bytes_written = FileReconstructor::new(&dyn_client, contents.file_hash)
            .with_config(&recon_config)
            .with_cancellation_token(token)
            .with_buffer_semaphore(single_permit)
            .reconstruct_to_writer(writer)
            .await
            .unwrap();
        assert_eq!(bytes_written, 0);

        let received = captured.lock().unwrap().len();
        assert!(received < contents.data.len());
    }

    /// Supplying a token that is never cancelled must not affect a normal reconstruction.
    #[tokio::test]
    async fn test_cancellation_token_not_set_completes_normally() {
        let (client, contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 2))]).await;
        let recon_config = test_config();
        let token = CancellationToken::new();

        let sink = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
        let writer = StaticCursorWriter(Arc::clone(&sink));
        let dyn_client: Arc<dyn Client> = client.clone();

        let bytes_written = FileReconstructor::new(&dyn_client, contents.file_hash)
            .with_config(&recon_config)
            .with_cancellation_token(token)
            .reconstruct_to_writer(writer)
            .await
            .unwrap();
        assert_eq!(bytes_written, contents.data.len() as u64);

        let output = sink.lock().unwrap().get_ref().clone();
        assert_eq!(output, contents.data);
    }
}
mod multirange_tests {
    use super::*;

    /// Builds a fresh runtime whose client config has multi-range fetching set to `enable`.
    fn with_multirange_config(enable: bool) -> Arc<XetRuntime> {
        let mut runtime_config = xet_runtime::config::XetConfig::new();
        runtime_config.client.enable_multirange_fetching = enable;
        XetRuntime::new_with_config(runtime_config).unwrap()
    }

    /// For both settings of multi-range fetching, each disjoint-range scenario must
    /// reconstruct correctly in full and for its middle half.
    #[test]
    fn test_multirange_local_client() {
        for enable in [false, true] {
            let runtime = with_multirange_config(enable);
            runtime
                .bridge_sync(async move {
                    let scenarios: [&[(u64, (u64, u64))]; 3] = [
                        &[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))],
                        &[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6)), (1, (8, 10)), (2, (8, 10))],
                        &[(1, (0, 2)), (2, (0, 3)), (3, (2, 5)), (1, (5, 8)), (2, (6, 8)), (3, (0, 2))],
                    ];
                    let recon_config = test_config();
                    for spec in scenarios {
                        let (client, contents) = setup_test_file(spec).await;
                        reconstruct_and_verify_full(&client, &contents, recon_config.clone()).await;

                        let total = contents.data.len() as u64;
                        let middle_half = FileRange::new(total / 4, total * 3 / 4);
                        reconstruct_and_verify_range(&client, &contents, middle_half, recon_config.clone()).await;
                    }
                })
                .unwrap();
        }
    }

    /// For both multi-range settings, a local client limited to 2 ranges per fetch still
    /// reconstructs four disjoint ranges correctly.
    #[test]
    fn test_multirange_max_ranges() {
        for enable in [false, true] {
            let runtime = with_multirange_config(enable);
            runtime
                .bridge_sync(async {
                    let client = LocalClient::temporary().await.unwrap();
                    client.set_max_ranges_per_fetch(2);
                    let contents = client
                        .upload_random_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10)), (1, (12, 14))], TEST_CHUNK_SIZE)
                        .await
                        .unwrap();
                    let recon_config = test_config();
                    let output = reconstruct_to_vec(&client, contents.file_hash, None, &recon_config, None)
                        .await
                        .unwrap();
                    assert_eq!(output, contents.data.as_ref());
                })
                .unwrap();
        }
    }
}
#[cfg(feature = "simulation")]
mod server_tests {
use super::*;
async fn reconstruct_via_server(
server: &xet_client::cas_client::LocalTestServer,
file_hash: MerkleHash,
byte_range: Option<FileRange>,
config: &ReconstructionConfig,
) -> Result<Vec<u8>> {
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
let writer = StaticCursorWriter(buffer.clone());
let client: Arc<dyn Client> = server.remote_client().clone();
let mut reconstructor = FileReconstructor::new(&client, file_hash).with_config(config);
if let Some(range) = byte_range {
reconstructor = reconstructor.with_byte_range(range);
}
reconstructor.reconstruct_to_writer(writer).await?;
let data = buffer.lock().unwrap().get_ref().clone();
Ok(data)
}
#[tokio::test]
async fn test_v1_fallback_full_reconstruction() {
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(&[(1, (0, 3)), (2, (0, 2))], TEST_CHUNK_SIZE)
.await
.unwrap();
server.disable_v2_reconstruction(404);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, None, &config)
.await
.unwrap();
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_v1_fallback_partial_range() {
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(&[(1, (0, 5)), (2, (0, 3))], TEST_CHUNK_SIZE)
.await
.unwrap();
server.disable_v2_reconstruction(404);
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, Some(range), &config)
.await
.unwrap();
assert_eq!(result, &file_contents.data[range.start as usize..range.end as usize]);
}
#[tokio::test]
async fn test_v1_fallback_non_contiguous_chunks() {
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(&[(1, (0, 2)), (1, (4, 6))], TEST_CHUNK_SIZE)
.await
.unwrap();
server.disable_v2_reconstruction(404);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, None, &config)
.await
.unwrap();
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_v1_fallback_multiple_xorbs() {
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(&[(1, (0, 2)), (2, (0, 3)), (3, (0, 2)), (1, (2, 4))], TEST_CHUNK_SIZE)
.await
.unwrap();
server.disable_v2_reconstruction(404);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, None, &config)
.await
.unwrap();
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_v1_fallback_triple_disjoint_ranges() {
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))], TEST_CHUNK_SIZE)
.await
.unwrap();
server.disable_v2_reconstruction(404);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, None, &config)
.await
.unwrap();
assert_eq!(result, file_contents.data.as_ref());
}
async fn reconstruct_via_server_with_max_ranges(
term_spec: &[(u64, (u64, u64))],
max_ranges: usize,
byte_range: Option<FileRange>,
) -> (Vec<u8>, RandomFileContents) {
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(term_spec, TEST_CHUNK_SIZE)
.await
.unwrap();
server.set_max_ranges_per_fetch(max_ranges);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, byte_range, &config)
.await
.unwrap();
(result, file_contents)
}
#[tokio::test]
async fn test_max_ranges_simple() {
let (result, file_contents) =
reconstruct_via_server_with_max_ranges(&[(1, (0, 3)), (2, (0, 2))], 2, None).await;
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_max_ranges_1_disjoint() {
let (result, file_contents) =
reconstruct_via_server_with_max_ranges(&[(1, (0, 2)), (1, (4, 6))], 1, None).await;
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_max_ranges_2_triple_disjoint() {
let (result, file_contents) =
reconstruct_via_server_with_max_ranges(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))], 2, None).await;
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_max_ranges_2_multi_xorb_disjoint() {
let term_spec = &[
(1, (0, 2)),
(2, (0, 2)),
(1, (4, 6)),
(2, (4, 6)),
(1, (8, 10)),
(2, (8, 10)),
];
let (result, file_contents) = reconstruct_via_server_with_max_ranges(term_spec, 2, None).await;
assert_eq!(result, file_contents.data.as_ref());
}
#[tokio::test]
async fn test_max_ranges_2_partial_range() {
let term_spec = &[
(1, (0, 3)),
(2, (0, 2)),
(1, (3, 5)),
(3, (1, 4)),
(2, (4, 6)),
(1, (0, 2)),
];
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let file_contents = server
.remote_client()
.upload_random_file(term_spec, TEST_CHUNK_SIZE)
.await
.unwrap();
server.set_max_ranges_per_fetch(2);
let file_len = file_contents.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
let config = test_config();
let result = reconstruct_via_server(&server, file_contents.file_hash, Some(range), &config)
.await
.unwrap();
assert_eq!(result, &file_contents.data[range.start as usize..range.end as usize]);
}
fn with_multirange_config(enable: bool) -> Arc<XetRuntime> {
let mut config = xet_runtime::config::XetConfig::new();
config.client.enable_multirange_fetching = enable;
XetRuntime::new_with_config(config).unwrap()
}
#[test]
fn test_multirange_via_server() {
for enable in [false, true] {
let rt = with_multirange_config(enable);
rt.bridge_sync(async {
let config = test_config();
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let fc = server
.remote_client()
.upload_random_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))], TEST_CHUNK_SIZE)
.await
.unwrap();
let result = reconstruct_via_server(&server, fc.file_hash, None, &config).await.unwrap();
assert_eq!(result, fc.data.as_ref());
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let fc = server
.remote_client()
.upload_random_file(
&[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6)), (1, (8, 10))],
TEST_CHUNK_SIZE,
)
.await
.unwrap();
server.set_max_ranges_per_fetch(2);
let result = reconstruct_via_server(&server, fc.file_hash, None, &config).await.unwrap();
assert_eq!(result, fc.data.as_ref());
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
let fc = server
.remote_client()
.upload_random_file(&[(1, (0, 3)), (2, (0, 2)), (1, (3, 5)), (2, (4, 6))], TEST_CHUNK_SIZE)
.await
.unwrap();
let file_len = fc.data.len() as u64;
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
let result = reconstruct_via_server(&server, fc.file_hash, Some(range), &config)
.await
.unwrap();
assert_eq!(result, &fc.data[range.start as usize..range.end as usize]);
})
.unwrap();
}
}
} }