use bytes::Bytes;
use futures_core::Stream;
use ripcurl::protocol::{
DestinationWriter, ReadOffset, SourceProtocol, SourceReader, TransferError,
};
use std::collections::VecDeque;
use std::future::Future;
use std::time::Duration;
use url::Url;
/// Scripted `SourceProtocol` test double.
///
/// Each `get_reader` call pops the next entry from `results` (front first) and
/// records its arguments in `get_reader_calls`. Running out of scripted results
/// panics.
pub struct MockSource {
    /// Remaining scripted outcomes, consumed front to back by `get_reader`.
    results: VecDeque<MockReaderResult>,
    /// Log of every `get_reader` call as `(url, start_byte_offset)`, in call order.
    pub get_reader_calls: Vec<(Url, u64)>,
}
/// One scripted outcome for a single `MockSource::get_reader` call.
pub enum MockReaderResult {
    /// `get_reader` succeeds with a `MockReader` over `chunks`, plus a `ReadOffset`
    /// built verbatim from `offset`, `total_size` and `supports_random_access`.
    Ok {
        offset: u64,
        total_size: Option<u64>,
        supports_random_access: bool,
        /// Items the reader's stream yields, in order; `Err` items are yielded as-is.
        chunks: Vec<Result<Bytes, TransferError>>,
    },
    /// `get_reader` resolves to this error.
    Err(TransferError),
}
impl MockSource {
    /// Builds a source that replays `results` in order, one per `get_reader`
    /// call, starting with an empty call log.
    pub fn new(results: Vec<MockReaderResult>) -> Self {
        let script: VecDeque<MockReaderResult> = results.into();
        Self {
            get_reader_calls: Vec::new(),
            results: script,
        }
    }
}
impl SourceProtocol for MockSource {
    type Reader = MockReader;

    /// Logs the call and hands back the next scripted outcome.
    ///
    /// The call is recorded and the script entry dequeued eagerly, before the
    /// returned future is polled. The scripted `ReadOffset` is returned as-is,
    /// regardless of the requested `start_byte_offset`.
    ///
    /// # Panics
    /// Panics immediately if no scripted results remain.
    fn get_reader(
        &mut self,
        url: Url,
        start_byte_offset: u64,
    ) -> impl Future<Output = Result<(Self::Reader, ReadOffset), TransferError>> {
        self.get_reader_calls.push((url, start_byte_offset));
        let scripted = self
            .results
            .pop_front()
            .expect("MockSource: no more scripted results for get_reader");

        async move {
            let (chunks, read_offset) = match scripted {
                MockReaderResult::Err(e) => return Err(e),
                MockReaderResult::Ok {
                    offset,
                    total_size,
                    supports_random_access,
                    chunks,
                } => (
                    chunks,
                    ReadOffset {
                        offset,
                        total_size,
                        supports_random_access,
                    },
                ),
            };
            Ok((MockReader { chunks }, read_offset))
        }
    }
}
/// Reader handed out by `MockSource`; replays its scripted chunks as a stream.
pub struct MockReader {
    /// Items to yield, in order (data chunks or errors).
    chunks: Vec<Result<Bytes, TransferError>>,
}
impl SourceReader for MockReader {
fn stream_bytes(self) -> impl Stream<Item = Result<Bytes, TransferError>> {
futures_util::stream::iter(self.chunks)
}
}
/// Kind of failure a [`MockWriter`] is scripted to produce.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MockErrorKind {
    /// Reported as `TransferError::Transient` carrying this reason.
    Transient { reason: String },
    /// Reported as `TransferError::Permanent` carrying this reason.
    Permanent { reason: String },
}

