1use std::path::{Path, PathBuf};
2use std::sync::{Arc, Mutex, OnceLock};
3use std::time::Instant;
4
5use console::Term;
6use hf_hub::api::tokio::{Api, ApiBuilder, ApiError, Progress};
7use hf_hub::{Cache, Repo, RepoType};
8use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
9use thiserror::Error;
10
11use crate::manifest::{paths_from_downloads, ModelComponent, ModelFile, ModelManifest};
12use crate::ModelPaths;
13
/// Progress notifications emitted by the callback-driven pull functions.
///
/// `file_index` is the zero-based position of the file in the manifest and
/// `total_files` is the number of files in the manifest. The `batch_*` fields
/// describe the whole pull: bytes downloaded so far across all files, the total
/// bytes that need downloading (files already placed on disk are excluded by
/// the pull functions), and milliseconds elapsed since the batch started.
#[derive(Debug, Clone)]
pub enum DownloadProgressEvent {
    /// A file is about to be downloaded, or was found already in place.
    FileStart {
        filename: String,
        file_index: usize,
        total_files: usize,
        /// Size of this file in bytes.
        size_bytes: u64,
        batch_bytes_downloaded: u64,
        batch_bytes_total: u64,
        batch_elapsed_ms: u64,
    },
    /// Bytes received so far for the current file. `CallbackProgress` throttles
    /// these to roughly one per 250 ms, plus one when the file reaches its total.
    FileProgress {
        filename: String,
        file_index: usize,
        /// Bytes of this file received so far.
        bytes_downloaded: u64,
        /// Expected size of this file in bytes.
        bytes_total: u64,
        batch_bytes_downloaded: u64,
        batch_bytes_total: u64,
        batch_elapsed_ms: u64,
    },
    /// Free-form status text (e.g. the pre-download "Verifying file ..." messages).
    Status { message: String },
    /// A file finished downloading, or was already in place.
    FileDone {
        filename: String,
        file_index: usize,
        total_files: usize,
        batch_bytes_downloaded: u64,
        batch_bytes_total: u64,
        batch_elapsed_ms: u64,
    },
}
49
/// Shared, thread-safe callback invoked with every [`DownloadProgressEvent`].
pub type DownloadProgressCallback = Arc<dyn Fn(DownloadProgressEvent) + Send + Sync>;
52
/// Options controlling a model pull.
#[derive(Debug, Clone, Default)]
pub struct PullOptions {
    /// When true, SHA-256 verification of placed files is skipped entirely.
    pub skip_verify: bool,
}
59
/// Errors produced while pulling, placing, verifying, or configuring models.
#[derive(Debug, Error)]
pub enum DownloadError {
    /// Access to the repository was refused (HTTP 403 or an error message that
    /// looks like one) — typically the model's license has not been accepted.
    #[error(
        "Model requires access approval on HuggingFace.\n\n 1. Visit: https://huggingface.co/{repo}\n 2. Accept the license agreement\n 3. Create a token at: https://huggingface.co/settings/tokens\n 4. Set: export HF_TOKEN=hf_...\n 5. Retry: mold pull {model}"
    )]
    GatedModel { repo: String, model: String },

    /// The request was rejected as unauthenticated (HTTP 401 or an error
    /// message that looks like one).
    #[error(
        "Authentication required for repository {repo}.\n\n 1. Create a token at: https://huggingface.co/settings/tokens\n (select at least \"Read\" access)\n 2. Set: export HF_TOKEN=hf_...\n Or run: huggingface-cli login\n 3. Retry: mold pull {model}\n\n If HF_TOKEN is already set, it may be invalid or expired."
    )]
    Unauthorized { repo: String, model: String },

    /// Any other failure reported by the async `hf_hub` client.
    #[error("Download failed for {filename} from {repo}: {source}")]
    DownloadFailed {
        repo: String,
        filename: String,
        source: ApiError,
    },

    /// A placed file's SHA-256 differs from the manifest; the file has already
    /// been deleted (best-effort) when this is returned.
    #[error("SHA-256 mismatch for {filename}\n Expected: {expected}\n Got: {actual}\n\nThe corrupted file has been removed. Re-run: mold pull {model}\nIf the file was intentionally updated on HuggingFace, use: mold pull {model} --skip-verify")]
    Sha256Mismatch {
        filename: String,
        expected: String,
        actual: String,
        model: String,
    },

    /// Building the async API client failed. The `#[from]` conversion is what
    /// lets `?` turn an `ApiError` into this variant.
    #[error("Failed to build HuggingFace API client: {0}")]
    ApiSetup(#[from] ApiError),

    /// Building the blocking (sync) API client failed.
    #[error("Failed to build sync HuggingFace API client: {0}")]
    SyncApiSetup(String),

    /// A non-auth failure from the blocking download path.
    #[error("Sync download failed for {filename} from {repo}: {message}")]
    SyncDownloadFailed {
        repo: String,
        filename: String,
        message: String,
    },

    /// The placed files did not form a complete `ModelPaths`, or an expected
    /// component (e.g. upscaler weights) was missing from the manifest.
    #[error("Missing component after download — this is a bug")]
    MissingComponent,

    /// Catch-all with a preformatted message (e.g. a failed pre-scan task).
    #[error("{0}")]
    Other(String),

    /// Creating directories, markers, links, or copies on disk failed.
    #[error("IO error during file placement: {0}")]
    FilePlacement(String),

    /// No manifest matches the requested model name.
    #[error("Unknown model '{model}'. No manifest found.")]
    UnknownModel { model: String },

    /// Persisting the updated config failed.
    #[error("Failed to save config: {0}")]
    ConfigSave(String),
}
115
116fn resolve_hf_token() -> Option<String> {
119 if let Ok(token) = std::env::var("HF_TOKEN") {
120 let token = token.trim().to_string();
121 if !token.is_empty() {
122 return Some(token);
123 }
124 }
125 Cache::new(hf_cache_dir())
126 .token()
127 .or_else(|| Cache::from_env().token())
128}
129
130fn models_dir() -> PathBuf {
141 static DIR: OnceLock<PathBuf> = OnceLock::new();
142 DIR.get_or_init(|| {
143 let dir = crate::Config::load_or_default().resolved_models_dir();
144 let _ = std::fs::create_dir_all(&dir);
145 dir
146 })
147 .clone()
148}
149
150fn hf_cache_dir() -> PathBuf {
153 static DIR: OnceLock<PathBuf> = OnceLock::new();
154 DIR.get_or_init(|| {
155 let dir = models_dir().join(".hf-cache");
156 let _ = std::fs::create_dir_all(&dir);
157 dir
158 })
159 .clone()
160}
161
162fn hardlink_or_copy(src: &std::path::Path, dst: &std::path::Path) -> Result<(), DownloadError> {
168 let real_src = src.canonicalize().map_err(|e| {
171 DownloadError::FilePlacement(format!(
172 "source file not found after download: {} ({e})",
173 src.display()
174 ))
175 })?;
176
177 if dst.exists() {
180 if let (Ok(src_meta), Ok(dst_meta)) = (real_src.metadata(), dst.metadata()) {
181 if src_meta.len() == dst_meta.len() {
182 return Ok(());
183 }
184 }
185 }
186
187 if dst.symlink_metadata().is_ok() {
193 let _ = std::fs::remove_file(dst);
194 }
195
196 if let Some(parent) = dst.parent() {
197 std::fs::create_dir_all(parent).map_err(|e| {
198 DownloadError::FilePlacement(format!(
199 "failed to create directory {}: {e}",
200 parent.display()
201 ))
202 })?;
203 }
204 match std::fs::hard_link(&real_src, dst) {
206 Ok(()) => return Ok(()),
207 Err(_e) => {
208 }
210 }
211 std::fs::copy(&real_src, dst).map_err(|e| {
213 DownloadError::FilePlacement(format!(
214 "failed to copy {} → {}: {e}",
215 real_src.display(),
216 dst.display()
217 ))
218 })?;
219 Ok(())
220}
221
222pub fn compute_sha256(path: &std::path::Path) -> anyhow::Result<String> {
224 use sha2::{Digest, Sha256};
225
226 let mut file = std::fs::File::open(path)?;
227 let mut hasher = Sha256::new();
228 std::io::copy(&mut file, &mut hasher)?;
229 Ok(format!("{:x}", hasher.finalize()))
230}
231
232pub fn verify_sha256(path: &std::path::Path, expected: &str) -> anyhow::Result<bool> {
237 Ok(compute_sha256(path)? == expected)
238}
239
240pub fn pulling_marker_rel_path(model_name: &str) -> PathBuf {
244 let canonical = crate::manifest::resolve_model_name(model_name);
245 PathBuf::from(canonical.replace(':', "-")).join(".pulling")
246}
247
248pub fn pulling_marker_path_in(models_dir: &Path, model_name: &str) -> PathBuf {
250 models_dir.join(pulling_marker_rel_path(model_name))
251}
252
253fn pulling_marker_path(model_name: &str) -> PathBuf {
255 pulling_marker_path_in(&models_dir(), model_name)
256}
257
258fn write_pulling_marker(model_name: &str) -> Result<(), DownloadError> {
260 let path = pulling_marker_path(model_name);
261 if let Some(parent) = path.parent() {
262 std::fs::create_dir_all(parent).map_err(|e| {
263 DownloadError::FilePlacement(format!(
264 "failed to create directory for pull marker {}: {e}",
265 parent.display()
266 ))
267 })?;
268 }
269 std::fs::write(&path, model_name).map_err(|e| {
270 DownloadError::FilePlacement(format!(
271 "failed to write pull marker {}: {e}",
272 path.display()
273 ))
274 })
275}
276
277pub fn remove_pulling_marker(model_name: &str) {
279 let path = pulling_marker_path(model_name);
280 let _ = std::fs::remove_file(path);
281}
282
283pub fn has_pulling_marker(model_name: &str) -> bool {
285 let canonical = crate::manifest::resolve_model_name(model_name);
286 pulling_marker_path(&canonical).exists()
287}
288
289fn verify_file_integrity(
292 clean_path: &std::path::Path,
293 file: &ModelFile,
294 model_name: &str,
295 skip_verify: bool,
296) -> Result<(), DownloadError> {
297 let expected = match file.sha256 {
298 Some(h) => h,
299 None => return Ok(()),
300 };
301 if skip_verify {
302 return Ok(());
303 }
304 match compute_sha256(clean_path) {
305 Ok(actual) if actual == expected => Ok(()),
306 Ok(actual) => {
307 let _ = std::fs::remove_file(clean_path);
308 Err(DownloadError::Sha256Mismatch {
309 filename: file.hf_filename.clone(),
310 expected: expected.to_string(),
311 actual,
312 model: model_name.to_string(),
313 })
314 }
315 Err(e) => {
316 eprintln!(
317 "warning: failed to verify SHA-256 for {}: {e}",
318 file.hf_filename
319 );
320 Ok(())
321 }
322 }
323}
324
/// Shorten `name` to at most `max_len` bytes for a fixed-width progress column.
///
/// The tail of the name is kept and prefixed with `...`, so the distinguishing
/// end of a path (e.g. the extension) stays visible.
///
/// The name is returned unchanged if it already fits, or if `max_len < 8`
/// (too narrow for a useful suffix).
///
/// The cut point is moved forward to the next UTF-8 character boundary, so
/// names with multi-byte characters never panic. In that case the result can
/// be a few bytes shorter than `max_len`.
fn truncate_filename(name: &str, max_len: usize) -> String {
    if name.len() <= max_len || max_len < 8 {
        return name.to_string();
    }
    const ELLIPSIS: &str = "...";
    let mut start = name.len() - (max_len - ELLIPSIS.len());
    // Slicing a `str` inside a multi-byte character panics; `name.len()` is
    // always a boundary, so this loop terminates.
    while !name.is_char_boundary(start) {
        start += 1;
    }
    format!("{ELLIPSIS}{}", &name[start..])
}
335
336fn filename_column_width() -> usize {
340 let term_width = Term::stderr().size().1 as usize;
341 term_width.saturating_sub(75).max(12)
342}
343
344#[derive(Clone)]
346struct DownloadProgress {
347 bar: ProgressBar,
348 max_msg_len: usize,
349 filename: String,
350}
351
352impl DownloadProgress {
353 fn new(bar: ProgressBar, max_msg_len: usize) -> Self {
354 Self {
355 bar,
356 max_msg_len,
357 filename: String::new(),
358 }
359 }
360}
361
362impl Progress for DownloadProgress {
363 async fn init(&mut self, size: usize, filename: &str) {
364 self.bar.set_length(size as u64);
365 self.filename = truncate_filename(filename, self.max_msg_len);
366 self.bar.set_message(self.filename.clone());
367 }
368
369 async fn update(&mut self, size: usize) {
370 self.bar.inc(size as u64);
371 }
372
373 async fn finish(&mut self) {
374 self.bar.finish_with_message(self.filename.clone());
375 }
376}
377
378#[derive(Clone)]
381struct CallbackProgress {
382 callback: DownloadProgressCallback,
383 file_index: usize,
384 total_files: usize,
385 batch_bytes_before_current: u64,
386 batch_bytes_total: u64,
387 batch_started_at: Instant,
388 shared: Arc<Mutex<CallbackProgressState>>,
389}
390
391struct CallbackProgressState {
392 accumulated: u64,
393 total: u64,
394 filename: String,
395 last_emit: Instant,
396}
397
398impl CallbackProgress {
399 fn new(
400 callback: DownloadProgressCallback,
401 file_index: usize,
402 total_files: usize,
403 batch_bytes_before_current: u64,
404 batch_bytes_total: u64,
405 batch_started_at: Instant,
406 ) -> Self {
407 Self {
408 callback,
409 file_index,
410 total_files,
411 batch_bytes_before_current,
412 batch_bytes_total,
413 batch_started_at,
414 shared: Arc::new(Mutex::new(CallbackProgressState {
415 accumulated: 0,
416 total: 0,
417 filename: String::new(),
418 last_emit: Instant::now(),
419 })),
420 }
421 }
422}
423
424impl Progress for CallbackProgress {
425 async fn init(&mut self, size: usize, filename: &str) {
426 let (fname, total) = {
427 let mut shared = self
428 .shared
429 .lock()
430 .expect("download progress mutex poisoned");
431 shared.total = size as u64;
432 shared.accumulated = 0;
433 shared.filename = filename.to_string();
434 shared.last_emit = Instant::now();
435 (shared.filename.clone(), shared.total)
436 };
437 (self.callback)(DownloadProgressEvent::FileStart {
438 filename: fname,
439 file_index: self.file_index,
440 total_files: self.total_files,
441 size_bytes: total,
442 batch_bytes_downloaded: self.batch_bytes_before_current,
443 batch_bytes_total: self.batch_bytes_total,
444 batch_elapsed_ms: self.batch_started_at.elapsed().as_millis() as u64,
445 });
446 }
447
448 async fn update(&mut self, size: usize) {
449 let mut shared = self
450 .shared
451 .lock()
452 .expect("download progress mutex poisoned");
453 shared.accumulated += size as u64;
454
455 let now = Instant::now();
456 let should_emit = now.duration_since(shared.last_emit).as_millis() >= 250
457 || shared.accumulated >= shared.total;
458 if !should_emit {
459 return;
460 }
461
462 shared.last_emit = now;
463 let filename = shared.filename.clone();
464 let accumulated = shared.accumulated;
465 let total = shared.total;
466 drop(shared);
467
468 (self.callback)(DownloadProgressEvent::FileProgress {
469 filename,
470 file_index: self.file_index,
471 bytes_downloaded: accumulated,
472 bytes_total: total,
473 batch_bytes_downloaded: self.batch_bytes_before_current + accumulated,
474 batch_bytes_total: self.batch_bytes_total,
475 batch_elapsed_ms: self.batch_started_at.elapsed().as_millis() as u64,
476 });
477 }
478
479 async fn finish(&mut self) {
480 let (fname, total) = {
481 let shared = self
482 .shared
483 .lock()
484 .expect("download progress mutex poisoned");
485 (shared.filename.clone(), shared.total)
486 };
487 (self.callback)(DownloadProgressEvent::FileDone {
488 filename: fname,
489 file_index: self.file_index,
490 total_files: self.total_files,
491 batch_bytes_downloaded: self.batch_bytes_before_current + total,
492 batch_bytes_total: self.batch_bytes_total,
493 batch_elapsed_ms: self.batch_started_at.elapsed().as_millis() as u64,
494 });
495 }
496}
497
498struct SyncDownloadProgress {
501 bar: ProgressBar,
502 max_msg_len: usize,
503 filename: String,
504}
505
506impl SyncDownloadProgress {
507 fn new(bar: ProgressBar, max_msg_len: usize) -> Self {
508 Self {
509 bar,
510 max_msg_len,
511 filename: String::new(),
512 }
513 }
514}
515
516impl hf_hub::api::Progress for SyncDownloadProgress {
517 fn init(&mut self, size: usize, filename: &str) {
518 self.bar.set_length(size as u64);
519 self.filename = truncate_filename(filename, self.max_msg_len);
520 self.bar.set_message(self.filename.clone());
521 }
522
523 fn update(&mut self, size: usize) {
524 self.bar.inc(size as u64);
525 }
526
527 fn finish(&mut self) {
528 self.bar.finish_with_message(self.filename.clone());
529 }
530}
531
532fn is_already_placed(
538 clean_path: &std::path::Path,
539 file: &ModelFile,
540 model_name: &str,
541 skip_verify: bool,
542) -> bool {
543 let size_ok = clean_path
544 .metadata()
545 .map(|m| m.len() == file.size_bytes)
546 .unwrap_or(false);
547 if !size_ok {
548 return false;
549 }
550 verify_file_integrity(clean_path, file, model_name, skip_verify).is_ok()
552}
553
554fn find_existing_placed_file(
557 models_dir: &std::path::Path,
558 manifest: &ModelManifest,
559 file: &ModelFile,
560 skip_verify: bool,
561) -> Result<Option<PathBuf>, DownloadError> {
562 let canonical_rel = crate::manifest::storage_path(manifest, file);
563 let canonical_path = models_dir.join(&canonical_rel);
564
565 for candidate_rel in crate::manifest::storage_path_candidates(manifest, file) {
566 let candidate_path = models_dir.join(candidate_rel);
567 if !is_already_placed(&candidate_path, file, &manifest.name, skip_verify) {
568 continue;
569 }
570 if candidate_path != canonical_path {
571 hardlink_or_copy(&candidate_path, &canonical_path)?;
572 verify_file_integrity(&canonical_path, file, &manifest.name, skip_verify)?;
573 }
574 return Ok(Some(canonical_path));
575 }
576
577 Ok(None)
578}
579
580pub async fn pull_model(
590 manifest: &ModelManifest,
591 opts: &PullOptions,
592) -> Result<ModelPaths, DownloadError> {
593 write_pulling_marker(&manifest.name)?;
594
595 let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
596 if let Some(token) = resolve_hf_token() {
597 builder = builder.with_token(Some(token));
598 }
599 let api = builder.build()?;
600
601 let multi = MultiProgress::with_draw_target(ProgressDrawTarget::stderr());
602 let msg_width = filename_column_width();
603 let bar_style = ProgressStyle::with_template(&format!(
604 " {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
605 ))
606 .unwrap()
607 .progress_chars("━╸─");
608
609 let mdir = models_dir();
610 let mut downloads: Vec<(ModelComponent, PathBuf)> = Vec::new();
611
612 for file in &manifest.files {
613 if let Some(clean_path) =
614 find_existing_placed_file(&mdir, manifest, file, opts.skip_verify)?
615 {
616 downloads.push((file.component, clean_path));
617 continue;
618 }
619
620 let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
621
622 let bar = multi.add(ProgressBar::new(file.size_bytes));
623 bar.set_style(bar_style.clone());
624 bar.set_message(truncate_filename(&file.hf_filename, msg_width));
625
626 let hf_path = download_file(
627 &api,
628 file,
629 DownloadProgress::new(bar, msg_width),
630 &manifest.name,
631 )
632 .await?;
633
634 hardlink_or_copy(&hf_path, &clean_path)?;
636
637 verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
638
639 downloads.push((file.component, clean_path));
640 }
641
642 remove_pulling_marker(&manifest.name);
643 paths_from_downloads(&downloads, &manifest.family).ok_or(DownloadError::MissingComponent)
644}
645
646pub async fn pull_model_with_callback(
651 manifest: &ModelManifest,
652 callback: DownloadProgressCallback,
653 opts: &PullOptions,
654) -> Result<ModelPaths, DownloadError> {
655 write_pulling_marker(&manifest.name)?;
656
657 let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
658 if let Some(token) = resolve_hf_token() {
659 builder = builder.with_token(Some(token));
660 }
661 let api = builder.build()?;
662
663 let mdir = models_dir();
664 let mut downloads: Vec<(ModelComponent, PathBuf)> = Vec::new();
665
666 let manifest_clone = manifest.clone();
670 let skip_verify = opts.skip_verify;
671 let mdir_clone = mdir.clone();
672 let cb = callback.clone();
673 let file_status: Vec<bool> = tokio::task::spawn_blocking(move || {
674 let total = manifest_clone.files.len();
675 manifest_clone
676 .files
677 .iter()
678 .enumerate()
679 .map(|(i, file)| {
680 cb(DownloadProgressEvent::Status {
681 message: format!(
682 "Verifying file [{}/{}] {}...",
683 i + 1,
684 total,
685 file.hf_filename
686 ),
687 });
688 find_existing_placed_file(&mdir_clone, &manifest_clone, file, skip_verify)
689 .map(|p| p.is_some())
690 .unwrap_or(false)
691 })
692 .collect()
693 })
694 .await
695 .map_err(|e| DownloadError::Other(format!("pre-scan task failed: {e}")))?;
696
697 let total_bytes_to_download: u64 = manifest
698 .files
699 .iter()
700 .zip(file_status.iter())
701 .filter(|(_, &placed)| !placed)
702 .map(|(file, _)| file.size_bytes)
703 .sum();
704 let total_files_count = manifest.files.len();
705 let mut completed_bytes = 0u64;
706 let batch_started_at = Instant::now();
707
708 for (file_pos, (file, &already_placed)) in
709 manifest.files.iter().zip(file_status.iter()).enumerate()
710 {
711 let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
712
713 if already_placed {
714 let elapsed = batch_started_at.elapsed().as_millis() as u64;
716 (callback)(DownloadProgressEvent::FileStart {
717 filename: file.hf_filename.clone(),
718 file_index: file_pos,
719 total_files: total_files_count,
720 size_bytes: file.size_bytes,
721 batch_bytes_downloaded: completed_bytes,
722 batch_bytes_total: total_bytes_to_download,
723 batch_elapsed_ms: elapsed,
724 });
725 (callback)(DownloadProgressEvent::FileDone {
726 filename: file.hf_filename.clone(),
727 file_index: file_pos,
728 total_files: total_files_count,
729 batch_bytes_downloaded: completed_bytes,
730 batch_bytes_total: total_bytes_to_download,
731 batch_elapsed_ms: elapsed,
732 });
733 downloads.push((file.component, clean_path));
734 continue;
735 }
736
737 let progress = CallbackProgress::new(
738 callback.clone(),
739 file_pos,
740 total_files_count,
741 completed_bytes,
742 total_bytes_to_download,
743 batch_started_at,
744 );
745 let hf_path = download_file(&api, file, progress, &manifest.name).await?;
746
747 hardlink_or_copy(&hf_path, &clean_path)?;
748
749 verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
750
751 downloads.push((file.component, clean_path));
752 completed_bytes += file.size_bytes;
753 }
754
755 remove_pulling_marker(&manifest.name);
756 paths_from_downloads(&downloads, &manifest.family).ok_or(DownloadError::MissingComponent)
757}
758
759async fn pull_model_files_only(
764 manifest: &ModelManifest,
765 opts: &PullOptions,
766) -> Result<(), DownloadError> {
767 write_pulling_marker(&manifest.name)?;
768
769 let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
770 if let Some(token) = resolve_hf_token() {
771 builder = builder.with_token(Some(token));
772 }
773 let api = builder.build()?;
774
775 let multi = MultiProgress::with_draw_target(ProgressDrawTarget::stderr());
776 let msg_width = filename_column_width();
777 let bar_style = ProgressStyle::with_template(&format!(
778 " {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
779 ))
780 .unwrap()
781 .progress_chars("━╸─");
782
783 let mdir = models_dir();
784
785 for file in &manifest.files {
786 if find_existing_placed_file(&mdir, manifest, file, opts.skip_verify)?.is_some() {
787 continue;
788 }
789
790 let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
791
792 let bar = multi.add(ProgressBar::new(file.size_bytes));
793 bar.set_style(bar_style.clone());
794 bar.set_message(truncate_filename(&file.hf_filename, msg_width));
795
796 let hf_path = download_file(
797 &api,
798 file,
799 DownloadProgress::new(bar, msg_width),
800 &manifest.name,
801 )
802 .await?;
803
804 hardlink_or_copy(&hf_path, &clean_path)?;
805
806 verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
807 }
808
809 remove_pulling_marker(&manifest.name);
810 Ok(())
811}
812
813async fn pull_model_files_only_with_callback(
815 manifest: &ModelManifest,
816 callback: DownloadProgressCallback,
817 opts: &PullOptions,
818) -> Result<(), DownloadError> {
819 write_pulling_marker(&manifest.name)?;
820
821 let mut builder = ApiBuilder::from_env().with_cache_dir(hf_cache_dir());
822 if let Some(token) = resolve_hf_token() {
823 builder = builder.with_token(Some(token));
824 }
825 let api = builder.build()?;
826
827 let mdir = models_dir();
828
829 let manifest_clone = manifest.clone();
830 let skip_verify = opts.skip_verify;
831 let mdir_clone = mdir.clone();
832 let cb = callback.clone();
833 let file_status: Vec<bool> = tokio::task::spawn_blocking(move || {
834 let total = manifest_clone.files.len();
835 manifest_clone
836 .files
837 .iter()
838 .enumerate()
839 .map(|(i, file)| {
840 cb(DownloadProgressEvent::Status {
841 message: format!(
842 "Verifying file [{}/{}] {}...",
843 i + 1,
844 total,
845 file.hf_filename
846 ),
847 });
848 find_existing_placed_file(&mdir_clone, &manifest_clone, file, skip_verify)
849 .map(|p| p.is_some())
850 .unwrap_or(false)
851 })
852 .collect()
853 })
854 .await
855 .map_err(|e| DownloadError::Other(format!("pre-scan task failed: {e}")))?;
856 let total_bytes_to_download: u64 = manifest
857 .files
858 .iter()
859 .zip(file_status.iter())
860 .filter(|(_, &placed)| !placed)
861 .map(|(file, _)| file.size_bytes)
862 .sum();
863 let total_files_count = manifest.files.len();
864 let mut completed_bytes = 0u64;
865 let batch_started_at = Instant::now();
866
867 for (file_pos, (file, &already_placed)) in
868 manifest.files.iter().zip(file_status.iter()).enumerate()
869 {
870 let clean_path = mdir.join(crate::manifest::storage_path(manifest, file));
871
872 if already_placed {
873 let elapsed = batch_started_at.elapsed().as_millis() as u64;
874 (callback)(DownloadProgressEvent::FileStart {
875 filename: file.hf_filename.clone(),
876 file_index: file_pos,
877 total_files: total_files_count,
878 size_bytes: file.size_bytes,
879 batch_bytes_downloaded: completed_bytes,
880 batch_bytes_total: total_bytes_to_download,
881 batch_elapsed_ms: elapsed,
882 });
883 (callback)(DownloadProgressEvent::FileDone {
884 filename: file.hf_filename.clone(),
885 file_index: file_pos,
886 total_files: total_files_count,
887 batch_bytes_downloaded: completed_bytes,
888 batch_bytes_total: total_bytes_to_download,
889 batch_elapsed_ms: elapsed,
890 });
891 continue;
892 }
893
894 let progress = CallbackProgress::new(
895 callback.clone(),
896 file_pos,
897 total_files_count,
898 completed_bytes,
899 total_bytes_to_download,
900 batch_started_at,
901 );
902
903 let hf_path = download_file(&api, file, progress, &manifest.name).await?;
904
905 hardlink_or_copy(&hf_path, &clean_path)?;
906
907 verify_file_integrity(&clean_path, file, &manifest.name, opts.skip_verify)?;
908 completed_bytes += file.size_bytes;
909 }
910
911 remove_pulling_marker(&manifest.name);
912 Ok(())
913}
914
915fn extract_http_status(err: &ApiError) -> Option<u16> {
917 if let ApiError::RequestError(reqwest_err) = err {
918 reqwest_err.status().map(|s| s.as_u16())
919 } else {
920 None
921 }
922}
923
924async fn download_file<P: Progress + Clone + Send + Sync + 'static>(
925 api: &Api,
926 file: &ModelFile,
927 progress: P,
928 model_name: &str,
929) -> Result<PathBuf, DownloadError> {
930 let repo = api.repo(Repo::new(file.hf_repo.clone(), RepoType::Model));
931
932 match repo
933 .download_with_progress(&file.hf_filename, progress)
934 .await
935 {
936 Ok(path) => Ok(path),
937 Err(e) => {
938 let status = extract_http_status(&e);
939 let err_str = e.to_string();
940 if status == Some(401) || err_str.contains("401") || err_str.contains("Unauthorized") {
941 Err(DownloadError::Unauthorized {
942 repo: file.hf_repo.clone(),
943 model: model_name.to_string(),
944 })
945 } else if status == Some(403)
946 || err_str.contains("403")
947 || err_str.contains("Forbidden")
948 || err_str.contains("gated")
949 || err_str.contains("Access denied")
950 {
951 Err(DownloadError::GatedModel {
952 repo: file.hf_repo.clone(),
953 model: model_name.to_string(),
954 })
955 } else {
956 Err(DownloadError::DownloadFailed {
957 repo: file.hf_repo.clone(),
958 filename: file.hf_filename.clone(),
959 source: e,
960 })
961 }
962 }
963 }
964}
965
966pub fn download_single_file_sync(
976 hf_repo: &str,
977 hf_filename: &str,
978 target_subdir: Option<&str>,
979) -> Result<PathBuf, DownloadError> {
980 use hf_hub::api::sync::ApiBuilder;
981
982 let mut builder = ApiBuilder::from_env()
983 .with_cache_dir(hf_cache_dir())
984 .with_progress(false);
985 if let Some(token) = resolve_hf_token() {
986 builder = builder.with_token(Some(token));
987 }
988 let api = builder
989 .build()
990 .map_err(|e| DownloadError::SyncApiSetup(e.to_string()))?;
991 let repo = api.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
992 let msg_width = filename_column_width();
993 let bar_style = ProgressStyle::with_template(&format!(
994 " {{msg:<{msg_width}}} [{{bar:30.cyan/dim}}] {{bytes}}/{{total_bytes}} ({{bytes_per_sec}}, {{eta}})"
995 ))
996 .unwrap()
997 .progress_chars("━╸─");
998 let bar = ProgressBar::new(0);
999 bar.set_style(bar_style);
1000 bar.set_message(truncate_filename(hf_filename, msg_width));
1001 let progress = SyncDownloadProgress::new(bar, msg_width);
1002 let hf_path = repo
1003 .download_with_progress(hf_filename, progress)
1004 .map_err(|e| {
1005 let err_str = e.to_string();
1006 if err_str.contains("401") || err_str.contains("Unauthorized") {
1007 DownloadError::Unauthorized {
1008 repo: hf_repo.to_string(),
1009 model: String::new(),
1010 }
1011 } else if err_str.contains("403")
1012 || err_str.contains("Forbidden")
1013 || err_str.contains("gated")
1014 || err_str.contains("Access denied")
1015 {
1016 DownloadError::GatedModel {
1017 repo: hf_repo.to_string(),
1018 model: String::new(),
1019 }
1020 } else {
1021 DownloadError::SyncDownloadFailed {
1022 repo: hf_repo.to_string(),
1023 filename: hf_filename.to_string(),
1024 message: err_str,
1025 }
1026 }
1027 })?;
1028
1029 if let Some(subdir) = target_subdir {
1031 let leaf = hf_filename.rsplit('/').next().unwrap_or(hf_filename);
1032 let clean_path = models_dir().join(subdir).join(leaf);
1033 hardlink_or_copy(&hf_path, &clean_path)?;
1034 Ok(clean_path)
1035 } else {
1036 Ok(hf_path)
1037 }
1038}
1039
1040pub fn cached_file_path(
1046 hf_repo: &str,
1047 hf_filename: &str,
1048 target_subdir: Option<&str>,
1049) -> Option<PathBuf> {
1050 if let Some(subdir) = target_subdir {
1052 let leaf = hf_filename.rsplit('/').next().unwrap_or(hf_filename);
1053 let clean_path = models_dir().join(subdir).join(leaf);
1054 if clean_path.exists() {
1055 return Some(clean_path);
1056 }
1057 }
1058
1059 let new_cache = Cache::new(hf_cache_dir());
1061 let new_repo = new_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
1062 if let Some(path) = new_repo.get(hf_filename) {
1063 return Some(path);
1064 }
1065
1066 let old_cache = Cache::new(models_dir());
1068 let old_repo = old_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
1069 if let Some(path) = old_repo.get(hf_filename) {
1070 return Some(path);
1071 }
1072
1073 let default_cache = Cache::from_env();
1075 let default_repo = default_cache.repo(Repo::new(hf_repo.to_string(), RepoType::Model));
1076 default_repo.get(hf_filename)
1077}
1078
1079pub async fn pull_and_configure(
1085 model: &str,
1086 opts: &PullOptions,
1087) -> Result<(crate::Config, Option<ModelPaths>), DownloadError> {
1088 use crate::config::Config;
1089 use crate::manifest::{find_manifest, resolve_model_name};
1090
1091 let canonical = resolve_model_name(model);
1092
1093 let manifest = find_manifest(&canonical).ok_or_else(|| DownloadError::UnknownModel {
1094 model: model.to_string(),
1095 })?;
1096
1097 if manifest.is_utility() {
1099 pull_model_files_only(manifest, opts).await?;
1100 let config = Config::load_or_default();
1101 return Ok((config, None));
1102 }
1103
1104 if manifest.is_upscaler() {
1107 pull_model_files_only(manifest, opts).await?;
1108
1109 let mdir = models_dir();
1111 let weights_file = manifest
1112 .files
1113 .iter()
1114 .find(|f| f.component == crate::manifest::ModelComponent::Upscaler)
1115 .ok_or(DownloadError::MissingComponent)?;
1116 let weights_path = mdir.join(crate::manifest::storage_path(manifest, weights_file));
1117
1118 let mut config = Config::load_or_default();
1119 let model_config = crate::config::ModelConfig {
1120 transformer: Some(weights_path.to_string_lossy().to_string()),
1121 family: Some("upscaler".to_string()),
1122 ..Default::default()
1123 };
1124 config.upsert_model(manifest.name.clone(), model_config);
1125 config
1126 .save()
1127 .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1128
1129 return Ok((config, None));
1130 }
1131
1132 let paths = pull_model(manifest, opts).await?;
1133
1134 let mut config = Config::load_or_default();
1135 let model_config = manifest.to_model_config(&paths);
1136
1137 if !Config::exists_on_disk() {
1139 config.default_model = manifest.name.clone();
1140 }
1141
1142 config.upsert_model(manifest.name.clone(), model_config);
1143 config
1144 .save()
1145 .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1146
1147 Ok((config, Some(paths)))
1148}
1149
1150pub async fn pull_and_configure_with_callback(
1153 model: &str,
1154 callback: DownloadProgressCallback,
1155 opts: &PullOptions,
1156) -> Result<(crate::Config, Option<ModelPaths>), DownloadError> {
1157 use crate::config::Config;
1158 use crate::manifest::{find_manifest, resolve_model_name};
1159
1160 let canonical = resolve_model_name(model);
1161
1162 let manifest = find_manifest(&canonical).ok_or_else(|| DownloadError::UnknownModel {
1163 model: model.to_string(),
1164 })?;
1165
1166 if manifest.is_utility() {
1168 pull_model_files_only_with_callback(manifest, callback, opts).await?;
1169 let config = Config::load_or_default();
1170 return Ok((config, None));
1171 }
1172
1173 if manifest.is_upscaler() {
1175 pull_model_files_only_with_callback(manifest, callback, opts).await?;
1176
1177 let mdir = models_dir();
1178 let weights_file = manifest
1179 .files
1180 .iter()
1181 .find(|f| f.component == crate::manifest::ModelComponent::Upscaler)
1182 .ok_or(DownloadError::MissingComponent)?;
1183 let weights_path = mdir.join(crate::manifest::storage_path(manifest, weights_file));
1184
1185 let mut config = Config::load_or_default();
1186 let model_config = crate::config::ModelConfig {
1187 transformer: Some(weights_path.to_string_lossy().to_string()),
1188 family: Some("upscaler".to_string()),
1189 ..Default::default()
1190 };
1191 config.upsert_model(manifest.name.clone(), model_config);
1192 config
1193 .save()
1194 .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1195
1196 return Ok((config, None));
1197 }
1198
1199 let paths = pull_model_with_callback(manifest, callback, opts).await?;
1200
1201 let mut config = Config::load_or_default();
1202 let model_config = manifest.to_model_config(&paths);
1203
1204 if !Config::exists_on_disk() {
1205 config.default_model = manifest.name.clone();
1206 }
1207
1208 config.upsert_model(manifest.name.clone(), model_config);
1209 config
1210 .save()
1211 .map_err(|e| DownloadError::ConfigSave(e.to_string()))?;
1212
1213 Ok((config, Some(paths)))
1214}
1215
#[cfg(test)]
mod tests {
    // Unit tests for the download helpers.
    //
    // Filesystem tests work inside a per-process `ScratchDir` that is removed on
    // drop, and tests that mutate `HF_TOKEN` serialize on `HF_TOKEN_LOCK` and
    // restore the variable through `HfTokenGuard`, so a failing assertion cannot
    // leak files or environment state into other tests.
    use super::*;

