use std::{os::unix::fs::FileExt as _, path::Path, sync::Arc};
use bytes::Bytes;
use futures_util::StreamExt as _;
use http::{Method, Uri};
use tokio::{
fs,
io::AsyncWriteExt as _,
sync::mpsc,
task::{self, JoinSet},
};
use crate::{
auth::Credentials,
config::Config,
error::{Error, Result},
http::{ObjectKey, request::build_signed, retry::send_with_retry},
trace::{maybe_debug, maybe_info},
transfer::part::Part,
};
/// Downloads a whole object with one signed GET and atomically places it at `dest`.
///
/// The body is streamed into a temporary file created next to `dest`, synced to
/// disk, and then renamed over `dest`, so a partially written file is never
/// visible at the destination path.
///
/// Returns the number of body bytes written.
///
/// # Errors
///
/// Fails if the URI cannot be built, signing or the request fails, the
/// response stream yields an error, or any filesystem step (directory
/// creation, temp-file creation, write, sync, rename) fails.
pub(crate) async fn download_single(
    http: &reqwest::Client, config: &Config, creds: &Credentials, bucket: &str, key: &ObjectKey,
    dest: &Path,
) -> Result<u64> {
    let uri: Uri = format!("{}/{bucket}/{}", config.endpoint_url(), key.encoded()).parse()?;
    let req = build_signed(Method::GET, uri, Bytes::new(), creds, &config.region)?;
    let resp = send_with_retry(http, req, &config.retry).await?;
    if let Some(parent) = dest.parent() {
        fs::create_dir_all(parent).await?;
    }
    // Create and reopen the temp file on the blocking pool: both are
    // synchronous filesystem calls and must not run on an async worker.
    let (tmp, std_file) = task::spawn_blocking({
        let dir = dest.parent().unwrap_or_else(|| Path::new(".")).to_owned();
        move || -> std::io::Result<(tempfile::NamedTempFile, std::fs::File)> {
            let tmp = tempfile::NamedTempFile::new_in(dir)?;
            let file = tmp.reopen()?;
            Ok((tmp, file))
        }
    })
    .await
    .map_err(|e| Error::Internal(e.to_string()))??;
    let mut file = fs::File::from_std(std_file);
    let mut size = 0_u64;
    let mut stream = resp.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        size = size.saturating_add(u64::try_from(chunk.len()).unwrap_or(u64::MAX));
        file.write_all(&chunk).await?;
    }
    file.flush().await?;
    // Make the data durable before the rename publishes it, matching the
    // multipart path, which syncs in its writer thread before persisting.
    file.sync_all().await?;
    drop(file);
    let dest_owned = dest.to_owned();
    task::spawn_blocking(move || tmp.persist(dest_owned).map_err(|e| Error::Io(e.error)))
        .await
        .map_err(|e| Error::Internal(e.to_string()))??;
    Ok(size)
}
/// One positioned write request handed to the [`DiskWriter`] thread.
struct WriteCmd {
    /// Bytes to write.
    buf: Vec<u8>,
    /// Absolute file offset at which `buf` is written.
    offset: u64,
}
/// Handle to a dedicated OS thread that applies [`WriteCmd`]s to a file.
struct DiskWriter {
    /// Sending side of the bounded command queue consumed by the writer thread.
    tx: std::sync::mpsc::SyncSender<WriteCmd>,
    /// Join handle of the writer thread; the thread's result is the outcome of
    /// its writes and final sync.
    handle: Option<std::thread::JoinHandle<std::io::Result<()>>>,
}
impl DiskWriter {
fn spawn(fd: Arc<std::fs::File>, capacity: usize) -> Self {
let (tx, rx) = std::sync::mpsc::sync_channel::<WriteCmd>(capacity);
let handle = std::thread::spawn(move || {
while let Ok(cmd) = rx.recv() {
fd.write_all_at(&cmd.buf, cmd.offset)?;
}
fd.sync_all()
});
Self {
tx,
handle: Some(handle),
}
}
fn finish(self) -> Result<()> {
drop(self.tx);
if let Some(h) = self.handle {
h.join()
.map_err(|_| Error::Internal("writer thread panicked".into()))?
.map_err(Error::Io)?;
}
Ok(())
}
}
/// Owned state shared (behind an `Arc`) by all part-download tasks of one
/// multipart download.
struct DownloadCtx {
    /// Bucket name used when building the request URI.
    bucket: String,
    /// Client configuration (endpoint, region, retry settings).
    config: Config,
    /// Credentials used to sign each part request.
    creds: Credentials,
    /// HTTP client used to send the part requests.
    http: reqwest::Client,
    /// Key of the object being downloaded.
    key: ObjectKey,
    /// Sender into the disk-writer queue; downloaded data is pushed through it.
    writer_tx: std::sync::mpsc::SyncSender<WriteCmd>,
}
/// Description of one byte range to fetch, as queued for the download workers.
struct DownloadPartJob {
    /// Part number; only read by the debug trace macro.
    #[cfg_attr(not(feature = "tracing"), expect(dead_code, reason = "read by maybe_debug!"))]
    number: u32,
    /// Offset of the part's first byte, used both for the range request and as
    /// the file write position.
    offset: u64,
    /// Length of the part in bytes.
    size: u64,
}
#[expect(clippy::too_many_arguments, reason = "internal fn, context struct would add indirection")]
pub(crate) async fn download_multipart(
http: &reqwest::Client, config: &Config, creds: &Credentials, bucket: &str, key: &ObjectKey,
parts: &[Part], dest: &Path, total_size: u64, concurrency: usize,
) -> Result<u64> {
assert!(concurrency > 0, "concurrency must be at least 1");
if let Some(parent) = dest.parent() {
fs::create_dir_all(parent).await?;
}
let tmp = task::spawn_blocking({
let dir = dest.parent().unwrap_or_else(|| Path::new(".")).to_owned();
move || tempfile::NamedTempFile::new_in(dir)
})
.await
.map_err(|e| Error::Internal(e.to_string()))??;
let std_file = tmp.reopen()?;
std_file.set_len(total_size)?;
let writer = DiskWriter::spawn(Arc::new(std_file), concurrency.saturating_mul(2).max(2));
maybe_info!(
key = %key, parts = parts.len(), concurrency, size = total_size,
"multipart download started"
);
let ctx = Arc::new(DownloadCtx {
bucket: bucket.to_owned(),
config: config.clone(),
creds: creds.clone(),
http: http.clone(),
key: key.clone(),
writer_tx: writer.tx.clone(),
});
let result = download_all_parts(&ctx, parts, concurrency).await;
drop(ctx);
match result {
Ok(()) => {
writer.finish()?;
let dest = dest.to_owned();
task::spawn_blocking(move || tmp.persist(dest).map_err(|e| Error::Io(e.error)))
.await
.map_err(|e| Error::Internal(e.to_string()))??;
maybe_info!(key = %key, size = total_size, "multipart download complete");
Ok(total_size)
},
Err(e) => {
drop(tmp);
Err(e)
},
}
}
/// Feeds one job per entry of `parts` into a bounded queue and runs the
/// download workers against it.
///
/// A separate task does the feeding so that queue back-pressure never stalls
/// the worker loop. Returns the worker loop's result once the feeder task has
/// ended; the feeder's own join result is discarded.
async fn download_all_parts(
    ctx: &Arc<DownloadCtx>, parts: &[Part], concurrency: usize,
) -> Result<()> {
    let (job_tx, job_rx) = mpsc::channel::<DownloadPartJob>(concurrency);
    let jobs: Vec<DownloadPartJob> = parts
        .iter()
        .map(|part| DownloadPartJob {
            number: part.number,
            offset: part.offset,
            size: part.size,
        })
        .collect();
    let feeder = task::spawn(async move {
        for job in jobs {
            // A closed queue means the workers stopped early; stop feeding.
            if job_tx.send(job).await.is_err() {
                break;
            }
        }
    });
    let outcome = run_download_workers(ctx, job_rx, concurrency).await;
    drop(feeder.await);
    outcome
}
/// Runs part downloads from `rx` with at most `concurrency` tasks in flight.
///
/// Finishes with `Ok(())` once the queue is closed and every spawned task has
/// completed. On the first failed part the queue is closed, the remaining
/// tasks are aborted, and that error is returned. A task that could not be
/// joined is reported as an internal error.
async fn run_download_workers(
    ctx: &Arc<DownloadCtx>, mut rx: mpsc::Receiver<DownloadPartJob>, concurrency: usize,
) -> Result<()> {
    let mut workers = JoinSet::new();
    let mut accepting = true;
    while accepting || !workers.is_empty() {
        // Only pull another job while the queue is open and a slot is free.
        let can_spawn = accepting && workers.len() < concurrency;
        tokio::select! {
            Some(joined) = workers.join_next() => {
                let outcome = joined.map_err(|e| Error::Internal(e.to_string()))?;
                if let Err(e) = outcome {
                    rx.close();
                    workers.abort_all();
                    return Err(e);
                }
            }
            next = rx.recv(), if can_spawn => {
                match next {
                    Some(job) => {
                        let shared = Arc::clone(ctx);
                        workers.spawn(async move { download_part(&shared, job).await });
                    },
                    None => { accepting = false; },
                }
            }
            else => break,
        }
    }
    Ok(())
}
const WRITE_BUF_SIZE: usize = 512 * 1024;
async fn download_part(ctx: &DownloadCtx, job: DownloadPartJob) -> Result<()> {
debug_assert!(job.size > 0, "zero-size part would underflow range calculation");
#[expect(
clippy::arithmetic_side_effects,
reason = "offset + size bounded by file size; size > 0 guaranteed by plan_parts"
)]
let range_end = job.offset + job.size - 1;
let range_header = format!("bytes={}-{range_end}", job.offset);
let uri: Uri =
format!("{}/{}/{}", ctx.config.endpoint_url(), ctx.bucket, ctx.key.encoded()).parse()?;
maybe_debug!(
key = %ctx.key, part = job.number, offset = job.offset, size = job.size,
"downloading part"
);
let mut signed = build_signed(Method::GET, uri, Bytes::new(), &ctx.creds, &ctx.config.region)?;
signed.headers_mut().insert(
"range",
range_header
.parse()
.map_err(|e: http::header::InvalidHeaderValue| Error::Internal(e.to_string()))?,
);
let resp = send_with_retry(&ctx.http, signed, &ctx.config.retry).await?;
let tx = ctx.writer_tx.clone();
let mut write_offset = job.offset;
let mut buf = Vec::with_capacity(WRITE_BUF_SIZE);
let mut stream = resp.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
buf.extend_from_slice(&chunk);
if buf.len() >= WRITE_BUF_SIZE {
let data = std::mem::replace(&mut buf, Vec::with_capacity(WRITE_BUF_SIZE));
let data_len = data.len();
tx.send(WriteCmd {
buf: data,
offset: write_offset,
})
.map_err(|_| Error::Internal("writer thread exited early".into()))?;
#[expect(clippy::arithmetic_side_effects, reason = "write_offset bounded by file size")]
{
write_offset += u64::try_from(data_len).unwrap_or(u64::MAX);
}
}
}
if !buf.is_empty() {
tx.send(WriteCmd {
buf,
offset: write_offset,
})
.map_err(|_| Error::Internal("writer thread exited early".into()))?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    /// A temp file requested inside a directory must have a path under that
    /// directory; the download functions rely on this to create their staging
    /// file next to the destination.
    #[test]
    fn tempfile_creates_in_target_dir() {
        let target = std::env::temp_dir();
        let staged = tempfile::NamedTempFile::new_in(&target).unwrap();
        let staged_path = staged.path();
        assert!(staged_path.starts_with(&target));
    }
}