/// In-memory `DestinationWriter` test double.
///
/// It records everything written and can be scripted to fail once a byte
/// threshold is crossed.
#[derive(Debug, Default)]
pub struct MockWriter {
    /// All bytes accepted so far; cleared by `truncate_and_reset`.
    pub written: Vec<u8>,
    /// Optional scripted failure `(threshold, kind)`.
    ///
    /// A write that would grow `written` past `threshold` bytes keeps only the
    /// bytes up to the threshold and then fails with `kind`. The failure is
    /// removed once it fires.
    error_at: Option<(usize, MockErrorKind)>,
    /// Set to `true` by `finalize`.
    ///
    /// NOTE(review): `finalize` takes the writer by value, so this flag cannot be
    /// read back from the consumed writer afterwards — confirm whether any test
    /// relies on it.
    pub finalized: bool,
    /// Number of `truncate_and_reset` calls so far.
    pub truncate_count: u32,
}

impl MockWriter {
    /// Creates a writer that is equivalent to [`Default::default`].
    ///
    /// It has an empty buffer, no scripted failure, is not finalized, and has a
    /// zero truncate count.
    pub fn new() -> Self {
        Self::default()
    }

    /// Scripts a one-shot transient failure at `threshold` bytes.
    ///
    /// Replaces any previously scripted failure.
    #[must_use]
    pub fn fail_transiently_at(self, threshold: usize, reason: &str) -> Self {
        self.with_scripted_failure(
            threshold,
            MockErrorKind::Transient {
                reason: reason.to_owned(),
            },
        )
    }

    /// Scripts a one-shot permanent failure at `threshold` bytes.
    ///
    /// Replaces any previously scripted failure.
    #[must_use]
    pub fn fail_permanently_at(self, threshold: usize, reason: &str) -> Self {
        self.with_scripted_failure(
            threshold,
            MockErrorKind::Permanent {
                reason: reason.to_owned(),
            },
        )
    }

    /// Installs `kind` as the single scripted failure, overwriting any earlier one.
    fn with_scripted_failure(mut self, threshold: usize, kind: MockErrorKind) -> Self {
        self.error_at = Some((threshold, kind));
        self
    }
}
impl DestinationWriter for MockWriter {
    /// Appends `bytes` to `written`, unless the scripted failure trips.
    ///
    /// If a failure is armed and this chunk would carry the buffer past its
    /// threshold, the write proceeds as follows:
    /// - Only the bytes up to the threshold are kept.
    /// - The failure is disarmed, so it fires at most once.
    /// - The matching error is returned.
    ///
    /// A transient error reports the full buffer length as `consumed_byte_count`
    /// and a 1 ms minimum retry delay.
    async fn write(&mut self, bytes: &[u8]) -> Result<(), TransferError> {
        // Pull the scripted failure out of its slot; re-arm it below if this
        // chunk does not reach the threshold.
        if let Some((limit, kind)) = self.error_at.take() {
            let room_left = limit.saturating_sub(self.written.len());
            if bytes.len() > room_left {
                self.written.extend_from_slice(&bytes[..room_left]);
                let consumed_byte_count = self.written.len() as u64;
                let failure = match kind {
                    MockErrorKind::Transient { reason } => TransferError::Transient {
                        consumed_byte_count,
                        minimum_retry_delay: Duration::from_millis(1),
                        reason,
                    },
                    MockErrorKind::Permanent { reason } => TransferError::Permanent { reason },
                };
                return Err(failure);
            }
            self.error_at = Some((limit, kind));
        }
        self.written.extend_from_slice(bytes);
        Ok(())
    }

    /// Marks the writer as finalized; always succeeds.
    ///
    /// NOTE(review): this consumes the writer, so the `finalized` flag set here
    /// is dropped along with it and is not observable by the caller afterwards.
    async fn finalize(self) -> Result<(), TransferError> {
        let mut this = self;
        this.finalized = true;
        Ok(())
    }

    /// Discards everything written so far and counts the reset; always succeeds.
    async fn truncate_and_reset(&mut self) -> Result<(), TransferError> {
        self.truncate_count += 1;
        self.written.clear();
        Ok(())
    }
}