    /// SHA-256 digest of the ASCII bytes `hello world`.
    const HELLO_WORLD_SHA256: &str =
        "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";

    /// A well-formed digest that will not match any test fixture.
    const WRONG_SHA256: &str = "0000000000000000000000000000000000000000000000000000000000000000";

    /// Temporary directory deleted when the guard is dropped — including when the
    /// owning test panics — so failed runs do not leave files behind.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        /// Creates a fresh, empty directory under the system temp dir.
        ///
        /// The process id is appended to `label` so concurrent test processes
        /// never share (and delete) each other's directory.
        fn new(label: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{label}_{}", std::process::id()));
            // Clear leftovers from an earlier run that happened to reuse this pid.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("failed to create scratch dir");
            Self(dir)
        }

        /// Path of `name` inside the scratch directory.
        fn join(&self, name: &str) -> std::path::PathBuf {
            self.0.join(name)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test's own result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn truncate_short_name_unchanged() {
        assert_eq!(truncate_filename("ae.safetensors", 45), "ae.safetensors");
    }

    #[test]
    fn truncate_exact_fit_unchanged() {
        let name = "x".repeat(30);
        assert_eq!(truncate_filename(&name, 30), name);
    }

    #[test]
    fn truncate_long_name_keeps_suffix() {
        let result = truncate_filename("unet/diffusion_pytorch_model.fp16.safetensors", 30);
        assert_eq!(result.len(), 30);
        assert!(result.starts_with("..."));
        assert!(result.ends_with(".fp16.safetensors"));
    }

