use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
use std::time::Instant;
use console::Term;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiError, Progress};
use hf_hub::{Cache, Repo, RepoType};
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
use thiserror::Error;
use crate::manifest::{paths_from_downloads, ModelComponent, ModelFile, ModelManifest};
use crate::ModelPaths;
/// Progress notifications delivered to a [`DownloadProgressCallback`] while
/// model files are being downloaded.
#[derive(Debug, Clone)]
pub enum DownloadProgressEvent {
/// Sent once when the downloader initialises the transfer of a file.
FileStart {
/// File name as reported by the downloader.
filename: String,
/// Zero-based position of this file among the files that need downloading.
file_index: usize,
/// Number of files that need downloading in this pull.
total_files: usize,
/// Total size of the file in bytes, as reported by the downloader.
size_bytes: u64,
},
/// Sent periodically while a file is being transferred.
FileProgress {
/// File name as reported by the downloader.
filename: String,
/// Zero-based position of this file among the files that need downloading.
file_index: usize,
/// Bytes received so far for this file.
bytes_downloaded: u64,
/// Total size of the file in bytes.
bytes_total: u64,
},
/// Sent once when the downloader reports that a file has finished.
FileDone {
/// File name as reported by the downloader.
filename: String,
/// Zero-based position of this file among the files that need downloading.
file_index: usize,
/// Number of files that need downloading in this pull.
total_files: usize,
},
}
/// Shared, thread-safe callback invoked with every [`DownloadProgressEvent`].
pub type DownloadProgressCallback = Arc<dyn Fn(DownloadProgressEvent) + Send + Sync>;
/// Options that tune how a model pull behaves.
#[derive(Debug, Clone, Default)]
pub struct PullOptions {
/// When `true`, SHA-256 verification of files is skipped.
pub skip_verify: bool,
}
/// Errors that can occur while pulling model files and registering them.
#[derive(Debug, Error)]
pub enum DownloadError {
/// The repository refused access (HTTP 403 or a "gated"-style message).
#[error(
"Model requires access approval on HuggingFace.\n\n 1. Visit: https://huggingface.co/{repo}\n 2. Accept the license agreement\n 3. Create a token at: https://huggingface.co/settings/tokens\n 4. Set: export HF_TOKEN=hf_...\n 5. Retry: mold pull {model}"
)]
GatedModel { repo: String, model: String },
/// The request was rejected as unauthenticated (HTTP 401).
#[error(
"Authentication required for repository {repo}.\n\n 1. Create a token at: https://huggingface.co/settings/tokens\n (select at least \"Read\" access)\n 2. Set: export HF_TOKEN=hf_...\n Or run: huggingface-cli login\n 3. Retry: mold pull {model}\n\n If HF_TOKEN is already set, it may be invalid or expired."
)]
Unauthorized { repo: String, model: String },
/// Any other failure of an async download; wraps the underlying API error.
#[error("Download failed for {filename} from {repo}: {source}")]
DownloadFailed {
repo: String,
filename: String,
source: ApiError,
},
/// The placed file's SHA-256 digest differs from the manifest's digest.
#[error("SHA-256 mismatch for {filename}\n Expected: {expected}\n Got: {actual}\n\nThe corrupted file has been removed. Re-run: mold pull {model}\nIf the file was intentionally updated on HuggingFace, use: mold pull {model} --skip-verify")]
Sha256Mismatch {
filename: String,
expected: String,
actual: String,
model: String,
},
/// Building the async API client failed (converted from `ApiError` via `?`).
#[error("Failed to build HuggingFace API client: {0}")]
ApiSetup(#[from] ApiError),
/// Building the sync API client failed; carries the error text.
#[error("Failed to build sync HuggingFace API client: {0}")]
SyncApiSetup(String),
/// A sync download failed for a reason other than 401/403-style errors.
#[error("Sync download failed for {filename} from {repo}: {message}")]
SyncDownloadFailed {
repo: String,
filename: String,
message: String,
},
/// The downloaded files could not be assembled into a full set of model paths.
#[error("Missing component after download — this is a bug")]
MissingComponent,
/// A filesystem operation failed while placing a file or writing a pull marker.
#[error("IO error during file placement: {0}")]
FilePlacement(String),
/// No manifest exists for the requested model name.
#[error("Unknown model '{model}'. No manifest found.")]
UnknownModel { model: String },
/// Persisting the updated configuration failed.
#[error("Failed to save config: {0}")]
ConfigSave(String),
}
/// Looks up a HuggingFace access token.
///
/// Sources are tried in order: a non-blank `HF_TOKEN` environment variable
/// (trimmed), the token stored in this crate's HF cache directory, and finally
/// the token of the environment-configured HF cache.
fn resolve_hf_token() -> Option<String> {
    std::env::var("HF_TOKEN")
        .ok()
        .map(|raw| raw.trim().to_string())
        // A blank variable counts as "not set" so the cache lookups still run.
        .filter(|trimmed| !trimmed.is_empty())
        .or_else(|| Cache::new(hf_cache_dir()).token())
        .or_else(|| Cache::from_env().token())
}
/// Returns the models directory taken from the loaded configuration.
///
/// The path is resolved once per process and cached; directory creation is
/// attempted on first use and any failure there is ignored.
fn models_dir() -> PathBuf {
    static DIR: OnceLock<PathBuf> = OnceLock::new();
    let cached = DIR.get_or_init(|| {
        let resolved = crate::Config::load_or_default().resolved_models_dir();
        // Best effort: a failure here is left for later file operations to report.
        let _ = std::fs::create_dir_all(&resolved);
        resolved
    });
    cached.to_path_buf()
}
/// Returns the `.hf-cache` directory inside the models directory.
///
/// The path is computed once per process and cached; directory creation is
/// attempted on first use and any failure there is ignored.
fn hf_cache_dir() -> PathBuf {
    static DIR: OnceLock<PathBuf> = OnceLock::new();
    let cached = DIR.get_or_init(|| {
        let mut cache_root = models_dir();
        cache_root.push(".hf-cache");
        let _ = std::fs::create_dir_all(&cache_root);
        cache_root
    });
    cached.to_path_buf()
}
/// Makes `dst` refer to the same content as `src`, preferring a hard link and
/// falling back to a full copy when linking fails.
///
/// `src` is canonicalized first. If `dst` already exists with the same byte
/// length as the resolved source, nothing is done. Otherwise any existing entry
/// at `dst` is removed, the parent directory is created, and the file is linked
/// or copied.
///
/// # Errors
/// Returns [`DownloadError::FilePlacement`] when the source cannot be resolved,
/// the parent directory cannot be created, or the fallback copy fails.
fn hardlink_or_copy(src: &std::path::Path, dst: &std::path::Path) -> Result<(), DownloadError> {
    let resolved = match src.canonicalize() {
        Ok(path) => path,
        Err(e) => {
            return Err(DownloadError::FilePlacement(format!(
                "source file not found after download: {} ({e})",
                src.display()
            )));
        }
    };

    // Same length at the destination is treated as "already placed".
    let already_there = dst.exists()
        && matches!(
            (resolved.metadata(), dst.metadata()),
            (Ok(from), Ok(to)) if from.len() == to.len()
        );
    if already_there {
        return Ok(());
    }

    // `symlink_metadata` also sees dangling symlinks, which `exists` misses.
    if dst.symlink_metadata().is_ok() {
        let _ = std::fs::remove_file(dst);
    }

    if let Some(dir) = dst.parent() {
        if let Err(e) = std::fs::create_dir_all(dir) {
            return Err(DownloadError::FilePlacement(format!(
                "failed to create directory {}: {e}",
                dir.display()
            )));
        }
    }

    // A failed hard link is not an error; fall through to copying.
    if std::fs::hard_link(&resolved, dst).is_ok() {
        return Ok(());
    }

    match std::fs::copy(&resolved, dst) {
        Ok(_) => Ok(()),
        Err(e) => Err(DownloadError::FilePlacement(format!(
            "failed to copy {} → {}: {e}",
            resolved.display(),
            dst.display()
        ))),
    }
}
/// Computes the SHA-256 digest of the file at `path`, streaming its contents
/// into the hasher.
///
/// Returns the digest as a lowercase hexadecimal string.
///
/// # Errors
/// Fails if the file cannot be opened or read.
pub fn compute_sha256(path: &std::path::Path) -> anyhow::Result<String> {
    use sha2::{Digest, Sha256};

    let mut input = std::fs::File::open(path)?;
    let mut sha = Sha256::new();
    std::io::copy(&mut input, &mut sha)?;
    let digest = sha.finalize();
    Ok(format!("{digest:x}"))
}
/// Checks whether the file at `path` has the SHA-256 digest `expected`.
///
/// `expected` is a hexadecimal digest. Surrounding whitespace is ignored and
/// the comparison is case-insensitive, because upper- and lowercase hex spell
/// the same digest.
///
/// # Errors
/// Fails if the file cannot be opened or read.
pub fn verify_sha256(path: &std::path::Path, expected: &str) -> anyhow::Result<bool> {
    let actual = compute_sha256(path)?;
    Ok(actual.eq_ignore_ascii_case(expected.trim()))
}
/// Returns the path of a model's `.pulling` marker relative to the models
/// directory.
///
/// The model name is resolved to its canonical form and every `:` is replaced
/// with `-` to form the directory name.
pub fn pulling_marker_rel_path(model_name: &str) -> PathBuf {
    let dir_name = crate::manifest::resolve_model_name(model_name).replace(':', "-");
    let mut marker = PathBuf::from(dir_name);
    marker.push(".pulling");
    marker
}
/// Returns the location of a model's `.pulling` marker inside the models
/// directory.
fn pulling_marker_path(model_name: &str) -> PathBuf {
    let mut marker = models_dir();
    marker.push(pulling_marker_rel_path(model_name));
    marker
}
/// Creates the `.pulling` marker for `model_name`, creating its parent
/// directory when needed. The marker file contains the model name as given.
///
/// # Errors
/// Returns [`DownloadError::FilePlacement`] if the directory or the marker file
/// cannot be written.
fn write_pulling_marker(model_name: &str) -> Result<(), DownloadError> {
    let marker = pulling_marker_path(model_name);

    if let Some(dir) = marker.parent() {
        if let Err(e) = std::fs::create_dir_all(dir) {
            return Err(DownloadError::FilePlacement(format!(
                "failed to create directory for pull marker {}: {e}",
                dir.display()
            )));
        }
    }

    match std::fs::write(&marker, model_name) {
        Ok(()) => Ok(()),
        Err(e) => Err(DownloadError::FilePlacement(format!(
            "failed to write pull marker {}: {e}",
            marker.display()
        ))),
    }
}
/// Deletes the `.pulling` marker for `model_name`.
///
/// Removal is best effort: a missing marker or any other failure is ignored.
pub fn remove_pulling_marker(model_name: &str) {
    let _ = std::fs::remove_file(pulling_marker_path(model_name));
}
/// Reports whether a `.pulling` marker currently exists for `model_name`.
///
/// The name is resolved to its canonical form before the marker path is built.
pub fn has_pulling_marker(model_name: &str) -> bool {
    let resolved = crate::manifest::resolve_model_name(model_name);
    let marker = pulling_marker_path(&resolved);
    marker.exists()
}
/// Verifies `clean_path` against the SHA-256 digest recorded for `file`.
///
/// Succeeds without checking when the manifest entry has no digest or when
/// `skip_verify` is set. On a mismatch the file at `clean_path` is deleted
/// (best effort) and an error is returned. If the digest cannot be computed at
/// all, a warning is printed to stderr and the file is accepted.
///
/// # Errors
/// Returns [`DownloadError::Sha256Mismatch`] when the computed digest differs
/// from the expected one.
fn verify_file_integrity(
    clean_path: &std::path::Path,
    file: &ModelFile,
    model_name: &str,
    skip_verify: bool,
) -> Result<(), DownloadError> {
    let Some(expected) = file.sha256 else {
        return Ok(());
    };
    if skip_verify {
        return Ok(());
    }

    let actual = match compute_sha256(clean_path) {
        Ok(digest) => digest,
        Err(e) => {
            // An unreadable file is reported but deliberately not treated as fatal.
            eprintln!(
                "warning: failed to verify SHA-256 for {}: {e}",
                file.hf_filename
            );
            return Ok(());
        }
    };

    if actual == expected {
        return Ok(());
    }

    // Remove the bad file so that a retry does not accept it as already placed.
    let _ = std::fs::remove_file(clean_path);
    Err(DownloadError::Sha256Mismatch {
        filename: file.hf_filename.clone(),
        expected: expected.to_string(),
        actual,
        model: model_name.to_string(),
    })
}
/// Shortens `name` to at most `max_len` bytes by keeping its tail and
/// prefixing it with `...`.
///
/// Names that already fit are returned unchanged, as are all names when
/// `max_len` is below 8 (too narrow for a useful truncation). The cut point is
/// moved forward to the next UTF-8 character boundary, so names containing
/// multi-byte characters never cause a slicing panic; in that case the result
/// may be a few bytes shorter than `max_len`.
fn truncate_filename(name: &str, max_len: usize) -> String {
    if name.len() <= max_len || max_len < 8 {
        return name.to_string();
    }
    // Reserve three bytes for the leading "...".
    let suffix_len = max_len - 3;
    let mut start = name.len() - suffix_len;
    // `name.len()` is always a boundary, so this loop terminates.
    while !name.is_char_boundary(start) {
        start += 1;
    }
    format!("...{}", &name[start..])
}
/// Computes the width of the filename column in the progress bar.
///
/// Takes the stderr terminal's column count, subtracts 75 columns (saturating
/// at zero) and never returns less than 12.
fn filename_column_width() -> usize {
    let (_rows, cols) = Term::stderr().size();
    let available = (cols as usize).saturating_sub(75);
    available.max(12)
}
/// Progress reporter that drives an `indicatif` progress bar.
#[derive(Clone)]
struct DownloadProgress {
    /// Bar that visualises the transfer.
    bar: ProgressBar,
    /// Maximum length of the filename shown as the bar message.
    max_msg_len: usize,
    /// Truncated filename, set when the download is initialised.
    filename: String,
}

