use std::fs::{self, File};
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use flate2::read::GzDecoder;
use sha2::{Digest, Sha256};
use tar::Archive;
use crate::classifier::manifest::Manifest;
use crate::error::{NlError, NlResult};
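/// Abstraction over "fetch the bytes at `url` and write them into `sink`".
/// Implementations return the number of bytes written so callers can report
/// or sanity-check transfer sizes.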
pub trait Downloader {
fn fetch(&self, url: &str, sink: &mut dyn Write) -> NlResult<u64>;
}
const UREQ_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
const UREQ_READ_TIMEOUT: Duration = Duration::from_secs(300);
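/// Returns a process-wide `ureq::Agent` with fixed connect/read timeouts,
/// built lazily on first use so all downloads share one connection pool.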
fn ureq_agent() -> &'static ureq::Agent {
static AGENT: OnceLock<ureq::Agent> = OnceLock::new();
AGENT.get_or_init(|| {
ureq::AgentBuilder::new()
.timeout_connect(UREQ_CONNECT_TIMEOUT)
.timeout_read(UREQ_READ_TIMEOUT)
.build()
})
}
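/// [`Downloader`] backed by HTTP(S) requests via `ureq`.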
pub struct UreqDownloader;
impl Downloader for UreqDownloader {
fn fetch(&self, url: &str, sink: &mut dyn Write) -> NlResult<u64> {
let response = ureq_agent()
.get(url)
.call()
.map_err(|e| NlError::DownloadFailed(format!("ureq GET {url} failed: {e}")))?;
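        // With ureq's default configuration, 4xx/5xx responses already surface
        // as `Err` from `call()`; this check is a defensive guard for agents
        // configured otherwise.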
if response.status() < 200 || response.status() >= 300 {
return Err(NlError::DownloadFailed(format!(
"HTTP {} from {url}",
response.status()
)));
}
let mut reader = response.into_reader();
let written = io::copy(&mut reader, sink)
.map_err(|e| NlError::DownloadFailed(format!("body copy from {url} failed: {e}")))?;
Ok(written)
}
}
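/// [`Downloader`] that serves `file://` URLs from the local filesystem,
/// useful for tests and pre-fetched archives.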
pub struct FileDownloader;
impl Downloader for FileDownloader {
fn fetch(&self, url: &str, sink: &mut dyn Write) -> NlResult<u64> {
let path = url.strip_prefix("file://").ok_or_else(|| {
NlError::DownloadFailed(format!("FileDownloader requires file:// URL, got {url:?}"))
})?;
let mut file =
File::open(path).map_err(|e| NlError::DownloadFailed(format!("open {path}: {e}")))?;
let written = io::copy(&mut file, sink)
.map_err(|e| NlError::DownloadFailed(format!("read {path}: {e}")))?;
Ok(written)
}
}
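/// Ensures the model described by `expected_manifest` is present in `cache_dir`,
/// downloading and extracting it with [`UreqDownloader`] if needed. Returns the
/// cache directory on success.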
pub fn ensure_model_in_cache(
cache_dir: &Path,
expected_manifest: &Manifest,
allow_download: bool,
) -> NlResult<PathBuf> {
ensure_model_in_cache_with(
cache_dir,
expected_manifest,
allow_download,
&UreqDownloader,
)
}
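/// Like [`ensure_model_in_cache`], but with an injectable [`Downloader`] so the
/// transport can be swapped out (e.g. [`FileDownloader`] in tests). The flow is:
///
/// 1. If `cache_dir/manifest.json` already exists, the cache is treated as valid.
/// 2. Otherwise (and only if `allow_download` is set) the archive is streamed to
///    a unique temp file while its SHA-256 is computed.
/// 3. The digest is checked against the manifest, the archive is renamed into
///    place, extracted into a staging directory, and promoted into `cache_dir`.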
pub fn ensure_model_in_cache_with(
cache_dir: &Path,
expected_manifest: &Manifest,
allow_download: bool,
downloader: &dyn Downloader,
) -> NlResult<PathBuf> {
if cache_dir.join("manifest.json").exists() {
return Ok(cache_dir.to_path_buf());
}
if !allow_download {
return Err(NlError::DownloadDisabled);
}
fs::create_dir_all(cache_dir)?;
let archive_path = cache_dir.join(&expected_manifest.archive);
let tmp_path = unique_tmp_path(&archive_path);
let actual_hash =
stream_to_file_with_hash(downloader, &expected_manifest.download_url, &tmp_path)
.inspect_err(|_| {
let _ = fs::remove_file(&tmp_path);
})?;
if !sha256_eq(&actual_hash, &expected_manifest.sha256) {
let _ = fs::remove_file(&tmp_path);
return Err(NlError::ManifestSha256Mismatch {
file: expected_manifest.archive.clone(),
expected: expected_manifest.sha256.clone(),
actual: actual_hash,
});
}
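    // On Unix `rename` replaces an existing target atomically; `AlreadyExists`
    // is mainly a Windows concern and means a concurrent process already put
    // the archive in place, so the temp file can simply be discarded.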
if let Err(e) = fs::rename(&tmp_path, &archive_path) {
match e.kind() {
io::ErrorKind::AlreadyExists => {
let _ = fs::remove_file(&tmp_path);
}
_ => {
let _ = fs::remove_file(&tmp_path);
return Err(NlError::DownloadFailed(format!(
"rename {} -> {}: {e}",
tmp_path.display(),
archive_path.display()
)));
}
}
}
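    // Extract into a unique staging directory first so an interrupted or failed
    // extraction never leaves a half-populated cache directory behind.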
let staging = unique_extract_dir(cache_dir);
if let Err(e) = extract_targz_into(&archive_path, &staging) {
let _ = fs::remove_dir_all(&staging);
return Err(e);
}
if let Err(e) = promote_extracted(&staging, cache_dir) {
let _ = fs::remove_dir_all(&staging);
return Err(e);
}
Ok(cache_dir.to_path_buf())
}
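/// Streams `url` into `path` through a [`HashingWriter`], fsyncs the file, and
/// returns the lowercase hex SHA-256 of everything written.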
fn stream_to_file_with_hash(
downloader: &dyn Downloader,
url: &str,
path: &Path,
) -> NlResult<String> {
let file = File::create(path)
.map_err(|e| NlError::DownloadFailed(format!("create {}: {e}", path.display())))?;
let mut sink = HashingWriter {
inner: file,
hasher: Sha256::new(),
};
downloader.fetch(url, &mut sink)?;
sink.inner
.sync_all()
.map_err(|e| NlError::DownloadFailed(format!("sync_all {}: {e}", path.display())))?;
Ok(hex_lower(&sink.hasher.finalize()))
}
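/// `Write` adapter that feeds every byte accepted by the inner writer into a
/// SHA-256 hasher, so downloading and hashing happen in a single pass.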
struct HashingWriter<W: Write> {
inner: W,
hasher: Sha256,
}
impl<W: Write> Write for HashingWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let n = self.inner.write(buf)?;
if n > 0 {
self.hasher.update(&buf[..n]);
}
Ok(n)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
}
}
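/// Case-insensitive comparison of two hex digests that accumulates differences
/// instead of short-circuiting on the first mismatching byte.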
fn sha256_eq(a: &str, b: &str) -> bool {
let a = a.as_bytes();
let b = b.as_bytes();
if a.len() != b.len() {
return false;
}
let mut diff: u8 = 0;
for (x, y) in a.iter().zip(b.iter()) {
diff |= if x.eq_ignore_ascii_case(y) { 0 } else { 1 };
}
diff == 0
}
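/// Lowercase hex encoding of a byte slice, kept local to avoid an extra
/// dependency for this one use.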
fn hex_lower(bytes: &[u8]) -> String {
const HEX: &[u8; 16] = b"0123456789abcdef";
let mut out = String::with_capacity(bytes.len() * 2);
for byte in bytes {
out.push(HEX[(byte >> 4) as usize] as char);
out.push(HEX[(byte & 0x0f) as usize] as char);
}
out
}
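/// Derives a per-process, per-call temp path next to `archive_path` so
/// concurrent downloads never clobber each other's partial files.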
fn unique_tmp_path(archive_path: &Path) -> PathBuf {
let mut buf = archive_path.as_os_str().to_owned();
buf.push(format!(
".tmp.{}.{}",
std::process::id(),
unique_suffix_nanos()
));
PathBuf::from(buf)
}
fn unique_extract_dir(cache_dir: &Path) -> PathBuf {
cache_dir.join(format!(
".extract.tmp.{}.{}",
std::process::id(),
unique_suffix_nanos()
))
}
fn unique_suffix_nanos() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |d| d.as_nanos())
}
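/// Unpacks the `.tar.gz` archive at `archive_path` into `into`, creating the
/// destination directory first.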
fn extract_targz_into(archive_path: &Path, into: &Path) -> NlResult<()> {
fs::create_dir_all(into).map_err(|e| {
NlError::DownloadFailed(format!("create staging dir {}: {e}", into.display()))
})?;
let file = File::open(archive_path)
.map_err(|e| NlError::DownloadFailed(format!("open {}: {e}", archive_path.display())))?;
let gz = GzDecoder::new(file);
let mut archive = Archive::new(gz);
archive.unpack(into).map_err(|e| {
NlError::DownloadFailed(format!(
"untar {} -> {}: {e}",
archive_path.display(),
into.display()
))
})
}
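/// Moves the extracted files from `staging` into `cache_dir`, tolerating entries
/// that a concurrent process promoted first, then removes the staging directory.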
fn promote_extracted(staging: &Path, cache_dir: &Path) -> NlResult<()> {
let source = pick_extracted_root(staging)?;
for entry in fs::read_dir(&source)
.map_err(|e| NlError::DownloadFailed(format!("read_dir {}: {e}", source.display())))?
{
let entry = entry.map_err(|e| {
NlError::DownloadFailed(format!("dir entry under {}: {e}", source.display()))
})?;
let from = entry.path();
let to = cache_dir.join(entry.file_name());
match fs::rename(&from, &to) {
Ok(()) => {}
Err(e) if e.kind() == io::ErrorKind::AlreadyExists => {
let _ = if from.is_dir() {
fs::remove_dir_all(&from)
} else {
fs::remove_file(&from)
};
}
Err(e) => {
return Err(NlError::DownloadFailed(format!(
"promote {} -> {}: {e}",
from.display(),
to.display()
)));
}
}
}
let _ = fs::remove_dir_all(staging);
Ok(())
}
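/// Locates the directory that actually holds `manifest.json`: either `staging`
/// itself or a single top-level directory inside the unpacked archive.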
fn pick_extracted_root(staging: &Path) -> NlResult<PathBuf> {
if staging.join("manifest.json").is_file() {
return Ok(staging.to_path_buf());
}
let entries: Vec<_> = fs::read_dir(staging)
.map_err(|e| NlError::DownloadFailed(format!("read_dir {}: {e}", staging.display())))?
.collect::<Result<_, _>>()
.map_err(|e| {
NlError::DownloadFailed(format!("dir entry under {}: {e}", staging.display()))
})?;
if entries.len() == 1 {
let candidate = entries[0].path();
if candidate.is_dir() && candidate.join("manifest.json").is_file() {
return Ok(candidate);
}
}
Err(NlError::DownloadFailed(format!(
"extracted archive at {} does not contain a manifest.json",
staging.display()
)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
    fn sha256_eq_is_case_insensitive_and_length_checked() {
assert!(sha256_eq("abc123", "ABC123"));
assert!(sha256_eq("abc123", "abc123"));
assert!(!sha256_eq("abc123", "abc124"));
assert!(!sha256_eq("abc", "abcd"));
}
#[test]
fn hex_lower_roundtrips_known_values() {
assert_eq!(hex_lower(&[0x00, 0xff, 0xab]), "00ffab");
assert_eq!(hex_lower(&[]), "");
}
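    // A minimal in-memory `Downloader` sketch: it ignores the URL and copies a
    // fixed payload into the sink, which is enough to exercise
    // `stream_to_file_with_hash` against the well-known SHA-256 of "abc".
    struct StaticDownloader(&'static [u8]);
    impl Downloader for StaticDownloader {
        fn fetch(&self, _url: &str, sink: &mut dyn Write) -> NlResult<u64> {
            sink.write_all(self.0)
                .map_err(|e| NlError::DownloadFailed(format!("mock write: {e}")))?;
            Ok(self.0.len() as u64)
        }
    }
    #[test]
    fn stream_to_file_with_hash_hashes_what_it_writes() {
        let path = unique_tmp_path(&std::env::temp_dir().join("nl-downloader-test"));
        let hash = stream_to_file_with_hash(&StaticDownloader(b"abc"), "mock://abc", &path)
            .expect("streaming to a temp file should succeed");
        let _ = fs::remove_file(&path);
        // SHA-256("abc") reference test vector.
        assert_eq!(
            hash,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }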
}