use crate::models::verify::{SignatureError, verify_minisign};
use sha2::{Digest, Sha256};
use std::fs;
use std::io::{self, BufReader, Read, Write};
use std::path::{Path, PathBuf};
/// Errors produced while downloading a model file or verifying one on disk.
///
/// Every variant carries enough context (a path or a URL) to make the
/// rendered message self-explanatory.
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
/// A filesystem operation on `path` failed. Read errors on the download
/// stream are also reported here, against the partial file's path.
#[error("io error on {path}: {source}")]
Io {
path: PathBuf,
#[source]
source: io::Error,
},
/// The HTTP request for `url` failed. The underlying `ureq::Error` is boxed
/// (presumably to keep this enum small; confirm before unboxing).
#[error("network error fetching {url}: {source}")]
Network {
url: String,
#[source]
source: Box<ureq::Error>,
},
/// The computed SHA-256 hex digest differs from the expected one. The
/// message shows only the first 16 characters of each digest.
#[error("checksum mismatch for {path}: expected {expected:.16}…, computed {actual:.16}…")]
ChecksumMismatch {
path: PathBuf,
expected: String,
actual: String,
},
/// The signing key or signature could not be parsed, or signature
/// verification of `path` failed; `source` says which.
#[error("signature invalid for {path}: {source}")]
SignatureInvalid {
path: PathBuf,
#[source]
source: SignatureError,
},
/// The URL does not start with `https://`, so the download was refused.
#[error("refusing to fetch model over a non-https URL: {url}")]
InsecureScheme { url: String },
/// The download stream delivered more than `max_bytes` bytes and was aborted.
#[error("download for {path} exceeded the {max_bytes}-byte cap")]
TooLarge { path: PathBuf, max_bytes: u64 },
}
impl From<SignatureError> for DownloadError {
fn from(source: SignatureError) -> Self {
DownloadError::SignatureInvalid {
path: PathBuf::from("(unknown)"),
source,
}
}
}
/// Downloads `url` to `dest`, verifying only its SHA-256 checksum.
///
/// Delegates to [`download_with_checksum_and_signature`] without a signature.
///
/// # Returns
///
/// `Ok(false)` when `dest` already holds a file with the expected checksum
/// (nothing is fetched), `Ok(true)` when the file was downloaded.
///
/// # Errors
///
/// Propagates any [`DownloadError`] from the underlying download.
pub fn download_with_checksum(
    url: &str,
    expected_sha256: &str,
    dest: &Path,
) -> Result<bool, DownloadError> {
    // This entry point never supplies a detached signature.
    let signature: Option<&str> = None;
    download_with_checksum_and_signature(url, expected_sha256, signature, dest)
}
/// Downloads `url` to `dest`, verifying its SHA-256 checksum and, when
/// `signature` is `Some`, its minisign signature.
///
/// The download size is limited to [`DEFAULT_MAX_MODEL_BYTES`].
///
/// # Returns
///
/// `Ok(false)` when `dest` already holds a valid file (nothing is fetched),
/// `Ok(true)` when the file was downloaded.
///
/// # Errors
///
/// Propagates any [`DownloadError`] from the underlying download.
pub fn download_with_checksum_and_signature(
    url: &str,
    expected_sha256: &str,
    signature: Option<&str>,
    dest: &Path,
) -> Result<bool, DownloadError> {
    let cap = DEFAULT_MAX_MODEL_BYTES;
    download_with_checksum_signature_and_cap(url, expected_sha256, signature, dest, cap)
}
/// Default upper bound on a downloaded model's size: 1 GiB (2^30 bytes).
pub(crate) const DEFAULT_MAX_MODEL_BYTES: u64 = 1 << 30;
/// Downloads `url` to `dest`, enforcing a SHA-256 checksum, an optional
/// minisign signature, and a hard cap on the number of bytes received.
///
/// If `dest` already exists and hashes to `expected_sha256`, nothing is
/// fetched: the signature (when given) is checked against the file on disk
/// and `Ok(false)` is returned. Otherwise the response body is streamed into
/// a hidden sibling file named `.{file name}.partial`, hashed and fed to the
/// signature verifier chunk by chunk, and renamed onto `dest` only after
/// every check has passed, in which case `Ok(true)` is returned.
///
/// The checksum comparison ignores ASCII case, so an uppercase hex digest is
/// accepted.
///
/// # Errors
///
/// * [`DownloadError::InsecureScheme`] if `url` does not start with
///   `https://` (checked only when a download is actually needed).
/// * [`DownloadError::Network`] if the HTTP request fails.
/// * [`DownloadError::TooLarge`] if more than `max_bytes` bytes arrive.
/// * [`DownloadError::ChecksumMismatch`] if the streamed bytes hash to a
///   different digest.
/// * [`DownloadError::SignatureInvalid`] if the key or signature cannot be
///   parsed, or verification fails.
/// * [`DownloadError::Io`] for filesystem failures and stream read errors.
///
/// Whenever a failure occurs after the partial file has been created, the
/// partial file is removed on a best-effort basis.
pub(crate) fn download_with_checksum_signature_and_cap(
    url: &str,
    expected_sha256: &str,
    signature: Option<&str>,
    dest: &Path,
    max_bytes: u64,
) -> Result<bool, DownloadError> {
    // Fast path: a cached file with the right checksum needs no network.
    if dest.exists() && verify_sha256(dest, expected_sha256).is_ok() {
        if let Some(sig) = signature {
            verify_minisign(dest, sig).map_err(|e| DownloadError::SignatureInvalid {
                path: dest.to_path_buf(),
                source: e,
            })?;
        }
        return Ok(false);
    }

    // Only reached when a fetch is required; `get(..8)` is `None` for short
    // URLs or when byte 8 is not a character boundary.
    if !url
        .get(..8)
        .is_some_and(|s| s.eq_ignore_ascii_case("https://"))
    {
        return Err(DownloadError::InsecureScheme {
            url: url.to_owned(),
        });
    }

    if let Some(parent) = dest.parent() {
        fs::create_dir_all(parent).map_err(|e| DownloadError::Io {
            path: parent.to_path_buf(),
            source: e,
        })?;
    }

    // Stream into a hidden sibling so `dest` never holds unverified bytes.
    let mut tmp = dest.to_path_buf();
    let original_name = dest.file_name().and_then(|s| s.to_str()).unwrap_or("model");
    tmp.set_file_name(format!(".{original_name}.partial"));

    // Parse the key and signature up front so malformed input fails before
    // any network traffic.
    let public_key = if signature.is_some() {
        Some(
            minisign_verify::PublicKey::from_base64(crate::models::verify::SIGNING_PUBKEY_BASE64)
                .map_err(|e| DownloadError::SignatureInvalid {
                    path: dest.to_path_buf(),
                    source: SignatureError::BadPublicKey(format!("{e:?}")),
                })?,
        )
    } else {
        None
    };
    let sig = if let Some(sig_text) = signature {
        Some(minisign_verify::Signature::decode(sig_text).map_err(|e| {
            DownloadError::SignatureInvalid {
                path: dest.to_path_buf(),
                source: SignatureError::BadSignature(format!("{e:?}")),
            }
        })?)
    } else {
        None
    };
    let mut verifier = if let (Some(pk), Some(s)) = (&public_key, &sig) {
        Some(
            pk.verify_stream(s)
                .map_err(|e| DownloadError::SignatureInvalid {
                    path: dest.to_path_buf(),
                    source: SignatureError::VerificationFailed(format!("{e:?}")),
                })?,
        )
    } else {
        None
    };

    let resp = ureq::get(url).call().map_err(|e| DownloadError::Network {
        url: url.to_owned(),
        source: Box::new(e),
    })?;
    let reader = BufReader::new(resp.into_body().into_reader());

    let mut file = fs::File::create(&tmp).map_err(|e| DownloadError::Io {
        path: tmp.clone(),
        source: e,
    })?;
    let mut hasher = Sha256::new();

    // Hash and signature-verify each chunk as it is written to disk.
    let streamed = {
        let mut on_chunk = |chunk: &[u8]| {
            hasher.update(chunk);
            if let Some(v) = verifier.as_mut() {
                v.update(chunk);
            }
        };
        write_capped(reader, &mut file, &tmp, max_bytes, &mut on_chunk)
    };
    let streamed = streamed.and_then(|()| {
        file.flush().map_err(|e| DownloadError::Io {
            path: tmp.clone(),
            source: e,
        })
    });
    // Close the handle before any cleanup or rename touches the path.
    drop(file);
    if let Err(e) = streamed {
        // A partial download is useless (the next attempt truncates it), so
        // do not leave it behind on stream or flush failures.
        let _ = fs::remove_file(&tmp);
        return Err(e);
    }

    let actual = format!("{:x}", hasher.finalize());
    // Hex digits are case-insensitive; accept an uppercase expected digest.
    if !actual.eq_ignore_ascii_case(expected_sha256) {
        let _ = fs::remove_file(&tmp);
        return Err(DownloadError::ChecksumMismatch {
            path: dest.to_path_buf(),
            expected: expected_sha256.to_owned(),
            actual,
        });
    }

    if let Some(mut v) = verifier {
        if let Err(e) = v.finalize() {
            let _ = fs::remove_file(&tmp);
            return Err(DownloadError::SignatureInvalid {
                path: dest.to_path_buf(),
                source: SignatureError::VerificationFailed(format!("{e:?}")),
            });
        }
    }

    // Publish the verified file under its final name.
    if let Err(e) = fs::rename(&tmp, dest) {
        let _ = fs::remove_file(&tmp);
        return Err(DownloadError::Io {
            path: tmp,
            source: e,
        });
    }
    Ok(true)
}
fn write_capped<R: Read>(
mut reader: R,
file: &mut fs::File,
tmp: &Path,
max_bytes: u64,
on_chunk: &mut dyn FnMut(&[u8]),
) -> Result<(), DownloadError> {
let mut buf = [0u8; 64 * 1024];
let mut written: u64 = 0;
loop {
let n = reader.read(&mut buf).map_err(|e| DownloadError::Io {
path: tmp.to_path_buf(),
source: e,
})?;
if n == 0 {
break;
}
written += n as u64;
if written > max_bytes {
let _ = fs::remove_file(tmp);
return Err(DownloadError::TooLarge {
path: tmp.to_path_buf(),
max_bytes,
});
}
file.write_all(&buf[..n]).map_err(|e| DownloadError::Io {
path: tmp.to_path_buf(),
source: e,
})?;
on_chunk(&buf[..n]);
}
Ok(())
}
pub fn verify_sha256(path: &Path, expected: &str) -> Result<(), DownloadError> {
let f = fs::File::open(path).map_err(|e| DownloadError::Io {
path: path.to_path_buf(),
source: e,
})?;
let mut reader = BufReader::new(f);
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = reader.read(&mut buf).map_err(|e| DownloadError::Io {
path: path.to_path_buf(),
source: e,
})?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
let actual = format!("{:x}", hasher.finalize());
if actual == expected {
Ok(())
} else {
Err(DownloadError::ChecksumMismatch {
path: path.to_path_buf(),
expected: expected.to_owned(),
actual,
})
}
}
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::Write;
    use tempfile::TempDir;

    /// Small fixture payload shared by the checksum tests.
    const TEST_BYTES: &[u8] = b"polyvoice";

    /// Lowercase hex SHA-256 of [`TEST_BYTES`].
    fn test_bytes_sha256() -> String {
        use sha2::{Digest, Sha256};
        format!("{:x}", Sha256::digest(TEST_BYTES))
    }

    #[test]
    fn verify_existing_file_passes_when_hash_matches() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("data.bin");
        fs::write(&target, TEST_BYTES).unwrap();
        let digest = test_bytes_sha256();
        verify_sha256(&target, &digest).expect("hash must match");
    }

    #[test]
    fn verify_existing_file_fails_when_hash_differs() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("data.bin");
        fs::write(&target, b"different content").unwrap();
        let outcome = verify_sha256(&target, &test_bytes_sha256());
        let err = outcome.expect_err("must mismatch");
        assert!(matches!(err, DownloadError::ChecksumMismatch { .. }));
    }

    /// Hashes a 5 MiB file of zeros written 1 KiB at a time.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn verify_streams_large_file_without_loading_into_ram() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("big.bin");
        let mut out = fs::File::create(&target).unwrap();
        let kib = [0u8; 1024];
        for _ in 0..5 * 1024 {
            out.write_all(&kib).unwrap();
        }
        let digest = sha256_of_zeros_5mb();
        verify_sha256(&target, &digest).expect("streaming hash should match");
    }

    /// Lowercase hex SHA-256 of 5 MiB of zero bytes, fed 1 KiB at a time.
    fn sha256_of_zeros_5mb() -> String {
        use sha2::{Digest, Sha256};
        let block = [0u8; 1024];
        let mut state = Sha256::new();
        (0..5 * 1024).for_each(|_| state.update(block));
        format!("{:x}", state.finalize())
    }

    /// A cached file with the right checksum short-circuits before the URL is
    /// ever inspected, so even a non-https, unparseable URL succeeds.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn download_with_checksum_no_signature_fallback() {
        const BOGUS_URL: &str = "http://[invalid:definitely:not:a:real:url]";

        let scratch = TempDir::new().unwrap();
        let dest = scratch.path().join("cached.bin");
        fs::write(&dest, TEST_BYTES).unwrap();
        let sha = test_bytes_sha256();

        let via_signature_api = download_with_checksum_and_signature(BOGUS_URL, &sha, None, &dest);
        assert!(
            via_signature_api.is_ok(),
            "fallback should succeed: {:?}",
            via_signature_api.err()
        );
        assert!(
            !via_signature_api.unwrap(),
            "should be cached (no download)"
        );

        let via_wrapper = download_with_checksum(BOGUS_URL, &sha, &dest);
        assert!(
            via_wrapper.is_ok(),
            "wrapper should succeed: {:?}",
            via_wrapper.err()
        );
        assert!(!via_wrapper.unwrap(), "wrapper should also be cached");
    }

    #[test]
    fn rejects_non_https_url() {
        let scratch = TempDir::new().unwrap();
        let dest = scratch.path().join("model.bin");
        let partial = scratch.path().join(".model.bin.partial");
        let sha = test_bytes_sha256();

        let outcome = download_with_checksum_and_signature(
            "http://unreachable.invalid/model.bin",
            &sha,
            None,
            &dest,
        );
        let err = outcome.expect_err("non-https URL must be rejected");

        assert!(matches!(err, DownloadError::InsecureScheme { .. }));
        assert!(!dest.exists(), "no file should be created");
        assert!(!partial.exists(), "no .partial should be created");
    }

    #[test]
    fn aborts_when_stream_exceeds_cap() {
        let scratch = TempDir::new().unwrap();
        let partial = scratch.path().join(".big.partial");
        let mut sink = fs::File::create(&partial).unwrap();
        let mut ignore_chunk = |_: &[u8]| {};
        let oversized = std::io::Cursor::new(vec![0u8; 100]);

        let outcome = write_capped(oversized, &mut sink, &partial, 10, &mut ignore_chunk);
        let err = outcome.expect_err("stream over the cap must abort");

        assert!(matches!(err, DownloadError::TooLarge { max_bytes: 10, .. }));
        drop(sink);
        assert!(!partial.exists(), ".partial must be deleted on cap overflow");
    }
}