impl DownloadProgress {
    /// Wraps `bar`; the filename stays empty until `init` is called.
    fn new(bar: ProgressBar, max_msg_len: usize) -> Self {
        let filename = String::new();
        Self {
            bar,
            max_msg_len,
            filename,
        }
    }
}

impl Progress for DownloadProgress {
    /// Sets the bar length to `size` and shows the truncated filename.
    async fn init(&mut self, size: usize, filename: &str) {
        self.bar.set_length(size as u64);
        let label = truncate_filename(filename, self.max_msg_len);
        self.bar.set_message(label.clone());
        self.filename = label;
    }

    /// Advances the bar by `size` bytes.
    async fn update(&mut self, size: usize) {
        let delta = size as u64;
        self.bar.inc(delta);
    }

    /// Marks the bar as finished, keeping the filename as its message.
    async fn finish(&mut self) {
        let label = self.filename.clone();
        self.bar.finish_with_message(label);
    }
}
/// Progress reporter that forwards download progress to a user callback as
/// [`DownloadProgressEvent`]s.
#[derive(Clone)]
struct CallbackProgress {
    /// Receiver of the progress events.
    callback: DownloadProgressCallback,
    /// Index reported for this file in every event.
    file_index: usize,
    /// File count reported in start/done events.
    total_files: usize,
    /// Bytes received so far.
    accumulated: u64,
    /// Total size of the file, set in `init`.
    total: u64,
    /// Filename as passed to `init`.
    filename: String,
    /// Time of the last `FileProgress` emission (initially construction time).
    last_emit: Instant,
}

