use std::{
fs::{File, OpenOptions},
io::{Seek, Write},
path::Path,
};
/// Downloads `url` into the file at `path`, using parallel byte-range requests when possible.
///
/// A `HEAD` request is sent first. The chunked strategy is chosen only when all of
/// the following hold:
///
/// * the `HEAD` response has a success status,
/// * its `Accept-Ranges` header lists the `bytes` unit,
/// * it carries a `Content-Length` that parses as a `u64`,
/// * `chunk_size` is greater than zero (a zero-sized chunk cannot partition the body).
///
/// In every other case the body is fetched with a single `GET` via
/// [`simple_download_to`].
///
/// Returns the output file, opened for both reading and writing.
///
/// # Errors
///
/// Returns an error if the `HEAD` request cannot be sent, or if the selected
/// download strategy fails.
pub async fn chuncked_download_to(
    client: &reqwest::Client,
    url: &str,
    chunk_size: u64,
    path: impl AsRef<Path>,
) -> anyhow::Result<File> {
    let head = client.head(url).send().await?;
    let headers = head.headers();

    // Range units are case-insensitive and the header value is a comma-separated
    // list, so compare each token rather than the raw header string. Headers of a
    // failed HEAD (e.g. 405 from servers that do not implement it) are not trusted.
    let accepts_ranges = head.status().is_success()
        && headers
            .get(reqwest::header::ACCEPT_RANGES)
            .and_then(|value| value.to_str().ok())
            .is_some_and(|value| {
                value
                    .split(',')
                    .any(|unit| unit.trim().eq_ignore_ascii_case("bytes"))
            });

    let content_length = headers
        .get(reqwest::header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.trim().parse::<u64>().ok());

    match content_length {
        Some(total) if accepts_ranges && chunk_size > 0 => {
            download_chunked(client, url, total, chunk_size, path).await
        }
        _ => simple_download_to(client, url, path).await,
    }
}
/// Downloads `total` bytes from `url` into `path` using concurrent `Range` requests.
///
/// The body is split into consecutive ranges of at most `chunk_size` bytes. Every
/// range is fetched by its own spawned task (all of them are started up front), and
/// each result is written at its offset in whatever order it arrives.
///
/// Returns the output file, opened for both reading and writing. The file cursor is
/// left wherever the last write ended; callers that want to read the content must
/// seek first.
///
/// # Errors
///
/// Returns an error if:
///
/// * `chunk_size` is zero,
/// * the file cannot be created, resized or written,
/// * a request fails or is answered with an HTTP error status,
/// * a response body is not exactly as long as the requested range (for example
///   because the server ignored the `Range` header),
/// * fewer than `total` bytes were received overall.
///
/// On failure the file at `path` may be left partially written.
async fn download_chunked(
    client: &reqwest::Client,
    url: &str,
    total: u64,
    chunk_size: u64,
    path: impl AsRef<Path>,
) -> anyhow::Result<File> {
    // A zero-sized chunk would never advance `start` in the loop below.
    anyhow::ensure!(chunk_size > 0, "chunk_size must be greater than zero");

    // Create the file before issuing any request so that a bad path fails fast.
    let mut file = OpenOptions::new()
        .create(true)
        .truncate(true)
        .write(true)
        .read(true)
        .open(path)?;
    file.set_len(total)?;

    let (tx, mut rx) = tokio::sync::mpsc::channel(4);

    let mut start = 0_u64;
    while start < total {
        let end = start.saturating_add(chunk_size).min(total);
        let offset = start;
        // `end > offset` holds because `chunk_size > 0` and `offset < total`.
        let last = end - 1;
        let expected_len = end - offset;
        let chunk_client = client.clone();
        let chunk_url = url.to_owned();
        let chunk_tx = tx.clone();
        let _handle = tokio::spawn(async move {
            let result = async {
                let chunk = chunk_client
                    .get(&chunk_url)
                    .header(reqwest::header::RANGE, format!("bytes={offset}-{last}"))
                    .send()
                    .await?
                    .error_for_status()?
                    .bytes()
                    .await?;
                // A server that ignores `Range` answers with the whole body; writing
                // that at `offset` would corrupt the file, so reject any body that is
                // not exactly the requested length.
                anyhow::ensure!(
                    u64::try_from(chunk.len()).ok() == Some(expected_len),
                    "range {}-{} returned {} bytes, expected {}",
                    offset,
                    last,
                    chunk.len(),
                    expected_len
                );
                anyhow::Ok((chunk, offset))
            }
            .await;
            // Sending only fails when the receiver was dropped, i.e. the download
            // has already been abandoned; there is nobody left to report to.
            drop(chunk_tx.send(result).await);
        });
        start = end;
    }
    // Only the per-chunk clones may keep the channel open, so that `recv` yields
    // `None` once every task has reported.
    drop(tx);

    let mut received = 0_u64;
    while let Some(result) = rx.recv().await {
        let (chunk, offset) = result?;
        file.seek(std::io::SeekFrom::Start(offset))?;
        file.write_all(&chunk)?;
        received = received.saturating_add(u64::try_from(chunk.len())?);
    }

    // A task that died without reporting would otherwise leave a silent hole.
    anyhow::ensure!(
        received == total,
        "download incomplete: received {} of {} bytes",
        received,
        total
    );
    Ok(file)
}
/// Downloads the whole body of `url` with a single `GET` and writes it to `path`.
///
/// The response is fetched in full (buffered in memory) and its status checked
/// before the file is opened, so a failed request does not truncate a file that
/// already exists at `path`.
///
/// Returns the output file, opened for both reading and writing. The file cursor is
/// left at the end of the written data; callers that want to read the content must
/// seek first.
///
/// # Errors
///
/// Returns an error if the request fails, if the server answers with an HTTP error
/// status (4xx/5xx), or if the file cannot be created or written.
pub async fn simple_download_to(
    client: &reqwest::Client,
    url: &str,
    path: impl AsRef<Path>,
) -> anyhow::Result<File> {
    // Reject error responses instead of saving an error page as the download.
    let data = client
        .get(url)
        .send()
        .await?
        .error_for_status()?
        .bytes()
        .await?;

    let mut file = OpenOptions::new()
        .create(true)
        .truncate(true)
        .write(true)
        .read(true)
        .open(path)?;
    file.write_all(&data)?;
    Ok(file)
}
#[cfg(test)]
mod tests {
    use std::io::{Read, Seek, SeekFrom};

