use crate::error::{Error, Result};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use std::error::Error as StdError;
use std::future::Future;
use std::path::Path;
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use url::Url;
/// Message of the I/O error that `limit_stream` yields once the cumulative
/// number of bytes seen on a stream exceeds the configured maximum.
pub const STREAM_LIMIT_ERROR: &str = "download size exceeds maximum";
/// Boxed, `Send + Sync` error used as the error half of every body-stream item.
pub type BodyError = Box<dyn StdError + Send + Sync>;
/// How a response body is read, as decided by `select_download_strategy`.
///
/// The enum has no payload, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum DownloadStrategy {
    /// Read the whole body into memory before returning it.
    Buffer,
    /// Hand the body back as a size-limited stream.
    Stream,
}
/// Successful outcome of a `Downloader::fetch` call: response metadata plus the body.
///
/// `S` is the stream type used when the body is delivered as a stream.
pub struct DownloadResult<S> {
/// Media type of the body. `ReqwestDownloader` takes it from the `Content-Type`
/// header (parameters after `;` removed) and uses `application/octet-stream`
/// when the header is absent or not valid text.
pub content_type: String,
/// File extension without a leading dot. `ReqwestDownloader` derives it from the
/// last URL path segment, then from the media type, and finally falls back to `bin`.
pub extension: String,
/// The body itself, either buffered or streamed.
pub body: DownloadBody<S>,
}
/// Body of a downloaded response: either fully in memory or still streaming.
pub enum DownloadBody<S> {
/// The complete body held in memory.
Buffered { bytes: Bytes },
/// The body exposed as a stream of type `S`.
Streamed {
/// Stream that yields the body.
stream: S,
/// Shared byte counter. In `ReqwestDownloader::fetch` this is the counter
/// returned by `limit_stream`, which adds each chunk's length as the chunk
/// passes through the stream.
size_counter: Arc<AtomicU64>,
},
}
/// Abstraction over something that can fetch the content behind a URL.
pub trait Downloader {
/// Stream of body chunks handed out for streamed downloads. Each item is either a
/// chunk of bytes or a `BodyError`. The stream must be `Send + Sync + 'static`.
type Stream: futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>
+ Send
+ Sync
+ 'static;
/// Future returned by `fetch`. It may borrow the downloader for `'a`, must be
/// `Send`, and resolves to a `DownloadResult` or a crate `Error`.
type Future<'a>: Future<Output = Result<DownloadResult<Self::Stream>>> + Send
where
Self: 'a;
/// Starts fetching `url`.
///
/// NOTE(review): the trait itself does not fix how the two limits are applied.
/// `ReqwestDownloader` buffers a body only when its declared length is at most
/// `max_buffered_bytes`, and rejects or cuts off bodies larger than `max_file_size`;
/// other implementors should presumably do the same — confirm.
fn fetch<'a>(
&'a self,
url: Url,
max_buffered_bytes: u64,
max_file_size: u64,
) -> Self::Future<'a>;
}
/// `Downloader` implementation that issues HTTP GET requests through a `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct ReqwestDownloader {
// Client used for every request made by this downloader.
client: reqwest::Client,
}
impl ReqwestDownloader {
pub fn new(client: reqwest::Client) -> Self {
Self { client }
}
}
impl Downloader for ReqwestDownloader {
    type Stream =
        SyncStream<futures_util::stream::BoxStream<'static, std::result::Result<Bytes, BodyError>>>;
    type Future<'a> =
        std::pin::Pin<Box<dyn Future<Output = Result<DownloadResult<Self::Stream>>> + Send + 'a>>;

