use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use bytes::{BufMut, BytesMut};
use http::HeaderValue;
use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE};
use crate::client::Client;
use crate::error::Error;
use crate::runtime::Runtime;
/// Builder for downloading a single URL as several concurrent byte-range
/// requests driven by runtime `R`.
pub struct ChunkDownload<R: Runtime> {
    /// HTTP client used for the `HEAD` probe and for every range request.
    client: Client<R>,
    /// Requested number of parallel range requests; `chunks()` keeps it at least 1.
    chunks: usize,
    /// Address of the resource to download.
    url: String,
    /// Ties the builder to the runtime type parameter.
    _runtime: PhantomData<R>,
}
impl<R: Runtime> std::fmt::Debug for ChunkDownload<R> {
    /// Renders the target URL and the configured chunk count.
    ///
    /// The client is deliberately left out; `finish_non_exhaustive` marks the
    /// omission with a trailing `..` so the output does not look complete.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChunkDownload")
            .field("url", &self.url)
            .field("chunks", &self.chunks)
            .finish_non_exhaustive()
    }
}
/// Outcome of [`ChunkDownload::download`]: the complete body and its size.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkDownloadResult {
    /// Size of the resource in bytes. This is the server-reported
    /// `Content-Length` when ranged fetching was used; otherwise it is the
    /// length of `data`.
    pub total_size: u64,
    /// The downloaded body, with chunks concatenated in byte order.
    pub data: bytes::Bytes,
}
type ChunkResults = Arc<Mutex<Vec<Option<std::result::Result<bytes::Bytes, Error>>>>>;
impl<R: Runtime> ChunkDownload<R> {
pub(crate) fn new(client: Client<R>, url: String) -> Self {
Self {
client,
url,
chunks: 4,
_runtime: PhantomData,
}
}
pub fn chunks(mut self, n: usize) -> Self {
self.chunks = n.max(1);
self
}
pub async fn download(self) -> Result<ChunkDownloadResult, Error> {
let client = self.client;
let url = self.url;
let head_resp = client.head(&url)?.send().await?;
if !head_resp.status().is_success() {
return Err(Error::Other(
format!("HEAD request failed: {}", head_resp.status()).into(),
));
}
let accepts_ranges = head_resp
.headers()
.get(ACCEPT_RANGES)
.and_then(|v| v.to_str().ok())
.map(|v| v.contains("bytes"))
.unwrap_or(false);
let content_length = head_resp
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok());
let total_size = match content_length {
Some(len) if accepts_ranges && len > 0 => len,
_ => {
let resp = client.get(&url)?.send().await?;
let data = resp.bytes().await?;
let len = data.len() as u64;
return Ok(ChunkDownloadResult {
total_size: len,
data,
});
}
};
let num_chunks = (self.chunks as u64).min(total_size) as usize;
let chunk_size = total_size / num_chunks as u64;
let results: ChunkResults = Arc::new(Mutex::new((0..num_chunks).map(|_| None).collect()));
let done_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
for i in 0..num_chunks {
let start = i as u64 * chunk_size;
let end = if i == num_chunks - 1 {
total_size - 1
} else {
(i as u64 + 1) * chunk_size - 1
};
let url = url.clone();
let range_value = format!("bytes={start}-{end}");
let client = client.clone();
let results = Arc::clone(&results);
let done_count = Arc::clone(&done_count);
R::spawn(async move {
let result: std::result::Result<bytes::Bytes, Error> = async {
let range_header = HeaderValue::from_str(&range_value)
.map_err(|e| Error::Other(Box::new(e)))?;
let resp = client.get(&url)?.header(RANGE, range_header).send().await?;
if resp.status() != http::StatusCode::PARTIAL_CONTENT
&& !resp.status().is_success()
{
return Err(Error::Other(
format!("chunk request failed: {}", resp.status()).into(),
));
}
resp.bytes().await
}
.await;
results.lock().unwrap()[i] = Some(result);
done_count.fetch_add(1, std::sync::atomic::Ordering::Release);
});
}
loop {
if done_count.load(std::sync::atomic::Ordering::Acquire) == num_chunks {
break;
}
R::sleep(std::time::Duration::from_millis(1)).await;
}
let chunk_data = Arc::try_unwrap(results)
.map_err(|_| Error::Other("failed to unwrap results".into()))?
.into_inner()
.map_err(|_| Error::Other("chunk result mutex poisoned".into()))?;
let mut buf = BytesMut::with_capacity(total_size as usize);
for result in chunk_data {
let data = result.ok_or_else(|| Error::Other("missing chunk".into()))??;
buf.put(data);
}
Ok(ChunkDownloadResult {
total_size,
data: buf.freeze(),
})
}
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::runtime::TokioRuntime;

    /// Builds a fresh Tokio-backed client for the builder tests.
    fn tokio_client() -> crate::Client<TokioRuntime> {
        crate::Client::<TokioRuntime>::new()
    }

    /// Reproduces the range split used by `download`. Chunks are
    /// `total_size / num_chunks` bytes each, and the final range runs to the
    /// last byte.
    fn split_ranges(total_size: u64, num_chunks: usize) -> Vec<(u64, u64)> {
        let chunk_size = total_size / num_chunks as u64;
        let last = num_chunks as u64 - 1;
        (0..num_chunks as u64)
            .map(|idx| {
                let start = idx * chunk_size;
                let end = if idx == last {
                    total_size - 1
                } else {
                    (idx + 1) * chunk_size - 1
                };
                (start, end)
            })
            .collect()
    }

    #[test]
    fn chunks_clamps_to_one() {
        let dl = tokio_client()
            .chunk_download("http://example.com/file")
            .chunks(0);
        assert_eq!(dl.chunks, 1);
    }

    #[test]
    fn chunks_accepts_large_value() {
        let dl = tokio_client()
            .chunk_download("http://example.com/file")
            .chunks(100);
        assert_eq!(dl.chunks, 100);
    }

    #[test]
    fn debug_format_includes_url() {
        let rendered = format!(
            "{:?}",
            tokio_client().chunk_download("http://example.com/large.bin")
        );
        assert!(rendered.contains("ChunkDownload"));
        assert!(rendered.contains("large.bin"));
    }

    #[test]
    fn range_splitting_single_chunk() {
        assert_eq!(split_ranges(100, 1), vec![(0, 99)]);
    }

    #[test]
    fn range_splitting_even() {
        assert_eq!(
            split_ranges(100, 4),
            vec![(0, 24), (25, 49), (50, 74), (75, 99)]
        );
    }

    #[test]
    fn range_splitting_uneven() {
        assert_eq!(split_ranges(10, 3), vec![(0, 2), (3, 5), (6, 9)]);
    }

    #[test]
    fn num_chunks_capped_at_total_size() {
        let (total_size, requested_chunks) = (3u64, 10u64);
        assert_eq!(requested_chunks.min(total_size) as usize, 3);
    }

    #[test]
    fn chunk_download_result_debug() {
        let rendered = format!(
            "{:?}",
            ChunkDownloadResult {
                total_size: 42,
                data: bytes::Bytes::from("hello"),
            }
        );
        assert!(rendered.contains("42"));
    }

    #[test]
    fn reassembly_order() {
        let mut buf = BytesMut::with_capacity(9);
        for piece in [&b"abc"[..], &b"def"[..], &b"ghi"[..]] {
            buf.put(piece);
        }
        assert_eq!(&buf[..], b"abcdefghi");
    }
}