#![deny(missing_docs)]
#![forbid(clippy::unwrap_used)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc = include_str!("../README.md")]
use std::fmt::Debug;
use std::future::{self, Future};
use std::io::{self, Read, Seek, SeekFrom};
use educe::Educe;
pub use settings::*;
use source::handle::SourceHandle;
use source::{DecodeError, Source, SourceStream};
use storage::StorageProvider;
use tokio_util::sync::CancellationToken;
use tracing::{debug, instrument, trace};
#[cfg(feature = "async-read")]
pub mod async_read;
#[cfg(feature = "http")]
pub mod http;
#[cfg(feature = "process")]
pub mod process;
#[cfg(feature = "registry")]
pub mod registry;
mod settings;
pub mod source;
pub mod storage;
/// A cloneable handle for observing when a stream's download task is done.
///
/// Instances are created by [`StreamDownload::handle`], which hands out a
/// clone of the download's cancellation token. Because the same token is
/// cancelled both when the download task returns and when the download is
/// cancelled explicitly, "completion" here covers either case.
#[derive(Debug, Clone)]
pub struct StreamHandle {
// Token awaited by `wait_for_completion`.
finished: CancellationToken,
}
impl StreamHandle {
    /// Waits until the token backing this handle has been cancelled.
    ///
    /// The handle is consumed; clone it first if it needs to be awaited from
    /// more than one place.
    pub async fn wait_for_completion(self) {
        let Self { finished } = self;
        finished.cancelled().await;
    }
}
/// A download in progress whose already-fetched bytes can be consumed through
/// the [`Read`] and [`Seek`] implementations on this type.
///
/// The type parameter `P` is the [`StorageProvider`] that supplies the reader
/// and writer halves of the backing storage.
#[derive(Debug)]
pub struct StreamDownload<P: StorageProvider> {
/// Reader half obtained from `StorageProvider::into_reader_writer`; all reads
/// and seeks are ultimately served from it.
output_reader: P::Reader,
/// Handle to the `Source` driving the download; used to query downloaded
/// ranges, wait for positions, request seeks and detect failure.
handle: SourceHandle,
/// Cancelled when the download task returns or when `cancel_download` is called.
download_task_cancellation_token: CancellationToken,
/// Copied from `Settings::cancel_on_drop`; when `true`, dropping this value
/// cancels the download.
cancel_on_drop: bool,
/// Content length reported by the source stream at creation time, if known.
content_length: Option<u64>,
/// Value of `StorageProvider::max_capacity()` at creation time; bounds the
/// accepted read buffer size and seek distance when present.
storage_capacity: Option<usize>,
}
impl<P: StorageProvider> StreamDownload<P> {
/// Creates a [`StreamDownload`] backed by an [`http::HttpStream`] that uses a
/// plain `reqwest::Client`.
///
/// This is a convenience wrapper around [`Self::new`]; `url` is passed through
/// as the stream's creation parameters. Only available with the `reqwest`
/// feature.
///
/// # Errors
///
/// Returns whatever [`Self::new`] returns for this stream type.
#[cfg(feature = "reqwest")]
pub async fn new_http(
    url: ::reqwest::Url,
    storage_provider: P,
    settings: Settings<http::HttpStream<::reqwest::Client>>,
) -> Result<Self, StreamInitializationError<http::HttpStream<::reqwest::Client>>> {
    Self::new::<http::HttpStream<::reqwest::Client>>(url, storage_provider, settings).await
}
/// Creates a [`StreamDownload`] backed by an [`http::HttpStream`] that uses a
/// `reqwest_middleware::ClientWithMiddleware`.
///
/// This is a convenience wrapper around [`Self::new`]; `url` is passed through
/// as the stream's creation parameters. Only available with the
/// `reqwest-middleware` feature.
///
/// # Errors
///
/// Returns whatever [`Self::new`] returns for this stream type.
#[cfg(feature = "reqwest-middleware")]
pub async fn new_http_with_middleware(
    url: ::reqwest::Url,
    storage_provider: P,
    settings: Settings<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
) -> Result<
    Self,
    StreamInitializationError<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
> {
    Self::new::<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>(
        url,
        storage_provider,
        settings,
    )
    .await
}
/// Creates a [`StreamDownload`] backed by an [`async_read::AsyncReadStream`]
/// over any `tokio::io::AsyncRead` source.
///
/// This is a convenience wrapper around [`Self::new`]; `params` is passed
/// through as the stream's creation parameters. Only available with the
/// `async-read` feature.
///
/// # Errors
///
/// Returns whatever [`Self::new`] returns for this stream type.
#[cfg(feature = "async-read")]
pub async fn new_async_read<T>(
    params: async_read::AsyncReadStreamParams<T>,
    storage_provider: P,
    settings: Settings<async_read::AsyncReadStream<T>>,
) -> Result<Self, StreamInitializationError<async_read::AsyncReadStream<T>>>
where
    T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static,
{
    Self::new::<async_read::AsyncReadStream<T>>(params, storage_provider, settings).await
}
/// Creates a [`StreamDownload`] backed by a [`process::ProcessStream`].
///
/// This is a convenience wrapper around [`Self::new`]; `params` is passed
/// through as the stream's creation parameters. Only available with the
/// `process` feature.
///
/// # Errors
///
/// Returns whatever [`Self::new`] returns for this stream type.
#[cfg(feature = "process")]
pub async fn new_process(
    params: process::ProcessStreamParams,
    storage_provider: P,
    settings: Settings<process::ProcessStream>,
) -> Result<Self, StreamInitializationError<process::ProcessStream>> {
    Self::new::<process::ProcessStream>(params, storage_provider, settings).await
}
/// Creates a [`StreamDownload`] for any [`SourceStream`] implementation.
///
/// The stream itself is built by calling `S::create(params)`; the download is
/// then started on a background task.
///
/// # Errors
///
/// * [`StreamInitializationError::StreamCreationFailure`] if `S::create` fails.
/// * [`StreamInitializationError::StorageCreationFailure`] if the storage
///   provider cannot produce its reader and writer.
pub async fn new<S>(
    params: S::Params,
    storage_provider: P,
    settings: Settings<S>,
) -> Result<Self, StreamInitializationError<S>>
where
    S: SourceStream,
    S::Error: Debug + Send,
{
    // Stream creation is deferred so that it runs inside `from_create_stream`.
    let create_stream = move || S::create(params);
    Self::from_create_stream(create_stream, storage_provider, settings).await
}
/// Creates a [`StreamDownload`] from a [`SourceStream`] that has already been
/// constructed.
///
/// Use this instead of [`Self::new`] when the stream was built ahead of time;
/// no stream creation step is performed here.
///
/// # Errors
///
/// Returns [`StreamInitializationError::StorageCreationFailure`] if the
/// storage provider cannot produce its reader and writer.
pub async fn from_stream<S>(
    stream: S,
    storage_provider: P,
    settings: Settings<S>,
) -> Result<Self, StreamInitializationError<S>>
where
    S: SourceStream,
    S::Error: Debug + Send,
{
    // The stream already exists, so "creation" is an immediately-ready success.
    let create_stream = move || future::ready(Ok::<S, S::StreamCreationError>(stream));
    Self::from_create_stream(create_stream, storage_provider, settings).await
}
/// Cancels the download by cancelling its cancellation token.
///
/// The same token backs every [`StreamHandle`] and every token returned by
/// [`Self::cancellation_token`], so those observers are notified as well.
pub fn cancel_download(&self) {
self.download_task_cancellation_token.cancel();
}
pub fn cancellation_token(&self) -> CancellationToken {
self.download_task_cancellation_token.clone()
}
/// Returns a [`StreamHandle`] that can be used to wait for the download to
/// finish independently of this value.
pub fn handle(&self) -> StreamHandle {
    let finished = self.download_task_cancellation_token.clone();
    StreamHandle { finished }
}
/// Returns the content length that the source stream reported when the
/// download was created, or `None` if it was not known.
pub fn content_length(&self) -> Option<u64> {
self.content_length
}
/// Shared constructor behind [`Self::new`] and [`Self::from_stream`].
///
/// Steps, in order:
/// 1. run `create_stream` to obtain the source stream;
/// 2. ask the storage provider for its reader/writer pair;
/// 3. build the `Source` and spawn a tokio task that runs the download and
///    cancels the shared token once the download returns.
///
/// # Errors
///
/// * [`StreamInitializationError::StreamCreationFailure`] if `create_stream`
///   resolves to an error.
/// * [`StreamInitializationError::StorageCreationFailure`] if
///   `into_reader_writer` fails.
async fn from_create_stream<S, F, Fut>(
    create_stream: F,
    storage_provider: P,
    settings: Settings<S>,
) -> Result<Self, StreamInitializationError<S>>
where
    S: SourceStream<Error: Debug + Send>,
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = Result<S, S::StreamCreationError>> + Send,
{
    let stream = match create_stream().await {
        Ok(stream) => stream,
        Err(e) => return Err(StreamInitializationError::StreamCreationFailure(e)),
    };
    let content_length = stream.content_length();

    // The capacity has to be read before the provider is consumed below.
    let storage_capacity = storage_provider.max_capacity();
    let (output_reader, writer) = match storage_provider.into_reader_writer(content_length) {
        Ok(halves) => halves,
        Err(e) => return Err(StreamInitializationError::StorageCreationFailure(e)),
    };

    let download_task_cancellation_token = CancellationToken::new();
    // Copied out because `settings` is moved into the `Source`.
    let cancel_on_drop = settings.cancel_on_drop;
    let mut source = Source::new(
        writer,
        content_length,
        settings,
        download_task_cancellation_token.clone(),
    );
    let handle = source.source_handle();

    let task_token = download_task_cancellation_token.clone();
    tokio::spawn(async move {
        source.download(stream).await;
        // Signal anyone waiting on the token that the task is done.
        task_token.cancel();
        debug!("download task finished");
    });

    Ok(Self {
        output_reader,
        handle,
        download_task_cancellation_token,
        cancel_on_drop,
        content_length,
        storage_capacity,
    })
}
/// Converts a [`SeekFrom`] into an absolute byte offset from the start of the
/// stream.
///
/// * `SeekFrom::Start(n)` yields `n` unchanged.
/// * `SeekFrom::Current(n)` is resolved against the output reader's current
///   stream position.
/// * `SeekFrom::End(n)` is resolved against the content length known to the
///   source handle.
///
/// # Errors
///
/// * [`io::ErrorKind::Unsupported`] when seeking from the end and the content
///   length is unknown.
/// * [`io::ErrorKind::InvalidInput`] when the resulting position would be
///   negative or would overflow `u64`.
/// * Any error returned while querying the reader's stream position.
fn get_absolute_seek_position(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
    let (base, offset) = match relative_position {
        SeekFrom::Start(position) => {
            debug!(seek_position = position, "seeking from start");
            return Ok(position);
        }
        SeekFrom::Current(position) => {
            debug!(seek_position = position, "seeking from current position");
            (self.output_reader.stream_position()?, position)
        }
        SeekFrom::End(position) => {
            debug!(seek_position = position, "seeking from end");
            let Some(length) = self.handle.content_length() else {
                return Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    "cannot seek from end when content length is unknown",
                ));
            };
            (length, position)
        }
    };

    // Checked arithmetic: casting through `i64` would wrap a negative result
    // into a huge `u64` (and could panic on overflow in debug builds).
    base.checked_add_signed(offset).ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "invalid seek to a negative or overflowing position",
        )
    })
}
/// Reads from the output reader into `buf` and then notifies the source
/// handle that a read took place.
///
/// The notification is sent whether or not the read succeeded; the read
/// result is returned unchanged.
fn handle_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
    let outcome = self.output_reader.read(buf);
    if let Ok(l) = &outcome {
        trace!(read_length = format!("{l:?}"), "returning read");
    }
    self.handle.notify_read();
    outcome
}
/// Clamps `requested_position` to the content length when it is known, so
/// that positions past the end of the stream are never requested.
///
/// With an unknown content length the position is returned unchanged.
fn normalize_requested_position(&self, requested_position: u64) -> u64 {
    self.content_length
        .map_or(requested_position, |length| requested_position.min(length))
}
/// Returns an error if the source handle reports that the download failed.
fn check_for_failure(&self) -> io::Result<()> {
    if !self.handle.is_failed() {
        return Ok(());
    }
    Err(io::Error::other("stream failed to download"))
}
/// Rejects read buffers that are larger than the storage's maximum capacity.
///
/// When the storage has no capacity limit every buffer size is accepted.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] if `buf_len` exceeds the capacity.
fn check_for_excessive_read(&self, buf_len: usize) -> io::Result<()> {
    match self.storage_capacity {
        Some(capacity) if buf_len > capacity => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("buffer size {buf_len} exceeds the max capacity of {capacity}",),
        )),
        // Either unbounded storage or a buffer that fits.
        _ => Ok(()),
    }
}
/// Rejects seeks that land further ahead of the current position than the
/// storage's maximum capacity allows.
///
/// When the storage has no capacity limit every position is accepted.
///
/// # Errors
///
/// * [`io::ErrorKind::InvalidInput`] if `absolute_seek_position` is greater
///   than the current stream position plus the maximum capacity.
/// * Any error returned while querying the reader's stream position.
fn check_for_excessive_seek(&mut self, absolute_seek_position: u64) -> io::Result<()> {
    let Some(max_capacity) = self.storage_capacity else {
        return Ok(());
    };

    // Computed once so the comparison and the error message always agree and
    // the reader is only queried a single time.
    let max_possible_seek_position = self
        .output_reader
        .stream_position()?
        .saturating_add(max_capacity as u64);

    if absolute_seek_position > max_possible_seek_position {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "seek position {absolute_seek_position} exceeds maximum of \
                 {max_possible_seek_position}"
            ),
        ));
    }
    Ok(())
}
}
/// Error returned when a [`StreamDownload`] cannot be initialized.
///
/// Generic over the [`SourceStream`] type because the stream creation error
/// is defined by that trait.
#[derive(thiserror::Error, Educe)]
#[educe(Debug)]
pub enum StreamInitializationError<S: SourceStream> {
/// The storage provider failed to create its reader and writer.
#[error("Storage creation failure: {0}")]
StorageCreationFailure(io::Error),
/// The source stream could not be created.
#[error("Stream creation failure: {0}")]
StreamCreationFailure(<S as SourceStream>::StreamCreationError),
}
impl<S: SourceStream> DecodeError for StreamInitializationError<S> {
    /// Produces a human-readable description of the error.
    ///
    /// Stream creation errors are delegated to their own `decode_error`
    /// implementation; storage errors use this type's `Display` output.
    async fn decode_error(self) -> String {
        match self {
            Self::StreamCreationFailure(stream_error) => stream_error.decode_error().await,
            storage_error @ Self::StorageCreationFailure(_) => storage_error.to_string(),
        }
    }
}
impl<P: StorageProvider> Drop for StreamDownload<P> {
    /// Cancels the download on drop when `cancel_on_drop` was enabled in the
    /// settings; otherwise the download task is left running.
    fn drop(&mut self) {
        if !self.cancel_on_drop {
            return;
        }
        self.cancel_download();
    }
}
impl<P: StorageProvider> Read for StreamDownload<P> {
    /// Reads into `buf`, first making sure the requested range is available.
    ///
    /// The range runs from the reader's current position to that position plus
    /// `buf.len()`, clamped to the content length when known. If the source
    /// handle reports that the range is already downloaded, the read is served
    /// immediately; otherwise `wait_for_position` is called on the handle
    /// before reading (presumably blocking until the data arrives).
    ///
    /// # Errors
    ///
    /// Fails if the download has failed, if `buf` is larger than the storage
    /// capacity, or if the underlying reader returns an error.
    #[instrument(skip_all, fields(len=buf.len()))]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.check_for_failure()?;
        self.check_for_excessive_read(buf.len())?;
        trace!(buffer_length = buf.len(), "read requested");

