use std::collections::VecDeque;
use std::io::{IoSlice, Write};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::sync::oneshot;
use tokio::task::{JoinHandle, JoinSet};
use xet_client::cas_types::FileRange;
use xet_runtime::core::{XetRuntime, check_sigint_shutdown};
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit;
use super::super::data_writer::{DataFuture, DataWriter};
use super::super::run_state::RunState;
use super::super::{FileReconstructionError, Result};
use crate::progress_tracking::ItemProgressUpdater;
/// Upper bound on the number of buffers handed to a single `write_vectored` call; buffers beyond
/// this are left queued for a later call. NOTE(review): presumably chosen to stay well under the
/// platform iovec limit — the exact value looks like a tuning choice.
const WRITEV_MAX_SLICE: usize = 24;

/// One entry of the ordered queue that feeds the output writer.
pub(crate) enum SequentialRetrievalItem {
    /// The next term of the output: its bytes arrive through `receiver` once the data future has
    /// resolved, and `permit` (if any) is held alongside them until they have been written.
    Data {
        receiver: oneshot::Receiver<Bytes>,
        permit: Option<AdjustableSemaphorePermit>,
    },
    /// No further items follow.
    Finish,
}

/// Bytes ready for the sink, paired with the permit to release once they have been written.
type PendingWrite = (Bytes, Option<AdjustableSemaphorePermit>);
struct SyncWriterThread {
rx: UnboundedReceiver<SequentialRetrievalItem>,
bytes_written: Arc<AtomicU64>,
progress_updater: Option<Arc<ItemProgressUpdater>>,
run_state: Arc<RunState>,
pending: Option<SequentialRetrievalItem>,
finished: bool,
}
impl SyncWriterThread {
fn new(
rx: UnboundedReceiver<SequentialRetrievalItem>,
bytes_written: Arc<AtomicU64>,
progress_updater: Option<Arc<ItemProgressUpdater>>,
run_state: Arc<RunState>,
) -> Self {
Self {
rx,
bytes_written,
progress_updater,
run_state,
pending: None,
finished: false,
}
}
#[inline]
fn next_write(&mut self, should_block: bool) -> Result<Option<PendingWrite>> {
if self.pending.is_none() {
self.pending = if should_block {
self.rx.blocking_recv()
} else {
self.rx.try_recv().ok()
};
}
match self.pending.take() {
Some(SequentialRetrievalItem::Data { mut receiver, permit }) => {
if should_block {
let data = match receiver.blocking_recv() {
Ok(data) => data,
Err(_) => {
self.run_state.check_error()?;
return Err(FileReconstructionError::InternalWriterError(
"Data sender was dropped before sending data.".to_string(),
));
},
};
Ok(Some((data, permit)))
} else {
match receiver.try_recv() {
Ok(data) => Ok(Some((data, permit))),
Err(oneshot::error::TryRecvError::Empty) => {
self.pending = Some(SequentialRetrievalItem::Data { receiver, permit });
Ok(None)
},
Err(oneshot::error::TryRecvError::Closed) => {
self.run_state.check_error()?;
Err(FileReconstructionError::InternalWriterError(
"Data sender was dropped before sending data.".to_string(),
))
},
}
}
},
Some(SequentialRetrievalItem::Finish) => {
self.finished = true;
Ok(None)
},
None => Ok(None),
}
}
fn run(mut self, mut writer: impl Write) -> Result<()> {
while let Some((data, permit)) = self.next_write(true)? {
let len = data.len() as u64;
writer.write_all(&data)?;
self.bytes_written.fetch_add(len, Ordering::Relaxed);
if let Some(ref updater) = self.progress_updater {
updater.report_bytes_written(len);
}
drop(permit);
if self.finished {
break;
}
check_sigint_shutdown()?;
}
debug_assert!(self.finished);
writer.flush()?;
Ok(())
}
fn run_vectorized(mut self, mut writer: impl Write) -> Result<()> {
let mut pending_writes: VecDeque<PendingWrite> = VecDeque::new();
while !self.finished || !pending_writes.is_empty() {
check_sigint_shutdown()?;
if pending_writes.is_empty() {
let Some(write) = self.next_write(true)? else {
break;
};
pending_writes.push_back(write);
}
while let Some(write) = self.next_write(false)? {
pending_writes.push_back(write);
}
let io_slices: Vec<IoSlice<'_>> = pending_writes
.iter()
.take(WRITEV_MAX_SLICE)
.map(|(data, _)| IoSlice::new(data))
.collect();
let written = match writer.write_vectored(&io_slices) {
Ok(0) if !io_slices.is_empty() => {
return Err(FileReconstructionError::IoError(Arc::new(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"write_vectored returned 0 with non-empty buffers",
))));
},
Ok(n) => n,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(FileReconstructionError::IoError(Arc::new(e))),
};
self.bytes_written.fetch_add(written as u64, Ordering::Relaxed);
if let Some(ref updater) = self.progress_updater {
updater.report_bytes_written(written as u64);
}
let mut remaining = written;
while remaining > 0 && !pending_writes.is_empty() {
let front_len = pending_writes.front().unwrap().0.len();
if remaining >= front_len {
remaining -= front_len;
pending_writes.pop_front();
} else {
let front = pending_writes.front_mut().unwrap();
front.0 = front.0.slice(remaining..);
remaining = 0;
}
}
}
writer.flush()?;
Ok(())
}
}
/// `DataWriter` that emits terms strictly in byte order, even though each term's data future may
/// resolve at any time.
///
/// Field order is significant: fields are dropped in declaration order.
pub struct SequentialWriter {
    /// Ordered queue of items for the consumer (background thread or streaming receiver).
    sender: UnboundedSender<SequentialRetrievalItem>,
    /// Byte offset at which the next registered range must start.
    next_position: u64,
    /// Handle of the blocking writer thread; `None` for a streaming writer.
    background_handle: Option<JoinHandle<()>>,
    /// Shared run state used for error propagation and cancellation.
    run_state: Arc<RunState>,
    /// Bytes written so far by the background writer thread.
    bytes_written: Arc<AtomicU64>,
    /// Tasks that resolve data futures and deliver their bytes to the queue.
    active_tasks: JoinSet<Result<()>>,
    /// Set by `finish` once it has passed its initial error check.
    finished: bool,
}

