Skip to main content

mold_core/
download.rs

1use std::path::{Path, PathBuf};
2use std::sync::{Arc, Mutex, OnceLock};
3use std::time::Instant;
4
5use console::Term;
6use hf_hub::api::tokio::{Api, ApiBuilder, ApiError, Progress};
7use hf_hub::{Cache, Repo, RepoType};
8use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
9use thiserror::Error;
10
11use crate::manifest::{paths_from_downloads, ModelComponent, ModelFile, ModelManifest};
12use crate::ModelPaths;
13
/// Progress notification delivered to a download progress callback.
///
/// `FileStart`, `FileProgress` and `FileDone` describe the lifecycle of a
/// single file, while `Status` carries a free-form message. The `batch_*`
/// fields repeat on every file-level variant and describe the whole pull
/// rather than the individual file.
#[derive(Debug, Clone)]
pub enum DownloadProgressEvent {
    /// A file download has started.
    FileStart {
        /// Name of the file this event refers to.
        filename: String,
        /// Index of the file within the batch.
        file_index: usize,
        /// Number of files in the batch.
        total_files: usize,
        /// Size of this file in bytes.
        size_bytes: u64,
        /// Bytes of the batch downloaded so far.
        batch_bytes_downloaded: u64,
        /// Total bytes of the batch.
        batch_bytes_total: u64,
        /// Milliseconds elapsed since the batch started.
        batch_elapsed_ms: u64,
    },

    /// Bytes downloaded for the current file.
    FileProgress {
        /// Name of the file this event refers to.
        filename: String,
        /// Index of the file within the batch.
        file_index: usize,
        /// Bytes of this file downloaded so far.
        bytes_downloaded: u64,
        /// Total size of this file in bytes.
        bytes_total: u64,
        /// Bytes of the batch downloaded so far.
        batch_bytes_downloaded: u64,
        /// Total bytes of the batch.
        batch_bytes_total: u64,
        /// Milliseconds elapsed since the batch started.
        batch_elapsed_ms: u64,
    },

    /// Status message (e.g. "Verifying cached files...").
    Status {
        /// Human-readable text to display.
        message: String,
    },

