use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use std::io;
use std::pin::Pin;
use std::task::{self, Poll};
use std::time::Instant;
use bytes::Bytes;
use educe::Educe;
use futures_util::{Future, Stream};
use mediatype::MediaTypeBuf;
#[cfg(feature = "reqwest")]
pub use reqwest;
use tracing::{debug, instrument, warn};
use crate::WrapIoResult;
use crate::source::{DecodeError, SourceStream};
#[cfg(feature = "reqwest")]
pub mod reqwest_client;
#[cfg(feature = "reqwest-middleware")]
pub(crate) mod reqwest_middleware_client;
/// An HTTP client capable of fetching a resource either in full or by byte range.
///
/// [`HttpStream`] is generic over this trait so different HTTP libraries can back it (see the
/// feature-gated `reqwest_client` / `reqwest_middleware_client` modules).
pub trait Client: Send + Sync + Unpin + 'static {
    /// URL type accepted by [`Client::get`] and [`Client::get_range`]. Its `Display` output is
    /// used in tracing fields and error messages.
    type Url: Display + Send + Sync + Unpin;
    /// Header type exposed by this client's responses.
    type Headers: ResponseHeaders;
    /// Response type returned by requests; its header type must be [`Client::Headers`].
    type Response: ClientResponse<Headers = Self::Headers>;
    /// Error returned when a request cannot be completed.
    type Error: Error + Send + Sync;
    /// Constructs a client instance. [`HttpStream`]'s `SourceStream::create` uses this when no
    /// client is supplied explicitly.
    fn create() -> Self;
    /// Requests the full resource at `url`.
    fn get(
        &self,
        url: &Self::Url,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send;
    /// Requests bytes `start..=end` of the resource at `url`.
    ///
    /// `end` is an *inclusive* byte offset, as in an HTTP `Range` header (`HttpStream`
    /// subtracts one from its exclusive end bound before calling this). `None` leaves the
    /// range open-ended; compare `format_range_header_bytes`, which renders that case as
    /// `bytes={start}-`.
    fn get_range(
        &self,
        url: &Self::Url,
        start: u64,
        end: Option<u64>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send;
}
/// The type and subtype of a parsed `Content-Type` header (for example `audio` / `mpeg`).
///
/// Only the type and subtype are kept; media-type parameters such as `charset` are dropped
/// when `HttpStream::new` builds this value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ContentType {
    /// Top-level media type, e.g. `audio` in `audio/mpeg`.
    pub r#type: String,
    /// Media subtype, e.g. `mpeg` in `audio/mpeg`.
    pub subtype: String,
}
/// Read-only access to the headers of a response.
pub trait ResponseHeaders: Send + Sync + Unpin {
    /// Returns the value of header `name`, or `None` if it is absent.
    ///
    /// NOTE(review): case-sensitivity of `name` and the handling of non-UTF-8 values are
    /// implementation-defined — confirm for each client.
    fn header(&self, name: &str) -> Option<&str>;
}
/// A response produced by a [`Client`].
pub trait ClientResponse: Send + Sync + Sized {
    /// Error returned by [`ClientResponse::into_result`]; turned into a message via
    /// [`DecodeError`].
    type ResponseError: DecodeError + Send;
    /// Error yielded by the body stream.
    type StreamError: Error + Send + Sync;
    /// Header type returned by [`ClientResponse::headers`].
    type Headers: ResponseHeaders;
    /// Content length of the response (presumably from `Content-Length`), if known.
    fn content_length(&self) -> Option<u64>;
    /// Raw `Content-Type` header value, if present.
    fn content_type(&self) -> Option<&str>;
    /// Returns the response headers as an owned value, so they can be kept after the response
    /// itself is consumed by [`ClientResponse::stream`].
    fn headers(&self) -> Self::Headers;
    /// Converts the response into `Err` if it represents a failed request.
    ///
    /// NOTE(review): the failure criteria (presumably a non-success HTTP status) are defined by
    /// each implementation.
    fn into_result(self) -> Result<Self, Self::ResponseError>;
    /// Consumes the response and returns its body as a stream of byte chunks.
    fn stream(
        self,
    ) -> Box<dyn Stream<Item = Result<Bytes, Self::StreamError>> + Unpin + Send + Sync>;
}
fn fmt<T>(val: &T, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error>
where
T: Display,
{
write!(fmt, "{val}")
}
/// Error returned when an [`HttpStream`] cannot be created.
#[derive(thiserror::Error, Educe)]
#[educe(Debug)]
pub enum HttpStreamError<C: Client> {
    /// [`Client::get`] returned an error, so no response was obtained.
    #[error("Failed to fetch: {0}")]
    FetchFailure(C::Error),
    /// A response was received but [`ClientResponse::into_result`] rejected it.
    #[error("Failed to get response: {0}")]
    ResponseFailure(<<C as Client>::Response as ClientResponse>::ResponseError),
}
impl<C: Client> DecodeError for HttpStreamError<C> {
    /// Produces a human-readable description of this error.
    ///
    /// A rejected response delegates to its own [`DecodeError`] implementation (which is
    /// awaited); a fetch failure falls back to this error's `Display` text.
    async fn decode_error(self) -> String {
        if let Self::ResponseFailure(response_error) = self {
            response_error.decode_error().await
        } else {
            self.to_string()
        }
    }
}
/// A [`SourceStream`] that reads its data over HTTP through a [`Client`].
///
/// Holds the currently active response body stream together with metadata captured from the
/// initial response.
#[derive(Educe)]
#[educe(Debug)]
pub struct HttpStream<C: Client> {
    /// Active body stream; replaced on seek and reconnect. Omitted from `Debug` output.
    #[educe(Debug = false)]
    stream: Box<
        dyn Stream<Item = Result<Bytes, <<C as Client>::Response as ClientResponse>::StreamError>>
            + Unpin
            + Send
            + Sync,
    >,
    /// Client used for the initial request and for any follow-up requests.
    client: C,
    /// Content length reported by the initial response, if any.
    content_length: Option<u64>,
    /// Content type parsed from the initial response, if present and valid.
    content_type: Option<ContentType>,
    /// Resource URL; shown in `Debug` output via its `Display` implementation.
    #[educe(Debug(method = "fmt"))]
    url: C::Url,
    /// Headers of the initial response. Omitted from `Debug` output.
    #[educe(Debug = false)]
    headers: C::Headers,
}
impl<C: Client> HttpStream<C> {
    /// Requests `url` with `client` and wraps the successful response in a new [`HttpStream`].
    ///
    /// The response's content length and content type are captured for later queries. When
    /// either is missing, or the content type cannot be parsed, a warning is logged and the
    /// value is stored as `None`. The response headers are kept too, so
    /// [`HttpStream::header`] reflects this initial response.
    ///
    /// # Errors
    ///
    /// * [`HttpStreamError::FetchFailure`] if [`Client::get`] fails.
    /// * [`HttpStreamError::ResponseFailure`] if [`ClientResponse::into_result`] rejects the
    ///   response.
    #[instrument(skip(client, url), fields(url = url.to_string()))]
    pub async fn new(
        client: C,
        url: <Self as SourceStream>::Params,
    ) -> Result<Self, HttpStreamError<C>> {
        debug!("requesting stream content");
        let request_start = Instant::now();
        let response = client
            .get(&url)
            .await
            .map_err(HttpStreamError::FetchFailure)?;
        debug!(
            duration = format!("{:?}", request_start.elapsed()),
            "request finished"
        );
        let response = response
            .into_result()
            .map_err(HttpStreamError::ResponseFailure)?;

        let content_length = response.content_length();
        if let Some(length) = content_length {
            debug!(content_length = length, "received content length");
        } else {
            warn!("content length header missing");
        }
        let content_type = Self::parse_content_type(response.content_type());

        let headers = response.headers();
        // `ClientResponse::stream` already returns a `Box<dyn Stream ..>` of exactly the
        // field's type, so store it as-is instead of boxing it a second time.
        let stream = response.stream();
        Ok(Self {
            stream,
            client,
            content_length,
            content_type,
            url,
            headers,
        })
    }

    /// Parses a raw `Content-Type` header value into a [`ContentType`].
    ///
    /// Returns `None` (after logging a warning) when the header is absent or is not a valid
    /// media type. Media-type parameters such as `charset` are discarded.
    fn parse_content_type(content_type: Option<&str>) -> Option<ContentType> {
        let Some(content_type) = content_type else {
            warn!("content type header missing");
            return None;
        };
        debug!(content_type, "received content type");
        match content_type.parse::<MediaTypeBuf>() {
            Ok(media_type) => Some(ContentType {
                r#type: media_type.ty().to_string(),
                subtype: media_type.subty().to_string(),
            }),
            Err(e) => {
                warn!("error parsing content type: {e:?}");
                None
            }
        }
    }

    /// Returns the content type parsed from the initial response, if any.
    pub fn content_type(&self) -> &Option<ContentType> {
        &self.content_type
    }

    /// Looks up header `name` on the initial response.
    ///
    /// Headers are captured once in [`HttpStream::new`]. Later seeks and reconnects do not
    /// replace them.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers.header(name)
    }

    /// Returns all headers captured from the initial response.
    pub fn headers(&self) -> &C::Headers {
        &self.headers
    }

    /// Returns `true` if the server advertised support for range requests.
    ///
    /// Support is assumed unless the `Accept-Ranges` header is missing or holds the reserved
    /// value `none`. Range-unit tokens are case-insensitive (RFC 9110 §14.1), so values such as
    /// `None` or ` none ` are treated the same way.
    fn supports_range_request(&self) -> bool {
        self.header("Accept-Ranges")
            .is_some_and(|value| !value.trim().eq_ignore_ascii_case("none"))
    }
}
impl<C: Client> Stream for HttpStream<C> {
    type Item = Result<Bytes, <<C as Client>::Response as ClientResponse>::StreamError>;

