use crate::models::verify::{verify_minisign, SignatureError};
use sha2::{Digest, Sha256};
use std::fs;
use std::io::{self, BufReader, Read, Write};
use std::path::{Path, PathBuf};
/// Errors produced while fetching an artifact and verifying its checksum
/// and/or minisign signature.
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
    /// A filesystem operation (create, read, write, rename, ...) failed on `path`.
    #[error("io error on {path}: {source}")]
    Io {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// The HTTP request for `url` failed.
    // The ureq error is boxed, presumably to keep `DownloadError` small — confirm.
    #[error("network error fetching {url}: {source}")]
    Network {
        url: String,
        #[source]
        source: Box<ureq::Error>,
    },
    /// The SHA-256 of `path` did not match the expected digest.
    ///
    /// `expected` is the caller-supplied hex digest and `actual` the computed
    /// lowercase hex digest. The `Display` output shows only the first 16
    /// characters of each (the `.16` precision truncates the strings).
    #[error("checksum mismatch for {path}: expected {expected:.16}…, computed {actual:.16}…")]
    ChecksumMismatch {
        path: PathBuf,
        expected: String,
        actual: String,
    },
    /// Decoding the public key or signature failed, or the signature did not
    /// verify, for `path`.
    ///
    /// `path` is `"(unknown)"` when this variant is built through the
    /// `From<SignatureError>` conversion.
    #[error("signature invalid for {path}: {source}")]
    SignatureInvalid {
        path: PathBuf,
        #[source]
        source: SignatureError,
    },
}
impl From<SignatureError> for DownloadError {
fn from(source: SignatureError) -> Self {
DownloadError::SignatureInvalid {
path: PathBuf::from("(unknown)"),
source,
}
}
}
/// Downloads `url` to `dest` and verifies its SHA-256, without any
/// signature check.
///
/// This is [`download_with_checksum_and_signature`] with `signature = None`.
/// It returns `Ok(false)` when `dest` already holds matching bytes, and
/// `Ok(true)` after a fresh download.
///
/// # Errors
///
/// Whatever [`download_with_checksum_and_signature`] returns.
pub fn download_with_checksum(
    url: &str,
    expected_sha256: &str,
    dest: &Path,
) -> Result<bool, DownloadError> {
    let unsigned: Option<&str> = None;
    download_with_checksum_and_signature(url, expected_sha256, unsigned, dest)
}
pub fn download_with_checksum_and_signature(
url: &str,
expected_sha256: &str,
signature: Option<&str>,
dest: &Path,
) -> Result<bool, DownloadError> {
if dest.exists() && verify_sha256(dest, expected_sha256).is_ok() {
if let Some(sig) = signature {
verify_minisign(dest, sig).map_err(|e| DownloadError::SignatureInvalid {
path: dest.to_path_buf(),
source: e,
})?;
}
return Ok(false);
}
if let Some(parent) = dest.parent() {
fs::create_dir_all(parent).map_err(|e| DownloadError::Io {
path: parent.to_path_buf(),
source: e,
})?;
}
let mut tmp = dest.to_path_buf();
let original_name = dest.file_name().and_then(|s| s.to_str()).unwrap_or("model");
tmp.set_file_name(format!(".{original_name}.partial"));
let public_key = if signature.is_some() {
Some(
minisign_verify::PublicKey::from_base64(crate::models::verify::SIGNING_PUBKEY_BASE64)
.map_err(|e| DownloadError::SignatureInvalid {
path: dest.to_path_buf(),
source: SignatureError::BadPublicKey(format!("{e:?}")),
})?,
)
} else {
None
};
let sig = if let Some(sig_text) = signature {
Some(
minisign_verify::Signature::decode(sig_text).map_err(|e| {
DownloadError::SignatureInvalid {
path: dest.to_path_buf(),
source: SignatureError::BadSignature(format!("{e:?}")),
}
})?,
)
} else {
None
};
let mut verifier =
if let (Some(pk), Some(s)) = (&public_key, &sig) {
Some(pk.verify_stream(s).map_err(|e| DownloadError::SignatureInvalid {
path: dest.to_path_buf(),
source: SignatureError::VerificationFailed(format!("{e:?}")),
})?)
} else {
None
};
let resp = ureq::get(url).call().map_err(|e| DownloadError::Network {
url: url.to_owned(),
source: Box::new(e),
})?;
let reader = resp.into_body().into_reader();
let mut reader = BufReader::new(reader);
let mut file = fs::File::create(&tmp).map_err(|e| DownloadError::Io {
path: tmp.clone(),
source: e,
})?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = reader.read(&mut buf).map_err(|e| DownloadError::Io {
path: tmp.clone(),
source: e,
})?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
file.write_all(&buf[..n]).map_err(|e| DownloadError::Io {
path: tmp.clone(),
source: e,
})?;
if let Some(ref mut v) = verifier {
v.update(&buf[..n]);
}
}
file.flush().map_err(|e| DownloadError::Io {
path: tmp.clone(),
source: e,
})?;
drop(file);
let actual = format!("{:x}", hasher.finalize());
if actual != expected_sha256 {
let _ = fs::remove_file(&tmp);
return Err(DownloadError::ChecksumMismatch {
path: dest.to_path_buf(),
expected: expected_sha256.to_owned(),
actual,
});
}
if let Some(mut v) = verifier {
v.finalize().map_err(|e| {
let _ = fs::remove_file(&tmp);
DownloadError::SignatureInvalid {
path: dest.to_path_buf(),
source: SignatureError::VerificationFailed(format!("{e:?}")),
}
})?;
}
fs::rename(&tmp, dest).map_err(|e| DownloadError::Io {
path: tmp.clone(),
source: e,
})?;
Ok(true)
}
pub fn verify_sha256(path: &Path, expected: &str) -> Result<(), DownloadError> {
let f = fs::File::open(path).map_err(|e| DownloadError::Io {
path: path.to_path_buf(),
source: e,
})?;
let mut reader = BufReader::new(f);
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = reader.read(&mut buf).map_err(|e| DownloadError::Io {
path: path.to_path_buf(),
source: e,
})?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
let actual = format!("{:x}", hasher.finalize());
if actual == expected {
Ok(())
} else {
Err(DownloadError::ChecksumMismatch {
path: path.to_path_buf(),
expected: expected.to_owned(),
actual,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::Write;
    use tempfile::TempDir;

    /// Small payload shared by the single-file tests.
    const TEST_BYTES: &[u8] = b"polyvoice";

    /// A URL that cannot be fetched; any call that reaches the network fails.
    const UNREACHABLE_URL: &str = "http://[invalid:definitely:not:a:real:url]";

    /// Number of 1 KiB zero blocks in the large-file test (5 MiB in total).
    const LARGE_BLOCKS: usize = 5 * 1024;

    /// Lowercase hex SHA-256 of the concatenation of `chunks`.
    fn hex_sha256<'a>(chunks: impl IntoIterator<Item = &'a [u8]>) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        for chunk in chunks {
            hasher.update(chunk);
        }
        format!("{:x}", hasher.finalize())
    }

    #[test]
    fn verify_existing_file_passes_when_hash_matches() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("data.bin");
        fs::write(&path, TEST_BYTES).unwrap();

        let expected = hex_sha256([TEST_BYTES]);
        verify_sha256(&path, &expected).expect("hash must match");
    }

    #[test]
    fn verify_existing_file_fails_when_hash_differs() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("data.bin");
        fs::write(&path, b"different content").unwrap();

        let expected = hex_sha256([TEST_BYTES]);
        let err = verify_sha256(&path, &expected).expect_err("must mismatch");
        assert!(matches!(err, DownloadError::ChecksumMismatch { .. }));
    }

    #[test]
    fn verify_streams_large_file_without_loading_into_ram() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("big.bin");
        let block = [0u8; 1024];

        let mut file = fs::File::create(&path).unwrap();
        for _ in 0..LARGE_BLOCKS {
            file.write_all(&block).unwrap();
        }

        let expected = hex_sha256(std::iter::repeat(&block[..]).take(LARGE_BLOCKS));
        verify_sha256(&path, &expected).expect("streaming hash should match");
    }

    #[test]
    fn download_with_checksum_no_signature_fallback() {
        let dir = TempDir::new().unwrap();
        let dest = dir.path().join("cached.bin");
        fs::write(&dest, TEST_BYTES).unwrap();
        let sha = hex_sha256([TEST_BYTES]);

        // `dest` already holds the right bytes, so neither call should reach
        // the network. The unusable URL would make either call fail if it did.
        let direct = download_with_checksum_and_signature(UNREACHABLE_URL, &sha, None, &dest);
        assert!(
            direct.is_ok(),
            "fallback should succeed: {:?}",
            direct.as_ref().err()
        );
        assert!(!direct.unwrap(), "should be cached (no download)");

        let wrapped = download_with_checksum(UNREACHABLE_URL, &sha, &dest);
        assert!(
            wrapped.is_ok(),
            "wrapper should succeed: {:?}",
            wrapped.as_ref().err()
        );
        assert!(!wrapped.unwrap(), "wrapper should also be cached");
    }
}