impl Drop for SequentialWriter {
    /// Cancels the shared run state unless `finish` has marked the writer as finished.
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        self.run_state.cancel();
    }
}
#[async_trait::async_trait]
impl DataWriter for SequentialWriter {
    /// Queues the data source for the next contiguous byte range of the output.
    ///
    /// The slot is queued for the consumer immediately, so output order follows call order, while
    /// `data_future` is awaited on a separate task. Once it resolves, its length is checked against
    /// `byte_range` and the bytes are delivered to the queued slot. `permit` travels with the slot.
    ///
    /// # Errors
    /// Returns an error already recorded in the run state or by a completed data task, or
    /// `InternalWriterError` if the writer has finished, `byte_range` does not start where the
    /// previous range ended (0 for the first call), `byte_range` is reversed, or the consumer
    /// channel is closed.
    async fn set_next_term_data_source(
        &mut self,
        byte_range: FileRange,
        permit: Option<AdjustableSemaphorePermit>,
        data_future: DataFuture,
    ) -> Result<()> {
        self.run_state.check_error()?;
        // Surface failures of data tasks that have already completed.
        while let Some(result) = self.active_tasks.try_join_next() {
            result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }
        if self.finished {
            return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
        }
        if byte_range.start != self.next_position {
            return Err(FileReconstructionError::InternalWriterError(format!(
                "Byte range not sequential: expected start at {}, got {}",
                self.next_position, byte_range.start
            )));
        }
        // A reversed range would otherwise underflow (panic in debug, wrap in release).
        let Some(expected_size) = byte_range.end.checked_sub(byte_range.start) else {
            return Err(FileReconstructionError::InternalWriterError(format!(
                "Invalid byte range: end {} is before start {}",
                byte_range.end, byte_range.start
            )));
        };
        let (sender, receiver) = oneshot::channel();
        if self.sender.send(SequentialRetrievalItem::Data { receiver, permit }).is_err() {
            self.run_state.check_error()?;
            return Err(FileReconstructionError::InternalWriterError("Background writer channel closed".to_string()));
        }
        // Advance only once the slot has actually been queued.
        self.next_position = byte_range.end;
        let run_state = self.run_state.clone();
        let task = async move {
            let result = async {
                run_state.check_error()?;
                let data = data_future.await?;
                if data.len() as u64 != expected_size {
                    return Err(FileReconstructionError::InternalWriterError(format!(
                        "Data size mismatch: expected {} bytes, got {} bytes",
                        expected_size,
                        data.len()
                    )));
                }
                if sender.send(data).is_err() {
                    run_state.check_error()?;
                    return Err(FileReconstructionError::InternalWriterError(
                        "Failed to send data: receiver dropped".to_string(),
                    ));
                }
                Ok(())
            }
            .await;
            // Record the failure in the shared run state so later calls and the consumer see it.
            if let Err(ref e) = result {
                run_state.set_error(e.clone());
            }
            result
        };
        self.active_tasks.spawn(task);
        Ok(())
    }

