//! This library provides the [`RequestFile`] type, which offers
//! an asynchronous file-like interface to a web resource.
//!
//! # Examples
//!
//! ```
//! # tokio_test::block_on(async {
//! use reqwest_file::RequestFile;
//! use tokio::io::{AsyncReadExt, AsyncSeekExt};
//!
//! let client = reqwest::Client::new();
//! let request = client.get("http://httpbin.org/base64/aGVsbG8gd29ybGQ=");
//! let mut file = RequestFile::new(request);
//!
//! let mut buffer = [0; 5];
//! assert_eq!(file.read(&mut buffer).await.unwrap(), 5);
//! assert_eq!(&buffer, b"hello");
//!
//! //let mut buffer = [0; 5];
//! //assert_eq!(file.seek(std::io::SeekFrom::Current(1)).await.unwrap(), 6);
//! //assert_eq!(file.read(&mut buffer).await.unwrap(), 5);
//! //assert_eq!(&buffer, b"world");
//! # })
//! ```
#![feature(generic_associated_types, type_alias_impl_trait, mixed_integer_ops, io_error_more)]
use std::pin::Pin;
use std::io::{SeekFrom, Error as IoError, ErrorKind};
use std::task::{Poll, Context};
use std::future::Future;
use bytes::Bytes;
use pin_project::pin_project;
use futures_util::{FutureExt, Stream, StreamExt};
use tokio::io::{ReadBuf, AsyncRead, AsyncSeek};
use tokio_util::io::StreamReader;
use reqwest::{RequestBuilder, StatusCode};
/// Wrap an arbitrary error into an [`IoError`] of kind [`ErrorKind::Other`],
/// keeping the original error as the wrapped inner error.
fn to_io_error(error: impl std::error::Error + Send + Sync + 'static) -> IoError {
    let kind = ErrorKind::Other;
    IoError::new(kind, error)
}
/// Error used when the server answers with a status code other than
/// `200 OK` or `206 Partial Content`.
#[derive(Debug)]
struct HttpResponseStatusError(StatusCode);

impl std::fmt::Display for HttpResponseStatusError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(status) = self;
        write!(formatter, "Http status error: {}", status)
    }
}

/// This error has no underlying cause, so the default `source()`
/// implementation (which returns `None`) is sufficient.
impl std::error::Error for HttpResponseStatusError {}
/// Byte stream of an HTTP response body, with errors converted to [`IoError`].
/// The concrete type is inferred from `send_request`.
type RequestStream = impl Stream<Item=Result<Bytes, IoError>>;
/// Future resolving to the size reported by the server (if any)
/// together with the response body stream.
/// The concrete type is inferred from `send_request`.
type RequestStreamFuture = impl Future<Output=Result<(Option<u64>, RequestStream), IoError>>;
/// Send a request with a `Range` header starting at `offset`,
/// returning a future for the response body stream.
///
/// The future resolves to the total size of the resource, if the
/// server reported one, together with the body stream.
///
/// # Errors
///
/// The future resolves to an error if:
/// * the request fails or the server returns an error status
///   (an [`ErrorKind::Other`] error wrapping the HTTP error),
/// * the server ignores the range (`200 OK`) while `offset != 0`
///   ([`ErrorKind::NotSeekable`]),
/// * the status is anything other than `200 OK` or `206 Partial Content`
///   ([`ErrorKind::Other`]).
///
/// # Panics
///
/// Panics if the request contains a streaming body and cannot be cloned.
fn send_request(request: &RequestBuilder, offset: u64) -> RequestStreamFuture {
    let request = request.try_clone().expect("request contains streaming body");
    request
        .header(reqwest::header::RANGE, format!("bytes={offset}-"))
        .send()
        .map(move |result| result
            .and_then(|response| response.error_for_status())
            .map_err(to_io_error)
            .and_then(|response| {
                let status = response.status();
                if status == StatusCode::OK {
                    // The server ignored the `Range` header and sent the
                    // whole body, which is only usable from the first byte.
                    if offset != 0 {
                        return Err(ErrorKind::NotSeekable.into())
                    }
                } else if status != StatusCode::PARTIAL_CONTENT {
                    return Err(to_io_error(HttpResponseStatusError(status)))
                }
                let size = total_size(&response, offset);
                let stream = response.bytes_stream().map(|result|
                    result.map_err(to_io_error));
                Ok((size, stream))
            })
        )
}