    #[test]
    fn truncate_very_small_max_returns_original() {
        // With a width this small, truncation is skipped and the name is kept as-is.
        let name = "something.safetensors";
        assert_eq!(truncate_filename(name, 5), name);
    }

    #[tokio::test]
    async fn callback_progress_clones_share_accumulated_bytes() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_for_cb = events.clone();
        let callback: DownloadProgressCallback = Arc::new(move |event| {
            events_for_cb
                .lock()
                .expect("events mutex poisoned")
                .push(event);
        });

        let mut progress = CallbackProgress::new(callback, 1, 3, 1_000, 10_000, Instant::now());
        progress.init(1_024, "weights.safetensors").await;

        // Clones used for separate chunks must feed one shared byte counter.
        let mut chunk_a = progress.clone();
        let mut chunk_b = progress.clone();
        chunk_a.update(512).await;
        chunk_b.update(512).await;
        progress.finish().await;

        // 2_024 assumes the fourth `CallbackProgress::new` argument (1_000) is the
        // batch bytes already downloaded before this file — TODO confirm.
        let events = events.lock().expect("events mutex poisoned");
        assert!(events.iter().any(|event| matches!(
            event,
            DownloadProgressEvent::FileProgress {
                bytes_downloaded: 1_024,
                bytes_total: 1_024,
                batch_bytes_downloaded: 2_024,
                ..
            }
        )));
    }

    #[test]
    fn download_error_gated_message() {
        let err = DownloadError::GatedModel {
            repo: "black-forest-labs/FLUX.1-dev".to_string(),
            model: "flux-dev:q8".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("huggingface.co/black-forest-labs/FLUX.1-dev"));
        assert!(msg.contains("HF_TOKEN"));
        assert!(msg.contains("mold pull flux-dev:q8"));
    }

    #[test]
    fn download_error_unauthorized_message() {
        let err = DownloadError::Unauthorized {
            repo: "black-forest-labs/FLUX.1-schnell".to_string(),
            model: "flux-schnell:q8".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("Authentication required"));
        assert!(msg.contains("black-forest-labs/FLUX.1-schnell"));
        assert!(msg.contains("HF_TOKEN"));
        assert!(msg.contains("huggingface-cli login"));
        assert!(msg.contains("mold pull flux-schnell:q8"));
    }

    /// Serializes the tests in this module that mutate the process-wide
    /// `HF_TOKEN` environment variable.
    static HF_TOKEN_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`HF_TOKEN_LOCK`], recovering from poisoning so that one failing
    /// test does not make every later HF_TOKEN test panic with `PoisonError`.
    fn lock_hf_token() -> std::sync::MutexGuard<'static, ()> {
        HF_TOKEN_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets `HF_TOKEN` for the guard's lifetime. On drop — including during a
    /// panic — it restores the previous value, or removes the variable if it was
    /// unset. The original is captured as an `OsString`, so non-UTF-8 values
    /// survive the round trip.
    struct HfTokenGuard {
        original: Option<std::ffi::OsString>,
    }

    impl HfTokenGuard {
        fn set(value: &str) -> Self {
            let original = std::env::var_os("HF_TOKEN");
            std::env::set_var("HF_TOKEN", value);
            Self { original }
        }
    }

    impl Drop for HfTokenGuard {
        fn drop(&mut self) {
            match self.original.take() {
                Some(v) => std::env::set_var("HF_TOKEN", v),
                None => std::env::remove_var("HF_TOKEN"),
            }
        }
    }

    #[test]
    fn resolve_hf_token_reads_env_var() {
        let _lock = lock_hf_token();
        let env = HfTokenGuard::set("hf_test_token_123");
        let token = resolve_hf_token();
        // Restore before asserting to keep the modified-env window short; the
        // guard would also restore on its own if `resolve_hf_token` panicked.
        drop(env);
        assert_eq!(token, Some("hf_test_token_123".to_string()));
    }

    #[test]
    fn resolve_hf_token_ignores_empty_env() {
        let _lock = lock_hf_token();
        let env = HfTokenGuard::set(" ");
        let token = resolve_hf_token();
        drop(env);
        // Only proves the whitespace value is not returned verbatim; a fallback
        // token source (if any) may still yield `Some(..)`.
        assert_ne!(token, Some(" ".to_string()));
    }

    #[test]
    fn compute_sha256_correct_digest() {
        let dir = ScratchDir::new("mold_test_sha256_compute");
        let path = dir.join("test_file.bin");
        std::fs::write(&path, b"hello world").unwrap();
        let digest = compute_sha256(&path).unwrap();
        assert_eq!(digest, HELLO_WORLD_SHA256);
    }

    #[test]
    fn verify_sha256_matches() {
        let dir = ScratchDir::new("mold_test_sha256_match");
        let path = dir.join("test_file.bin");
        std::fs::write(&path, b"hello world").unwrap();
        assert!(verify_sha256(&path, HELLO_WORLD_SHA256).unwrap());
    }

    #[test]
    fn verify_sha256_mismatch() {
        let dir = ScratchDir::new("mold_test_sha256_mismatch");
        let path = dir.join("test_file.bin");
        std::fs::write(&path, b"hello world").unwrap();
        assert!(!verify_sha256(&path, WRONG_SHA256).unwrap());
    }

    #[test]
    fn verify_file_integrity_deletes_on_mismatch() {
        use crate::manifest::{ModelComponent, ModelFile};
        let dir = ScratchDir::new("mold_test_integrity_mismatch");
        let path = dir.join("corrupted.bin");
        std::fs::write(&path, b"corrupted data").unwrap();

        let file = ModelFile {
            hf_repo: "test/repo".to_string(),
            hf_filename: "corrupted.bin".to_string(),
            component: ModelComponent::Transformer,
            size_bytes: 14,
            gated: false,
            sha256: Some(WRONG_SHA256),
        };

        let result = verify_file_integrity(&path, &file, "test-model:q8", false);
        assert!(
            matches!(result, Err(DownloadError::Sha256Mismatch { .. })),
            "expected Sha256Mismatch, got {result:?}"
        );
        // A corrupted download must be removed so the next pull re-fetches it.
        assert!(!path.exists());
    }

    #[test]
    fn verify_file_integrity_skip_verify_ignores_mismatch() {
        use crate::manifest::{ModelComponent, ModelFile};
        let dir = ScratchDir::new("mold_test_integrity_skip");
        let path = dir.join("file.bin");
        std::fs::write(&path, b"some data").unwrap();

        let file = ModelFile {
            hf_repo: "test/repo".to_string(),
            hf_filename: "file.bin".to_string(),
            component: ModelComponent::Transformer,
            size_bytes: 9,
            gated: false,
            sha256: Some(WRONG_SHA256),
        };

        // With skip_verify = true, the mismatching hash is ignored and the file is kept.
        let result = verify_file_integrity(&path, &file, "test-model:q8", true);
        assert!(result.is_ok());
        assert!(path.exists());
    }

    #[test]
    fn verify_file_integrity_no_hash_is_ok() {
        use crate::manifest::{ModelComponent, ModelFile};
        let dir = ScratchDir::new("mold_test_integrity_nohash");
        let path = dir.join("file.bin");
        std::fs::write(&path, b"data").unwrap();

        let file = ModelFile {
            hf_repo: "test/repo".to_string(),
            hf_filename: "file.bin".to_string(),
            component: ModelComponent::Transformer,
            size_bytes: 4,
            gated: false,
            sha256: None,
        };

        assert!(verify_file_integrity(&path, &file, "test:q8", false).is_ok());
    }

    #[test]
    fn pulling_marker_roundtrip() {
        let dir = ScratchDir::new("mold_test_marker_roundtrip");
        let marker = dir.join(".pulling");

        std::fs::write(&marker, "test-model:q8").unwrap();
        assert!(marker.exists());

        let _ = std::fs::remove_file(&marker);
        assert!(!marker.exists());
    }

    #[test]
    fn sha256_mismatch_error_message() {
        let err = DownloadError::Sha256Mismatch {
            filename: "transformer.gguf".to_string(),
            expected: "aaa".to_string(),
            actual: "bbb".to_string(),
            model: "flux-dev:q8".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("SHA-256 mismatch"));
        assert!(msg.contains("transformer.gguf"));
        assert!(msg.contains("mold pull flux-dev:q8"));
        assert!(msg.contains("--skip-verify"));
    }
}