    /// Signals that no more ranges follow, waits for all data tasks and the background writer,
    /// and returns the total number of bytes written.
    ///
    /// For a streaming writer (no background thread) the end of the last queued range is
    /// returned without further verification.
    ///
    /// # Errors
    /// Returns the first data-task failure, a background-thread join failure, any error recorded
    /// in the run state, or `InternalWriterError` if the written byte count does not match the
    /// queued ranges.
    async fn finish(mut self: Box<Self>) -> Result<u64> {
        self.run_state.check_error()?;
        if self.finished {
            return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
        }
        self.finished = true;
        if self.sender.send(SequentialRetrievalItem::Finish).is_err() {
            self.run_state.check_error()?;
            return Err(FileReconstructionError::InternalWriterError("Background writer channel closed".to_string()));
        }
        let expected_bytes = self.next_position;
        while let Some(result) = self.active_tasks.join_next().await {
            result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }
        match self.background_handle.take() {
            Some(handle) => {
                handle.await.map_err(|e| {
                    FileReconstructionError::InternalWriterError(format!("Background writer task failed: {e}"))
                })?;
                // The thread records its own failures in the run state rather than returning them.
                self.run_state.check_error()?;
                let actual_bytes = self.bytes_written.load(Ordering::Relaxed);
                if actual_bytes != expected_bytes {
                    return Err(FileReconstructionError::InternalWriterError(format!(
                        "Bytes written mismatch: expected {} bytes, but wrote {} bytes",
                        expected_bytes, actual_bytes
                    )));
                }
                Ok(actual_bytes)
            },
            None => Ok(expected_bytes),
        }
    }
}
impl SequentialWriter {
pub(crate) fn new_streaming(
run_state: Arc<RunState>,
) -> (Box<dyn DataWriter>, UnboundedReceiver<SequentialRetrievalItem>) {
let (tx, rx) = unbounded_channel::<SequentialRetrievalItem>();
let writer = Self {
sender: tx,
next_position: 0,
background_handle: None,
run_state,
bytes_written: Arc::new(AtomicU64::new(0)),
active_tasks: JoinSet::new(),
finished: false,
};
(Box::new(writer), rx)
}
#[allow(clippy::new_ret_no_self)]
pub(crate) fn new<W: Write + Send + 'static>(
writer: W,
use_vectorized: bool,
run_state: Arc<RunState>,
) -> Box<dyn DataWriter> {
let (tx, rx) = unbounded_channel::<SequentialRetrievalItem>();
let bytes_written = Arc::new(AtomicU64::new(0));
let run_state_clone = run_state.clone();
let run_state_thread = run_state.clone();
let bytes_written_clone = bytes_written.clone();
let progress_updater = run_state.progress_updater().cloned();
let handle = XetRuntime::current().spawn_blocking(move || {
let writer_thread = SyncWriterThread::new(rx, bytes_written_clone, progress_updater, run_state_thread);
let result = if use_vectorized {
writer_thread.run_vectorized(writer)
} else {
writer_thread.run(writer)
};
if let Err(err) = result {
run_state_clone.set_error(err);
}
});
Box::new(Self {
sender: tx,
next_position: 0,
background_handle: Some(handle),
run_state,
bytes_written,
active_tasks: JoinSet::new(),
finished: false,
})
}
}
#[cfg(test)]
mod tests {
    use std::io;
    use std::time::Duration;