    /// Issues a GET request for `url` and returns the body either buffered or streamed.
    ///
    /// The body is buffered when the response declares a length of at most
    /// `max_buffered_bytes`; otherwise (larger or unknown length) it is returned as a
    /// stream that fails with [`STREAM_LIMIT_ERROR`] once more than `max_file_size`
    /// bytes have passed through it.
    ///
    /// The content type is the `Content-Type` header without its parameters, or
    /// `application/octet-stream` when the header is missing, unreadable, or blank.
    /// The extension comes from the last URL path segment, then from the content
    /// type, and finally defaults to `bin`.
    ///
    /// # Errors
    ///
    /// * `Error::Download` when the request fails, the status is not a success, or
    ///   the buffered body cannot be read.
    /// * `Error::Config` when the limits are inconsistent, the declared length exceeds
    ///   `max_file_size`, or the buffered body turns out larger than `max_file_size`.
    fn fetch<'a>(
        &'a self,
        url: Url,
        max_buffered_bytes: u64,
        max_file_size: u64,
    ) -> Self::Future<'a> {
        Box::pin(async move {
            let response =
                self.client
                    .get(url.clone())
                    .send()
                    .await
                    .map_err(|e| Error::Download {
                        url: url.to_string(),
                        reason: e.to_string(),
                    })?;
            if !response.status().is_success() {
                return Err(Error::Download {
                    url: url.to_string(),
                    reason: format!("HTTP {}", response.status()),
                });
            }

            let strategy = select_download_strategy(
                response.content_length(),
                max_buffered_bytes,
                max_file_size,
            )?;

            let content_type = response
                .headers()
                .get(reqwest::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok())
                .map(|s| s.split(';').next().unwrap_or(s).trim())
                // A header such as ";charset=utf-8" leaves nothing before the
                // parameters; treat it like a missing header.
                .filter(|s| !s.is_empty())
                .map(str::to_owned)
                .unwrap_or_else(|| "application/octet-stream".to_string());

            let ext = url
                .path_segments()
                .and_then(|mut segments| segments.next_back())
                .and_then(|name| Path::new(name).extension().and_then(|e| e.to_str()))
                // `Path::extension` yields `Some("")` for names ending in '.';
                // an empty extension is useless, so fall through to the MIME guess.
                .filter(|e| !e.is_empty())
                .or_else(|| {
                    mime_guess::get_mime_extensions_str(&content_type)
                        .and_then(|arr| arr.first().copied())
                })
                .unwrap_or("bin")
                .to_string();

            match strategy {
                DownloadStrategy::Buffer => {
                    let bytes = response.bytes().await.map_err(|e| Error::Download {
                        url: url.to_string(),
                        reason: format!("failed to read response body: {}", e),
                    })?;
                    // The declared length was checked up front; re-check what was
                    // actually received.
                    if bytes.len() as u64 > max_file_size {
                        return Err(Error::Config {
                            message: format!(
                                "downloaded content size {} exceeds maximum {}",
                                bytes.len(),
                                max_file_size
                            ),
                        });
                    }
                    Ok(DownloadResult {
                        content_type,
                        extension: ext,
                        body: DownloadBody::Buffered { bytes },
                    })
                }
                DownloadStrategy::Stream => {
                    let stream = response
                        .bytes_stream()
                        .map_err(|e| -> BodyError { Box::new(e) });
                    let (limited_stream, counter) = limit_stream(stream, max_file_size);
                    // Box to erase the concrete type, then wrap so the result is `Sync`.
                    let boxed_stream = limited_stream.boxed();
                    let sync_stream = SyncStream::new(boxed_stream);
                    Ok(DownloadResult {
                        content_type,
                        extension: ext,
                        body: DownloadBody::Streamed {
                            stream: sync_stream,
                            size_counter: counter,
                        },
                    })
                }
            }
        })
    }
}
/// Wrapper that keeps a heap-pinned stream behind a mutex.
///
/// A `Mutex<T>` is `Sync` whenever `T` is `Send`, so wrapping a stream that is only
/// `Send` produces a value that is also `Sync`.
pub struct SyncStream<S> {
    inner: std::sync::Mutex<std::pin::Pin<Box<S>>>,
}