    /// A file download completed.
    FileDone {
        /// Name of the file this event refers to.
        filename: String,
        /// Index of the file within the batch.
        file_index: usize,
        /// Number of files in the batch.
        total_files: usize,
        /// Bytes of the batch downloaded so far.
        batch_bytes_downloaded: u64,
        /// Total bytes of the batch.
        batch_bytes_total: u64,
        /// Milliseconds elapsed since the batch started.
        batch_elapsed_ms: u64,
    },
}
49
50/// Callback type for download progress reporting.
51pub type DownloadProgressCallback = Arc<dyn Fn(DownloadProgressEvent) + Send + Sync>;
52
/// Knobs that adjust how a model pull behaves.
///
/// `Default` yields the safest settings (verification enabled).
#[derive(Debug, Clone, Default)]
pub struct PullOptions {
    /// When `true`, SHA-256 verification after download is skipped.
    /// Intended for the case where HuggingFace updated a file and the
    /// recorded digest no longer matches.
    pub skip_verify: bool,
}
59
60#[derive(Debug, Error)]
61pub enum DownloadError {
62    #[error(
63        "Model requires access approval on HuggingFace.\n\n  1. Visit: https://huggingface.co/{repo}\n  2. Accept the license agreement\n  3. Create a token at: https://huggingface.co/settings/tokens\n  4. Set: export HF_TOKEN=hf_...\n  5. Retry: mold pull {model}"
64    )]
65    GatedModel { repo: String, model: String },
66
67    #[error(
68        "Authentication required for repository {repo}.\n\n  1. Create a token at: https://huggingface.co/settings/tokens\n     (select at least \"Read\" access)\n  2. Set: export HF_TOKEN=hf_...\n     Or run: huggingface-cli login\n  3. Retry: mold pull {model}\n\n  If HF_TOKEN is already set, it may be invalid or expired."
69    )]
70    Unauthorized { repo: String, model: String },
71
72    #[error("Download failed for {filename} from {repo}: {source}")]
73    DownloadFailed {
74        repo: String,
75        filename: String,
76        source: ApiError,
77    },
78
79    #[error("SHA-256 mismatch for {filename}\n  Expected: {expected}\n  Got:      {actual}\n\nThe corrupted file has been removed. Re-run: mold pull {model}\nIf the file was intentionally updated on HuggingFace, use: mold pull {model} --skip-verify")]
80    Sha256Mismatch {
81        filename: String,
82        expected: String,
83        actual: String,
84        model: String,
85    },
86
87    #[error("Failed to build HuggingFace API client: {0}")]
88    ApiSetup(#[from] ApiError),
89
90    #[error("Failed to build sync HuggingFace API client: {0}")]
91    SyncApiSetup(String),
92
93    #[error("Sync download failed for {filename} from {repo}: {message}")]
94    SyncDownloadFailed {
95        repo: String,
96        filename: String,
97        message: String,
98    },
99
100    #[error("Missing component after download — this is a bug")]
101    MissingComponent,
102
103    #[error("{0}")]
104    Other(String),
105
106    #[error("IO error during file placement: {0}")]
107    FilePlacement(String),
108
109    #[error("Unknown model '{model}'. No manifest found.")]
110    UnknownModel { model: String },
111
112    #[error("Failed to save config: {0}")]
113    ConfigSave(String),
114}
115
116/// Resolve HuggingFace token: `HF_TOKEN` env var takes precedence over
117/// the token file (`~/.cache/huggingface/token` from `huggingface-cli login`).
118fn resolve_hf_token() -> Option<String> {
119    if let Ok(token) = std::env::var("HF_TOKEN") {
120        let token = token.trim().to_string();
121        if !token.is_empty() {
122            return Some(token);
123        }
124    }
125    Cache::new(hf_cache_dir())
126        .token()
127        .or_else(|| Cache::from_env().token())
128}
129
130/// Resolve the mold models directory. Computed once from config on first access.
131/// Resolution order: `MOLD_MODELS_DIR` env var → config `models_dir` → `~/.mold/models`.
132///
133/// This is the clean model storage root. Actual model files live at clean paths like
134/// `models/flux-schnell-q8/transformer.gguf` and `models/shared/flux/ae.safetensors`.
135///
136/// **OnceLock caching**: The directory is resolved once on the first call and cached
137/// for the entire process lifetime. Changing `MOLD_MODELS_DIR` or the config file
138/// after the first call has no effect. This is by design — model paths recorded in
139/// config must remain stable within a single process run.
140fn models_dir() -> PathBuf {
141    static DIR: OnceLock<PathBuf> = OnceLock::new();
142    DIR.get_or_init(|| {
143        let dir = crate::Config::load_or_default().resolved_models_dir();
144        let _ = std::fs::create_dir_all(&dir);
145        dir
146    })
147    .clone()
148}
149
150/// Internal hf-hub cache directory: `<models_dir>/.hf-cache/`.
151/// Hidden from users; files get hardlinked to clean paths after download.
152fn hf_cache_dir() -> PathBuf {
153    static DIR: OnceLock<PathBuf> = OnceLock::new();
154    DIR.get_or_init(|| {
155        let dir = models_dir().join(".hf-cache");
156        let _ = std::fs::create_dir_all(&dir);
157        dir
158    })
159    .clone()
160}
161
162/// Hardlink `src` to `dst`, falling back to copy if hardlink fails (cross-filesystem).
163/// Idempotent: skips if `dst` already exists with the same size as `src`.
164///
165/// The source path is canonicalized to resolve hf-hub's symlink chain
166/// (`snapshots/<sha>/file → ../../blobs/<hash>`) before any filesystem ops.
167fn hardlink_or_copy(src: &std::path::Path, dst: &std::path::Path) -> Result<(), DownloadError> {
168    // Resolve symlinks — hf-hub cache returns symlink paths that can cause
169    // ENOENT on some filesystems when passed directly to hard_link or copy.
170    let real_src = src.canonicalize().map_err(|e| {
171        DownloadError::FilePlacement(format!(
172            "source file not found after download: {} ({e})",
173            src.display()
174        ))
175    })?;
176
177    // Check if dst already has the correct content (idempotent skip).
178    // Use metadata() which follows symlinks — only skip if the real target matches.
179    if dst.exists() {
180        if let (Ok(src_meta), Ok(dst_meta)) = (real_src.metadata(), dst.metadata()) {
181            if src_meta.len() == dst_meta.len() {
182                return Ok(());
183            }
184        }
185    }
186
187    // Remove stale destination before placement. A previous hard_link on an
188    // hf-hub symlink creates a relative symlink that dangles from the new
189    // location (e.g. shared/sd3/file → ../../blobs/hash, which doesn't exist
190    // relative to shared/sd3/). symlink_metadata() sees these even though
191    // exists() returns false for dangling symlinks.
192    if dst.symlink_metadata().is_ok() {
193        let _ = std::fs::remove_file(dst);
194    }
195
196    if let Some(parent) = dst.parent() {
197        std::fs::create_dir_all(parent).map_err(|e| {
198            DownloadError::FilePlacement(format!(
199                "failed to create directory {}: {e}",
200                parent.display()
201            ))
202        })?;
203    }
204    // Try hardlink first (zero extra disk space, instant)
205    match std::fs::hard_link(&real_src, dst) {
206        Ok(()) => return Ok(()),
207        Err(_e) => {
208            // Expected on cross-filesystem setups; fall through to copy
209        }
210    }
211    // Fall back to copy (cross-filesystem or hard_link unsupported)
212    std::fs::copy(&real_src, dst).map_err(|e| {
213        DownloadError::FilePlacement(format!(
214            "failed to copy {} → {}: {e}",
215            real_src.display(),
216            dst.display()
217        ))
218    })?;
219    Ok(())
220}
221
222/// Compute the SHA-256 hex digest of a file.
223pub fn compute_sha256(path: &std::path::Path) -> anyhow::Result<String> {
224    use sha2::{Digest, Sha256};
225
226    let mut file = std::fs::File::open(path)?;
227    let mut hasher = Sha256::new();
228    std::io::copy(&mut file, &mut hasher)?;
229    Ok(format!("{:x}", hasher.finalize()))
230}
231
232/// Verify the SHA-256 digest of a file against an expected hex string.
233///
234/// Returns `Ok(true)` when the digest matches, `Ok(false)` on mismatch.
235/// Errors only on I/O failures (e.g. file not found).
236pub fn verify_sha256(path: &std::path::Path, expected: &str) -> anyhow::Result<bool> {
237    Ok(compute_sha256(path)? == expected)
238}
239
240// ── Pull marker file (.pulling) ──────────────────────────────────────────────
241
242/// Relative path to a model's `.pulling` marker: `<sanitized-name>/.pulling`.
243pub fn pulling_marker_rel_path(model_name: &str) -> PathBuf {
244    let canonical = crate::manifest::resolve_model_name(model_name);
245    PathBuf::from(canonical.replace(':', "-")).join(".pulling")
246}
247
248/// Path to the `.pulling` marker for a model under an explicit models dir.
249pub fn pulling_marker_path_in(models_dir: &Path, model_name: &str) -> PathBuf {
250    models_dir.join(pulling_marker_rel_path(model_name))
251}
252
253/// Path to the `.pulling` marker for a model: `<models_dir>/<sanitized-name>/.pulling`.
254fn pulling_marker_path(model_name: &str) -> PathBuf {
255    pulling_marker_path_in(&models_dir(), model_name)
256}
257
258/// Write a `.pulling` marker to signal an in-progress download.
259fn write_pulling_marker(model_name: &str) -> Result<(), DownloadError> {
260    let path = pulling_marker_path(model_name);
261    if let Some(parent) = path.parent() {
262        std::fs::create_dir_all(parent).map_err(|e| {
263            DownloadError::FilePlacement(format!(
264                "failed to create directory for pull marker {}: {e}",
265                parent.display()
266            ))
267        })?;
268    }
269    std::fs::write(&path, model_name).map_err(|e| {
270        DownloadError::FilePlacement(format!(
271            "failed to write pull marker {}: {e}",
272            path.display()
273        ))
274    })
275}
276
277/// Remove the `.pulling` marker (best-effort, ignores errors).
278pub fn remove_pulling_marker(model_name: &str) {
279    let path = pulling_marker_path(model_name);
280    let _ = std::fs::remove_file(path);
281}
282
283/// Check whether a model has an active `.pulling` marker (incomplete download).
284pub fn has_pulling_marker(model_name: &str) -> bool {
285    let canonical = crate::manifest::resolve_model_name(model_name);
286    pulling_marker_path(&canonical).exists()
287}
288
289/// Verify SHA-256 integrity of a downloaded file. On mismatch, deletes the
290/// corrupted file and returns `Sha256Mismatch`. Respects `skip_verify`.
291fn verify_file_integrity(
292    clean_path: &std::path::Path,
293    file: &ModelFile,
294    model_name: &str,
295    skip_verify: bool,
296) -> Result<(), DownloadError> {
297    let expected = match file.sha256 {
298        Some(h) => h,
299        None => return Ok(()),
300    };
301    if skip_verify {
302        return Ok(());
303    }
304    match compute_sha256(clean_path) {
305        Ok(actual) if actual == expected => Ok(()),
306        Ok(actual) => {
307            let _ = std::fs::remove_file(clean_path);
308            Err(DownloadError::Sha256Mismatch {
309                filename: file.hf_filename.clone(),
310                expected: expected.to_string(),
311                actual,
312                model: model_name.to_string(),
313            })
314        }
315        Err(e) => {
316            eprintln!(
317                "warning: failed to verify SHA-256 for {}: {e}",
318                file.hf_filename
319            );
320            Ok(())
321        }
322    }
323}
324
/// Shorten `name` so it occupies at most `max_len` bytes by dropping its
/// beginning and prefixing the remainder with `"..."`.
///
/// The tail of the name is kept because it is usually the distinguishing
/// part. The name is returned unchanged when it already fits, or when
/// `max_len` is below 8 (too small for a useful truncation).
///
/// Lengths are measured in bytes. The cut point is moved forward to the next
/// UTF-8 character boundary when necessary, so names containing multi-byte
/// characters never cause a slicing panic; in that case the result may be a
/// few bytes shorter than `max_len`.
fn truncate_filename(name: &str, max_len: usize) -> String {
    if name.len() <= max_len || max_len < 8 {
        return name.to_string();
    }
    // Reserve three bytes for the "..." prefix.
    let suffix_len = max_len - 3;
    let mut start = name.len() - suffix_len;
    // `name.len()` is always a boundary, so this loop terminates in bounds.
    while !name.is_char_boundary(start) {
        start += 1;
    }
    format!("...{}", &name[start..])
}
335
336/// Maximum characters for the filename column in progress bars.
337/// Derived from terminal width minus the fixed overhead of the bar template:
338/// 2 (indent) + 1 (space) + 1 ([) + 30 (bar) + 1 (]) + ~40 (bytes/speed/eta) = ~75 chars overhead.
339fn filename_column_width() -> usize {
340    let term_width = Term::stderr().size().1 as usize;
341    term_width.saturating_sub(75).max(12)
342}
343
344/// Progress adapter bridging hf-hub's `Progress` trait to an `indicatif::ProgressBar`.
345#[derive(Clone)]
346struct DownloadProgress {
347    bar: ProgressBar,
348    max_msg_len: usize,
349    filename: String,
350}
351
352impl DownloadProgress {
353    fn new(bar: ProgressBar, max_msg_len: usize) -> Self {
354        Self {
355            bar,
356            max_msg_len,
357            filename: String::new(),
358        }
359    }
360}
361
362impl Progress for DownloadProgress {
363    async fn init(&mut self, size: usize, filename: &str) {
364        self.bar.set_length(size as u64);
365        self.filename = truncate_filename(filename, self.max_msg_len);
366        self.bar.set_message(self.filename.clone());
367    }
368
369    async fn update(&mut self, size: usize) {
370        self.bar.inc(size as u64);
371    }
372
373    async fn finish(&mut self) {
374        self.bar.finish_with_message(self.filename.clone());
375    }
376}
377
378/// Progress adapter that dispatches to a callback instead of indicatif.
379/// Throttles `FileProgress` events to ~4/sec per file to avoid flooding SSE.
380#[derive(Clone)]
381struct CallbackProgress {
382    callback: DownloadProgressCallback,
383    file_index: usize,
384    total_files: usize,
385    batch_bytes_before_current: u64,
386    batch_bytes_total: u64,
387    batch_started_at: Instant,
388    shared: Arc<Mutex<CallbackProgressState>>,
389}
390
391struct CallbackProgressState {
392    accumulated: u64,
393    total: u64,
394    filename: String,
395    last_emit: Instant,
396}
397
398impl CallbackProgress {
399    fn new(
400        callback: DownloadProgressCallback,
401        file_index: usize,
402        total_files: usize,
403        batch_bytes_before_current: u64,
404        batch_bytes_total: u64,
405        batch_started_at: Instant,
406    ) -> Self {
407        Self {
408            callback,
409            file_index,
410            total_files,
411            batch_bytes_before_current,
412            batch_bytes_total,
413            batch_started_at,
414            shared: Arc::new(Mutex::new(CallbackProgressState {
415                accumulated: 0,
416                total: 0,
417                filename: String::new(),
418                last_emit: Instant::now(),
419            })),
420        }
421    }
422}
423
424impl Progress for CallbackProgress {
425    async fn init(&mut self, size: usize, filename: &str) {
426        let (fname, total) = {
427            let mut shared = self
428                .shared
429                .lock()
430                .expect("download progress mutex poisoned");
431            shared.total = size as u64;
432            shared.accumulated = 0;
433            shared.filename = filename.to_string();
434            shared.last_emit = Instant::now();
435            (shared.filename.clone(), shared.total)
436        };
437        (self.callback)(DownloadProgressEvent::FileStart {
438            filename: fname,
439            file_index: self.file_index,
440            total_files: self.total_files,
441            size_bytes: total,
442            batch_bytes_downloaded: self.batch_bytes_before_current,
443            batch_bytes_total: self.batch_bytes_total,
444            batch_elapsed_ms: self.batch_started_at.elapsed().as_millis() as u64,
445        });
446    }
447
448    async fn update(&mut self, size: usize) {
449        let mut shared = self
450            .shared
451            .lock()
452            .expect("download progress mutex poisoned");
453        shared.accumulated += size as u64;
454
455        let now = Instant::now();
456        let should_emit = now.duration_since(shared.last_emit).as_millis() >= 250
457            || shared.accumulated >= shared.total;
458        if !should_emit {
459            return;
460        }
461
462        shared.last_emit = now;
463        let filename = shared.filename.clone();
464        let accumulated = shared.accumulated;
465        let total = shared.total;
466        drop(shared);
467
468        (self.callback)(DownloadProgressEvent::FileProgress {
469            filename,
470            file_index: self.file_index,
471            bytes_downloaded: accumulated,
472            bytes_total: total,
473            batch_bytes_downloaded: self.batch_bytes_before_current + accumulated,
474            batch_bytes_total: self.batch_bytes_total,
475            batch_elapsed_ms: self.batch_started_at.elapsed().as_millis() as u64,
476        });
477    }
478
479    async fn finish(&mut self) {
480        let (fname, total) = {
481            let shared = self
482                .shared
483                .lock()
484                .expect("download progress mutex poisoned");
485            (shared.filename.clone(), shared.total)
486        };
487        (self.callback)(DownloadProgressEvent::FileDone {
488            filename: fname,
489            file_index: self.file_index,
490            total_files: self.total_files,
491            batch_bytes_downloaded: self.batch_bytes_before_current + total,
492            batch_bytes_total: self.batch_bytes_total,
493            batch_elapsed_ms: self.batch_started_at.elapsed().as_millis() as u64,
494        });
495    }
496}
497
498/// Sync progress adapter bridging hf-hub's sync `Progress` trait to our
499/// local `indicatif::ProgressBar`.
500struct SyncDownloadProgress {
501    bar: ProgressBar,
502    max_msg_len: usize,
503    filename: String,
504}
505
506impl SyncDownloadProgress {
507    fn new(bar: ProgressBar, max_msg_len: usize) -> Self {
508        Self {
509            bar,
510            max_msg_len,
511            filename: String::new(),
512        }
513    }
514}
515
516impl hf_hub::api::Progress for SyncDownloadProgress {
517    fn init(&mut self, size: usize, filename: &str) {
518        self.bar.set_length(size as u64);
519        self.filename = truncate_filename(filename, self.max_msg_len);
520        self.bar.set_message(self.filename.clone());
521    }
522
523    fn update(&mut self, size: usize) {
524        self.bar.inc(size as u64);
525    }
526
527    fn finish(&mut self) {
528        self.bar.finish_with_message(self.filename.clone());
529    }
530}
531
532/// Returns `true` if the file already exists at `clean_path` with the correct
533/// size and (if a SHA-256 is available) the correct digest.
534///
535/// **Side-effect**: if the file exists with matching size but failing integrity,
536/// `verify_file_integrity` will delete the corrupted file before returning `false`.
537fn is_already_placed(
538    clean_path: &std::path::Path,
539    file: &ModelFile,
540    model_name: &str,
541    skip_verify: bool,
542) -> bool {
543    let size_ok = clean_path
544        .metadata()
545        .map(|m| m.len() == file.size_bytes)
546        .unwrap_or(false);
547    if !size_ok {
548        return false;
549    }
550    // Verify integrity — a same-size but corrupted file must not be accepted
551    verify_file_integrity(clean_path, file, model_name, skip_verify).is_ok()
552}
553
554/// Return an existing valid clean path for a manifest file, migrating from a
555/// legacy location when needed.
556fn find_existing_placed_file(
557    models_dir: &std::path::Path,
558    manifest: &ModelManifest,
559    file: &ModelFile,
560    skip_verify: bool,
561) -> Result<Option<PathBuf>, DownloadError> {
562    let canonical_rel = crate::manifest::storage_path(manifest, file);
563    let canonical_path = models_dir.join(&canonical_rel);
564
565    for candidate_rel in crate::manifest::storage_path_candidates(manifest, file) {
566        let candidate_path = models_dir.join(candidate_rel);
567        if !is_already_placed(&candidate_path, file, &manifest.name, skip_verify) {
568            continue;
569        }
570        if candidate_path != canonical_path {
571            hardlink_or_copy(&candidate_path, &canonical_path)?;
572            verify_file_integrity(&canonical_path, file, &manifest.name, skip_verify)?;
573        }
574        return Ok(Some(canonical_path));
575    }
576
577    Ok(None)
578}
579
580/// Download all files for a model manifest, returning resolved paths.
581///
582/// Downloads go to a hidden hf-hub cache (`.hf-cache/`) for resume/dedup support,
583/// then files are hardlinked to clean paths:
584/// - Transformers → `<model-name>/<filename>`
585/// - Shared components → `shared/<family>/<filename>`
586///
587/// A `.pulling` marker file is written before downloads begin and removed on
588/// success. If the pull is interrupted, the marker signals an incomplete state.
589pub async fn pull_model(
590    manifest: &ModelManifest,
591    opts: &PullOptions,
592) -> Result<ModelPaths, DownloadError> {
593    write_pulling_marker(&manifest.name)?;
594
595    let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
596    if let Some(token) = resolve_hf_token() {
597        builder = builder.with_token(Some(token));
598    }
599    let api = builder.build()?;
600
601    let multi = MultiProgress::with_draw_target(ProgressDrawTarget::stderr());
602    let msg_width = filename_column_width();
603    let bar_style = ProgressStyle::with_template(&format!(
604        "  {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
605    ))
606    .unwrap()
607    .progress_chars("━╸─");
608
609    let mdir = models_dir();
610    let mut downloads: Vec<(ModelComponent, PathBuf)> = Vec::new();
611
612    for file in &manifest.files {
613        if let Some(clean_path) =
614            find_existing_placed_file(&mdir, manifest, file, opts.skip_verify)?
615        {
616            downloads.push((file.component, clean_path));
617            continue;
618        }
619
620        let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
621
622        let bar = multi.add(ProgressBar::new(file.size_bytes));
623        bar.set_style(bar_style.clone());
624        bar.set_message(truncate_filename(&file.hf_filename, msg_width));
625
626        let hf_path = download_file(
627            &api,
628            file,
629            DownloadProgress::new(bar, msg_width),
630            &manifest.name,
631        )
632        .await?;
633
634        // Place at clean path via hardlink (or copy as fallback)
635        hardlink_or_copy(&hf_path, &clean_path)?;
636
637        verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
638
639        downloads.push((file.component, clean_path));
640    }
641
642    remove_pulling_marker(&manifest.name);
643    paths_from_downloads(&downloads, &manifest.family).ok_or(DownloadError::MissingComponent)
644}
645
646/// Download all files for a model manifest, reporting progress via callback.
647///
648/// Same as `pull_model` but uses a callback instead of indicatif progress bars.
649/// Suitable for server-side downloads where terminal bars are not appropriate.
650pub async fn pull_model_with_callback(
651    manifest: &ModelManifest,
652    callback: DownloadProgressCallback,
653    opts: &PullOptions,
654) -> Result<ModelPaths, DownloadError> {
655    write_pulling_marker(&manifest.name)?;
656
657    let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
658    if let Some(token) = resolve_hf_token() {
659        builder = builder.with_token(Some(token));
660    }
661    let api = builder.build()?;
662
663    let mdir = models_dir();
664    let mut downloads: Vec<(ModelComponent, PathBuf)> = Vec::new();
665
666    // Pre-compute which files need downloading vs already cached.
667    // Run in spawn_blocking because SHA-256 verification of multi-GB cached
668    // files blocks the async runtime and prevents SSE event delivery.
669    let manifest_clone = manifest.clone();
670    let skip_verify = opts.skip_verify;
671    let mdir_clone = mdir.clone();
672    let cb = callback.clone();
673    let file_status: Vec<bool> = tokio::task::spawn_blocking(move || {
674        let total = manifest_clone.files.len();
675        manifest_clone
676            .files
677            .iter()
678            .enumerate()
679            .map(|(i, file)| {
680                cb(DownloadProgressEvent::Status {
681                    message: format!(
682                        "Verifying file [{}/{}] {}...",
683                        i + 1,
684                        total,
685                        file.hf_filename
686                    ),
687                });
688                find_existing_placed_file(&mdir_clone, &manifest_clone, file, skip_verify)
689                    .map(|p| p.is_some())
690                    .unwrap_or(false)
691            })
692            .collect()
693    })
694    .await
695    .map_err(|e| DownloadError::Other(format!("pre-scan task failed: {e}")))?;
696
697    let total_bytes_to_download: u64 = manifest
698        .files
699        .iter()
700        .zip(file_status.iter())
701        .filter(|(_, &placed)| !placed)
702        .map(|(file, _)| file.size_bytes)
703        .sum();
704    let total_files_count = manifest.files.len();
705    let mut completed_bytes = 0u64;
706    let batch_started_at = Instant::now();
707
708    for (file_pos, (file, &already_placed)) in
709        manifest.files.iter().zip(file_status.iter()).enumerate()
710    {
711        let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
712
713        if already_placed {
714            // Emit events for cached files so the TUI shows checkmarks.
715            let elapsed = batch_started_at.elapsed().as_millis() as u64;
716            (callback)(DownloadProgressEvent::FileStart {
717                filename: file.hf_filename.clone(),
718                file_index: file_pos,
719                total_files: total_files_count,
720                size_bytes: file.size_bytes,
721                batch_bytes_downloaded: completed_bytes,
722                batch_bytes_total: total_bytes_to_download,
723                batch_elapsed_ms: elapsed,
724            });
725            (callback)(DownloadProgressEvent::FileDone {
726                filename: file.hf_filename.clone(),
727                file_index: file_pos,
728                total_files: total_files_count,
729                batch_bytes_downloaded: completed_bytes,
730                batch_bytes_total: total_bytes_to_download,
731                batch_elapsed_ms: elapsed,
732            });
733            downloads.push((file.component, clean_path));
734            continue;
735        }
736
737        let progress = CallbackProgress::new(
738            callback.clone(),
739            file_pos,
740            total_files_count,
741            completed_bytes,
742            total_bytes_to_download,
743            batch_started_at,
744        );
745        let hf_path = download_file(&api, file, progress, &manifest.name).await?;
746
747        hardlink_or_copy(&hf_path, &clean_path)?;
748
749        verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
750
751        downloads.push((file.component, clean_path));
752        completed_bytes += file.size_bytes;
753    }
754
755    remove_pulling_marker(&manifest.name);
756    paths_from_downloads(&downloads, &manifest.family).ok_or(DownloadError::MissingComponent)
757}
758
759/// Download all files for a utility model (no ModelPaths, no config writing).
760///
761/// Used for models like qwen3-expand that are not diffusion models and don't
762/// have a VAE. Files are downloaded and placed at their standard storage paths.
763async fn pull_model_files_only(
764    manifest: &ModelManifest,
765    opts: &PullOptions,
766) -> Result<(), DownloadError> {
767    write_pulling_marker(&manifest.name)?;
768
769    let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
770    if let Some(token) = resolve_hf_token() {
771        builder = builder.with_token(Some(token));
772    }
773    let api = builder.build()?;
774
775    let multi = MultiProgress::with_draw_target(ProgressDrawTarget::stderr());
776    let msg_width = filename_column_width();
777    let bar_style = ProgressStyle::with_template(&format!(
778        "  {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
779    ))
780    .unwrap()
781    .progress_chars("━╸─");
782
783    let mdir = models_dir();
784
785    for file in &manifest.files {
786        if find_existing_placed_file(&mdir, manifest, file, opts.skip_verify)?.is_some() {
787            continue;
788        }
789
790        let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
791
792        let bar = multi.add(ProgressBar::new(file.size_bytes));
793        bar.set_style(bar_style.clone());
794        bar.set_message(truncate_filename(&file.hf_filename, msg_width));
795
796        let hf_path = download_file(
797            &api,
798            file,
799            DownloadProgress::new(bar, msg_width),
800            &manifest.name,
801        )
802        .await?;
803
804        hardlink_or_copy(&hf_path, &clean_path)?;
805
806        verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
807    }
808
809    remove_pulling_marker(&manifest.name);
810    Ok(())
811}
812
813/// Download all files for a utility model, reporting progress via callback.
814async fn pull_model_files_only_with_callback(
815    manifest: &ModelManifest,
816    callback: DownloadProgressCallback,
817    opts: &PullOptions,
818) -> Result<(), DownloadError> {
819    write_pulling_marker(&manifest.name)?;
820
821    let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
822    if let Some(token) = resolve_hf_token() {
823        builder = builder.with_token(Some(token));
824    }
825    let api = builder.build()?;
826
827    let mdir = models_dir();
828
829    let manifest_clone = manifest.clone();
830    let skip_verify = opts.skip_verify;
831    let mdir_clone = mdir.clone();
832    let cb = callback.clone();
833    let file_status: Vec<bool> = tokio::task::spawn_blocking(move || {
834        let total = manifest_clone.files.len();
835        manifest_clone
836            .files
837            .iter()
838            .enumerate()
839            .map(|(i, file)| {
840                cb(DownloadProgressEvent::Status {
841                    message: format!(
842                        "Verifying file [{}/{}] {}...",
843                        i + 1,
844                        total,
845                        file.hf_filename
846                    ),
847                });
848                find_existing_placed_file(&mdir_clone, &manifest_clone, file, skip_verify)
849                    .map(|p| p.is_some())
850                    .unwrap_or(false)
851            })
852            .collect()
853    })
854    .await
855    .map_err(|e| DownloadError::Other(format!("pre-scan task failed: {e}")))?;
856    let total_bytes_to_download: u64 = manifest
857        .files
858        .iter()
859        .zip(file_status.iter())
860        .filter(|(_, &placed)| !placed)
861        .map(|(file, _)| file.size_bytes)
862        .sum();
863    let total_files_count = manifest.files.len();
864    let mut completed_bytes = 0u64;
865    let batch_started_at = Instant::now();
866
867    for (file_pos, (file, &already_placed)) in
868        manifest.files.iter().zip(file_status.iter()).enumerate()
869    {
870        let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
871
872        if already_placed {
873            let elapsed = batch_started_at.elapsed().as_millis() as u64;
874            (callback)(DownloadProgressEvent::FileStart {
875                filename: file.hf_filename.clone(),
876                file_index: file_pos,
877                total_files: total_files_count,
878                size_bytes: file.size_bytes,
879                batch_bytes_downloaded: completed_bytes,
880                batch_bytes_total: total_bytes_to_download,
881                batch_elapsed_ms: elapsed,
882            });
883            (callback)(DownloadProgressEvent::FileDone {
884                filename: file.hf_filename.clone(),
885                file_index: file_pos,
886                total_files: total_files_count,
887                batch_bytes_downloaded: completed_bytes,
888                batch_bytes_total: total_bytes_to_download,
889                batch_elapsed_ms: elapsed,
890            });
891            continue;
892        }
893
894        let progress = CallbackProgress::new(
895            callback.clone(),
896            file_pos,
897            total_files_count,
898            completed_bytes,
899            total_bytes_to_download,
900            batch_started_at,
901        );
902
903        let hf_path = download_file(&api, file, progress, &manifest.name).await?;
904
905        hardlink_or_copy(&hf_path, &clean_path)?;
906
907        verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
908        completed_bytes += file.size_bytes;
909    }
910
911    remove_pulling_marker(&manifest.name);
912    Ok(())
913}
914
915/// Extract HTTP status code from an async `ApiError`, if available.
916fn extract_http_status(err: &ApiError) -> Option<u16> {
917    if let ApiError::RequestError(reqwest_err) = err {
918        reqwest_err.status().map(|s| s.as_u16())
919    } else {
920        None
921    }
922}
923
924async fn download_file<P: Progress + Clone + Send + Sync + 'static>(
925    api: &Api,
926    file: &ModelFile,
927    progress: P,
928    model_name: &str,
929) -> Result<PathBuf, DownloadError> {
930    let repo = api.repo(Repo::new(file.hf_repo.clone(), RepoType::Model));
931
932    match repo
933        .download_with_progress(&file.hf_filename, progress)
934        .await
935    {
936        Ok(path) => Ok(path),
937        Err(e) => {
938            let status = extract_http_status(&e);
939            let err_str = e.to_string();
940            if status == Some(401) || err_str.contains("401") || err_str.contains("Unauthorized") {
941                Err(DownloadError::Unauthorized {
942                    repo: file.hf_repo.clone(),
943                    model: model_name.to_string(),
944                })
945            } else if status == Some(403)
946                || err_str.contains("403")
947                || err_str.contains("Forbidden")
948                || err_str.contains("gated")
949                || err_str.contains("Access denied")
950            {
951                Err(DownloadError::GatedModel {
952                    repo: file.hf_repo.clone(),
953                    model: model_name.to_string(),
954                })
955            } else {
956                Err(DownloadError::DownloadFailed {
957                    repo: file.hf_repo.clone(),
958                    filename: file.hf_filename.clone(),
959                    source: e,
960                })
961            }
962        }
963    }
964}
965
966// ── Synchronous single-file download (for use from spawn_blocking) ───────────
967
968/// Download a single file from HuggingFace, returning its path.
969/// Uses the sync hf-hub API — safe to call from `spawn_blocking`.
970/// Returns immediately if already cached.
971///
972/// If `target_subdir` is provided (e.g., `"shared/t5-gguf"`), the file is hardlinked
973/// from the hf-cache to `<models_dir>/<target_subdir>/<leaf_filename>` and that clean
974/// path is returned. If `None`, the raw hf-cache path is returned.
975pub fn download_single_file_sync(
976    hf_repo: &str,
977    hf_filename: &str,
978    target_subdir: Option<&str>,
979) -> Result<PathBuf, DownloadError> {
980    use hf_hub::api::sync::ApiBuilder;
981
982    let mut builder = ApiBuilder::from_env()
983        .with_cache_dir(hf_cache_dir())
984        .with_progress(false);
985    if let Some(token) = resolve_hf_token() {
986        builder = builder.with_token(Some(token));
987    }
988    let api = builder
989        .build()
990        .map_err(|e| DownloadError::SyncApiSetup(e.to_string()))?;
991    let repo = api.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
992    let msg_width = filename_column_width();
993    let bar_style = ProgressStyle::with_template(&format!(
994        "  {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
995    ))
996    .unwrap()
997    .progress_chars("━╸─");
998    let bar = ProgressBar::new(0);
999    bar.set_style(bar_style);
1000    bar.set_message(truncate_filename(hf_filename, msg_width));
1001    let progress = SyncDownloadProgress::new(bar, msg_width);
1002    let hf_path = repo
1003        .download_with_progress(hf_filename, progress)
1004        .map_err(|e| {
1005            let err_str = e.to_string();
1006            if err_str.contains("401") || err_str.contains("Unauthorized") {
1007                DownloadError::Unauthorized {
1008                    repo: hf_repo.to_string(),
1009                    model: String::new(),
1010                }
1011            } else if err_str.contains("403")
1012                || err_str.contains("Forbidden")
1013                || err_str.contains("gated")
1014                || err_str.contains("Access denied")
1015            {
1016                DownloadError::GatedModel {
1017                    repo: hf_repo.to_string(),
1018                    model: String::new(),
1019                }
1020            } else {
1021                DownloadError::SyncDownloadFailed {
1022                    repo: hf_repo.to_string(),
1023                    filename: hf_filename.to_string(),
1024                    message: err_str,
1025                }
1026            }
1027        })?;
1028
1029    // Place at clean path if target_subdir specified
1030    if let Some(subdir) = target_subdir {
1031        let leaf = hf_filename.rsplit('/').next().unwrap_or(hf_filename);
1032        let clean_path = models_dir().join(subdir).join(leaf);
1033        hardlink_or_copy(&hf_path, &clean_path)?;
1034        Ok(clean_path)
1035    } else {
1036        Ok(hf_path)
1037    }
1038}
1039
1040/// Check if a file is already cached locally (no download).
1041///
1042/// If `target_subdir` is provided, checks the clean path first
1043/// (`<models_dir>/<target_subdir>/<leaf_filename>`). Then checks the hf-cache,
1044/// old mold models dir (backward compat), and default HF cache.
1045pub fn cached_file_path(
1046    hf_repo: &str,
1047    hf_filename: &str,
1048    target_subdir: Option<&str>,
1049) -> Option<PathBuf> {
1050    // 1. Check clean path (if target_subdir specified)
1051    if let Some(subdir) = target_subdir {
1052        let leaf = hf_filename.rsplit('/').next().unwrap_or(hf_filename);
1053        let clean_path = models_dir().join(subdir).join(leaf);
1054        if clean_path.exists() {
1055            return Some(clean_path);
1056        }
1057    }
1058
1059    // 2. Check new hf-cache location (~/.mold/models/.hf-cache/)
1060    let new_cache = Cache::new(hf_cache_dir());
1061    let new_repo = new_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
1062    if let Some(path) = new_repo.get(hf_filename) {
1063        return Some(path);
1064    }
1065
1066    // 3. Check old mold models dir (backward compat — HF cached here before .hf-cache/)
1067    let old_cache = Cache::new(models_dir());
1068    let old_repo = old_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
1069    if let Some(path) = old_repo.get(hf_filename) {
1070        return Some(path);
1071    }
1072
1073    // 4. Check default HF cache (~/.cache/huggingface/hub/)
1074    let default_cache = Cache::from_env();
1075    let default_repo = default_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
1076    default_repo.get(hf_filename)
1077}
1078
1079// ── Pull and configure (shared between CLI and server) ───────────────────────
1080
1081/// Download a model and save its paths to config. Returns the updated config
1082/// and resolved model paths. Used by both the CLI `pull` command and the
1083/// server's auto-pull logic.
1084pub async fn pull_and_configure(
1085    model: &str,
1086    opts: &PullOptions,
1087) -> Result<(crate::Config, Option<ModelPaths>), DownloadError> {
1088    use crate::config::Config;
1089    use crate::manifest::{find_manifest, resolve_model_name};
1090
1091    let canonical = resolve_model_name(model);
1092
1093    let manifest = find_manifest(&canonical).ok_or_else(|| DownloadError::UnknownModel {
1094        model: model.to_string(),
1095    })?;
1096
1097    // Utility models (e.g., qwen3-expand) have no VAE and don't need config entries.
1098    if manifest.is_utility() {
1099        pull_model_files_only(manifest, opts).await?;
1100        let config = Config::load_or_default();
1101        return Ok((config, None));
1102    }
1103
1104    // Upscaler models have a single weights file (no VAE, no encoders).
1105    // Download files and create a minimal config entry with the weights path.
1106    if manifest.is_upscaler() {
1107        pull_model_files_only(manifest, opts).await?;
1108
1109        // Resolve the weights path from the manifest storage path
1110        let mdir = models_dir();
1111        let weights_file = manifest
1112            .files
1113            .iter()
1114            .find(|f| f.component == crate::manifest::ModelComponent::Upscaler)
1115            .ok_or(DownloadError::MissingComponent)?;
1116        let weights_path = mdir.join(crate::manifest::storage_path(manifest, weights_file));
1117
1118        let mut config = Config::load_or_default();
1119        let model_config = crate::config::ModelConfig {
1120            transformer: Some(weights_path.to_string_lossy().to_string()),
1121            family: Some("upscaler".to_string()),
1122            ..Default::default()
1123        };
1124        config.upsert_model(manifest.name.clone(), model_config);
1125        config
1126            .save()
1127            .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1128
1129        return Ok((config, None));
1130    }
1131
1132    let paths = pull_model(manifest, opts).await?;
1133
1134    let mut config = Config::load_or_default();
1135    let model_config = manifest.to_model_config(&paths);
1136
1137    // Auto-set default_model if no config existed before
1138    if !Config::exists_on_disk() {
1139        config.default_model = manifest.name.clone();
1140    }
1141
1142    config.upsert_model(manifest.name.clone(), model_config);
1143    config
1144        .save()
1145        .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1146
1147    Ok((config, Some(paths)))
1148}
1149
1150/// Download a model and save its paths to config, reporting progress via callback.
1151/// Same as `pull_and_configure` but uses a callback instead of indicatif bars.
1152pub async fn pull_and_configure_with_callback(
1153    model: &str,
1154    callback: DownloadProgressCallback,
1155    opts: &PullOptions,
1156) -> Result<(crate::Config, Option<ModelPaths>), DownloadError> {
1157    use crate::config::Config;
1158    use crate::manifest::{find_manifest, resolve_model_name};
1159
1160    let canonical = resolve_model_name(model);
1161
1162    let manifest = find_manifest(&canonical).ok_or_else(|| DownloadError::UnknownModel {
1163        model: model.to_string(),
1164    })?;
1165
1166    // Utility models (e.g., qwen3-expand) have no VAE and don't need config entries.
1167    if manifest.is_utility() {
1168        pull_model_files_only_with_callback(manifest, callback, opts).await?;
1169        let config = Config::load_or_default();
1170        return Ok((config, None));
1171    }
1172
1173    // Upscaler models: download files, create minimal config with weights path.
1174    if manifest.is_upscaler() {
1175        pull_model_files_only_with_callback(manifest, callback, opts).await?;
1176
1177        let mdir = models_dir();
1178        let weights_file = manifest
1179            .files
1180            .iter()
1181            .find(|f| f.component == crate::manifest::ModelComponent::Upscaler)
1182            .ok_or(DownloadError::MissingComponent)?;
1183        let weights_path = mdir.join(crate::manifest::storage_path(manifest, weights_file));
1184
1185        let mut config = Config::load_or_default();
1186        let model_config = crate::config::ModelConfig {
1187            transformer: Some(weights_path.to_string_lossy().to_string()),
1188            family: Some("upscaler".to_string()),
1189            ..Default::default()
1190        };
1191        config.upsert_model(manifest.name.clone(), model_config);
1192        config
1193            .save()
1194            .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1195
1196        return Ok((config, None));
1197    }
1198
1199    let paths = pull_model_with_callback(manifest, callback, opts).await?;
1200
1201    let mut config = Config::load_or_default();
1202    let model_config = manifest.to_model_config(&paths);
1203
1204    if !Config::exists_on_disk() {
1205        config.default_model = manifest.name.clone();
1206    }
1207
1208    config.upsert_model(manifest.name.clone(), model_config);
1209    config
1210        .save()
1211        .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1212
1213    Ok((config, Some(paths)))
1214}
1215
/// Unit tests for filename truncation, callback progress accounting, error
/// messages, HF token resolution, SHA-256 helpers, and integrity checks.
#[cfg(test)]
mod tests {
    use super::*;