impl CallbackProgress {
    /// Creates a reporter with zeroed counters and an empty filename.
    fn new(callback: DownloadProgressCallback, file_index: usize, total_files: usize) -> Self {
        let created_at = Instant::now();
        Self {
            callback,
            file_index,
            total_files,
            accumulated: 0,
            total: 0,
            filename: String::new(),
            last_emit: created_at,
        }
    }
}

impl Progress for CallbackProgress {
    /// Records size and name, resets the byte counter and emits `FileStart`.
    async fn init(&mut self, size: usize, filename: &str) {
        self.filename = filename.to_string();
        self.accumulated = 0;
        self.total = size as u64;
        let event = DownloadProgressEvent::FileStart {
            filename: self.filename.clone(),
            file_index: self.file_index,
            total_files: self.total_files,
            size_bytes: self.total,
        };
        (self.callback)(event);
    }

    /// Adds `size` bytes and emits `FileProgress` when at least 250 ms have
    /// passed since the last emission or the byte count has reached the total.
    async fn update(&mut self, size: usize) {
        self.accumulated += size as u64;
        let now = Instant::now();
        let interval_elapsed = now.duration_since(self.last_emit).as_millis() >= 250;
        let reached_total = self.accumulated >= self.total;
        if !(interval_elapsed || reached_total) {
            return;
        }
        self.last_emit = now;
        let event = DownloadProgressEvent::FileProgress {
            filename: self.filename.clone(),
            file_index: self.file_index,
            bytes_downloaded: self.accumulated,
            bytes_total: self.total,
        };
        (self.callback)(event);
    }