    /// Forwards the poll to the currently active body stream.
    ///
    /// The active stream is swapped out by `seek_range` and `reconnect`, so polls after a seek
    /// yield data from the new position.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        // Every field of `HttpStream` is `Unpin`, so a plain mutable reference can be taken.
        let this = self.get_mut();
        let body = &mut this.stream;
        Pin::new(body).poll_next(cx)
    }
}
impl<C: Client> SourceStream for HttpStream<C> {
    type Params = C::Url;
    type StreamCreationError = HttpStreamError<C>;

    /// Builds a stream for `params` using a client obtained from [`Client::create`].
    async fn create(params: Self::Params) -> Result<Self, Self::StreamCreationError> {
        Self::new(C::create(), params).await
    }

    /// Content length reported by the initial response, if any.
    fn content_length(&self) -> Option<u64> {
        self.content_length
    }

    /// Replaces the active body stream with one covering bytes `start..end`.
    ///
    /// `end` is exclusive; `None` means "until the end of the resource". Two cases are served
    /// as an empty stream without any network request:
    /// * `start` equals the known content length;
    /// * the range is empty (`end <= start`).
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if the range request cannot be sent, or if the client rejects
    /// the response. In the second case the message comes from [`DecodeError::decode_error`].
    #[instrument(skip(self))]
    async fn seek_range(&mut self, start: u64, end: Option<u64>) -> io::Result<()> {
        if Some(start) == self.content_length {
            debug!(
                "attempting to seek where start is the length of the stream, returning empty \
                 stream"
            );
            self.stream = Box::new(futures_util::stream::empty());
            return Ok(());
        }

        // HTTP ranges use an inclusive last byte while `end` is exclusive, hence `end - 1`.
        // An empty range (`end <= start`) has no valid `Range` header representation, and
        // `end - 1` would underflow for `end == 0`. Serve it as an empty stream instead.
        let inclusive_end = match end {
            Some(end) if end <= start => {
                debug!("requested range is empty, returning empty stream");
                self.stream = Box::new(futures_util::stream::empty());
                return Ok(());
            }
            Some(end) => Some(end - 1),
            None => None,
        };

        if !self.supports_range_request() {
            warn!("Accept-Ranges header not present. Attempting seek anyway.");
        }

        debug!("sending HTTP range request");
        let request_start = Instant::now();
        let response = self
            .client
            .get_range(&self.url, start, inclusive_end)
            .await
            .map_err(|e| io::Error::other(e.to_string()))
            .wrap_err(&format!("error sending HTTP range request to {}", self.url))?;
        debug!(
            duration = format!("{:?}", request_start.elapsed()),
            "HTTP request finished"
        );
        let response = match response.into_result() {
            Ok(response) => Ok(response),
            Err(e) => {
                let error = e.decode_error().await;
                Err(io::Error::other(error)).wrap_err(&format!(
                    "error getting HTTP range response from {}",
                    self.url
                ))
            }
        }?;
        // `ClientResponse::stream` already yields the boxed stream type stored in the field.
        self.stream = response.stream();
        debug!("done seeking");
        Ok(())
    }

