use std::collections::VecDeque;
use std::num::NonZeroU64;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use futures::io;
use prost::Message;
use reflink_copy::ReflinkBlockBuilder;
use crate::{
Download,
download_manager::checksum::check_final_file_checksum,
download_metadata::{DownloadMetadata, PartDetails},
error::OdlError,
fs_utils::{atomic_write, set_file_mtime_async},
progress::{
DownloadContext, Phase, ProgressEvent, SAMPLE_INTERVAL, speed_window_rate,
trim_speed_window,
},
};
/// Size of the reusable scratch buffer used by the non-reflink copy path (1 MiB).
const COPY_BUF_SIZE: usize = 1 << 20;
/// Span of progress samples kept when estimating assembly throughput.
const ASSEMBLY_SPEED_WINDOW: Duration = Duration::from_millis(1_500);
/// Pseudo part identifier under which assembly progress and speed are reported.
pub const ASSEMBLY_ULID: &str = "_assemble";
pub async fn remove_all_parts(download_dir: &Path) {
if let Ok(mut entries) = tokio::fs::read_dir(download_dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
if let Some(ext) = path.extension()
&& ext == Download::PART_EXTENSION
{
let _ = tokio::fs::remove_file(&path).await;
}
}
}
}
/// Assembles every downloaded part of `metadata` into the final output file.
///
/// Parts are sorted by offset and handed to a blocking task. That task places each
/// part at its offset, reflinking where possible and byte-copying otherwise. A
/// background sampler meanwhile reports progress and speed under the
/// [`ASSEMBLY_ULID`] pseudo part.
///
/// After a successful assembly, the output's mtime is set from
/// `metadata.last_modified` when `metadata.use_server_time` is enabled. The final
/// checksum is then verified.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the blocking task cannot be joined (e.g. it panicked);
/// - an I/O operation during assembly fails;
/// - `check_final_file_checksum` fails.
///
/// A failure to set the mtime is only logged.
pub async fn assemble_final_file(
    metadata: &DownloadMetadata,
    instruction: &Download,
    ctx: &DownloadContext,
) -> Result<PathBuf, OdlError> {
    let final_path = instruction.final_file_path();
    let mut sorted_parts: Vec<&PartDetails> = metadata.parts.values().collect();
    sorted_parts.sort_by_key(|p| p.offset);
    let total: u64 = sorted_parts.iter().map(|p| p.size).sum();
    ctx.emit(ProgressEvent::PhaseChanged(Phase::Assembling));
    ctx.emit(ProgressEvent::Progress {
        downloaded: 0,
        total: Some(total),
    });
    ctx.emit(ProgressEvent::PartAdded {
        ulid: ASSEMBLY_ULID.to_string(),
        offset: 0,
        size: total,
    });
    let parts: Vec<(PathBuf, u64, u64)> = sorted_parts
        .iter()
        .map(|p| (instruction.part_path(&p.ulid), p.offset, p.size))
        .collect();
    // The output must reach the furthest part end. If parts ever overlap, that is
    // not necessarily the end of the part with the largest offset.
    let final_end: u64 = sorted_parts
        .iter()
        .map(|p| p.offset + p.size)
        .max()
        .unwrap_or(0);
    let final_path_for_blocking = final_path.clone();
    let ctx_for_blocking = ctx.clone();
    let done_counter = Arc::new(AtomicU64::new(0));
    let done_for_blocking = Arc::clone(&done_counter);
    let sampler_handle = spawn_assembly_sampler(ctx.clone(), Arc::clone(&done_counter), total);
    let blocking_result = tokio::task::spawn_blocking(move || -> std::io::Result<()> {
        assemble_blocking(
            &final_path_for_blocking,
            final_end,
            parts,
            done_for_blocking,
            ctx_for_blocking,
        )
    })
    .await;
    // `abort` only *requests* cancellation. Wait until the sampler has actually
    // stopped, so it cannot emit a stale progress sample after the completion
    // events below.
    sampler_handle.abort();
    let _ = sampler_handle.await;
    if matches!(&blocking_result, Ok(Ok(()))) {
        ctx.emit(ProgressEvent::Progress {
            downloaded: total,
            total: Some(total),
        });
        ctx.emit(ProgressEvent::PartProgress {
            ulid: ASSEMBLY_ULID.to_string(),
            downloaded: total,
            total,
        });
        ctx.emit(ProgressEvent::PartFinished {
            ulid: ASSEMBLY_ULID.to_string(),
        });
    }
    // Outer `?`: join failure of the blocking task; inner `?`: its I/O error.
    blocking_result??;
    if metadata.use_server_time
        && let Some(last_modified) = metadata.last_modified
        && let Err(e) = set_file_mtime_async(&final_path, last_modified).await
    {
        tracing::error!(
            "Failed to set file mtime for {}: {}",
            final_path.display(),
            e
        );
    }
    ctx.emit(ProgressEvent::PhaseChanged(Phase::Verifying));
    check_final_file_checksum(metadata, instruction, false).await?;
    Ok(final_path)
}
/// Spawns a task that periodically publishes assembly progress and throughput.
///
/// Every `SAMPLE_INTERVAL` the task does the following:
/// - reads the shared `done` byte counter;
/// - appends it to a sliding window trimmed to `ASSEMBLY_SPEED_WINDOW`;
/// - emits overall and per-part (`ASSEMBLY_ULID`) speed and progress events.
///
/// The loop only ends on its own when `ctx.cancel` fires. Otherwise the returned
/// handle has to be aborted.
fn spawn_assembly_sampler(
    ctx: DownloadContext,
    done: Arc<AtomicU64>,
    total: u64,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        // Seed the window with a zero sample so the first tick can yield a rate.
        let mut samples: VecDeque<(Instant, u64)> = VecDeque::from([(Instant::now(), 0)]);
        loop {
            tokio::select! {
                _ = ctx.cancel.cancelled() => break,
                _ = tokio::time::sleep(SAMPLE_INTERVAL) => {}
            }
            let sampled_at = Instant::now();
            let copied = done.load(Ordering::Relaxed);
            samples.push_back((sampled_at, copied));
            trim_speed_window(&mut samples, sampled_at, ASSEMBLY_SPEED_WINDOW);
            if let Some(bytes_per_second) = speed_window_rate(&samples) {
                ctx.emit(ProgressEvent::Speed { bytes_per_second });
                ctx.emit(ProgressEvent::PartSpeed {
                    ulid: ASSEMBLY_ULID.to_string(),
                    bytes_per_second,
                });
            }
            ctx.emit(ProgressEvent::Progress {
                downloaded: copied,
                total: Some(total),
            });
            ctx.emit(ProgressEvent::PartProgress {
                ulid: ASSEMBLY_ULID.to_string(),
                downloaded: copied,
                total,
            });
        }
    })
}
/// Writes every part into `final_path`, each at its recorded offset.
///
/// Intended to run on a blocking thread. `parts` holds `(part_path, offset, size)`
/// tuples.
///
/// The output is created or truncated, then pre-sized to `final_end`. Each
/// non-empty part is first offered to `ReflinkBlockBuilder`. The first reflink
/// failure disables reflinking for all remaining parts, which are then copied
/// through a reusable buffer instead.
///
/// `done` is advanced as bytes land, so an async sampler can report progress.
/// [`Phase::Flushing`] is emitted right before the final `sync_data`.
///
/// # Errors
///
/// Propagates any I/O error from opening, sizing, copying or syncing. Returns
/// `UnexpectedEof` when a part file is shorter than its recorded size.
fn assemble_blocking(
    final_path: &Path,
    final_end: u64,
    parts: Vec<(PathBuf, u64, u64)>,
    done: Arc<AtomicU64>,
    ctx: DownloadContext,
) -> std::io::Result<()> {
    let final_file = std::fs::OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open(final_path)?;
    if final_end > 0 {
        final_file.set_len(final_end)?;
    }
    let cluster_size =
        NonZeroU64::new(Download::ASSEMBLY_CLUSTER_SIZE).expect("ASSEMBLY_CLUSTER_SIZE non-zero");
    // The mask-based alignment test below is only valid if ASSEMBLY_CLUSTER_SIZE
    // is a power of two (presumably guaranteed where it is defined).
    let cluster_mask = Download::ASSEMBLY_CLUSTER_SIZE - 1;
    let mut reflink_disabled = false;
    let mut buf = vec![0u8; COPY_BUF_SIZE];
    for (part_path, offset, size) in parts {
        // Empty parts contribute nothing and cannot be reflinked.
        let Some(len_nz) = NonZeroU64::new(size) else {
            continue;
        };
        let mut part_file = std::fs::File::open(&part_path)?;
        let aligned_offset = offset & cluster_mask == 0;
        let aligned_size = size & cluster_mask == 0;
        // An unaligned length is only acceptable for the block that ends exactly at
        // the end of the output file. Comparing against `final_end`, rather than the
        // part's list position, still finds the tail when trailing parts are empty.
        #[cfg(windows)]
        let tail_reflinkable = false;
        #[cfg(not(windows))]
        let tail_reflinkable = offset + size == final_end;
        let reflinkable = !reflink_disabled && aligned_offset && (aligned_size || tail_reflinkable);
        if reflinkable {
            let res = ReflinkBlockBuilder::new(&part_file, &final_file, len_nz)
                .from_offset(0)
                .to_offset(offset)
                .cluster_size(cluster_size)
                .reflink_block();
            match res {
                Ok(()) => {
                    done.fetch_add(size, Ordering::Relaxed);
                    continue;
                }
                Err(e) => {
                    tracing::debug!(error = %e, "reflink failed, falling back to copy");
                    reflink_disabled = true;
                }
            }
        }
        copy_part_bytes(
            &mut part_file,
            &part_path,
            &final_file,
            offset,
            size,
            &mut buf,
            &done,
        )?;
    }
    ctx.emit(ProgressEvent::PhaseChanged(Phase::Flushing));
    final_file.sync_data()?;
    Ok(())
}