    use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;

    use super::*;

    /// `Write` sink that appends every byte to a shared buffer the test can inspect afterwards.
    struct SharedBuffer(Arc<std::sync::Mutex<Vec<u8>>>);

    impl Write for SharedBuffer {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Knobs that make `TestWriter` deviate from an ideal sink.
    #[derive(Clone, Default)]
    struct TestWriterConfig {
        /// Maximum bytes accepted by one `write` call (forces partial writes).
        max_write_size: Option<usize>,
        /// Maximum total bytes accepted by one `write_vectored` call (forces partial writes).
        max_vectored_write_size: Option<usize>,
        /// `write_vectored` fails with `InvalidInput` when given more buffers than this.
        hard_limit_vectored_write_slice: Option<usize>,
        /// Whether to fail every `interrupt_frequency`-th call (starting with the first) with
        /// `Interrupted`.
        simulate_interrupts: bool,
        interrupt_frequency: usize,
    }

    impl TestWriterConfig {
        fn vectorized() -> Self {
            Self::default()
        }

        fn vectorized_partial(max_size: usize) -> Self {
            Self {
                max_vectored_write_size: Some(max_size),
                ..Default::default()
            }
        }

        fn vectorized_hard_limit(max_slice: usize) -> Self {
            Self {
                hard_limit_vectored_write_slice: Some(max_slice),
                ..Default::default()
            }
        }

        fn partial(max_size: usize) -> Self {
            Self {
                max_write_size: Some(max_size),
                ..Default::default()
            }
        }

        fn vectorized_with_interrupts() -> Self {
            Self {
                simulate_interrupts: true,
                interrupt_frequency: 2,
                ..Default::default()
            }
        }
    }

    /// Configurable in-memory `Write` sink that counts its `write` and `write_vectored` calls.
    struct TestWriter {
        buffer: Arc<std::sync::Mutex<Vec<u8>>>,
        config: TestWriterConfig,
        write_count: Arc<AtomicU64>,
        vectored_write_count: Arc<AtomicU64>,
        interrupt_counter: Arc<AtomicU64>,
    }

    impl TestWriter {
        fn new(config: TestWriterConfig) -> Self {
            Self {
                buffer: Arc::new(std::sync::Mutex::new(Vec::new())),
                config,
                write_count: Arc::new(AtomicU64::new(0)),
                vectored_write_count: Arc::new(AtomicU64::new(0)),
                interrupt_counter: Arc::new(AtomicU64::new(0)),
            }
        }

        fn should_interrupt(&self) -> bool {
            if !self.config.simulate_interrupts {
                return false;
            }
            let count = self.interrupt_counter.fetch_add(1, Ordering::Relaxed);
            count % self.config.interrupt_frequency as u64 == 0
        }
    }

    impl Write for TestWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            if self.should_interrupt() {
                return Err(io::Error::new(io::ErrorKind::Interrupted, "simulated interrupt"));
            }
            self.write_count.fetch_add(1, Ordering::Relaxed);
            let bytes_to_write = match self.config.max_write_size {
                Some(max) => buf.len().min(max),
                None => buf.len(),
            };
            self.buffer.lock().unwrap().extend_from_slice(&buf[..bytes_to_write]);
            Ok(bytes_to_write)
        }

        fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
            if self.should_interrupt() {
                return Err(io::Error::new(io::ErrorKind::Interrupted, "simulated interrupt"));
            }
            if let Some(max_slice) = self.config.hard_limit_vectored_write_slice
                && bufs.len() > max_slice
            {
                return Err(io::Error::new(io::ErrorKind::InvalidInput, "simulated iovcnt EINVAL"));
            }
            self.vectored_write_count.fetch_add(1, Ordering::Relaxed);
            let total_len: usize = bufs.iter().map(|b| b.len()).sum();
            let max_write = self.config.max_vectored_write_size.unwrap_or(total_len);
            let bytes_to_write = total_len.min(max_write);
            let mut remaining = bytes_to_write;
            let mut buffer = self.buffer.lock().unwrap();
            for buf in bufs {
                if remaining == 0 {
                    break;
                }
                let to_write = buf.len().min(remaining);
                buffer.extend_from_slice(&buf[..to_write]);
                remaining -= to_write;
            }
            Ok(bytes_to_write)
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Wraps `data` in a future that resolves immediately.
    fn immediate_future(data: Bytes) -> DataFuture {
        Box::pin(async move { Ok(data) })
    }

    /// Registers `chunks` as consecutive, immediately ready terms starting at offset 0.
    async fn write_chunks(writer: &mut dyn DataWriter, chunks: &[&'static str]) {
        let mut offset = 0u64;
        for &chunk in chunks {
            let end = offset + chunk.len() as u64;
            writer
                .set_next_term_data_source(FileRange::new(offset, end), None, immediate_future(Bytes::from(chunk)))
                .await
                .unwrap();
            offset = end;
        }
    }

    #[tokio::test]
    async fn test_sequential_writes() {
        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let buffer_clone = buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " ", "World"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World");
    }

    #[tokio::test]
    async fn test_delayed_future() {
        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let buffer_clone = buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
        // Futures resolve in reverse order; the output must still follow registration order.
        let f0: DataFuture = Box::pin(async {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(Bytes::from("Hello"))
        });
        let f1: DataFuture = Box::pin(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok(Bytes::from(" "))
        });
        let f2: DataFuture = Box::pin(async { Ok(Bytes::from("World")) });
        writer.set_next_term_data_source(FileRange::new(0, 5), None, f0).await.unwrap();
        writer.set_next_term_data_source(FileRange::new(5, 6), None, f1).await.unwrap();
        writer.set_next_term_data_source(FileRange::new(6, 11), None, f2).await.unwrap();
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World");
    }

    #[tokio::test]
    async fn test_size_mismatch_error() {
        let buffer = std::io::Cursor::new(Vec::new());
        let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
        writer
            .set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        let result = writer.finish().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_background_writer_error_propagates() {
        struct FailingWriter;
        impl Write for FailingWriter {
            fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
                Err(io::Error::other("Simulated write failure"))
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        let mut writer = SequentialWriter::new(Box::new(FailingWriter), false, RunState::new_for_test());
        writer
            .set_next_term_data_source(FileRange::new(0, 4), None, immediate_future(Bytes::from("Test")))
            .await
            .unwrap();
        // Give the background thread time to hit the failure and record it.
        tokio::time::sleep(Duration::from_millis(200)).await;
        let result = writer
            .set_next_term_data_source(FileRange::new(4, 8), None, immediate_future(Bytes::from("More")))
            .await;
        assert!(result.is_err());
        assert!(matches!(result, Err(FileReconstructionError::IoError(_))));
    }

    #[tokio::test]
    async fn test_flush_error_propagates() {
        struct FlushFailingWriter;
        impl Write for FlushFailingWriter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }

            fn flush(&mut self) -> io::Result<()> {
                Err(io::Error::other("Simulated flush failure"))
            }
        }
        let writer = SequentialWriter::new(Box::new(FlushFailingWriter), false, RunState::new_for_test());
        let result = writer.finish().await;
        assert!(result.is_err());
        assert!(matches!(result, Err(FileReconstructionError::IoError(_))));
    }