    /// A name shorter than the limit is returned untouched.
    #[test]
    fn truncate_short_name_unchanged() {
        assert_eq!(truncate_filename("ae.safetensors", 45), "ae.safetensors");
    }

    /// A name exactly as long as the limit is returned untouched.
    #[test]
    fn truncate_exact_fit_unchanged() {
        let name = "x".repeat(30);
        assert_eq!(truncate_filename(&name, 30), name);
    }

    /// An over-long name is cut to the limit, prefixed with "..." and keeps
    /// its trailing part (including the extension).
    #[test]
    fn truncate_long_name_keeps_suffix() {
        let result = truncate_filename("unet/diffusion_pytorch_model.fp16.safetensors", 30);
        assert_eq!(result.len(), 30);
        assert!(result.starts_with("..."));
        assert!(result.ends_with(".fp16.safetensors"));
    }

    /// A very small limit leaves the name alone.
    #[test]
    fn truncate_very_small_max_returns_original() {
        // max_len < 8 returns unchanged to avoid degenerate "..." output
        let name = "something.safetensors";
        assert_eq!(truncate_filename(name, 5), name);
    }

    /// Clones of a `CallbackProgress` must add to one shared byte counter, so
    /// two 512-byte updates on separate clones report 1024 bytes in total.
    #[tokio::test]
    async fn callback_progress_clones_share_accumulated_bytes() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_for_cb = events.clone();
        let callback: DownloadProgressCallback = Arc::new(move |event| {
            events_for_cb
                .lock()
                .expect("events mutex poisoned")
                .push(event);
        });

