#![deny(missing_docs, rustdoc::broken_intra_doc_links)]
use bytes::Bytes;
use futures::{future::Either, Future};
use std::io;
/// A source of bytes that can be read asynchronously at arbitrary offsets.
///
/// Every method takes `&mut self` and returns a future whose type is tied to
/// that exclusive borrow, so at most one operation can be in flight per
/// reader handle at a time.
// `len` here returns a fallible future rather than a plain number, so the
// lint's expectation of a matching sync `is_empty` does not fit this trait.
#[allow(clippy::len_without_is_empty)]
pub trait AsyncSliceReader {
    /// Future returned by [`AsyncSliceReader::read_at`].
    type ReadAtFuture<'a>: Future<Output = io::Result<Bytes>> + 'a
    where
        Self: 'a;
    /// Reads up to `len` bytes starting at byte `offset`.
    ///
    /// The returned buffer may be shorter than `len`. The smoke tests in this
    /// crate expect a read that runs past the end to return only the bytes
    /// that exist, and a read that starts past the end to return an empty
    /// buffer.
    #[must_use = "io futures must be polled to completion"]
    fn read_at(&mut self, offset: u64, len: usize) -> Self::ReadAtFuture<'_>;
    /// Future returned by [`AsyncSliceReader::len`].
    type LenFuture<'a>: Future<Output = io::Result<u64>> + 'a
    where
        Self: 'a;
    /// Returns the total length of the readable data, in bytes.
    #[must_use = "io futures must be polled to completion"]
    fn len(&mut self) -> Self::LenFuture<'_>;
}
/// Lets a mutable reference stand in for the reader it points to.
///
/// Both operations are forwarded unchanged; the reference is coerced back to
/// `&mut T` at the call, so the produced futures are exactly `T`'s futures.
impl<'b, T: AsyncSliceReader> AsyncSliceReader for &'b mut T {
    type LenFuture<'a> = T::LenFuture<'a> where T: 'a, 'b: 'a;
    type ReadAtFuture<'a> = T::ReadAtFuture<'a> where T: 'a, 'b: 'a;

    fn len(&mut self) -> Self::LenFuture<'_> {
        <T as AsyncSliceReader>::len(self)
    }

    fn read_at(&mut self, offset: u64, len: usize) -> Self::ReadAtFuture<'_> {
        <T as AsyncSliceReader>::read_at(self, offset, len)
    }
}
/// Lets a boxed reader be used directly as a reader.
///
/// The `&mut Box<T>` receiver is deref-coerced to `&mut T`, so every call is
/// served by the boxed value and returns `T`'s own future types.
impl<T: AsyncSliceReader> AsyncSliceReader for Box<T> {
    type LenFuture<'a> = T::LenFuture<'a> where T: 'a;
    type ReadAtFuture<'a> = T::ReadAtFuture<'a> where T: 'a;