    #[tokio::test]
    async fn test_future_error_propagates() {
        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let buffer_clone = buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
        let failing_future: DataFuture =
            Box::pin(async { Err(FileReconstructionError::InternalError("Simulated future error".to_string())) });
        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
            .await
            .unwrap();
        let result = writer.finish().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_mismatch_too_small() {
        let buffer = std::io::Cursor::new(Vec::new());
        let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
        writer
            .set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hi")))
            .await
            .unwrap();
        let result = writer.finish().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_mismatch_too_large() {
        let buffer = std::io::Cursor::new(Vec::new());
        let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
        writer
            .set_next_term_data_source(FileRange::new(0, 2), None, immediate_future(Bytes::from("Hello World")))
            .await
            .unwrap();
        let result = writer.finish().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bytes_written_tracking() {
        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let buffer_clone = buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " World", "!!!!!"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World!!!!!");
        assert_eq!(result.len(), 16);
    }

    #[tokio::test]
    async fn test_non_sequential_range_returns_error() {
        let buffer = std::io::Cursor::new(Vec::new());
        let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        let result = writer
            .set_next_term_data_source(FileRange::new(10, 15), None, immediate_future(Bytes::from("World")))
            .await;
        assert!(result.is_err());
        assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
    }

    #[tokio::test]
    async fn test_first_range_must_start_at_zero() {
        let buffer = std::io::Cursor::new(Vec::new());
        let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
        let result = writer
            .set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("Hello")))
            .await;
        assert!(result.is_err());
        assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
    }

    #[tokio::test]
    async fn test_semaphore_permit_released_after_write() {
        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let buffer_clone = buffer.clone();
        let semaphore = AdjustableSemaphore::new(2, (0, 2));
        let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
        let permit1 = semaphore.acquire().await.unwrap();
        let permit2 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
        writer
            .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(semaphore.available_permits(), 1);
        writer
            .set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(semaphore.available_permits(), 2);
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello ");
    }

    #[tokio::test]
    async fn test_vectorized_basic_writes() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized());
        let buffer = test_writer.buffer.clone();
        let vectored_count = test_writer.vectored_write_count.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " ", "World"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World");
        assert!(vectored_count.load(Ordering::Relaxed) > 0);
    }

    #[tokio::test]
    async fn test_vectorized_partial_writes() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(3));
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " ", "World", "!"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World!");
    }

    #[tokio::test]
    async fn test_vectorized_with_delays() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized());
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        let f0: DataFuture = Box::pin(async {
            tokio::time::sleep(Duration::from_millis(30)).await;
            Ok(Bytes::from("A"))
        });
        let f1: DataFuture = Box::pin(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok(Bytes::from("B"))
        });
        let f2: DataFuture = Box::pin(async { Ok(Bytes::from("C")) });
        writer.set_next_term_data_source(FileRange::new(0, 1), None, f0).await.unwrap();
        writer.set_next_term_data_source(FileRange::new(1, 2), None, f1).await.unwrap();
        writer.set_next_term_data_source(FileRange::new(2, 3), None, f2).await.unwrap();
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"ABC");
    }

    #[tokio::test]
    async fn test_vectorized_many_small_writes() {
        let expected: Vec<u8> = (0..100u8).collect();
        let test_writer = TestWriter::new(TestWriterConfig::vectorized());
        let buffer = test_writer.buffer.clone();
        let vectored_count = test_writer.vectored_write_count.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        for i in 0..100u8 {
            writer
                .set_next_term_data_source(
                    FileRange::new(i as u64, i as u64 + 1),
                    None,
                    immediate_future(Bytes::from(vec![i])),
                )
                .await
                .unwrap();
        }
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, &expected);
        // Batching should need fewer vectored calls than there are terms.
        let vectored_calls = vectored_count.load(Ordering::Relaxed);
        assert!(vectored_calls < 100);
    }

    #[tokio::test]
    async fn test_vectorized_with_interrupts() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_with_interrupts());
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " ", "World"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World");
    }

    #[tokio::test]
    async fn test_vectorized_permit_release() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized());
        let buffer = test_writer.buffer.clone();
        let semaphore = AdjustableSemaphore::new(2, (0, 2));
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        let permit1 = semaphore.acquire().await.unwrap();
        let permit2 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
        writer
            .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(semaphore.available_permits(), 1);
        writer
            .set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(semaphore.available_permits(), 2);
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello ");
    }

    #[tokio::test]
    async fn test_vectorized_partial_permit_release() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(2));
        let buffer = test_writer.buffer.clone();
        let semaphore = AdjustableSemaphore::new(3, (0, 3));
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        let permit1 = semaphore.acquire().await.unwrap();
        let permit2 = semaphore.acquire().await.unwrap();
        let permit3 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
        writer
            .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(5, 11), Some(permit2), immediate_future(Bytes::from(" World")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(11, 12), Some(permit3), immediate_future(Bytes::from("!")))
            .await
            .unwrap();
        writer.finish().await.unwrap();
        assert_eq!(semaphore.available_permits(), 3);
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World!");
    }

    #[tokio::test]
    async fn test_non_vectorized_basic_writes() {
        let test_writer = TestWriter::new(TestWriterConfig::default());
        let buffer = test_writer.buffer.clone();
        let write_count = test_writer.write_count.clone();
        let vectored_count = test_writer.vectored_write_count.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " ", "World"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World");
        assert!(write_count.load(Ordering::Relaxed) > 0);
        assert_eq!(vectored_count.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_non_vectorized_partial_writes() {
        let test_writer = TestWriter::new(TestWriterConfig::partial(3));
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
        write_chunks(&mut *writer, &["Hello", " ", "World", "!"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"Hello World!");
    }

    #[tokio::test]
    async fn test_vectorized_single_byte_partial() {
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(1));
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        write_chunks(&mut *writer, &["ABCDE", "FGHIJ"]).await;
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, b"ABCDEFGHIJ");
    }

    #[tokio::test]
    async fn test_vectorized_large_data() {
        let expected: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let test_writer = TestWriter::new(TestWriterConfig::vectorized());
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        for i in 0..10 {
            let start = i * 1000;
            let end = start + 1000;
            let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
            writer
                .set_next_term_data_source(
                    FileRange::new(start as u64, end as u64),
                    None,
                    immediate_future(Bytes::from(chunk)),
                )
                .await
                .unwrap();
        }
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, &expected);
    }

    #[tokio::test]
    async fn test_vectorized_large_data_partial() {
        let expected: Vec<u8> = (0..5000).map(|i| (i % 256) as u8).collect();
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(100));
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        for i in 0..10 {
            let start = i * 500;
            let end = start + 500;
            let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
            writer
                .set_next_term_data_source(
                    FileRange::new(start as u64, end as u64),
                    None,
                    immediate_future(Bytes::from(chunk)),
                )
                .await
                .unwrap();
        }
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, &expected);
    }

    #[tokio::test]
    async fn test_vectorized_exceeded_max_slice() {
        // The sink rejects more than 2 buffers per call, so batching eventually fails.
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(2));
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        for i in 0..1000 {
            let start = i * 10;
            let end = start + 10;
            let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
            if writer
                .set_next_term_data_source(
                    FileRange::new(start as u64, end as u64),
                    None,
                    immediate_future(Bytes::from(chunk)),
                )
                .await
                .is_err()
            {
                break;
            }
        }
        let ret = writer.finish().await;
        assert!(ret.is_err());
        if let Err(FileReconstructionError::IoError(inner_err)) = ret {
            assert_eq!(inner_err.kind(), std::io::ErrorKind::InvalidInput);
        }
    }

    #[tokio::test]
    async fn test_vectorized_controlled_max_slice() {
        // The sink accepts up to 40 buffers per call, above `WRITEV_MAX_SLICE`, so writes succeed.
        let expected: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(40));
        let buffer = test_writer.buffer.clone();
        let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
        for i in 0..1000 {
            let start = i * 10;
            let end = start + 10;
            let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
            writer
                .set_next_term_data_source(
                    FileRange::new(start as u64, end as u64),
                    None,
                    immediate_future(Bytes::from(chunk)),
                )
                .await
                .unwrap();
        }
        writer.finish().await.unwrap();
        let result = buffer.lock().unwrap();
        assert_eq!(&*result, &expected);
    }
}