        let stream_position = self.output_reader.stream_position()?;
        let requested_position =
            self.normalize_requested_position(stream_position + buf.len() as u64);
        trace!(
            current_position = stream_position,
            requested_position = requested_position
        );

        match self.handle.get_downloaded_at_position(stream_position) {
            Some(closest_set) => {
                trace!(
                    downloaded_range = format!("{closest_set:?}"),
                    "current position already downloaded"
                );
                // Fast path: the whole requested range is already available.
                if closest_set.end >= requested_position {
                    trace!("requested position already downloaded");
                    return self.handle_read(buf);
                }
                debug!("requested position not yet downloaded");
            }
            None => {
                debug!("stream position not yet downloaded");
            }
        }

        self.handle.wait_for_position(requested_position);
        // The download may have failed while we were waiting.
        self.check_for_failure()?;
        debug!(
            current_position = stream_position,
            requested_position = requested_position,
            output_stream_position = self.output_reader.stream_position()?,
            "reached requested position"
        );
        self.handle_read(buf)
    }
}
impl<P: StorageProvider> Seek for StreamDownload<P> {
    /// Seeks the output reader to the position described by
    /// `relative_position`.
    ///
    /// The target is converted to an absolute offset, clamped to the content
    /// length when known, and checked against the storage capacity. If the
    /// target has not been downloaded yet, the seek is forwarded to the source
    /// handle first (presumably blocking until the position is reachable) and
    /// the download is re-checked for failure.
    ///
    /// # Errors
    ///
    /// Fails if the download has failed, if the position cannot be resolved or
    /// is beyond the allowed capacity, or if the underlying reader returns an
    /// error.
    #[instrument(skip(self))]
    fn seek(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
        self.check_for_failure()?;
        let absolute_seek_position = self.get_absolute_seek_position(relative_position)?;
        let absolute_seek_position = self.normalize_requested_position(absolute_seek_position);
        self.check_for_excessive_seek(absolute_seek_position)?;
        debug!(absolute_seek_position, "absolute seek position");

        if let Some(closest_set) = self
            .handle
            .get_downloaded_at_position(absolute_seek_position)
        {
            debug!(
                downloaded_range = format!("{closest_set:?}"),
                "seek position already downloaded"
            );
        } else {
            self.handle.seek(absolute_seek_position);
            // The download may have failed while the seek was being handled.
            self.check_for_failure()?;
            debug!("reached seek position");
        }

        // `inspect` (not `inspect_err`): the message reports the position that
        // is returned on success, mirroring the logging in `handle_read`.
        self.output_reader
            .seek(SeekFrom::Start(absolute_seek_position))
            .inspect(|p| debug!(position = format!("{p:?}"), "returning seek position"))
    }
}
/// Extension trait for attaching a context message to an I/O error.
pub(crate) trait WrapIoResult {
    /// Prefixes the error, if any, with `msg` while keeping its
    /// [`io::ErrorKind`]. Successful values pass through untouched.
    fn wrap_err(self, msg: &str) -> Self;
}

impl<T> WrapIoResult for io::Result<T> {
    fn wrap_err(self, msg: &str) -> Self {
        self.map_err(|e| io::Error::new(e.kind(), format!("{msg}: {e}")))
    }
}