    fn len(&mut self) -> Self::LenFuture<'_> {
        <T as AsyncSliceReader>::len(self)
    }

    fn read_at(&mut self, offset: u64, len: usize) -> Self::ReadAtFuture<'_> {
        <T as AsyncSliceReader>::read_at(self, offset, len)
    }
}
/// Convenience methods layered on top of [`AsyncSliceReader`].
pub trait AsyncSliceReaderExt: AsyncSliceReader {
    /// Reads everything from offset 0 onwards in a single request.
    ///
    /// This is `read_at(0, usize::MAX)`. It relies on `read_at` returning only
    /// the bytes that exist, which the crate's smoke tests expect, rather than
    /// allocating `usize::MAX` bytes.
    fn read_to_end(&mut self) -> Self::ReadAtFuture<'_> {
        self.read_at(0, usize::MAX)
    }
}
/// Every [`AsyncSliceReader`] automatically gets the extension methods.
impl<T: AsyncSliceReader> AsyncSliceReaderExt for T {}
/// A sink of bytes that can be written asynchronously at arbitrary offsets.
///
/// As with [`AsyncSliceReader`], each method takes `&mut self` and returns a
/// future tied to that exclusive borrow, so at most one operation can be in
/// flight per writer handle.
pub trait AsyncSliceWriter: Sized {
    /// Future returned by [`AsyncSliceWriter::write_at`].
    type WriteAtFuture<'a>: Future<Output = io::Result<()>> + 'a
    where
        Self: 'a;
    /// Writes the bytes of `data` starting at byte `offset`.
    ///
    /// The returned future borrows only `self`, not `data`. An implementation
    /// therefore has to copy the slice before it returns the future.
    ///
    /// Writing past the current end extends the data. The crate's smoke tests
    /// expect any gap to be filled with zero bytes.
    #[must_use = "io futures must be polled to completion"]
    fn write_at(&mut self, offset: u64, data: &[u8]) -> Self::WriteAtFuture<'_>;
    /// Future returned by [`AsyncSliceWriter::write_bytes_at`].
    type WriteBytesAtFuture<'a>: Future<Output = io::Result<()>> + 'a
    where
        Self: 'a;
    /// Like [`AsyncSliceWriter::write_at`], but takes ownership of `data`.
    ///
    /// Because `data` is an owned [`Bytes`], implementations can hold on to
    /// it inside the future without copying it up front.
    #[must_use = "io futures must be polled to completion"]
    fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> Self::WriteBytesAtFuture<'_>;
    /// Future returned by [`AsyncSliceWriter::set_len`].
    type SetLenFuture<'a>: Future<Output = io::Result<()>> + 'a
    where
        Self: 'a;
    /// Truncates or extends the data to exactly `len` bytes.
    ///
    /// The tests' reference model extends with zero bytes.
    #[must_use = "io futures must be polled to completion"]
    fn set_len(&mut self, len: u64) -> Self::SetLenFuture<'_>;
    /// Future returned by [`AsyncSliceWriter::sync`].
    type SyncFuture<'a>: Future<Output = io::Result<()>> + 'a
    where
        Self: 'a;
    /// Synchronizes the writer with whatever storage backs it.
    ///
    /// The exact durability guarantee is left to each implementation. It does
    /// not change the logical contents; the tests' reference model treats it
    /// as a no-op.
    #[must_use = "io futures must be polled to completion"]
    fn sync(&mut self) -> Self::SyncFuture<'_>;
}
/// Lets a mutable reference stand in for the writer it points to.
///
/// Each call is forwarded to `T` via deref coercion of the receiver, so the
/// resulting futures are exactly the referenced writer's futures.
impl<'b, T: AsyncSliceWriter> AsyncSliceWriter for &'b mut T {
    type SetLenFuture<'a> = T::SetLenFuture<'a> where T: 'a, 'b: 'a;
    type SyncFuture<'a> = T::SyncFuture<'a> where T: 'a, 'b: 'a;
    type WriteAtFuture<'a> = T::WriteAtFuture<'a> where T: 'a, 'b: 'a;
    type WriteBytesAtFuture<'a> = T::WriteBytesAtFuture<'a> where T: 'a, 'b: 'a;

    fn set_len(&mut self, len: u64) -> Self::SetLenFuture<'_> {
        <T as AsyncSliceWriter>::set_len(self, len)
    }

    fn sync(&mut self) -> Self::SyncFuture<'_> {
        <T as AsyncSliceWriter>::sync(self)
    }

    fn write_at(&mut self, offset: u64, data: &[u8]) -> Self::WriteAtFuture<'_> {
        <T as AsyncSliceWriter>::write_at(self, offset, data)
    }

    fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> Self::WriteBytesAtFuture<'_> {
        <T as AsyncSliceWriter>::write_bytes_at(self, offset, data)
    }
}
/// Lets a boxed writer be used directly as a writer.
///
/// The `&mut Box<T>` receiver is deref-coerced to `&mut T`, so every call is
/// served by the boxed value and returns `T`'s own future types.
impl<T: AsyncSliceWriter> AsyncSliceWriter for Box<T> {
    type SetLenFuture<'a> = T::SetLenFuture<'a> where T: 'a;
    type SyncFuture<'a> = T::SyncFuture<'a> where T: 'a;
    type WriteAtFuture<'a> = T::WriteAtFuture<'a> where T: 'a;
    type WriteBytesAtFuture<'a> = T::WriteBytesAtFuture<'a> where T: 'a;