/// Determine the total size of the resource from the response headers.
///
/// `Content-Length` only counts the bytes actually sent, which for a range
/// response starting at `offset` is `total - offset`. The total is therefore
/// taken from `Content-Range` (`bytes start-end/total`) when present, and
/// otherwise reconstructed as `offset + Content-Length`.
fn total_size(response: &reqwest::Response, offset: u64) -> Option<u64> {
    let headers = response.headers();
    let content_range_total = headers
        .get(reqwest::header::CONTENT_RANGE)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("bytes "))
        .and_then(|value| value.rsplit_once('/'))
        // An unknown total (`*`) fails to parse and falls through.
        .and_then(|(_, total)| total.parse::<u64>().ok());
    let content_length_total = || headers
        .get(reqwest::header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.parse::<u64>().ok())
        .and_then(|length| offset.checked_add(length));
    content_range_total.or_else(content_length_total)
}
/*
macro_rules! take_if {
(
$pattern:pat => $value:expr -> $replace:expr {
$($body:tt)*
} else {
$($alt:tt)*
}
) => {
if let Some($pattern) = $value {
if let Some($pattern) = std::mem::replace($value, $replace) {
$($body)*
} else {
unreachable!()
}
} else {
$($alt)*
}
}
}
*/
/// Largest forward seek, in bytes, that may be performed by reading and
/// discarding bytes from the current response instead of sending a new request.
const FF_SEEK_MAX_DELTA: usize = 128 * 1024;
/// Size, in bytes, of the scratch buffer used while discarding bytes
/// during such a fast-forward seek.
const FF_SEEK_BUFFER_SIZE: usize = 4096;
/// State of the request file.
///
/// `P` is the future of a request in flight and `R` is the reader
/// over the response body (see the `state` field of [`RequestFile`]).
enum State<P, R> {
/// The initial state, before a request is sent.
Initial,
/// The request is being sent.
Pending(Pin<Box<P>>),
/// The request is finished, the response stream is ready.
Ready(Pin<Box<R>>),
/// The response body is seeking forward: the reader, and the
/// number of bytes that still have to be read and discarded.
Seeking(Pin<Box<R>>, u64),
/// Placeholder stored while the real state has been temporarily
/// moved out with `std::mem::replace`; it is not expected to be
/// observed between polls.
Transient,
}
/// Type that provides an asynchronous file-like interface to a web resource.
///
/// This type implements [`AsyncRead`] and [`AsyncSeek`],
/// so it can be used like an asynchronous file.
///
/// Seeking backwards, or far forwards, is implemented by sending
/// out a new request with a `Range` header.
/// All HTTP requests made by this type include this header.
///
/// If the webserver does not support range requests,
/// seeking to anything other than the start of the file
/// will return a [`NotSeekable`] error.
///
/// If the webserver does not provide the `Content-Length` header,
/// and no size was given to the [`RequestFile`] constructor,
/// then seeking relative to the file end
/// will return an [`Unsupported`] error.
///
/// If the HTTP request fails during a `read` or `seek`
/// operation, it returns an [`Other`] error that wraps the
/// original HTTP error.
/// If the webserver returns anything other than status code
/// `206 Partial Content`, or `200 OK` for responses
/// starting at the first byte, that is also considered a failure.
///
/// For transient errors that require a new HTTP request,
/// the [`reset()`](RequestFile::reset) method can be used.
///
/// # Assumptions
///
/// This type assumes that the HTTP resource
/// is of constant size, and thus that the separate requests
/// performed while seeking are all consistent.
///
/// # Reading
///
/// Reads are implemented by wrapping the response body
/// stream in [`StreamReader`].
///
/// # Seeking
///
/// Seeking to a position before the first byte will
/// return an [`InvalidInput`] error.
///
/// This type performs no special handling of seeking
/// beyond the file EOF, so what happens in this case
/// depends on the webserver.
///
/// [`NotSeekable`]: std::io::ErrorKind::NotSeekable
/// [`Unsupported`]: std::io::ErrorKind::Unsupported
/// [`InvalidInput`]: std::io::ErrorKind::InvalidInput
/// [`Other`]: std::io::ErrorKind::Other
#[pin_project(project = RequestFileProjection)]
pub struct RequestFile {
/// Template request; a clone of it is sent for every range request.
request: RequestBuilder,
/// Progress of the current request / response body.
state: State<RequestStreamFuture, StreamReader<RequestStream, Bytes>>,
/// Track the size of the response body.
size: Option<u64>,
/// Track the current position in the response body.
position: u64,
}
impl RequestFile {
    /// Create a new file-like object for a web resource.
    ///
    /// # Panics
    ///
    /// This function will panic if the request:
    /// * already includes a `Range` header
    /// * contains a streaming body
    /// * building the request fails.
    pub fn new(request: RequestBuilder) -> Self {
        Self::with_size(request, None)
    }