/// Copies the first `size` bytes of `part_file` into `final_file` starting at
/// `offset`.
///
/// Reads interrupted by a signal are retried. Every chunk written is added to
/// `done`.
///
/// # Errors
///
/// Returns `UnexpectedEof` (naming `part_path`) if the part ends early, and
/// propagates any other read or write error.
fn copy_part_bytes(
    part_file: &mut std::fs::File,
    part_path: &Path,
    final_file: &std::fs::File,
    offset: u64,
    size: u64,
    buf: &mut [u8],
    done: &AtomicU64,
) -> std::io::Result<()> {
    use std::io::Read;
    let mut write_offset = offset;
    let mut remaining = size;
    while remaining > 0 {
        let want = remaining.min(buf.len() as u64) as usize;
        let n = match part_file.read(&mut buf[..want]) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    format!(
                        "part file {} shorter than recorded size ({} bytes missing)",
                        part_path.display(),
                        remaining
                    ),
                ));
            }
            Ok(n) => n,
            // EINTR is transient; `Read::read_exact` retries it as well.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        pwrite_all(final_file, &buf[..n], write_offset)?;
        write_offset += n as u64;
        remaining -= n as u64;
        done.fetch_add(n as u64, Ordering::Relaxed);
    }
    Ok(())
}
/// Writes all of `buf` to `file` at absolute byte `offset`.
///
/// Uses the positional `write_all_at`, so the file cursor is left untouched.
#[cfg(unix)]
fn pwrite_all(file: &std::fs::File, buf: &[u8], offset: u64) -> std::io::Result<()> {
    std::os::unix::fs::FileExt::write_all_at(file, buf, offset)
}
/// Writes all of `buf` to `file` at absolute byte `offset`.
///
/// Mirrors the contract of `std::os::unix::fs::FileExt::write_all_at`:
/// - short writes are resumed;
/// - `Interrupted` errors are retried;
/// - a zero-length write is reported as `WriteZero`.
///
/// Unlike the Unix variant, `seek_write` also moves the file cursor.
#[cfg(windows)]
fn pwrite_all(file: &std::fs::File, buf: &[u8], offset: u64) -> std::io::Result<()> {
    use std::os::windows::fs::FileExt;
    let mut written = 0;
    while written < buf.len() {
        match file.seek_write(&buf[written..], offset + written as u64) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            Ok(n) => written += n,
            // Transient; retry like std's `write_all`/`write_all_at` do.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
/// Portable fallback: moves the shared file cursor to `offset` and writes all of
/// `buf`.
///
/// Because the cursor is shared, this is only safe while a single writer is
/// using `file`.
#[cfg(not(any(unix, windows)))]
fn pwrite_all(file: &std::fs::File, buf: &[u8], offset: u64) -> std::io::Result<()> {
    use std::io::{Seek, SeekFrom, Write};
    let mut cursor = file;
    cursor
        .seek(SeekFrom::Start(offset))
        .and_then(|_| cursor.write_all(buf))
}
/// Returns the combined on-disk length of all part files of `metadata`.
///
/// Returns `None` when `metadata.size` is unknown. A part file whose metadata
/// cannot be read (e.g. missing) counts as 0 bytes. All lookups run concurrently.
pub async fn sum_parts_on_disk(instruction: &Download, metadata: &DownloadMetadata) -> Option<u64> {
    if metadata.size.is_none() {
        return None;
    }
    let lookups = metadata.parts.values().map(|part| {
        let location = instruction.part_path(&part.ulid);
        async move {
            tokio::fs::metadata(&location)
                .await
                .map(|meta| meta.len())
                .unwrap_or(0)
        }
    });
    let on_disk = futures::future::join_all(lookups)
        .await
        .into_iter()
        .sum::<u64>();
    Some(on_disk)
}
/// Serializes `metadata` as a length-delimited prost message.
///
/// The encoded bytes are handed to [`persist_encoded_metadata`] for the download
/// described by `instruction`.
pub async fn persist_metadata(
    metadata: &DownloadMetadata,
    instruction: &Download,
) -> io::Result<()> {
    persist_encoded_metadata(&metadata.encode_length_delimited_to_vec(), instruction).await
}
/// Writes already-encoded metadata bytes to the download's metadata path via
/// `atomic_write`.
///
/// The metadata temp path is passed as the staging location. This is presumably
/// write-then-rename; confirm against `atomic_write`.
pub async fn persist_encoded_metadata(encoded: &[u8], instruction: &Download) -> io::Result<()> {
    let target = instruction.metadata_path();
    let staging = instruction.metadata_temp_path();
    atomic_write(target, staging, encoded).await
}