        let mut progress = CallbackProgress::new(callback, 1, 3, 1_000, 10_000, Instant::now());
        progress.init(1_024, "weights.safetensors").await;

        let mut chunk_a = progress.clone();
        let mut chunk_b = progress.clone();
        chunk_a.update(512).await;
        chunk_b.update(512).await;
        progress.finish().await;

        let events = events.lock().expect("events mutex poisoned");
        // 1_000 bytes already done in the batch + 1_024 for this file.
        assert!(events.iter().any(|event| matches!(
            event,
            DownloadProgressEvent::FileProgress {
                bytes_downloaded: 1_024,
                bytes_total: 1_024,
                batch_bytes_downloaded: 2_024,
                ..
            }
        )));
    }

    /// The gated-model message names the repo, the token variable and the
    /// retry command.
    #[test]
    fn download_error_gated_message() {
        let err = DownloadError::GatedModel {
            repo: "black-forest-labs/FLUX.1-dev".to_string(),
            model: "flux-dev:q8".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("huggingface.co/black-forest-labs/FLUX.1-dev"));
        assert!(msg.contains("HF_TOKEN"));
        assert!(msg.contains("mold pull flux-dev:q8"));
    }

    /// The unauthorized message explains how to authenticate and retry.
    #[test]
    fn download_error_unauthorized_message() {
        let err = DownloadError::Unauthorized {
            repo: "black-forest-labs/FLUX.1-schnell".to_string(),
            model: "flux-schnell:q8".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("Authentication required"));
        assert!(msg.contains("black-forest-labs/FLUX.1-schnell"));
        assert!(msg.contains("HF_TOKEN"));
        assert!(msg.contains("huggingface-cli login"));
        assert!(msg.contains("mold pull flux-schnell:q8"));
    }

    /// Mutex to serialize tests that mutate `HF_TOKEN` — `set_var`/`remove_var`
    /// are process-global and not thread-safe, so parallel tests race.
    static HF_TOKEN_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire `HF_TOKEN_LOCK`, ignoring poisoning.
    ///
    /// The lock guards no data; it only serializes environment mutation. If
    /// one test fails an assertion while holding the guard, the others should
    /// still run instead of failing with an unrelated `PoisonError`.
    fn lock_hf_token_env() -> std::sync::MutexGuard<'static, ()> {
        HF_TOKEN_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// A non-empty `HF_TOKEN` environment variable is returned as the token.
    #[test]
    fn resolve_hf_token_reads_env_var() {
        let _guard = lock_hf_token_env();
        let original = std::env::var("HF_TOKEN").ok();
        std::env::set_var("HF_TOKEN", "hf_test_token_123");
        let token = resolve_hf_token();
        // Restore before asserting so we don't leak on panic
        match &original {
            Some(v) => std::env::set_var("HF_TOKEN", v),
            None => std::env::remove_var("HF_TOKEN"),
        }
        assert_eq!(token, Some("hf_test_token_123".to_string()));
    }

    /// A whitespace-only `HF_TOKEN` is never returned verbatim.
    #[test]
    fn resolve_hf_token_ignores_empty_env() {
        let _guard = lock_hf_token_env();
        let original = std::env::var("HF_TOKEN").ok();
        std::env::set_var("HF_TOKEN", "  ");
        let token = resolve_hf_token();
        // Restore before asserting
        match &original {
            Some(v) => std::env::set_var("HF_TOKEN", v),
            None => std::env::remove_var("HF_TOKEN"),
        }
        // Should fall through to file-based token (which may or may not exist)
        assert_ne!(token, Some("  ".to_string()));
    }

    /// `compute_sha256` returns the known hex digest of "hello world".
    #[test]
    fn compute_sha256_correct_digest() {
        let dir = std::env::temp_dir().join("mold_test_sha256_compute");
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join("test_file.bin");
        std::fs::write(&path, b"hello world").unwrap();
        let digest = compute_sha256(&path).unwrap();
        assert_eq!(
            digest,
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `verify_sha256` reports `true` for the correct digest.
    #[test]
    fn verify_sha256_matches() {
        let dir = std::env::temp_dir().join("mold_test_sha256_match");
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join("test_file.bin");
        std::fs::write(&path, b"hello world").unwrap();
        // SHA-256 of "hello world"
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_sha256(&path, expected).unwrap());
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `verify_sha256` reports `false` for a wrong digest.
    #[test]
    fn verify_sha256_mismatch() {
        let dir = std::env::temp_dir().join("mold_test_sha256_mismatch");
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join("test_file.bin");
        std::fs::write(&path, b"hello world").unwrap();
        let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
        assert!(!verify_sha256(&path, wrong).unwrap());
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// A hash mismatch yields `Sha256Mismatch` and removes the bad file.
    #[test]
    fn verify_file_integrity_deletes_on_mismatch() {
        use crate::manifest::{ModelComponent, ModelFile};
        let dir = std::env::temp_dir().join("mold_test_integrity_mismatch");
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join("corrupted.bin");
        std::fs::write(&path, b"corrupted data").unwrap();

        let file = ModelFile {
            hf_repo: "test/repo".to_string(),
            hf_filename: "corrupted.bin".to_string(),
            component: ModelComponent::Transformer,
            size_bytes: 14,
            gated: false,
            sha256: Some("0000000000000000000000000000000000000000000000000000000000000000"),
        };

        let result = verify_file_integrity(&path, &file, "test-model:q8", false);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DownloadError::Sha256Mismatch { .. }
        ));
        // File should be deleted
        assert!(!path.exists());
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// With `skip_verify = true` a mismatching hash is tolerated and the file
    /// is kept.
    #[test]
    fn verify_file_integrity_skip_verify_ignores_mismatch() {
        use crate::manifest::{ModelComponent, ModelFile};
        let dir = std::env::temp_dir().join("mold_test_integrity_skip");
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join("file.bin");
        std::fs::write(&path, b"some data").unwrap();

        let file = ModelFile {
            hf_repo: "test/repo".to_string(),
            hf_filename: "file.bin".to_string(),
            component: ModelComponent::Transformer,
            size_bytes: 9,
            gated: false,
            sha256: Some("0000000000000000000000000000000000000000000000000000000000000000"),
        };

        let result = verify_file_integrity(&path, &file, "test-model:q8", true);
        assert!(result.is_ok());
        // File should still exist
        assert!(path.exists());
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// A manifest entry without a recorded hash passes the integrity check.
    #[test]
    fn verify_file_integrity_no_hash_is_ok() {
        use crate::manifest::{ModelComponent, ModelFile};
        let dir = std::env::temp_dir().join("mold_test_integrity_nohash");
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join("file.bin");
        std::fs::write(&path, b"data").unwrap();

        let file = ModelFile {
            hf_repo: "test/repo".to_string(),
            hf_filename: "file.bin".to_string(),
            component: ModelComponent::Transformer,
            size_bytes: 4,
            gated: false,
            sha256: None,
        };

        assert!(verify_file_integrity(&path, &file, "test:q8", false).is_ok());
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Writing and then removing a `.pulling` marker file leaves nothing
    /// behind. This exercises plain filesystem calls only, not the marker
    /// helpers themselves.
    #[test]
    fn pulling_marker_roundtrip() {
        let dir = std::env::temp_dir().join("mold_test_marker_roundtrip");
        let _ = std::fs::create_dir_all(&dir);
        let marker = dir.join(".pulling");

        // Write
        std::fs::write(&marker, "test-model:q8").unwrap();
        assert!(marker.exists());

        // Remove
        let _ = std::fs::remove_file(&marker);
        assert!(!marker.exists());

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// The SHA-256 mismatch message names the file and both recovery options.
    #[test]
    fn sha256_mismatch_error_message() {
        let err = DownloadError::Sha256Mismatch {
            filename: "transformer.gguf".to_string(),
            expected: "aaa".to_string(),
            actual: "bbb".to_string(),
            model: "flux-dev:q8".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("SHA-256 mismatch"));
        assert!(msg.contains("transformer.gguf"));
        assert!(msg.contains("mold pull flux-dev:q8"));
        assert!(msg.contains("--skip-verify"));
    }
}