impl<S> SyncStream<S> {
    /// Pins `stream` on the heap and stores it behind a fresh mutex.
    fn new(stream: S) -> Self {
        let pinned = Box::pin(stream);
        SyncStream {
            inner: std::sync::Mutex::new(pinned),
        }
    }
}
impl<S> futures_util::Stream for SyncStream<S>
where
    S: futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>,
{
    type Item = std::result::Result<Bytes, BodyError>;

    /// Polls the wrapped stream and forwards its result unchanged.
    ///
    /// `poll_next` receives exclusive access to `self`, so the mutex is reached
    /// through `Mutex::get_mut` and no lock is taken: the call cannot block the
    /// executor thread, and because no guard is held a panic inside the wrapped
    /// stream does not poison the mutex. A poisoned mutex is still reported as an
    /// error item with the message `stream lock poisoned`.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `SyncStream` only contains a `Mutex<Pin<Box<S>>>`, which is `Unpin` for
        // every `S`, so the mutable reference can be taken out of the pin safely.
        let this = self.get_mut();
        match this.inner.get_mut() {
            Ok(stream) => stream.as_mut().poll_next(cx),
            Err(_) => std::task::Poll::Ready(Some(Err(Box::new(std::io::Error::other(
                "stream lock poisoned",
            ))))),
        }
    }
}
/// Chooses between buffering and streaming a response body.
///
/// * `content_length` — body length declared by the response, if any.
/// * `max_buffered_bytes` — largest declared length that is still buffered.
/// * `max_file_size` — largest declared length that is accepted at all.
///
/// Returns `Buffer` when the declared length is at most `max_buffered_bytes`, and
/// `Stream` when it is larger or unknown.
///
/// # Errors
///
/// Returns `Error::Config` when `max_buffered_bytes` is zero, when it exceeds
/// `max_file_size`, or when the declared length exceeds `max_file_size`.
fn select_download_strategy(
    content_length: Option<u64>,
    max_buffered_bytes: u64,
    max_file_size: u64,
) -> Result<DownloadStrategy> {
    // Validate the limits themselves before looking at the response.
    let limits_problem = if max_buffered_bytes == 0 {
        Some("max_buffered_bytes must be greater than zero")
    } else if max_buffered_bytes > max_file_size {
        Some("max_buffered_bytes must not exceed max_file_size")
    } else {
        None
    };
    if let Some(problem) = limits_problem {
        return Err(Error::Config {
            message: problem.into(),
        });
    }

    // Without a declared length the size is only known while reading, so stream.
    let Some(size) = content_length else {
        return Ok(DownloadStrategy::Stream);
    };

    if size > max_file_size {
        return Err(Error::Config {
            message: format!("content size {} exceeds maximum {}", size, max_file_size),
        });
    }

    if size <= max_buffered_bytes {
        Ok(DownloadStrategy::Buffer)
    } else {
        Ok(DownloadStrategy::Stream)
    }
}
fn limit_stream<S>(
stream: S,
max: u64,
) -> (
impl futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>,
Arc<AtomicU64>,
)
where
S: futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>,
{
let counter = Arc::new(AtomicU64::new(0));
let counter_inner = Arc::clone(&counter);
let limited = stream.map(move |result| match result {
Ok(bytes) => {
let read = bytes.len() as u64;
let total = counter_inner.fetch_add(read, Ordering::Relaxed) + read;
if total > max {
return Err(Box::new(std::io::Error::other(STREAM_LIMIT_ERROR)) as BodyError);
}
Ok(bytes)
}
Err(err) => Err(err),
});
(limited, counter)
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;

    /// Buffering threshold shared by the strategy tests.
    const BUFFER_LIMIT: u64 = 8 * 1024;
    /// File size ceiling shared by the strategy tests.
    const FILE_LIMIT: u64 = 64 * 1024;
    /// Byte limit shared by the stream tests.
    const STREAM_LIMIT: u64 = 8;

    /// Runs a single successful chunk holding `payload` through `limit_stream`
    /// and collects the outcome together with the byte counter.
    async fn run_limited(
        payload: Vec<u8>,
    ) -> (std::result::Result<Vec<Bytes>, BodyError>, Arc<AtomicU64>) {
        let source =
            futures_util::stream::iter(vec![Ok::<Bytes, BodyError>(Bytes::from(payload))]);
        let (limited, counter) = limit_stream(source, STREAM_LIMIT);
        let outcome: std::result::Result<Vec<Bytes>, BodyError> = limited.try_collect().await;
        (outcome, counter)
    }

    #[test]
    fn test_select_download_strategy_buffers_small_content() {
        let chosen = select_download_strategy(Some(1024), BUFFER_LIMIT, FILE_LIMIT)
            .expect("expected a valid strategy");
        assert_eq!(chosen, DownloadStrategy::Buffer);
    }

    #[test]
    fn test_select_download_strategy_streams_large_content() {
        let chosen = select_download_strategy(Some(16 * 1024), BUFFER_LIMIT, FILE_LIMIT)
            .expect("expected a valid strategy");
        assert_eq!(chosen, DownloadStrategy::Stream);
    }

    #[test]
    fn test_select_download_strategy_streams_when_length_unknown() {
        let chosen = select_download_strategy(None, BUFFER_LIMIT, FILE_LIMIT)
            .expect("expected a valid strategy");
        assert_eq!(chosen, DownloadStrategy::Stream);
    }

    #[test]
    fn test_select_download_strategy_rejects_oversize() {
        let outcome = select_download_strategy(Some(128 * 1024), BUFFER_LIMIT, FILE_LIMIT);
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_limit_stream_enforces_max_size() {
        // 16 bytes against a limit of 8 must be rejected.
        let (outcome, _) = run_limited(vec![0_u8; 16]).await;
        let err = outcome.expect_err("expected size limit error");
        assert!(err.to_string().contains(STREAM_LIMIT_ERROR));
    }

    #[tokio::test]
    async fn test_limit_stream_tracks_size() {
        // 5 bytes stay under the limit of 8 and must be counted exactly.
        let payload = vec![1_u8; 5];
        let expected_len = payload.len();
        let (outcome, counter) = run_limited(payload).await;
        let chunks = outcome.expect("expected stream to succeed");
        let total: usize = chunks.iter().map(|chunk| chunk.len()).sum();
        assert_eq!(total, expected_len);
        assert_eq!(counter.load(Ordering::Relaxed), expected_len as u64);
    }
}