    use tempfile::TempDir;

    use super::*;

    /// Fetches a 1024-byte body in two 512-byte ranges and checks that the file
    /// holds the repeating lowercase alphabet. Requires network access.
    #[tokio::test]
    async fn chuncked_download_to_test() {
        const N: u64 = 1024;

        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("output");
        let http = reqwest::Client::new();
        let url = format!("https://httpbin.org/range/{N}");

        let mut downloaded = chuncked_download_to(&http, &url, N / 2, target)
            .await
            .unwrap();

        // The cursor sits wherever the last write ended, so rewind before reading.
        downloaded.seek(SeekFrom::Start(0)).unwrap();
        let mut actual = Vec::new();
        downloaded.read_to_end(&mut actual).unwrap();

        let wanted: Vec<u8> = (b'a'..=b'z')
            .cycle()
            .take(usize::try_from(N).unwrap())
            .collect();
        assert_eq!(actual, wanted);
    }

    /// Fetches a small body with a single request and checks the decoded text.
    /// Requires network access.
    #[tokio::test]
    async fn simple_download_to_test() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("output");
        let http = reqwest::Client::new();

        let mut downloaded = simple_download_to(
            &http,
            "https://httpbin.org/base64/SGVsbG8gV29ybGQ=",
            target,
        )
        .await
        .unwrap();

        // The cursor is at the end of the written data, so rewind before reading.
        downloaded.seek(SeekFrom::Start(0)).unwrap();
        let mut text = String::new();
        downloaded.read_to_string(&mut text).unwrap();

        assert_eq!(text, "Hello World");
    }
}