    /// Re-establishes the body stream (presumably after the connection was interrupted).
    ///
    /// If the server advertised range support, resumes at `current_position` via
    /// `seek_range`. Otherwise it re-requests the whole resource.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if the request cannot be sent, or if the client rejects the
    /// response.
    async fn reconnect(&mut self, current_position: u64) -> Result<(), io::Error> {
        if self.supports_range_request() {
            return self.seek_range(current_position, None).await;
        }
        // NOTE(review): without range support the new body presumably starts at byte 0 rather
        // than at `current_position` — confirm the caller accounts for this.
        let response = self
            .client
            .get(&self.url)
            .await
            .map_err(|e| io::Error::other(e.to_string()))
            .wrap_err(&format!("error sending HTTP request to {}", self.url))?;
        // Apply the same acceptance check as `new` and `seek_range`. A rejected response is
        // then reported as an error instead of its body being streamed as content.
        let response = match response.into_result() {
            Ok(response) => Ok(response),
            Err(e) => {
                let error = e.decode_error().await;
                Err(io::Error::other(error))
                    .wrap_err(&format!("error getting HTTP response from {}", self.url))
            }
        }?;
        self.stream = response.stream();
        Ok(())
    }

    /// Always `true`. Seeks are implemented with range requests, which are attempted even
    /// when the server does not advertise support for them.
    fn supports_seek(&self) -> bool {
        true
    }
}
/// Name of the HTTP request header used to ask for a byte range.
pub const RANGE_HEADER_KEY: &str = "Range";
/// Renders the value of an HTTP `Range` header for bytes `start..=end`.
///
/// `end` is written verbatim, so it must already be the *inclusive* last byte offset, as the
/// HTTP `Range` syntax requires. `None` produces an open-ended range (`bytes={start}-`).
pub fn format_range_header_bytes(start: u64, end: Option<u64>) -> String {
    match end {
        Some(last_byte) => format!("bytes={start}-{last_byte}"),
        None => format!("bytes={start}-"),
    }
}