    fn set_len(&mut self, len: u64) -> Self::SetLenFuture<'_> {
        <T as AsyncSliceWriter>::set_len(self, len)
    }

    fn sync(&mut self) -> Self::SyncFuture<'_> {
        <T as AsyncSliceWriter>::sync(self)
    }

    fn write_at(&mut self, offset: u64, data: &[u8]) -> Self::WriteAtFuture<'_> {
        <T as AsyncSliceWriter>::write_at(self, offset, data)
    }

    fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> Self::WriteBytesAtFuture<'_> {
        <T as AsyncSliceWriter>::write_bytes_at(self, offset, data)
    }
}
/// Declares a public, `#[repr(transparent)]` newtype around an inner future.
///
/// Usage: `newtype_future!(#[doc = "..."] Name, InnerFutureType, OutputType)`.
/// The generated `Name<'a>` is pin-projected with `pin_project` and
/// implements `Future<Output = OutputType>` by polling the wrapped value, so
/// the inner future type does not appear in the public API.
#[cfg(any(feature = "tokio-io", feature = "http"))]
macro_rules! newtype_future {
    ($(#[$outer:meta])* $name:ident, $inner:ty, $output:ty) => {
        // Same layout as the inner future; `#[pin]` makes `project()` yield a
        // pinned reference to field 0 so it can be polled in place.
        #[repr(transparent)]
        #[pin_project::pin_project]
        #[must_use]
        $(#[$outer])*
        pub struct $name<'a>(#[pin] $inner);
        impl<'a> futures::Future for $name<'a> {
            type Output = $output;
            // Pure delegation: forwards the poll to the wrapped future.
            fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
                self.project().0.poll(cx)
            }
        }
    };
}
#[cfg(feature = "tokio-io")]
mod tokio_io;
#[cfg(feature = "tokio-io")]
pub use tokio_io::*;
#[cfg(feature = "http")]
mod http;
#[cfg(feature = "http")]
pub use http::*;
mod mem;
/// Reads from whichever side of a `futures` [`Either`] is populated.
///
/// The side's future is returned wrapped in the matching `Either` variant,
/// which is itself a future, so the two sides may use different future types.
impl<L, R> AsyncSliceReader for futures::future::Either<L, R>
where
    L: AsyncSliceReader + 'static,
    R: AsyncSliceReader + 'static,
{
    type LenFuture<'a> = Either<L::LenFuture<'a>, R::LenFuture<'a>>;
    type ReadAtFuture<'a> = Either<L::ReadAtFuture<'a>, R::ReadAtFuture<'a>>;

    fn len(&mut self) -> Self::LenFuture<'_> {
        match self {
            Either::Left(left) => Either::Left(left.len()),
            Either::Right(right) => Either::Right(right.len()),
        }
    }

    fn read_at(&mut self, offset: u64, len: usize) -> Self::ReadAtFuture<'_> {
        match self {
            Either::Left(left) => Either::Left(left.read_at(offset, len)),
            Either::Right(right) => Either::Right(right.read_at(offset, len)),
        }
    }
}
/// Writes to whichever side of a `futures` [`Either`] is populated.
///
/// Each side's future is wrapped in the matching `Either` variant, so the two
/// writers may produce different future types.
impl<L, R> AsyncSliceWriter for futures::future::Either<L, R>
where
    L: AsyncSliceWriter + 'static,
    R: AsyncSliceWriter + 'static,
{
    type SetLenFuture<'a> = Either<L::SetLenFuture<'a>, R::SetLenFuture<'a>>;
    type SyncFuture<'a> = Either<L::SyncFuture<'a>, R::SyncFuture<'a>>;
    type WriteAtFuture<'a> = Either<L::WriteAtFuture<'a>, R::WriteAtFuture<'a>>;
    type WriteBytesAtFuture<'a> = Either<L::WriteBytesAtFuture<'a>, R::WriteBytesAtFuture<'a>>;

    fn set_len(&mut self, len: u64) -> Self::SetLenFuture<'_> {
        match self {
            Either::Left(left) => Either::Left(left.set_len(len)),
            Either::Right(right) => Either::Right(right.set_len(len)),
        }
    }

    fn sync(&mut self) -> Self::SyncFuture<'_> {
        match self {
            Either::Left(left) => Either::Left(left.sync()),
            Either::Right(right) => Either::Right(right.sync()),
        }
    }

    fn write_at(&mut self, offset: u64, data: &[u8]) -> Self::WriteAtFuture<'_> {
        match self {
            Either::Left(left) => Either::Left(left.write_at(offset, data)),
            Either::Right(right) => Either::Right(right.write_at(offset, data)),
        }
    }

    fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> Self::WriteBytesAtFuture<'_> {
        match self {
            Either::Left(left) => Either::Left(left.write_bytes_at(offset, data)),
            Either::Right(right) => Either::Right(right.write_bytes_at(offset, data)),
        }
    }
}
/// Reads from whichever side of a `tokio_util` `Either` is populated.
///
/// The resulting futures are wrapped in the `futures` [`Either`], because that
/// is the type the associated future types are declared with.
#[cfg(feature = "tokio-util")]
impl<L, R> AsyncSliceReader for tokio_util::either::Either<L, R>
where
    L: AsyncSliceReader + 'static,
    R: AsyncSliceReader + 'static,
{
    type LenFuture<'a> = Either<L::LenFuture<'a>, R::LenFuture<'a>>;
    type ReadAtFuture<'a> = Either<L::ReadAtFuture<'a>, R::ReadAtFuture<'a>>;

    fn len(&mut self) -> Self::LenFuture<'_> {
        match self {
            tokio_util::either::Either::Left(left) => Either::Left(left.len()),
            tokio_util::either::Either::Right(right) => Either::Right(right.len()),
        }
    }

    fn read_at(&mut self, offset: u64, len: usize) -> Self::ReadAtFuture<'_> {
        match self {
            tokio_util::either::Either::Left(left) => Either::Left(left.read_at(offset, len)),
            tokio_util::either::Either::Right(right) => {
                Either::Right(right.read_at(offset, len))
            }
        }
    }
}
/// Writes to whichever side of a `tokio_util` `Either` is populated.
///
/// The resulting futures are wrapped in the `futures` [`Either`], matching the
/// declared associated future types.
#[cfg(feature = "tokio-util")]
impl<L, R> AsyncSliceWriter for tokio_util::either::Either<L, R>
where
    L: AsyncSliceWriter + 'static,
    R: AsyncSliceWriter + 'static,
{
    type SetLenFuture<'a> = Either<L::SetLenFuture<'a>, R::SetLenFuture<'a>>;
    type SyncFuture<'a> = Either<L::SyncFuture<'a>, R::SyncFuture<'a>>;
    type WriteAtFuture<'a> = Either<L::WriteAtFuture<'a>, R::WriteAtFuture<'a>>;
    type WriteBytesAtFuture<'a> = Either<L::WriteBytesAtFuture<'a>, R::WriteBytesAtFuture<'a>>;

    fn set_len(&mut self, len: u64) -> Self::SetLenFuture<'_> {
        match self {
            tokio_util::either::Either::Left(left) => Either::Left(left.set_len(len)),
            tokio_util::either::Either::Right(right) => Either::Right(right.set_len(len)),
        }
    }

    fn sync(&mut self) -> Self::SyncFuture<'_> {
        match self {
            tokio_util::either::Either::Left(left) => Either::Left(left.sync()),
            tokio_util::either::Either::Right(right) => Either::Right(right.sync()),
        }
    }

    fn write_at(&mut self, offset: u64, data: &[u8]) -> Self::WriteAtFuture<'_> {
        match self {
            tokio_util::either::Either::Left(left) => Either::Left(left.write_at(offset, data)),
            tokio_util::either::Either::Right(right) => {
                Either::Right(right.write_at(offset, data))
            }
        }
    }

    fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> Self::WriteBytesAtFuture<'_> {
        match self {
            tokio_util::either::Either::Left(left) => {
                Either::Left(left.write_bytes_at(offset, data))
            }
            tokio_util::either::Either::Right(right) => {
                Either::Right(right.write_bytes_at(offset, data))
            }
        }
    }
}
/// Wraps any boxable error as an [`io::Error`] of kind `Other`.
///
/// The conversion to a boxed trait object is done up front, so the call to
/// `io::Error::new` receives the box directly.
#[cfg(any(feature = "tokio-io", feature = "http"))]
fn make_io_error<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let boxed: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::Other, boxed)
}
#[cfg(test)]
mod tests {
    use crate::mem::limited_range;
    use super::*;
    use bytes::BytesMut;
    use proptest::prelude::*;
    use std::io;
    #[cfg(feature = "tokio-io")]
    use std::io::Write;

