use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use snapdir_core::manifest::{Manifest, ManifestEntry, PathType};
use snapdir_core::merkle::Blake3Hasher;
use snapdir_core::store::StoreError;
use snapdir_core::Meter;
use crate::adaptive::{
p95_object_size, AdaptiveGate, AdaptivePolicy as ControllerPolicy, ControllerDriver, OpResult,
OpSample,
};
use crate::transfer::{
classify_error, run_adaptive, run_concurrent, AdaptivePolicy, RateLimiter, TransferConfig,
};
use crate::util::file_present_and_verified;
/// Turns a manifest-style path into a form suitable for joining onto a
/// destination directory.
///
/// At most one leading `./` is removed, and then at most one trailing `/` is
/// removed. Anything else is returned untouched, so `"./"` becomes `""` and
/// `"././x"` becomes `"./x"`.
fn strip_leading_dot_slash(path: &str) -> &str {
    let without_prefix = match path.strip_prefix("./") {
        Some(rest) => rest,
        None => path,
    };
    match without_prefix.strip_suffix('/') {
        Some(rest) => rest,
        None => without_prefix,
    }
}
/// Writes `bytes` to `target` by first writing a sibling temporary file and
/// then renaming it over `target`.
///
/// The parent directory of `target` is created if necessary. The temporary
/// file lives in that same directory and is named
/// `{file_name}.{pid}.{n}.tmp`, where `n` comes from a process-wide counter,
/// so concurrent writers in this process never share a temporary file.
///
/// # Errors
///
/// Returns the I/O error (converted into [`StoreError`]) from creating the
/// parent directory, writing the temporary file, or renaming it. When the
/// write or the rename fails, the temporary file is removed on a best-effort
/// basis so failed attempts do not leave `*.tmp` litter next to the target.
pub(crate) fn write_atomic(target: &Path, bytes: &[u8]) -> Result<(), StoreError> {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let parent = target.parent();
    if let Some(dir) = parent {
        std::fs::create_dir_all(dir)?;
    }

    // pid + counter keeps the temporary name unique across processes and
    // across concurrent calls within this process.
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let file_name = target
        .file_name()
        .map(|s| s.to_string_lossy().into_owned())
        .unwrap_or_default();
    let tmp_name = format!("{file_name}.{pid}.{n}.tmp");
    let tmp = match parent {
        Some(dir) => dir.join(tmp_name),
        None => std::path::PathBuf::from(tmp_name),
    };

    let outcome = std::fs::write(&tmp, bytes).and_then(|()| std::fs::rename(&tmp, target));
    if let Err(err) = outcome {
        // Best-effort cleanup: the original error is what the caller needs to
        // see, so a failure to remove the temporary file is ignored.
        let _ = std::fs::remove_file(&tmp);
        return Err(err.into());
    }
    Ok(())
}
/// Materialises the entries of `manifest` under `dest`, downloading the files
/// that are not already present.
///
/// The work is done in two phases:
///
/// 1. A sequential planning pass over `manifest.entries()`. Directory entries
///    are created immediately. File entries for which
///    `file_present_and_verified` returns `true` are skipped (and reported via
///    `meter.add_skipped(1)`); every other file has its parent directory
///    created and is queued for download.
/// 2. The queued files are fetched with `download` and written with
///    [`write_atomic`]. Scheduling is done by `run_concurrent` when
///    `config.adaptive` is `AdaptivePolicy::Off`, and by
///    [`run_adaptive_downloads`] when it is `AdaptivePolicy::On`.
///
/// `meter`, when present, receives the total queued byte count up front and
/// per-object start / bytes-in / bytes-out / finish events. `meter_arc` is
/// only forwarded to the adaptive path.
///
/// # Errors
///
/// Propagates errors from directory creation, from `download`, from
/// [`write_atomic`], and from the selected scheduler.
pub(crate) async fn fetch_files_concurrent<'a, F, Fut>(
    manifest: &'a Manifest,
    dest: &Path,
    config: &TransferConfig,
    rate_limiter: &RateLimiter,
    meter: Option<&Meter>,
    meter_arc: Option<Arc<Meter>>,
    download: F,
) -> Result<(), StoreError>
where
    F: Fn(&'a ManifestEntry) -> Fut,
    Fut: std::future::Future<Output = Result<Vec<u8>, StoreError>>,
{
    let hasher = Blake3Hasher::new();

    // Planning pass: decide, entry by entry, what still has to be fetched.
    // NOTE(review): `entry.path` is joined onto `dest` as-is; presumably the
    // manifest is validated upstream so paths cannot escape `dest` -- confirm.
    let mut pending: Vec<(&'a ManifestEntry, std::path::PathBuf)> = Vec::new();
    for entry in manifest.entries() {
        let target = dest.join(strip_leading_dot_slash(&entry.path));
        match entry.path_type {
            PathType::Directory => {
                std::fs::create_dir_all(&target)?;
            }
            PathType::File => {
                let already_ok = file_present_and_verified(&target, &entry.checksum, &hasher);
                if already_ok {
                    if let Some(progress) = meter {
                        progress.add_skipped(1);
                    }
                    continue;
                }
                if let Some(parent_dir) = target.parent() {
                    std::fs::create_dir_all(parent_dir)?;
                }
                pending.push((entry, target));
            }
        }
    }

    // Only the files that will actually be downloaded count towards the total.
    if let Some(progress) = meter {
        let total_bytes = pending.iter().map(|(entry, _)| entry.size).sum::<u64>();
        progress.set_total(total_bytes);
    }

    // One unit of work: throttle, download, write to disk, report progress.
    let download_one = |entry: &'a ManifestEntry, target: std::path::PathBuf| {
        let fetch = &download;
        let limiter = &rate_limiter;
        async move {
            limiter.acquire(entry.size).await;
            if let Some(progress) = meter {
                progress.object_started();
            }

            let payload = fetch(entry).await?;
            let payload_len = payload.len() as u64;
            if let Some(progress) = meter {
                progress.add_in(payload_len);
            }

            write_atomic(&target, &payload)?;
            if let Some(progress) = meter {
                progress.add_out(payload_len);
                progress.object_finished();
            }
            Ok::<(), StoreError>(())
        }
    };

    match config.adaptive {
        AdaptivePolicy::Off => {
            run_concurrent(pending, config.concurrency, |(entry, target)| {
                download_one(entry, target)
            })
            .await?;
        }
        AdaptivePolicy::On { fraction, ceiling } => {
            // NOTE(review): `meter_arc` is presumably the same meter as
            // `meter`, in shared form for the controller -- confirm at callers.
            run_adaptive_downloads(
                pending,
                config,
                rate_limiter,
                meter_arc,
                fraction,
                ceiling,
                download_one,
            )
            .await?;
        }
    }
    Ok(())
}
/// Runs the queued downloads through the adaptive concurrency controller.
///
/// Setup, in order:
///
/// * the entry sizes are summarised with `p95_object_size`;
/// * a `ControllerPolicy` is built from `fraction`, `ceiling`, the total RAM
///   reported by `snapdir_core::resources::total_ram_bytes()` (`0` when that
///   is unavailable) and `config.max_bytes_per_sec`;
/// * an `AdaptiveGate` is created from `config.concurrency` and `ceiling`;
/// * a rate-applier callback is installed that spawns a task calling
///   `set_rate` on a clone of `rate_limiter`;
/// * a background task calls `driver.tick()` every 250 ms, skipping missed
///   ticks.
///
/// Every download is timed and reported to the driver as an `OpSample`:
/// `entry.size` bytes with `OpResult::Ok` on success, or `0` bytes with
/// `classify_error(err)` on failure.
///
/// The ticker task is owned by a drop guard, so it is aborted when the
/// downloads finish and also when this future is dropped early or a download
/// panics. Dropping a bare `JoinHandle` would detach the task instead, leaving
/// it ticking indefinitely.
///
/// # Errors
///
/// Returns the error produced by `run_adaptive`, if any.
// The inputs are kept as a flat parameter list; the allowance silences
// clippy's argument-count lint for this internal helper.
#[allow(clippy::too_many_arguments)]
async fn run_adaptive_downloads<'a, D, DFut>(
    to_download: Vec<(&'a ManifestEntry, std::path::PathBuf)>,
    config: &TransferConfig,
    rate_limiter: &RateLimiter,
    meter_arc: Option<Arc<Meter>>,
    fraction: f64,
    ceiling: usize,
    download_one: D,
) -> Result<(), StoreError>
where
    D: Fn(&'a ManifestEntry, std::path::PathBuf) -> DFut,
    DFut: std::future::Future<Output = Result<(), StoreError>>,
{
    /// Aborts the wrapped task when dropped, tying the task's lifetime to the
    /// scope that owns the guard.
    struct AbortOnDrop<T>(tokio::task::JoinHandle<T>);

    impl<T> Drop for AbortOnDrop<T> {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    let sizes: Vec<u64> = to_download.iter().map(|(e, _)| e.size).collect();
    let p95 = p95_object_size(&sizes);
    let total_ram = snapdir_core::resources::total_ram_bytes().unwrap_or(0);
    let policy = ControllerPolicy::new(fraction, ceiling, total_ram, config.max_bytes_per_sec);
    let gate = AdaptiveGate::new(config.concurrency.get(), ceiling);

    // `set_rate` is async but the callback is a plain `Fn`, so each rate
    // change is applied from a freshly spawned task.
    let limiter = rate_limiter.clone();
    let rate_applier: Arc<dyn Fn(Option<u64>) + Send + Sync> = Arc::new(move |rate| {
        let limiter = limiter.clone();
        tokio::spawn(async move {
            limiter.set_rate(rate).await;
        });
    });

    let driver = ControllerDriver::new(policy, gate.clone(), p95, Some(rate_applier), meter_arc);

    // Periodic controller tick. The loop never exits on its own, so the guard
    // below is what stops it.
    let tick_driver = driver.clone();
    let mut ticker = tokio::time::interval(std::time::Duration::from_millis(250));
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    let ticker_guard = AbortOnDrop(tokio::spawn(async move {
        loop {
            ticker.tick().await;
            tick_driver.tick();
        }
    }));

    let result = run_adaptive(to_download, &gate, |(entry, target)| {
        let download_one = &download_one;
        let driver = &driver;
        async move {
            let started = Instant::now();
            let outcome = download_one(entry, target).await;
            let latency = started.elapsed();
            let (bytes, op_result) = match &outcome {
                Ok(()) => (entry.size, OpResult::Ok),
                Err(err) => (0, classify_error(err)),
            };
            driver.record_op(OpSample {
                bytes,
                latency,
                result: op_result,
            });
            outcome
        }
    })
    .await;

    // Stop ticking as soon as the downloads are done, as before.
    drop(ticker_guard);
    result.map(|_| ())
}
#[cfg(test)]
mod tests {
    use super::*;
    use snapdir_core::merkle::Hasher;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Builds a four-worker multi-threaded tokio runtime with the timer
    /// driver enabled.
    fn runtime() -> tokio::runtime::Runtime {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.worker_threads(4).enable_time();
        builder.build().expect("build tokio runtime")
    }

    /// A scratch directory under the system temp dir, removed again on drop.
    struct TempDir {
        path: std::path::PathBuf,
    }

    impl TempDir {
        /// Creates a directory whose name combines the process id with a
        /// per-process counter.
        fn new() -> Self {
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let n = COUNTER.fetch_add(1, Ordering::Relaxed);
            let dir_name = format!("snapdir-fetch-test-{}-{n}", std::process::id());
            let path = std::env::temp_dir().join(dir_name);
            std::fs::create_dir_all(&path).expect("create temp dir");
            Self { path }
        }

        /// The directory's location on disk.
        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not fail the test.
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    /// Hex checksum of `bytes` as computed by `Blake3Hasher`.
    fn checksum_of(bytes: &[u8]) -> String {
        let hasher = Blake3Hasher::new();
        hasher.hash_hex(bytes)
    }

    /// Builds a manifest with a root directory entry followed by one file
    /// entry per `(path, contents)` pair.
    fn manifest_for(files: &[(&str, &[u8])]) -> Manifest {
        let mut staging = Manifest::new();
        let root = ManifestEntry::new(PathType::Directory, "700", "0".repeat(64), 0, "./");
        staging.push(root);
        for (path, contents) in files {
            let file_entry = ManifestEntry::new(
                PathType::File,
                "600",
                checksum_of(contents),
                contents.len() as u64,
                format!("./{path}"),
            );
            staging.push(file_entry);
        }
        Manifest::from_entries(staging.entries().to_vec())
    }

    /// Asserts that every `(path, contents)` pair exists under `root` with
    /// exactly the expected bytes.
    fn assert_materialized(root: &Path, files: &[(&str, &[u8])]) {
        for (path, contents) in files {
            let got = std::fs::read(root.join(path))
                .unwrap_or_else(|e| panic!("missing {path}: {e}"));
            assert_eq!(&got, contents, "wrong bytes for {path}");
        }
    }

    /// In-memory download backend that serves bodies by checksum and records
    /// which checksums were requested and the peak number of in-flight calls.
    struct FakeDownloader {
        contents: HashMap<String, Vec<u8>>,
        called: Mutex<Vec<String>>,
        in_flight: AtomicUsize,
        high_water: AtomicUsize,
    }

    impl FakeDownloader {
        /// Indexes the given file bodies by their checksum.
        fn new(files: &[(&str, &[u8])]) -> Arc<Self> {
            let mut contents = HashMap::with_capacity(files.len());
            for (_, body) in files {
                contents.insert(checksum_of(body), body.to_vec());
            }
            Arc::new(Self {
                contents,
                called: Mutex::new(Vec::new()),
                in_flight: AtomicUsize::new(0),
                high_water: AtomicUsize::new(0),
            })
        }

        /// Serves the body for `entry.checksum` after a 20 ms pause, keeping
        /// the in-flight and high-water counters up to date.
        async fn download(&self, entry: &ManifestEntry) -> Result<Vec<u8>, StoreError> {
            // Single statement so the mutex guard is released before the await.
            self.called.lock().unwrap().push(entry.checksum.clone());
            let now_in_flight = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.high_water.fetch_max(now_in_flight, Ordering::SeqCst);
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            match self.contents.get(&entry.checksum) {
                Some(body) => Ok(body.clone()),
                None => Err(StoreError::ObjectNotFound {
                    checksum: entry.checksum.clone(),
                }),
            }
        }
    }

    /// Every file is written with the right bytes, and the peak number of
    /// simultaneous downloads equals the configured concurrency.
    #[test]
    fn concurrent_download_orchestrator_materializes_all() {
        let files: &[(&str, &[u8])] = &[
            ("a.txt", b"alpha" as &[u8]),
            ("nested/b.txt", b"bravo"),
            ("nested/deep/c.txt", b"charlie"),
            ("d.txt", b"delta"),
        ];
        let manifest = manifest_for(files);

        for concurrency in [1usize, 4] {
            let dest = TempDir::new();
            let fake = FakeDownloader::new(files);
            let cfg = TransferConfig::new(concurrency, None);
            let limiter = RateLimiter::new(None);
            let rt = runtime();
            let shared = Arc::clone(&fake);

            let outcome = rt.block_on(async {
                fetch_files_concurrent(
                    &manifest,
                    dest.path(),
                    &cfg,
                    &limiter,
                    None,
                    None,
                    |entry| {
                        let downloader = Arc::clone(&shared);
                        async move { downloader.download(entry).await }
                    },
                )
                .await
            });
            outcome.expect("orchestrator must succeed");

            assert_materialized(dest.path(), files);
            assert!(dest.path().join("nested/deep").is_dir());

            let hw = fake.high_water.load(Ordering::SeqCst);
            let expected = concurrency.min(files.len());
            assert_eq!(
                hw, expected,
                "concurrency={concurrency}: peak in-flight {hw} != expected {expected}"
            );
        }
    }

    /// A file that already exists with matching content is not downloaded
    /// again; only the missing one is fetched.
    #[test]
    fn concurrent_download_skips_present_and_verified() {
        let files: &[(&str, &[u8])] = &[
            ("present.txt", b"already-here" as &[u8]),
            ("missing.txt", b"needs-download"),
        ];
        let manifest = manifest_for(files);
        let dest = TempDir::new();
        std::fs::write(dest.path().join("present.txt"), b"already-here").unwrap();

        let fake = FakeDownloader::new(files);
        let cfg = TransferConfig::new(4, None);
        let limiter = RateLimiter::new(None);
        let rt = runtime();
        let shared = Arc::clone(&fake);

        let outcome = rt.block_on(async {
            fetch_files_concurrent(
                &manifest,
                dest.path(),
                &cfg,
                &limiter,
                None,
                None,
                |entry| {
                    let downloader = Arc::clone(&shared);
                    async move { downloader.download(entry).await }
                },
            )
            .await
        });
        outcome.expect("orchestrator must succeed");

        let called = fake.called.lock().unwrap().clone();
        let present_sum = checksum_of(b"already-here");
        let missing_sum = checksum_of(b"needs-download");
        assert!(
            !called.contains(&present_sum),
            "present+verified file must not be downloaded"
        );
        assert_eq!(
            called,
            vec![missing_sum],
            "only the missing file should be downloaded"
        );

        let written = std::fs::read(dest.path().join("missing.txt")).unwrap();
        assert_eq!(written, b"needs-download");
    }

    /// With the adaptive policy on, all files are written and the observed
    /// peak concurrency stays within the configured ceiling.
    #[test]
    fn adaptive_download_materializes_all_within_ceiling() {
        use crate::transfer::AdaptivePolicy;

        let files: &[(&str, &[u8])] = &[
            ("a.txt", b"alpha" as &[u8]),
            ("nested/b.txt", b"bravo"),
            ("nested/deep/c.txt", b"charlie"),
            ("d.txt", b"delta"),
            ("e.txt", b"echo"),
        ];
        let manifest = manifest_for(files);
        let dest = TempDir::new();
        let fake = FakeDownloader::new(files);

        let policy = AdaptivePolicy::On {
            fraction: 0.8,
            ceiling: 2,
        };
        let cfg = TransferConfig::new(4, None).with_adaptive(policy);
        let limiter = RateLimiter::new(None);
        let rt = runtime();
        let shared = Arc::clone(&fake);

        let outcome = rt.block_on(async {
            fetch_files_concurrent(
                &manifest,
                dest.path(),
                &cfg,
                &limiter,
                None,
                None,
                |entry| {
                    let downloader = Arc::clone(&shared);
                    async move { downloader.download(entry).await }
                },
            )
            .await
        });
        outcome.expect("adaptive orchestrator must succeed");

        assert_materialized(dest.path(), files);

        let hw = fake.high_water.load(Ordering::SeqCst);
        assert!(
            hw <= 2,
            "effective concurrency must not exceed the ceiling 2, got {hw}"
        );
        let requested = fake.called.lock().unwrap().len();
        assert_eq!(requested, files.len());
    }

    /// An error returned by one download is surfaced to the caller.
    #[test]
    fn concurrent_download_propagates_error() {
        let files: &[(&str, &[u8])] = &[
            ("ok1.txt", b"one" as &[u8]),
            ("boom.txt", b"two"),
            ("ok2.txt", b"three"),
        ];
        let manifest = manifest_for(files);
        let boom_sum = checksum_of(b"two");
        let dest = TempDir::new();
        let cfg = TransferConfig::new(4, None);
        let limiter = RateLimiter::new(None);
        let rt = runtime();

        let result = rt.block_on(async {
            fetch_files_concurrent(
                &manifest,
                dest.path(),
                &cfg,
                &limiter,
                None,
                None,
                |entry| {
                    let boom = boom_sum.clone();
                    async move {
                        if entry.checksum == boom {
                            Err(StoreError::Backend {
                                message: "download blew up".to_owned(),
                                source: None,
                            })
                        } else {
                            Ok(b"unused".to_vec())
                        }
                    }
                },
            )
            .await
        });

        let err = result.expect_err("the failing download must surface");
        assert!(
            matches!(err, StoreError::Backend { ref message, .. } if message == "download blew up"),
            "unexpected error: {err:?}"
        );
    }
}