    /// Emits `FileDone`.
    async fn finish(&mut self) {
        let event = DownloadProgressEvent::FileDone {
            filename: self.filename.clone(),
            file_index: self.file_index,
            total_files: self.total_files,
        };
        (self.callback)(event);
    }
}
/// Decides whether `clean_path` already holds a usable copy of `file`.
///
/// The on-disk size must equal the manifest size; only then is the integrity
/// check run. Because it delegates to `verify_file_integrity`, a digest
/// mismatch also deletes the file.
fn is_already_placed(
    clean_path: &std::path::Path,
    file: &ModelFile,
    model_name: &str,
    skip_verify: bool,
) -> bool {
    // A missing or unreadable file counts as a size mismatch.
    let size_matches = matches!(
        clean_path.metadata(),
        Ok(meta) if meta.len() == file.size_bytes
    );
    size_matches && verify_file_integrity(clean_path, file, model_name, skip_verify).is_ok()
}
/// Downloads every file of `manifest` that is not yet in place, showing
/// terminal progress bars on stderr, and returns the resolved model paths.
///
/// A `.pulling` marker is written before any work starts and removed only
/// after all files were handled, so an early error return leaves it behind.
///
/// # Errors
/// Propagates marker, API setup, download, placement and verification errors;
/// returns [`DownloadError::MissingComponent`] when the collected files do not
/// form a complete set of paths.
pub async fn pull_model(
    manifest: &ModelManifest,
    opts: &PullOptions,
) -> Result<ModelPaths, DownloadError> {
    write_pulling_marker(&manifest.name)?;

    let builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
    let builder = match resolve_hf_token() {
        Some(token) => builder.with_token(Some(token)),
        None => builder,
    };
    let api = builder.build()?;

    let bars = MultiProgress::with_draw_target(ProgressDrawTarget::stderr());
    let msg_width = filename_column_width();
    let template = format!(
        " {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
    );
    let style = ProgressStyle::with_template(&template)
        .unwrap()
        .progress_chars("━╸─");

    let root = models_dir();
    let mut placed: Vec<(ModelComponent, PathBuf)> = Vec::new();
    for file in &manifest.files {
        let target = root.join(crate::manifest::storage_path(manifest, file));
        // Only files that are missing or failed verification are fetched.
        if !is_already_placed(&target, file, &manifest.name, opts.skip_verify) {
            let bar = bars.add(ProgressBar::new(file.size_bytes));
            bar.set_style(style.clone());
            bar.set_message(truncate_filename(&file.hf_filename, msg_width));
            let reporter = DownloadProgress::new(bar, msg_width);
            let cached = download_file(&api, file, reporter, &manifest.name).await?;
            hardlink_or_copy(&cached, &target)?;
            verify_file_integrity(&target, file, &manifest.name, opts.skip_verify)?;
        }
        placed.push((file.component, target));
    }

    remove_pulling_marker(&manifest.name);
    paths_from_downloads(&placed).ok_or(DownloadError::MissingComponent)
}
pub async fn pull_model_with_callback(
manifest: &ModelManifest,
callback: DownloadProgressCallback,
opts: &PullOptions,
) -> Result<ModelPaths, DownloadError> {
write_pulling_marker(&manifest.name)?;
let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
if let Some(token) = resolve_hf_token() {
builder = builder.with_token(Some(token));
}
let api = builder.build()?;
let mdir = models_dir();
let mut downloads: Vec<(ModelComponent, PathBuf)> = Vec::new();
let total_to_download = manifest
.files
.iter()
.filter(|file| {
let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
!is_already_placed(&clean_path, file, &manifest.name, opts.skip_verify)
})
.count();
let mut download_idx = 0;
for file in &manifest.files {
let clean_rel = crate::manifest::storage_path(manifest, file);
let clean_path = mdir.join(&clean_rel);
if is_already_placed(&clean_path, file, &manifest.name, opts.skip_verify) {
downloads.push((file.component, clean_path));
continue;
}
let progress = CallbackProgress::new(callback.clone(), download_idx, total_to_download);
download_idx += 1;
let hf_path = download_file(&api, file, progress, &manifest.name).await?;
hardlink_or_copy(&hf_path, &clean_path)?;
verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
downloads.push((file.component, clean_path));
}
remove_pulling_marker(&manifest.name);
paths_from_downloads(&downloads).ok_or(DownloadError::MissingComponent)
}
/// Downloads every file of `manifest` that is not yet in place, showing
/// terminal progress bars on stderr, without assembling model paths.
///
/// A `.pulling` marker is written before any work starts and removed only
/// after all files were handled, so an early error return leaves it behind.
///
/// # Errors
/// Propagates marker, API setup, download, placement and verification errors.
async fn pull_model_files_only(
    manifest: &ModelManifest,
    opts: &PullOptions,
) -> Result<(), DownloadError> {
    write_pulling_marker(&manifest.name)?;

    let builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
    let builder = match resolve_hf_token() {
        Some(token) => builder.with_token(Some(token)),
        None => builder,
    };
    let api = builder.build()?;

    let bars = MultiProgress::with_draw_target(ProgressDrawTarget::stderr());
    let msg_width = filename_column_width();
    let template = format!(
        " {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
    );
    let style = ProgressStyle::with_template(&template)
        .unwrap()
        .progress_chars("━╸─");

    let root = models_dir();
    for file in &manifest.files {
        let target = root.join(crate::manifest::storage_path(manifest, file));
        // Only files that are missing or failed verification are fetched.
        if !is_already_placed(&target, file, &manifest.name, opts.skip_verify) {
            let bar = bars.add(ProgressBar::new(file.size_bytes));
            bar.set_style(style.clone());
            bar.set_message(truncate_filename(&file.hf_filename, msg_width));
            let reporter = DownloadProgress::new(bar, msg_width);
            let cached = download_file(&api, file, reporter, &manifest.name).await?;
            hardlink_or_copy(&cached, &target)?;
            verify_file_integrity(&target, file, &manifest.name, opts.skip_verify)?;
        }
    }

    remove_pulling_marker(&manifest.name);
    Ok(())
}
async fn pull_model_files_only_with_callback(
manifest: &ModelManifest,
callback: DownloadProgressCallback,
opts: &PullOptions,
) -> Result<(), DownloadError> {
write_pulling_marker(&manifest.name)?;
let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
if let Some(token) = resolve_hf_token() {
builder = builder.with_token(Some(token));
}
let api = builder.build()?;
let mdir = models_dir();
let total_to_download = manifest
.files
.iter()
.filter(|file| {
let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
!is_already_placed(&clean_path, file, &manifest.name, opts.skip_verify)
})
.count();
let mut download_idx = 0;
for file in &manifest.files {
let clean_rel = crate::manifest::storage_path(manifest, file);
let clean_path = mdir.join(&clean_rel);
if is_already_placed(&clean_path, file, &manifest.name, opts.skip_verify) {
continue;
}
let progress = CallbackProgress::new(callback.clone(), download_idx, total_to_download);
download_idx += 1;
let hf_path = download_file(&api, file, progress, &manifest.name).await?;
hardlink_or_copy(&hf_path, &clean_path)?;
verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
}
remove_pulling_marker(&manifest.name);
Ok(())
}
/// Returns the HTTP status code carried by `err`, if it is a request error
/// that has one.
fn extract_http_status(err: &ApiError) -> Option<u16> {
    match err {
        ApiError::RequestError(inner) => inner.status().map(|code| code.as_u16()),
        _ => None,
    }
}
/// Downloads one manifest file through the async HF API and returns its path
/// in the HF cache.
///
/// Failures are classified into actionable errors. When the underlying error
/// carries an HTTP status, that status alone decides the category. The
/// message-substring heuristics are used only when no status is available, so
/// that an unrelated error whose text or URL happens to contain "401" or "403"
/// is not misreported as an authentication problem.
///
/// # Errors
/// * [`DownloadError::Unauthorized`] for HTTP 401 (or a matching message).
/// * [`DownloadError::GatedModel`] for HTTP 403 (or a matching message).
/// * [`DownloadError::DownloadFailed`] for everything else.
async fn download_file<P: Progress + Clone + Send + Sync + 'static>(
    api: &Api,
    file: &ModelFile,
    progress: P,
    model_name: &str,
) -> Result<PathBuf, DownloadError> {
    let repo = api.repo(Repo::new(file.hf_repo.clone(), RepoType::Model));
    let err = match repo
        .download_with_progress(&file.hf_filename, progress)
        .await
    {
        Ok(path) => return Ok(path),
        Err(e) => e,
    };

    let (unauthorized, gated) = match extract_http_status(&err) {
        // A concrete status is authoritative.
        Some(code) => (code == 401, code == 403),
        // No status available: fall back to inspecting the message text.
        None => {
            let err_str = err.to_string();
            let unauthorized = err_str.contains("401") || err_str.contains("Unauthorized");
            let gated = err_str.contains("403")
                || err_str.contains("Forbidden")
                || err_str.contains("gated")
                || err_str.contains("Access denied");
            (unauthorized, gated)
        }
    };

    if unauthorized {
        Err(DownloadError::Unauthorized {
            repo: file.hf_repo.clone(),
            model: model_name.to_string(),
        })
    } else if gated {
        Err(DownloadError::GatedModel {
            repo: file.hf_repo.clone(),
            model: model_name.to_string(),
        })
    } else {
        Err(DownloadError::DownloadFailed {
            repo: file.hf_repo.clone(),
            filename: file.hf_filename.clone(),
            source: err,
        })
    }
}
/// Downloads a single file with the blocking HF API.
///
/// With `target_subdir` set, the file is additionally linked or copied to
/// `<models_dir>/<target_subdir>/<leaf name>` and that path is returned;
/// otherwise the path inside the HF cache is returned.
///
/// # Errors
/// * [`DownloadError::SyncApiSetup`] if the client cannot be built.
/// * [`DownloadError::Unauthorized`] / [`DownloadError::GatedModel`] when the
///   error text indicates a 401 / 403 style failure (with an empty model name).
/// * [`DownloadError::SyncDownloadFailed`] for any other download failure.
/// * [`DownloadError::FilePlacement`] if placing the file fails.
pub fn download_single_file_sync(
    hf_repo: &str,
    hf_filename: &str,
    target_subdir: Option<&str>,
) -> Result<PathBuf, DownloadError> {
    use hf_hub::api::sync::ApiBuilder;

    let builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
    let builder = match resolve_hf_token() {
        Some(token) => builder.with_token(Some(token)),
        None => builder,
    };
    let api = match builder.build() {
        Ok(api) => api,
        Err(e) => return Err(DownloadError::SyncApiSetup(e.to_string())),
    };

    let repo = api.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
    let hf_path = match repo.get(hf_filename) {
        Ok(path) => path,
        Err(e) => {
            // Classification relies on the error text only.
            let text = e.to_string();
            let unauthorized = text.contains("401") || text.contains("Unauthorized");
            let gated = text.contains("403")
                || text.contains("Forbidden")
                || text.contains("gated")
                || text.contains("Access denied");
            let classified = if unauthorized {
                DownloadError::Unauthorized {
                    repo: hf_repo.to_string(),
                    model: String::new(),
                }
            } else if gated {
                DownloadError::GatedModel {
                    repo: hf_repo.to_string(),
                    model: String::new(),
                }
            } else {
                DownloadError::SyncDownloadFailed {
                    repo: hf_repo.to_string(),
                    filename: hf_filename.to_string(),
                    message: text,
                }
            };
            return Err(classified);
        }
    };

    let Some(subdir) = target_subdir else {
        return Ok(hf_path);
    };
    // Only the last path segment of the remote filename is used locally.
    let leaf = hf_filename.rsplit('/').next().unwrap_or(hf_filename);
    let clean_path = models_dir().join(subdir).join(leaf);
    hardlink_or_copy(&hf_path, &clean_path)?;
    Ok(clean_path)
}
pub fn cached_file_path(
hf_repo: &str,
hf_filename: &str,
target_subdir: Option<&str>,
) -> Option<PathBuf> {
if let Some(subdir) = target_subdir {
let leaf = hf_filename.rsplit('/').next().unwrap_or(hf_filename);
let clean_path = models_dir().join(subdir).join(leaf);
if clean_path.exists() {
return Some(clean_path);
}
}
let new_cache = Cache::new(hf_cache_dir());
let new_repo = new_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
if let Some(path) = new_repo.get(hf_filename) {
return Some(path);
}
let old_cache = Cache::new(models_dir());
let old_repo = old_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
if let Some(path) = old_repo.get(hf_filename) {
return Some(path);
}
let default_cache = Cache::from_env();
let default_repo = default_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
default_repo.get(hf_filename)
}
/// Pulls `model` with terminal progress bars and registers it in the
/// configuration.
///
/// Utility manifests only have their files downloaded; the loaded config is
/// returned unchanged together with `None`. For other models the config entry
/// is inserted or updated and saved. The model also becomes the default when
/// no config file exists on disk yet.
///
/// # Errors
/// * [`DownloadError::UnknownModel`] if no manifest matches `model`.
/// * [`DownloadError::ConfigSave`] if the configuration cannot be saved.
/// * Any error from the underlying pull.
pub async fn pull_and_configure(
    model: &str,
    opts: &PullOptions,
) -> Result<(crate::Config, Option<ModelPaths>), DownloadError> {
    use crate::config::Config;
    use crate::manifest::{find_manifest, resolve_model_name};

    let canonical = resolve_model_name(model);
    let Some(manifest) = find_manifest(&canonical) else {
        return Err(DownloadError::UnknownModel {
            model: model.to_string(),
        });
    };

    if manifest.is_utility() {
        pull_model_files_only(manifest, opts).await?;
        return Ok((Config::load_or_default(), None));
    }

    let paths = pull_model(manifest, opts).await?;
    let mut config = Config::load_or_default();
    let model_config = manifest.to_model_config(&paths);
    // First model pulled on a fresh install becomes the default.
    let no_config_yet = !Config::exists_on_disk();
    if no_config_yet {
        config.default_model = manifest.name.clone();
    }
    config.upsert_model(manifest.name.clone(), model_config);
    if let Err(e) = config.save() {
        return Err(DownloadError::ConfigSave(e.to_string()));
    }
    Ok((config, Some(paths)))
}
/// Pulls `model`, reporting progress through `callback`, and registers it in
/// the configuration.
///
/// Utility manifests only have their files downloaded; the loaded config is
/// returned unchanged together with `None`. For other models the config entry
/// is inserted or updated and saved. The model also becomes the default when
/// no config file exists on disk yet.
///
/// # Errors
/// * [`DownloadError::UnknownModel`] if no manifest matches `model`.
/// * [`DownloadError::ConfigSave`] if the configuration cannot be saved.
/// * Any error from the underlying pull.
pub async fn pull_and_configure_with_callback(
    model: &str,
    callback: DownloadProgressCallback,
    opts: &PullOptions,
) -> Result<(crate::Config, Option<ModelPaths>), DownloadError> {
    use crate::config::Config;
    use crate::manifest::{find_manifest, resolve_model_name};

    let canonical = resolve_model_name(model);
    let Some(manifest) = find_manifest(&canonical) else {
        return Err(DownloadError::UnknownModel {
            model: model.to_string(),
        });
    };

    if manifest.is_utility() {
        pull_model_files_only_with_callback(manifest, callback, opts).await?;
        return Ok((Config::load_or_default(), None));
    }

    let paths = pull_model_with_callback(manifest, callback, opts).await?;
    let mut config = Config::load_or_default();
    let model_config = manifest.to_model_config(&paths);
    // First model pulled on a fresh install becomes the default.
    let no_config_yet = !Config::exists_on_disk();
    if no_config_yet {
        config.default_model = manifest.name.clone();
    }
    config.upsert_model(manifest.name.clone(), model_config);
    if let Err(e) = config.save() {
        return Err(DownloadError::ConfigSave(e.to_string()));
    }
    Ok((config, Some(paths)))
}
// Unit tests for filename truncation, error messages, token resolution and
// SHA-256 verification helpers.
#[cfg(test)]
mod tests {
use super::*;
// Names shorter than the limit are returned as-is.
#[test]
fn truncate_short_name_unchanged() {
assert_eq!(truncate_filename("ae.safetensors", 45), "ae.safetensors");
}
// A name exactly as long as the limit is not truncated.
#[test]
fn truncate_exact_fit_unchanged() {
let name = "x".repeat(30);
assert_eq!(truncate_filename(&name, 30), name);
}
// Over-long names keep their tail and gain a "..." prefix.
#[test]
fn truncate_long_name_keeps_suffix() {
let result = truncate_filename("unet/diffusion_pytorch_model.fp16.safetensors", 30);
assert_eq!(result.len(), 30);
assert!(result.starts_with("..."));
assert!(result.ends_with(".fp16.safetensors"));
}
// Limits below 8 disable truncation entirely.
#[test]
fn truncate_very_small_max_returns_original() {
let name = "something.safetensors";
assert_eq!(truncate_filename(name, 5), name);
}
// The gated-model message names the repo URL, the token variable and the retry command.
#[test]
fn download_error_gated_message() {
let err = DownloadError::GatedModel {
repo: "black-forest-labs/FLUX.1-dev".to_string(),
model: "flux-dev:q8".to_string(),
};
let msg = err.to_string();
assert!(msg.contains("huggingface.co/black-forest-labs/FLUX.1-dev"));
assert!(msg.contains("HF_TOKEN"));
assert!(msg.contains("mold pull flux-dev:q8"));
}
// The unauthorized message names the repo, both login options and the retry command.
#[test]
fn download_error_unauthorized_message() {
let err = DownloadError::Unauthorized {
repo: "black-forest-labs/FLUX.1-schnell".to_string(),
model: "flux-schnell:q8".to_string(),
};
let msg = err.to_string();
assert!(msg.contains("Authentication required"));
assert!(msg.contains("black-forest-labs/FLUX.1-schnell"));
assert!(msg.contains("HF_TOKEN"));
assert!(msg.contains("huggingface-cli login"));
assert!(msg.contains("mold pull flux-schnell:q8"));
}
// Serialises the tests that mutate the process-wide HF_TOKEN variable.
static HF_TOKEN_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
// A non-empty HF_TOKEN is returned verbatim.
#[test]
fn resolve_hf_token_reads_env_var() {
let _guard = HF_TOKEN_LOCK.lock().unwrap();
let original = std::env::var("HF_TOKEN").ok();
std::env::set_var("HF_TOKEN", "hf_test_token_123");
let token = resolve_hf_token();
// Restore the environment before asserting so a failure cannot leak the test value.
match &original {
Some(v) => std::env::set_var("HF_TOKEN", v),
None => std::env::remove_var("HF_TOKEN"),
}
assert_eq!(token, Some("hf_test_token_123".to_string()));
}
// A whitespace-only HF_TOKEN must not be returned as the token.
#[test]
fn resolve_hf_token_ignores_empty_env() {
let _guard = HF_TOKEN_LOCK.lock().unwrap();
let original = std::env::var("HF_TOKEN").ok();
std::env::set_var("HF_TOKEN", " ");
let token = resolve_hf_token();
// Restore the environment before asserting so a failure cannot leak the test value.
match &original {
Some(v) => std::env::set_var("HF_TOKEN", v),
None => std::env::remove_var("HF_TOKEN"),
}
// Only rules out the blank value; a token from a cache file may still be returned.
assert_ne!(token, Some(" ".to_string()));
}
// The digest of "hello world" matches the well-known SHA-256 value.
#[test]
fn compute_sha256_correct_digest() {
let dir = std::env::temp_dir().join("mold_test_sha256_compute");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("test_file.bin");
std::fs::write(&path, b"hello world").unwrap();
let digest = compute_sha256(&path).unwrap();
assert_eq!(
digest,
"b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
);
let _ = std::fs::remove_dir_all(&dir);
}
// verify_sha256 returns true for the correct digest.
#[test]
fn verify_sha256_matches() {
let dir = std::env::temp_dir().join("mold_test_sha256_match");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("test_file.bin");
std::fs::write(&path, b"hello world").unwrap();
let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
assert!(verify_sha256(&path, expected).unwrap());
let _ = std::fs::remove_dir_all(&dir);
}
// verify_sha256 returns false (not an error) for a wrong digest.
#[test]
fn verify_sha256_mismatch() {
let dir = std::env::temp_dir().join("mold_test_sha256_mismatch");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("test_file.bin");
std::fs::write(&path, b"hello world").unwrap();
let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
assert!(!verify_sha256(&path, wrong).unwrap());
let _ = std::fs::remove_dir_all(&dir);
}
// A digest mismatch yields Sha256Mismatch and removes the file.
#[test]
fn verify_file_integrity_deletes_on_mismatch() {
use crate::manifest::{ModelComponent, ModelFile};
let dir = std::env::temp_dir().join("mold_test_integrity_mismatch");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("corrupted.bin");
std::fs::write(&path, b"corrupted data").unwrap();
let file = ModelFile {
hf_repo: "test/repo".to_string(),
hf_filename: "corrupted.bin".to_string(),
component: ModelComponent::Transformer,
size_bytes: 14,
gated: false,
sha256: Some("0000000000000000000000000000000000000000000000000000000000000000"),
};
let result = verify_file_integrity(&path, &file, "test-model:q8", false);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
DownloadError::Sha256Mismatch { .. }
),);
assert!(!path.exists());
let _ = std::fs::remove_dir_all(&dir);
}
// With skip_verify set, a wrong digest is accepted and the file is kept.
#[test]
fn verify_file_integrity_skip_verify_ignores_mismatch() {
use crate::manifest::{ModelComponent, ModelFile};
let dir = std::env::temp_dir().join("mold_test_integrity_skip");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("file.bin");
std::fs::write(&path, b"some data").unwrap();
let file = ModelFile {
hf_repo: "test/repo".to_string(),
hf_filename: "file.bin".to_string(),
component: ModelComponent::Transformer,
size_bytes: 9,
gated: false,
sha256: Some("0000000000000000000000000000000000000000000000000000000000000000"),
};
let result = verify_file_integrity(&path, &file, "test-model:q8", true);
assert!(result.is_ok());
assert!(path.exists());
let _ = std::fs::remove_dir_all(&dir);
}
// Manifest entries without a digest always pass verification.
#[test]
fn verify_file_integrity_no_hash_is_ok() {
use crate::manifest::{ModelComponent, ModelFile};
let dir = std::env::temp_dir().join("mold_test_integrity_nohash");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("file.bin");
std::fs::write(&path, b"data").unwrap();
let file = ModelFile {
hf_repo: "test/repo".to_string(),
hf_filename: "file.bin".to_string(),
component: ModelComponent::Transformer,
size_bytes: 4,
gated: false,
sha256: None,
};
assert!(verify_file_integrity(&path, &file, "test:q8", false).is_ok());
let _ = std::fs::remove_dir_all(&dir);
}
// Exercises create/remove of a ".pulling" file directly in a temp directory.
// NOTE(review): this does not call the marker helpers themselves.
#[test]
fn pulling_marker_roundtrip() {
let dir = std::env::temp_dir().join("mold_test_marker_roundtrip");
let _ = std::fs::create_dir_all(&dir);
let marker = dir.join(".pulling");
std::fs::write(&marker, "test-model:q8").unwrap();
assert!(marker.exists());
let _ = std::fs::remove_file(&marker);
assert!(!marker.exists());
let _ = std::fs::remove_dir_all(&dir);
}
// The mismatch message names the file, the retry command and the --skip-verify escape hatch.
#[test]
fn sha256_mismatch_error_message() {
let err = DownloadError::Sha256Mismatch {
filename: "transformer.gguf".to_string(),
expected: "aaa".to_string(),
actual: "bbb".to_string(),
model: "flux-dev:q8".to_string(),
};
let msg = err.to_string();
assert!(msg.contains("SHA-256 mismatch"));
assert!(msg.contains("transformer.gguf"));
assert!(msg.contains("mold pull flux-dev:q8"));
assert!(msg.contains("--skip-verify"));
}
}