    /// Create a new file-like object for a web resource.
    ///
    /// This function allows you to set the response body size
    /// if you happen to know it.
    /// A size given here takes precedence over any size
    /// reported by the server; `None` lets the server's
    /// response headers provide it.
    ///
    /// Tip: To limit the response size, use
    /// [`tokio::io::AsyncReadExt::take()`].
    ///
    /// # Panics
    ///
    /// This function will panic if the request:
    /// * already includes a `Range` header
    /// * contains a streaming body
    /// * building the request fails.
    pub fn with_size(request: RequestBuilder, size: impl Into<Option<u64>>) -> Self {
        // Build a throwaway copy only to validate the request; the builder
        // itself is kept so it can be cloned for every range request.
        let has_range = request
            .try_clone().expect("request contains streaming body")
            .build().expect("invalid request")
            .headers().contains_key(reqwest::header::RANGE);
        assert!(!has_range, "request already has range header set");
        let stream_future = Box::pin(send_request(&request, 0));
        Self {
            request,
            state: State::Pending(stream_future),
            size: size.into(),
            position: 0,
        }
    }

    /// Force a new HTTP request to begin,
    /// without changing the current position.
    ///
    /// Any current response or request in flight is discarded.
    ///
    /// This method can be used if the request is broken
    /// in a way that can be fixed by restarting it,
    /// for example due to a transient network issue.
    pub fn reset(&mut self) {
        let stream_future = Box::pin(send_request(&self.request, self.position));
        self.state = State::Pending(stream_future);
    }
}
impl RequestFileProjection<'_> {
/// Drive the state from `State::Initial` to to `State::Pending`.
fn drive_initial(&mut self) {
if let State::Initial = self.state {
let stream_future = Box::pin(send_request(self.request, self.position));
*self.state = State::Pending(stream_future);
}
}
/// Drive the state from `State::Pending` to to `State::Ready`.
fn poll_drive_pending(
&mut self,
context: &mut Context<'_>,
) -> Poll<Result<(), IoError>> {
if let State::Pending(future) = self.state {
future.as_mut().poll(context).map_ok(move |(size, stream)| {
*self.size = self.size.or(size);
*self.state = State::Ready(Box::pin(StreamReader::new(stream)));
()
})
} else {
Poll::Ready(Ok(()))
}
}
/// Drive the state from `State::Seeking` to to `State::Ready`.
fn poll_drive_seeking(
&mut self,
context: &mut Context<'_>,
) -> Poll<Result<(), IoError>> {
match std::mem::replace(self.state, State::Transient) {
State::Seeking(mut reader, mut delta) => {
let mut array = [std::mem::MaybeUninit::uninit(); FF_SEEK_BUFFER_SIZE];
let mut buffer = ReadBuf::uninit(&mut array[0..(delta as usize)]);
loop {
assert!(delta > 0);
buffer.clear();
match reader.as_mut().poll_read(context, &mut buffer) {
Poll::Ready(Ok(())) => {
let read = buffer.filled().len() as u64;
delta = delta.checked_sub(read).unwrap();
*self.position = self.position.saturating_add(read);
if delta == 0 {
*self.state = State::Ready(reader);
return Poll::Ready(Ok(()))
} else {
continue
}
},
other => {
*self.state = State::Seeking(reader, delta);
return other
},
}
}
}
state => {
*self.state = state;
Poll::Ready(Ok(()))
},
}
}
fn poll_reader<'a>(
&'a mut self,
context: &mut Context<'_>,
) -> Poll<Result<Pin<&'a mut StreamReader<RequestStream, Bytes>>, IoError>> {
self.drive_initial();
match self.poll_drive_pending(context) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
};
match self.poll_drive_seeking(context) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
};
match self.state {
State::Ready(reader) => Poll::Ready(Ok(reader.as_mut())),
// _ => unreachable!(),
ref state => { // TODO rad
let statename = match state {
State::Pending(_) => "Pending",
State::Ready(_) => "Ready",
State::Seeking(_, _) => "Seeking",
State::Transient => "Transient",
};
unreachable!("unreachable state: {statename}")
},
}
/*
let stream_future = match self.state {
State::Pending(future) => future,
State::Ready(reader) => return Poll::Ready(Ok(reader.as_mut())),
State::Seeking(reader, delta) => {
assert!(delta > 0);
let array = [0; SEEK_READ_BUFFER_SIZE];
let buffer = ReadBuf::new(&mut array[0..delta]);
match reader.as_mut().poll_read(context, &mut buffer) {
Poll::Ready(Ok(())) => {
let read = buffer.filled().len() as u64;
let delta = delta.checked_sub(read).unwrap();
self.position.saturating_add(read);
if delta == 0 {
self.state = State::Ready(reader);
reader
} else {
self.state = State::Seeking(reader, delta);
}
},
Poll::Ready(Err(e)) => return Poll::Read(Err(e)),
Poll::Pending => return Poll::Pending,
}
},
};
stream_future.as_mut().poll(context)
.map_ok(move |(size, stream)| {
*self.size = self.size.or(size);
*self.state = State::Ready(Box::pin(StreamReader::new(stream)));
match self.state {
State::Ready(reader) => reader.as_mut(),
_ => unreachable!(),
}
})
*/
}
}
impl AsyncRead for RequestFile {
fn poll_read(
self: Pin<&mut Self>,
context: &mut Context<'_>,
buffer: &mut ReadBuf<'_>
) -> Poll<Result<(), IoError>> {
let mut this = self.project();
let reader = match this.poll_reader(context) {
Poll::Ready(Ok(reader)) => reader,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
};
let initial_size = buffer.filled().len();
reader.poll_read(context, buffer).map_ok(|_| {
let delta = buffer.filled().len().checked_sub(initial_size).unwrap();
*this.position += this.position.saturating_add(delta as u64);
/*
if delta == 0 {
// The end of the file has been reached,
// so we can update our knowledge of the size.
// In case a seek just set the position beyond
// the actual size, the current position may
// actually be larger than the true size,
// so the size should not be increased.
if let Some(size) = this.size {
*this.size = Some(*this.position.min(size));
} else {
*this.size = Some(*this.position);
}
}
*/
})
}
}
impl AsyncSeek for RequestFile {
fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<(), IoError> {
let this = self.project();
let initial_position = *this.position;
match position {
SeekFrom::Start(position) => *this.position = position,
SeekFrom::End(delta) => if let Some(size) = this.size {
if let Some(position) = size.checked_add_signed(delta) {
*this.position = position;
} else if delta > 0 {
// seek overflow; wrap to maximum
*this.position = u64::MAX;
} else {
// seek to negative position
return Err(ErrorKind::InvalidInput.into())
}
} else {
// size not known
return Err(ErrorKind::Unsupported.into())
},
SeekFrom::Current(delta) => {
if let Some(position) = this.position.checked_add_signed(delta) {
*this.position = position;
} else if delta > 0 {
// seek overflow; wrap to maximum
*this.position = u64::MAX;
} else {
// seek to negative position
return Err(ErrorKind::InvalidInput.into())
}
},
};
if initial_position != *this.position {
let delta_forward = this.position.saturating_sub(initial_position);
if delta_forward <= FF_SEEK_MAX_DELTA {
// seeking forwards by a small leap
match this.state {
State::Pending(_) => {
// TODO
let stream_future = Box::pin(send_request(this.request, *this.position));
*this.state = State::Pending(stream_future);
},
State::Ready(reader) => {
*this.state = State::Seeking(reader, delta_forward);
},
State::Seeking(reader, delta) => {
// Don't perform a fast-forward seek twice consecutively,
// assuming there might be another.
let stream_future = Box::pin(send_request(this.request, *this.position));
*this.state = State::Pending(stream_future);
},
State::Transient => unreachable!(),
}
let stream_future = Box::pin(send_request(this.request, *this.position));
*this.state = State::Pending(stream_future);
} else {
// seeking backwards or a large leap forwards
let stream_future = Box::pin(send_request(this.request, *this.position));
*this.state = State::Pending(stream_future);
}
}
Ok(())
}
fn poll_complete(
self: Pin<&mut Self>,
context: &mut Context<'_>
) -> Poll<Result<u64, IoError>> {
let mut this = self.project();
let position = *this.position;
this.poll_reader(context).map_ok(|_| position)
}
}
#[cfg(test)]
mod tests {
    use std::io::SeekFrom;
    use tokio::io::{AsyncReadExt, AsyncSeekExt};
    use super::RequestFile;