    /// Tiny axum server that serves a fixed buffer and understands
    /// `Range: bytes=<start>-<end>` requests; used by the HTTP adapter tests.
    #[cfg(feature = "http")]
    mod test_server {
        use super::*;
        use axum::{routing::get, Extension, Router};
        use hyper::{Body, Request, Response, StatusCode};
        use std::{net::SocketAddr, ops::Range, sync::Arc};

        /// Binds a server for `data` to an OS-assigned port on all interfaces.
        ///
        /// Returns the bound address together with the server future; nothing
        /// is served until the caller drives that future (e.g. `tokio::spawn`).
        pub fn serve(data: Vec<u8>) -> (SocketAddr, impl Future<Output = hyper::Result<()>>) {
            let app = Router::new()
                .route("/", get(handler))
                .layer(Extension(Arc::new(data)));
            let addr: SocketAddr = SocketAddr::from(([0, 0, 0, 0], 0));
            let fut = axum::Server::bind(&addr).serve(app.into_make_service());
            (fut.local_addr(), fut)
        }

        /// Returns the whole buffer, or a `206 Partial Content` slice when a
        /// parseable `Range` header is present.
        ///
        /// Both bounds are clamped to the buffer. A range that starts at or
        /// past the end therefore yields an empty `206` body instead of
        /// panicking.
        async fn handler(state: Extension<Arc<Vec<u8>>>, req: Request<Body>) -> Response<Body> {
            let data = state.0.as_ref();
            // A header that is not visible ASCII, or cannot be parsed, falls
            // back to serving the full body rather than panicking.
            let range = req
                .headers()
                .get("Range")
                .and_then(|value| value.to_str().ok())
                .and_then(|value| parse_range_header(value).ok());
            if let Some(range) = range {
                // Clamp `start` to `end` too: slicing with start > end panics.
                let end = range.end.min(data.len());
                let start = range.start.min(end);
                let sliced_data = &data[start..end];
                return Response::builder()
                    .status(StatusCode::PARTIAL_CONTENT)
                    .header("Content-Type", "application/octet-stream")
                    .header("Content-Length", sliced_data.len())
                    .header(
                        "Content-Range",
                        // `end` is 0 for an empty buffer; `end - 1` would underflow.
                        format!("bytes {}-{}/{}", start, end.saturating_sub(1), data.len()),
                    )
                    .body(Body::from(sliced_data.to_vec()))
                    .unwrap();
            }
            Response::new(data.to_owned().into())
        }