    /// Query parameters understood by the test server.
    #[derive(serde::Deserialize)]
    pub struct QueryParams {
        /// Response body to serve.
        data: String,
        // Reserved for the disabled `Content-Length` tests below.
        #[allow(dead_code)]
        content_length: Option<u64>,
    }

    /// Serve `data` starting at the byte offset given by the `Range` header,
    /// always answering `206 Partial Content`.
    async fn index(
        range_header: axum::extract::TypedHeader<axum::headers::Range>,
        axum::extract::Query(query): axum::extract::Query<QueryParams>,
    ) -> (axum::http::StatusCode, Vec<u8>) {
        let range = range_header.iter().next().expect("missing range");
        if let std::ops::Bound::Included(offset) = range.0 {
            // Range offsets count bytes, not characters.
            let offset = usize::try_from(offset).unwrap_or(usize::MAX);
            let response = query.data.as_bytes().get(offset..).unwrap_or(&[]).to_vec();
            (axum::http::StatusCode::PARTIAL_CONTENT, response)
        } else {
            panic!("missing range start") // TODO
        }
    }

    /// Start the test server on an ephemeral loopback port and return its URL.
    fn start_server() -> String {
        let app = axum::Router::new().route("/", axum::routing::get(index));
        // Bind to loopback: connecting to the unspecified address 0.0.0.0
        // is not portable across platforms.
        let server = axum::Server::bind(&"127.0.0.1:0".parse().unwrap())
            .serve(app.into_make_service());
        let address = server.local_addr();
        tokio::spawn(server);
        format!("http://{address}/")
    }

    /// Generate a test that performs the given seeks (checking each returned
    /// position), then reads the rest of the file and compares it.
    macro_rules! test {
        (
            $name:ident
            [ $( SeekFrom::$seek_from:ident($offset:literal) => $tell:literal );* $(;)? ]
            //$( Content-Length = $size:literal )?
            $data:literal => $result:literal
        ) => {
            #[tokio::test]
            async fn $name() {
                let url = start_server();
                let data = $data;
                let client = reqwest::Client::new();
                let size = "";
                //$( let size = format!("&size={}", $size); )?
                let request = client.get(format!("{url}?data={data}{size}"));
                let mut file = RequestFile::new(request);
                let mut response_data = String::new();
                $(
                    let pos = file.seek(SeekFrom::$seek_from($offset)).await.unwrap();
                    assert_eq!(pos, $tell);
                )*
                file.read_to_string(&mut response_data).await.unwrap();
                assert_eq!(response_data, $result);
            }
        }
    }

    test! {
        test_read
        []
        "abc" => "abc"
    }
    test! {
        test_from_start_first
        [ SeekFrom::Start(0) => 0 ]
        "abc" => "abc"
    }
    test! {
        test_from_start_middle
        [ SeekFrom::Start(2) => 2 ]
        "abcd" => "cd"
    }
    test! {
        test_from_start_last
        [ SeekFrom::Start(4) => 4 ]
        "abcd" => ""
    }
    test! {
        test_from_start_beyond
        [ SeekFrom::Start(9) => 9 ]
        "abcd" => ""
    }
    /*
    test! {
        test_from_current_before
        [ SeekFrom::Start(2); SeekFrom::Current(-4) ]
        "abcd" => Err
    }
    */
    test! {
        test_from_current_first
        [ SeekFrom::Start(2) => 2; SeekFrom::Current(-2) => 0 ]
        "abcd" => "abcd"
    }
    test! {
        test_from_current_middle_backward
        [ SeekFrom::Start(4) => 4; SeekFrom::Current(-2) => 2 ]
        "abcdef" => "cdef"
    }
    test! {
        test_from_current_middle_forward
        [ SeekFrom::Start(2) => 2; SeekFrom::Current(2) => 4 ]
        "abcdef" => "ef"
    }
    test! {
        test_from_current_last
        [ SeekFrom::Start(2) => 2; SeekFrom::Current(2) => 4 ]
        "abcd" => ""
    }
    test! {
        test_from_current_beyond
        [ SeekFrom::Start(2) => 2; SeekFrom::Current(4) => 6 ]
        "abcd" => ""
    }
    /*
    test! {
        test_from_end_before
        [ SeekFrom::End(-6) => Err ]
        Content-Length: true
        "abcd" => Err
    }
    test! {
        test_from_end_first
        [ SeekFrom::End(-4) => 0 ]
        Content-Length: true
        "abcd" => "abcd"
    }
    test! {
        test_from_end_middle
        [ SeekFrom::End(-2) => 2 ]
        Content-Length: true
        "abcd" => "cd"
    }
    test! {
        test_from_end_last
        [ SeekFrom::End(0) => 4 ]
        Content-Length: true
        "abcd" => ""
    }
    */

    /// Test that the file supports reading the stream
    /// when it is already at EOF.
    #[tokio::test]
    async fn test_read_at_end() {
        let url = start_server();
        let client = reqwest::Client::new();
        let data = "abc";
        let request = client.get(format!("{url}?data={data}"));
        let mut file = RequestFile::new(request);
        let mut response_data = String::new();
        file.read_to_string(&mut response_data).await.unwrap();
        assert_eq!(response_data, data);
        response_data.clear();
        file.read_to_string(&mut response_data).await.unwrap();
        assert_eq!(response_data, "");
    }
}