        /// Parses `bytes=<start>-<end>` (inclusive end) into a half-open range.
        fn parse_range_header(
            header_value: &str,
        ) -> std::result::Result<Range<usize>, &'static str> {
            let range_str = header_value
                .strip_prefix("bytes=")
                .ok_or("Invalid Range header format")?;
            let (start, end) = range_str
                .split_once('-')
                .ok_or("Invalid Range header format")?;
            let start = start.parse().map_err(|_| "Failed to parse range start")?;
            let end: usize = end.parse().map_err(|_| "Failed to parse range end")?;
            // HTTP ends are inclusive; saturate so `bytes=0-<usize::MAX>` cannot overflow.
            Ok(start..end.saturating_add(1))
        }
    }

    /// Checks basic reads against a reader holding the bytes `0..100`.
    async fn read_mut_smoke(mut file: impl AsyncSliceReader) -> io::Result<()> {
        let expected = (0..100u8).collect::<Vec<_>>();
        let res = file.read_at(0, usize::MAX).await?;
        assert_eq!(res, expected);
        let res = file.len().await?;
        assert_eq!(res, 100);
        let res = file.read_at(0, 3).await?;
        assert_eq!(res, vec![0, 1, 2]);
        let res = file.read_at(95, 10).await?;
        assert_eq!(res, vec![95, 96, 97, 98, 99]);
        let res = file.read_at(110, 10).await?;
        assert_eq!(res, vec![]);
        Ok(())
    }

    /// Checks basic writes, gap filling and truncation on an initially empty writer.
    async fn write_mut_smoke<F: AsyncSliceWriter, C: Fn(&F) -> Vec<u8>>(
        mut file: F,
        contents: C,
    ) -> io::Result<()> {
        file.write_bytes_at(0, vec![0, 1, 2].into()).await?;
        assert_eq!(contents(&file), &[0, 1, 2]);
        file.write_bytes_at(5, vec![0, 1, 2].into()).await?;
        assert_eq!(contents(&file), &[0, 1, 2, 0, 0, 0, 1, 2]);
        file.write_at(8, &1u16.to_le_bytes()).await?;
        assert_eq!(contents(&file), &[0, 1, 2, 0, 0, 0, 1, 2, 1, 0]);
        file.set_len(0).await?;
        assert_eq!(contents(&file).len(), 0);
        Ok(())
    }

    #[cfg(feature = "tokio-io")]
    #[tokio::test]
    async fn file_reading_smoke() -> io::Result<()> {
        let mut file = tempfile::tempfile().unwrap();
        file.write_all(&(0..100u8).collect::<Vec<_>>()).unwrap();
        read_mut_smoke(File::from_std(file)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn bytes_reading_smoke() -> io::Result<()> {
        let bytes: Bytes = (0..100u8).collect::<Vec<_>>().into();
        read_mut_smoke(bytes).await?;
        Ok(())
    }

    #[tokio::test]
    async fn bytes_mut_reading_smoke() -> io::Result<()> {
        let mut bytes: BytesMut = BytesMut::new();
        bytes.extend(0..100u8);
        read_mut_smoke(bytes).await?;
        Ok(())
    }

    fn bytes_mut_contents(bytes: &BytesMut) -> Vec<u8> {
        bytes.to_vec()
    }

    #[cfg(feature = "tokio-io")]
    #[tokio::test]
    async fn async_slice_writer_smoke() -> io::Result<()> {
        let file = tempfile::tempfile().unwrap();
        write_mut_smoke(File::from_std(file), |x| x.read_contents()).await?;
        Ok(())
    }

    #[tokio::test]
    async fn bytes_mut_writing_smoke() -> io::Result<()> {
        let bytes: BytesMut = BytesMut::new();
        write_mut_smoke(bytes, |x| x.as_ref().to_vec()).await?;
        Ok(())
    }

    fn random_slice(offset: u64, size: usize) -> impl Strategy<Value = (u64, Vec<u8>)> {
        (0..offset, 0..size).prop_map(|(offset, size)| {
            let data = (0..size).map(|x| x as u8).collect::<Vec<_>>();
            (offset, data)
        })
    }

    fn random_write_op(offset: u64, size: usize) -> impl Strategy<Value = WriteOp> {
        prop_oneof![
            20 => random_slice(offset, size).prop_map(|(offset, data)| WriteOp::Write(offset, data)),
            1 => (0..(offset + size as u64)).prop_map(WriteOp::SetLen),
            1 => Just(WriteOp::Sync),
        ]
    }

    fn random_write_ops(offset: u64, size: usize, n: usize) -> impl Strategy<Value = Vec<WriteOp>> {
        prop::collection::vec(random_write_op(offset, size), n)
    }

    fn random_read_ops(offset: u64, size: usize, n: usize) -> impl Strategy<Value = Vec<ReadOp>> {
        prop::collection::vec(random_read_op(offset, size), n)
    }

    /// Mostly 0, occasionally a small forward or backward jump.
    fn sequential_offset(mag: usize) -> impl Strategy<Value = isize> {
        prop_oneof![
            20 => Just(0),
            1 => (0..mag).prop_map(|x| x as isize),
            1 => (0..mag).prop_map(|x| -(x as isize)),
        ]
    }

    fn random_read_op(offset: u64, size: usize) -> impl Strategy<Value = ReadOp> {
        prop_oneof![
            20 => (0..offset, 0..size).prop_map(|(offset, len)| ReadOp::ReadAt(offset, len)),
            1 => (sequential_offset(1024), 0..size).prop_map(|(offset, len)| ReadOp::ReadSequential(offset, len)),
            1 => Just(ReadOp::Len),
        ]
    }

    #[derive(Debug, Clone)]
    enum ReadOp {
        /// Read at an absolute offset.
        ReadAt(u64, usize),
        /// Read at an offset relative to the end of the previous read.
        ReadSequential(isize, usize),
        Len,
    }

    #[derive(Debug, Clone)]
    enum WriteOp {
        Write(u64, Vec<u8>),
        SetLen(u64),
        Sync,
    }

    /// Applies `op` to the in-memory reference model of a writer.
    fn apply_op(file: &mut Vec<u8>, op: &WriteOp) {
        match op {
            WriteOp::Write(offset, data) => {
                if data.is_empty() {
                    return;
                }
                let end = offset.saturating_add(data.len() as u64);
                let start = usize::try_from(*offset).unwrap();
                let end = usize::try_from(end).unwrap();
                if end > file.len() {
                    file.resize(end, 0);
                }
                file[start..end].copy_from_slice(data);
            }
            WriteOp::SetLen(offset) => {
                let offset = usize::try_from(*offset).unwrap_or(usize::MAX);
                file.resize(offset, 0);
            }
            WriteOp::Sync => {}
        }
    }

    /// Runs `f` to completion on a fresh current-thread tokio runtime.
    fn async_test<F: Future>(f: F) -> F::Output {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(f)
    }

    /// Applies `ops` to `bytes` and checks its contents against the reference model after each step.
    async fn write_op_test<W: AsyncSliceWriter, C: Fn(&W) -> Vec<u8>>(
        ops: Vec<WriteOp>,
        mut bytes: W,
        content: C,
    ) -> io::Result<()> {
        let mut reference = Vec::new();
        for op in ops {
            apply_op(&mut reference, &op);
            match op {
                WriteOp::Write(offset, data) => {
                    AsyncSliceWriter::write_bytes_at(&mut bytes, offset, data.into()).await?;
                }
                WriteOp::SetLen(offset) => {
                    AsyncSliceWriter::set_len(&mut bytes, offset).await?;
                }
                WriteOp::Sync => {
                    AsyncSliceWriter::sync(&mut bytes).await?;
                }
            }
            assert_eq!(content(&bytes), reference.as_slice());
        }
        io::Result::Ok(())
    }

    /// Applies `ops` to `file` and checks every result against the known contents `actual`.
    async fn read_op_test<R: AsyncSliceReader>(
        ops: Vec<ReadOp>,
        mut file: R,
        actual: &[u8],
    ) -> io::Result<()> {
        let mut current = 0u64;
        for op in ops {
            match op {
                ReadOp::ReadAt(offset, len) => {
                    let data = AsyncSliceReader::read_at(&mut file, offset, len).await?;
                    assert_eq!(&data, &actual[limited_range(offset, len, actual.len())]);
                    current = offset.checked_add(len as u64).unwrap();
                }
                ReadOp::ReadSequential(offset, len) => {
                    let offset = if offset >= 0 {
                        current.saturating_add(offset as u64)
                    } else {
                        current.saturating_sub((-offset) as u64)
                    };
                    let data = AsyncSliceReader::read_at(&mut file, offset, len).await?;
                    assert_eq!(&data, &actual[limited_range(offset, len, actual.len())]);
                    current = offset.checked_add(len as u64).unwrap();
                }
                ReadOp::Len => {
                    let len = AsyncSliceReader::len(&mut file).await?;
                    assert_eq!(len, actual.len() as u64);
                }
            }
        }
        io::Result::Ok(())
    }

    #[cfg(feature = "http")]
    #[tokio::test]
    #[cfg_attr(target_os = "windows", ignore)]
    async fn http_smoke() {
        let (addr, server) = test_server::serve(b"hello world".to_vec());
        let url = format!("http://{}", addr);
        println!("serving from {}", url);
        let url = reqwest::Url::parse(&url).unwrap();
        let server = tokio::spawn(server);
        let mut reader = HttpAdapter::new(url).await.unwrap();
        let len = reader.len().await.unwrap();
        assert_eq!(len, 11);
        println!("reader: {:?}", reader);
        let part = reader.read_at(0, 11).await.unwrap();
        assert_eq!(part.as_ref(), b"hello world");
        let part = reader.read_at(6, 5).await.unwrap();
        assert_eq!(part.as_ref(), b"world");
        let part = reader.read_at(6, 10).await.unwrap();
        assert_eq!(part.as_ref(), b"world");
        let part = reader.read_at(100, 10).await.unwrap();
        assert_eq!(part.as_ref(), b"");
        server.abort();
    }

    proptest! {
        #[test]
        fn bytes_write(ops in random_write_ops(1024, 1024, 10)) {
            async_test(write_op_test(ops, BytesMut::new(), bytes_mut_contents)).unwrap();
        }

        #[cfg(feature = "tokio-io")]
        #[test]
        fn file_write(ops in random_write_ops(1024, 1024, 10)) {
            let file = tempfile::tempfile().unwrap();
            async_test(write_op_test(ops, File::from_std(file), |x| x.read_contents())).unwrap();
        }

        #[test]
        fn bytes_read(data in proptest::collection::vec(any::<u8>(), 0..1024), ops in random_read_ops(1024, 1024, 2)) {
            async_test(read_op_test(ops, Bytes::from(data.clone()), &data)).unwrap();
        }

        #[cfg(feature = "tokio-io")]
        #[test]
        fn file_read(data in proptest::collection::vec(any::<u8>(), 0..1024), ops in random_read_ops(1024, 1024, 2)) {
            let mut file = tempfile::tempfile().unwrap();
            file.write_all(&data).unwrap();
            async_test(read_op_test(ops, File::from_std(file), &data)).unwrap();
        }

        #[cfg(feature = "http")]
        #[cfg_attr(target_os = "windows", ignore)]
        #[test]
        fn http_read(data in proptest::collection::vec(any::<u8>(), 0..10), ops in random_read_ops(10, 10, 2)) {
            async_test(async move {
                let (addr, server) = test_server::serve(data.clone());
                let server = tokio::spawn(server);
                let url = reqwest::Url::parse(&format!("http://{}", addr)).unwrap();
                let file = HttpAdapter::new(url).await.unwrap();
                read_op_test(ops, file, &data).await.unwrap();
                server.abort();
            });
        }
    }
}