1use std::collections::HashSet;
20use std::fs::{self, File};
21use std::io::{BufReader, Read, Write};
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
25use std::time::{Duration, Instant};
26
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30use url::Url;
31
32use crate::search::policy::{ModelDownloadPolicy, SemanticPolicy};
33
/// Coarse lifecycle state of a model.
///
/// [`ModelState::summary`] renders every variant as a short lowercase string.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelState {
    /// Summarised as "not installed".
    NotInstalled,
    /// Summarised as "needs consent".
    NeedsConsent,
    /// A download is in flight.
    Downloading {
        /// Percentage shown by `summary` as "downloading (N%)".
        progress_pct: u8,
        /// Presumably the bytes received so far — confirm against the producer.
        bytes_downloaded: u64,
        /// Presumably the expected total in bytes — confirm against the producer.
        total_bytes: u64,
    },
    /// Summarised as "verifying".
    Verifying,
    /// The only state for which [`ModelState::is_ready`] returns `true`.
    Ready,
    /// Summarised as "disabled: {reason}".
    Disabled { reason: String },
    /// Summarised as "verification failed: {reason}".
    VerificationFailed { reason: String },
    /// Summarised as "update available: {current_revision} -> {latest_revision}".
    UpdateAvailable {
        current_revision: String,
        latest_revision: String,
    },
    /// Summarised as "cancelled".
    Cancelled,
}
81
82impl ModelState {
83 pub fn is_ready(&self) -> bool {
85 matches!(self, ModelState::Ready)
86 }
87
88 pub fn is_downloading(&self) -> bool {
90 matches!(self, ModelState::Downloading { .. })
91 }
92
93 pub fn needs_consent(&self) -> bool {
95 matches!(self, ModelState::NeedsConsent)
96 }
97
98 pub fn summary(&self) -> String {
100 match self {
101 ModelState::NotInstalled => "not installed".into(),
102 ModelState::NeedsConsent => "needs consent".into(),
103 ModelState::Downloading { progress_pct, .. } => {
104 format!("downloading ({progress_pct}%)")
105 }
106 ModelState::Verifying => "verifying".into(),
107 ModelState::Ready => "ready".into(),
108 ModelState::Disabled { reason } => format!("disabled: {reason}"),
109 ModelState::VerificationFailed { reason } => format!("verification failed: {reason}"),
110 ModelState::UpdateAvailable {
111 current_revision,
112 latest_revision,
113 } => {
114 format!("update available: {current_revision} -> {latest_revision}")
115 }
116 ModelState::Cancelled => "cancelled".into(),
117 }
118 }
119}
120
/// Policy inputs that govern whether and how a model may be acquired.
///
/// The `Default` impl enables downloads, requires consent, and leaves the
/// byte budget and mirror unset.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelAcquisitionPolicy {
    /// Whether downloads are enabled at all (defaults to `true`).
    pub downloads_enabled: bool,
    /// Whether consent is required before acquiring (defaults to `true`).
    pub requires_consent: bool,
    /// Presumably set when the host should be treated as offline — confirm with callers.
    pub offline: bool,
    /// Presumably set when the connection is metered — confirm with callers.
    pub metered: bool,
    /// Presumably permits downloads on a metered connection — confirm with callers.
    pub allow_metered: bool,
    /// Optional upper bound on model size in bytes; `None` means no bound is configured.
    pub max_model_bytes: Option<u64>,
    /// Optional mirror base URL; `None` by default.
    pub mirror_base_url: Option<String>,
    /// Label naming where this policy came from ("compiled_default" for `Default`,
    /// "semantic_policy" for `from_semantic_policy`).
    pub config_source: String,
}
145
146impl Default for ModelAcquisitionPolicy {
147 fn default() -> Self {
148 Self {
149 downloads_enabled: true,
150 requires_consent: true,
151 offline: false,
152 metered: false,
153 allow_metered: false,
154 max_model_bytes: None,
155 mirror_base_url: None,
156 config_source: "compiled_default".to_string(),
157 }
158 }
159}
160
161impl ModelAcquisitionPolicy {
162 pub fn from_semantic_policy(policy: &SemanticPolicy) -> Self {
164 const MIB: u64 = 1_048_576;
165
166 Self {
167 downloads_enabled: policy.mode.should_build_semantic(),
168 requires_consent: matches!(policy.download_policy, ModelDownloadPolicy::OptIn),
169 max_model_bytes: Some(policy.max_model_size_mb.saturating_mul(MIB)),
170 config_source: "semantic_policy".to_string(),
171 ..Self::default()
172 }
173 }
174}
175
/// Detailed state of the on-disk model cache.
///
/// Serialized with an internal `state` tag in snake_case. `code()` returns a
/// stable string per variant and `is_usable()` is `true` only for `Acquired`,
/// `PreseededLocal` and `MirrorSourced`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "state")]
pub enum ModelCacheState {
    /// The model is not acquired (code `not_acquired`).
    NotAcquired {
        /// File names listed in the summary as missing.
        missing_files: Vec<String>,
        /// Selects the "user consent required" vs "ready to acquire" wording.
        needs_consent: bool,
    },
    /// Acquisition is in progress (code `acquiring`).
    Acquiring {
        /// Directory shown in the summary as the acquisition location.
        staging_dir: PathBuf,
        bytes_present: u64,
        total_bytes: u64,
    },
    /// Summarised as "model cache acquired and verified" (code `acquired`).
    Acquired { model_dir: PathBuf },
    /// A file's checksum did not match (code `checksum_mismatch`).
    ChecksumMismatch {
        file: String,
        expected: String,
        actual: String,
    },
    /// The cached revision differs from the expected one (code `incompatible_version`).
    IncompatibleVersion {
        current_revision: String,
        expected_revision: String,
    },
    /// Acquisition is disabled (code `disabled_by_policy`).
    DisabledByPolicy { reason: String },
    /// The model needs more bytes than the policy allows (code `budget_blocked`).
    BudgetBlocked { required_bytes: u64, max_bytes: u64 },
    /// The cache has been quarantined (code `quarantined_corrupt`).
    QuarantinedCorrupt {
        /// Presumably the path of the marker file recording the quarantine — confirm.
        marker_path: PathBuf,
        reason: String,
    },
    /// Summarised as "preseeded local model cache verified" (code `preseeded_local`).
    PreseededLocal { model_dir: PathBuf },
    /// The cache was verified from a mirror (code `mirror_sourced`).
    MirrorSourced {
        model_dir: PathBuf,
        mirror_base_url: String,
    },
    /// Offline and the model is not acquired (code `offline_blocked`).
    OfflineBlocked { missing_files: Vec<String> },
}
225
226impl ModelCacheState {
227 pub fn code(&self) -> &'static str {
229 match self {
230 Self::NotAcquired { .. } => "not_acquired",
231 Self::Acquiring { .. } => "acquiring",
232 Self::Acquired { .. } => "acquired",
233 Self::ChecksumMismatch { .. } => "checksum_mismatch",
234 Self::IncompatibleVersion { .. } => "incompatible_version",
235 Self::DisabledByPolicy { .. } => "disabled_by_policy",
236 Self::BudgetBlocked { .. } => "budget_blocked",
237 Self::QuarantinedCorrupt { .. } => "quarantined_corrupt",
238 Self::PreseededLocal { .. } => "preseeded_local",
239 Self::MirrorSourced { .. } => "mirror_sourced",
240 Self::OfflineBlocked { .. } => "offline_blocked",
241 }
242 }
243
244 pub fn summary(&self) -> String {
246 match self {
247 Self::NotAcquired {
248 missing_files,
249 needs_consent,
250 } => {
251 let action = if *needs_consent {
252 "user consent required"
253 } else {
254 "ready to acquire"
255 };
256 format!(
257 "model not acquired ({action}); missing {}",
258 missing_files.join(", ")
259 )
260 }
261 Self::Acquiring {
262 bytes_present,
263 total_bytes,
264 staging_dir,
265 } => format!(
266 "model acquisition in progress at {} ({bytes_present}/{total_bytes} bytes)",
267 staging_dir.display()
268 ),
269 Self::Acquired { .. } => "model cache acquired and verified".to_string(),
270 Self::ChecksumMismatch {
271 file,
272 expected,
273 actual,
274 } => format!("checksum mismatch for {file}: expected {expected}, got {actual}"),
275 Self::IncompatibleVersion {
276 current_revision,
277 expected_revision,
278 } => format!("model revision mismatch: {current_revision} != {expected_revision}"),
279 Self::DisabledByPolicy { reason } => format!("model acquisition disabled: {reason}"),
280 Self::BudgetBlocked {
281 required_bytes,
282 max_bytes,
283 } => {
284 format!("model requires {required_bytes} bytes but policy allows {max_bytes} bytes")
285 }
286 Self::QuarantinedCorrupt { reason, .. } => {
287 format!("model cache quarantined: {reason}")
288 }
289 Self::PreseededLocal { .. } => "preseeded local model cache verified".to_string(),
290 Self::MirrorSourced {
291 mirror_base_url, ..
292 } => {
293 format!("model cache verified from mirror {mirror_base_url}")
294 }
295 Self::OfflineBlocked { missing_files } => {
296 format!(
297 "offline and model is not acquired; missing {}",
298 missing_files.join(", ")
299 )
300 }
301 }
302 }
303
304 pub fn next_step(&self) -> Option<&'static str> {
306 match self {
307 Self::NotAcquired { .. } => {
308 Some("Run `cass models install`, or keep using lexical search.")
309 }
310 Self::Acquiring { .. } => {
311 Some("Wait for model acquisition to finish, or keep using lexical search.")
312 }
313 Self::Acquired { .. } | Self::PreseededLocal { .. } | Self::MirrorSourced { .. } => {
314 None
315 }
316 Self::ChecksumMismatch { .. } | Self::QuarantinedCorrupt { .. } => Some(
317 "Run `cass models verify --repair`, or reinstall with `cass models install -y`.",
318 ),
319 Self::IncompatibleVersion { .. } => {
320 Some("Run `cass models install -y` to refresh the model cache.")
321 }
322 Self::DisabledByPolicy { .. } => {
323 Some("Use lexical search or re-enable semantic model acquisition in policy.")
324 }
325 Self::BudgetBlocked { .. } => {
326 Some("Increase the semantic model budget or keep using lexical search.")
327 }
328 Self::OfflineBlocked { .. } => Some(
329 "Reconnect or install from local files with `cass models install --from-file`.",
330 ),
331 }
332 }
333
334 pub fn is_usable(&self) -> bool {
336 matches!(
337 self,
338 Self::Acquired { .. } | Self::PreseededLocal { .. } | Self::MirrorSourced { .. }
339 )
340 }
341}
342
/// Serializable snapshot pairing a model with its cache state.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelCacheReport {
    /// Identifier of the model this report describes.
    pub model_id: String,
    /// Directory associated with the model.
    pub model_dir: PathBuf,
    /// Cache state; `state_code()` and `is_usable()` delegate to it.
    pub state: ModelCacheState,
    /// Presumably the total bytes the model requires — confirm against the producer.
    pub required_size_bytes: u64,
    /// Presumably the bytes currently installed on disk — confirm against the producer.
    pub installed_size_bytes: u64,
    /// Presumably mirrors `ModelAcquisitionPolicy::config_source` — confirm against the producer.
    pub policy_source: String,
}
353
354impl ModelCacheReport {
355 pub fn state_code(&self) -> &'static str {
357 self.state.code()
358 }
359
360 pub fn is_usable(&self) -> bool {
362 self.state.is_usable()
363 }
364}
365
/// One downloadable artifact belonging to a model manifest.
#[derive(Debug, Clone)]
pub struct ModelFile {
    /// Path of the file within the model repository (may contain `/`, e.g. `onnx/model.onnx`).
    pub name: String,
    /// Expected SHA-256 digest as a hex string, or `PLACEHOLDER_CHECKSUM` when not yet known.
    pub sha256: String,
    /// Expected size in bytes.
    pub size: u64,
}
378
379impl ModelFile {
380 pub fn local_name(&self) -> &str {
385 self.name.rsplit('/').next().unwrap_or(&self.name)
386 }
387}
388
/// Description of a downloadable model: where it lives and which files it has.
#[derive(Debug, Clone)]
pub struct ModelManifest {
    /// Identifier of the model (used in generated script directory names and commands).
    pub id: String,
    /// Repository path inserted into download URLs (e.g. `org/name`).
    pub repo: String,
    /// Revision inserted into download URLs; `"main"` counts as unpinned.
    pub revision: String,
    /// Files that make up the model.
    pub files: Vec<ModelFile>,
    /// License identifier string for the model.
    pub license: String,
}
407
/// Sentinel stored in `ModelFile::sha256` when the real digest is not yet known.
///
/// Manifests containing it are rejected by `ModelDownloader::download_with_mirror`.
pub const PLACEHOLDER_CHECKSUM: &str = "PLACEHOLDER_VERIFY_AFTER_DOWNLOAD";
413
414pub fn normalize_mirror_base_url(base_url: &str) -> Result<String, DownloadError> {
419 let trimmed = base_url.trim();
420 if trimmed.is_empty() {
421 return Err(invalid_mirror_url(base_url, "mirror URL cannot be empty"));
422 }
423
424 let parsed = Url::parse(trimmed).map_err(|err| invalid_mirror_url(trimmed, err.to_string()))?;
425
426 match parsed.scheme() {
427 "http" | "https" => {}
428 scheme => {
429 return Err(invalid_mirror_url(
430 trimmed,
431 format!("unsupported URL scheme '{scheme}' (expected http or https)"),
432 ));
433 }
434 }
435
436 if parsed.host_str().is_none() {
437 return Err(invalid_mirror_url(
438 trimmed,
439 "mirror URL must include a host",
440 ));
441 }
442
443 if parsed.query().is_some() || parsed.fragment().is_some() {
444 return Err(invalid_mirror_url(
445 trimmed,
446 "mirror URL must not include query or fragment components",
447 ));
448 }
449
450 Ok(parsed.to_string().trim_end_matches('/').to_string())
451}
452
453fn invalid_mirror_url(url: impl Into<String>, reason: impl Into<String>) -> DownloadError {
454 DownloadError::InvalidMirrorUrl {
455 url: url.into(),
456 reason: reason.into(),
457 }
458}
459
460impl ModelManifest {
461 pub fn has_verified_checksums(&self) -> bool {
466 self.files.iter().all(|f| f.sha256 != PLACEHOLDER_CHECKSUM)
467 }
468
469 pub fn has_pinned_revision(&self) -> bool {
474 self.revision != "main"
475 }
476
477 pub fn is_production_ready(&self) -> bool {
483 self.has_verified_checksums() && self.has_pinned_revision()
484 }
485
486 pub fn minilm_v2() -> Self {
491 Self {
492 id: "all-minilm-l6-v2".into(),
493 repo: "sentence-transformers/all-MiniLM-L6-v2".into(),
494 revision: "c9745ed1d9f207416be6d2e6f8de32d1f16199bf".into(),
496 files: vec![
497 ModelFile {
498 name: "onnx/model.onnx".into(),
500 sha256: "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452"
501 .into(),
502 size: 90405214,
503 },
504 ModelFile {
505 name: "tokenizer.json".into(),
506 sha256: "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037"
507 .into(),
508 size: 466247,
509 },
510 ModelFile {
511 name: "config.json".into(),
512 sha256: "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41"
513 .into(),
514 size: 612,
515 },
516 ModelFile {
520 name: "special_tokens_map.json".into(),
521 sha256: "303df45a03609e4ead04bc3dc1536d0ab19b5358db685b6f3da123d05ec200e3"
522 .into(),
523 size: 112,
524 },
525 ModelFile {
526 name: "tokenizer_config.json".into(),
527 sha256: "acb92769e8195aabd29b7b2137a9e6d6e25c476a4f15aa4355c233426c61576b"
528 .into(),
529 size: 350,
530 },
531 ],
532 license: "Apache-2.0".into(),
533 }
534 }
535
536 pub fn snowflake_arctic_s() -> Self {
550 Self {
551 id: "snowflake-arctic-embed-s".into(),
552 repo: "Snowflake/snowflake-arctic-embed-s".into(),
553 revision: "e596f507467533e48a2e17c007f0e1dacc837b33".into(),
554 files: vec![
555 ModelFile {
556 name: "onnx/model.onnx".into(),
557 sha256: "579c1f1778a0993eb0d2a1403340ffb491c769247fb46acc4f5cf8ac5b89c1e1"
558 .into(),
559 size: 133_093_492,
560 },
561 ModelFile {
562 name: "tokenizer.json".into(),
563 sha256: "91f1def9b9391fdabe028cd3f3fcc4efd34e5d1f08c3bf2de513ebb5911a1854"
564 .into(),
565 size: 711_649,
566 },
567 ModelFile {
568 name: "config.json".into(),
569 sha256: "4e519aa92ec40943356032afe458c8829d70c5766b109e4a57490b82f72dcfb7"
570 .into(),
571 size: 703,
572 },
573 ModelFile {
574 name: "special_tokens_map.json".into(),
575 sha256: "5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a"
576 .into(),
577 size: 695,
578 },
579 ModelFile {
580 name: "tokenizer_config.json".into(),
581 sha256: "9ca59277519f6e3692c8685e26b94d4afca2d5438deff66483db495e48735810"
582 .into(),
583 size: 1_433,
584 },
585 ],
586 license: "Apache-2.0".into(),
587 }
588 }
589
590 pub fn nomic_embed() -> Self {
598 Self {
599 id: "nomic-embed-text-v1.5".into(),
600 repo: "nomic-ai/nomic-embed-text-v1.5".into(),
601 revision: "e5cf08aadaa33385f5990def41f7a23405aec398".into(),
602 files: vec![
603 ModelFile {
604 name: "onnx/model.onnx".into(),
605 sha256: "147d5aa88c2101237358e17796cf3a227cead1ec304ec34b465bb08e9d952965"
606 .into(),
607 size: 547_310_275,
608 },
609 ModelFile {
610 name: "tokenizer.json".into(),
611 sha256: "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
612 .into(),
613 size: 711_396,
614 },
615 ModelFile {
616 name: "config.json".into(),
617 sha256: "0168e0883705b0bf8f2b381e10f45a9f3e1ef4b13869b43c160e4c8a70ddf442"
618 .into(),
619 size: 2_331,
620 },
621 ModelFile {
622 name: "special_tokens_map.json".into(),
623 sha256: "5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a"
624 .into(),
625 size: 695,
626 },
627 ModelFile {
628 name: "tokenizer_config.json".into(),
629 sha256: "d7e0000bcc80134debd2222220427e6bf5fa20a669f40a0d0d1409cc18e0a9bc"
630 .into(),
631 size: 1_191,
632 },
633 ],
634 license: "Apache-2.0".into(),
635 }
636 }
637
638 pub fn msmarco_reranker() -> Self {
650 Self {
651 id: "ms-marco-MiniLM-L-6-v2".into(),
652 repo: "cross-encoder/ms-marco-MiniLM-L6-v2".into(),
653 revision: "c5ee24cb16019beea0893ab7796b1df96625c6b8".into(),
654 files: vec![
655 ModelFile {
656 name: "onnx/model.onnx".into(),
657 sha256: "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
658 .into(),
659 size: 91_011_230,
660 },
661 ModelFile {
662 name: "tokenizer.json".into(),
663 sha256: "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
664 .into(),
665 size: 711_396,
666 },
667 ModelFile {
668 name: "config.json".into(),
669 sha256: "380e02c93f431831be65d99a4e7e5f67c133985bf2e77d9d4eba46847190bacc"
670 .into(),
671 size: 794,
672 },
673 ModelFile {
674 name: "special_tokens_map.json".into(),
675 sha256: "3c3507f36dff57bce437223db3b3081d1e2b52ec3e56ee55438193ecb2c94dd6"
676 .into(),
677 size: 132,
678 },
679 ModelFile {
680 name: "tokenizer_config.json".into(),
681 sha256: "a5c2e5a7b1a29a0702cd28c08a399b5ecc110c263009d17f7e3b415f25905fd8"
682 .into(),
683 size: 1_330,
684 },
685 ],
686 license: "Apache-2.0".into(),
687 }
688 }
689
690 pub fn jina_reranker_turbo() -> Self {
697 Self {
698 id: "jina-reranker-v1-turbo-en".into(),
699 repo: "jinaai/jina-reranker-v1-turbo-en".into(),
700 revision: "b8c14f4e723d9e0aab4732a7b7b93741eeeb77c2".into(),
701 files: vec![
702 ModelFile {
703 name: "onnx/model.onnx".into(),
704 sha256: "c1296c66c119de645fa9cdee536d8637740efe85224cfa270281e50f213aa565"
705 .into(),
706 size: 151_296_975,
707 },
708 ModelFile {
709 name: "tokenizer.json".into(),
710 sha256: "0046da43cc8c424b317f56b092b0512aaaa65c4f925d2f16af9d9eeb4d0ef902"
711 .into(),
712 size: 2_030_772,
713 },
714 ModelFile {
715 name: "config.json".into(),
716 sha256: "e050ff6a15ae9295e84882fa0e98051bd8754856cd5201395ebf00ce9f2d609b"
717 .into(),
718 size: 1_206,
719 },
720 ModelFile {
721 name: "special_tokens_map.json".into(),
722 sha256: "06e405a36dfe4b9604f484f6a1e619af1a7f7d09e34a8555eb0b77b66318067f"
723 .into(),
724 size: 280,
725 },
726 ModelFile {
727 name: "tokenizer_config.json".into(),
728 sha256: "d291c6652d96d56ffdbcf1ea19d9bae5ed79003f7648c627e725a619227ce8fa"
729 .into(),
730 size: 1_215,
731 },
732 ],
733 license: "Apache-2.0".into(),
734 }
735 }
736
737 pub fn for_embedder(name: &str) -> Option<Self> {
741 match name {
742 "minilm" => Some(Self::minilm_v2()),
743 "snowflake-arctic-s" => Some(Self::snowflake_arctic_s()),
744 "nomic-embed" => Some(Self::nomic_embed()),
745 _ => None,
746 }
747 }
748
749 pub fn for_reranker(name: &str) -> Option<Self> {
751 match name {
752 "ms-marco" => Some(Self::msmarco_reranker()),
753 "jina-reranker-turbo" => Some(Self::jina_reranker_turbo()),
754 _ => None,
755 }
756 }
757
758 pub fn bakeoff_embedder_candidates() -> Vec<Self> {
762 vec![Self::snowflake_arctic_s(), Self::nomic_embed()]
763 }
764
765 pub fn bakeoff_reranker_candidates() -> Vec<Self> {
769 vec![Self::jina_reranker_turbo()]
770 }
771
772 pub fn bakeoff_candidates() -> Vec<Self> {
776 let mut candidates = Self::bakeoff_embedder_candidates();
777 candidates.extend(Self::bakeoff_reranker_candidates());
778 candidates
779 }
780
781 pub fn total_size(&self) -> u64 {
783 self.files.iter().map(|f| f.size).sum()
784 }
785
786 pub fn download_url_with_base(&self, file: &ModelFile, base_url: Option<&str>) -> String {
788 let root = base_url.unwrap_or("https://huggingface.co");
789 format!(
790 "{}/{}/resolve/{}/{}",
791 root.trim_end_matches('/'),
792 self.repo.trim_start_matches('/'),
793 self.revision,
794 file.name.trim_start_matches('/')
795 )
796 }
797
798 pub fn download_url(&self, file: &ModelFile) -> String {
800 self.download_url_with_base(file, None)
801 }
802
803 pub fn air_gap_bash_script(&self, base_url: Option<&str>) -> String {
810 fn quote_url(url: &str) -> String {
816 debug_assert!(
817 !url.contains('\''),
818 "model download URL unexpectedly contains a single quote: {url}"
819 );
820 format!("'{url}'")
821 }
822
823 let mut out = String::new();
824 out.push_str("# Air-gap model install (bash / Git Bash / MSYS2)\n");
825 out.push_str(
826 "# Run these commands, then re-run `cass models install --from-file \"$DIR\"`.\n",
827 );
828 out.push_str("set -euo pipefail\n");
829 out.push_str(&format!("DIR=\"${{DIR:-./{}_files}}\"\n", self.id));
830 out.push_str("mkdir -p \"$DIR\"\n");
831 for file in &self.files {
832 let url = self.download_url_with_base(file, base_url);
837 out.push_str(&format!(
838 "curl -fL --retry 3 {} -o \"$DIR/{}\" # {} bytes\n",
839 quote_url(&url),
840 file.local_name(),
841 file.size,
842 ));
843 }
844 out.push_str(&format!(
845 "cass models install {} --from-file \"$DIR\" -y\n",
846 self.id
847 ));
848 out
849 }
850
851 pub fn air_gap_powershell_script(&self, base_url: Option<&str>) -> String {
854 fn quote_url_ps(url: &str) -> String {
856 debug_assert!(
857 !url.contains('\''),
858 "model download URL unexpectedly contains a single quote: {url}"
859 );
860 format!("'{url}'")
861 }
862
863 let mut out = String::new();
864 out.push_str("# Air-gap model install (PowerShell 5.1+ and 7+)\n");
865 out.push_str("$ErrorActionPreference = 'Stop'\n");
866 out.push_str(
869 "[System.Net.ServicePointManager]::SecurityProtocol = \
870 [System.Net.ServicePointManager]::SecurityProtocol -bor \
871 [System.Net.SecurityProtocolType]::Tls12\n",
872 );
873 out.push_str(&format!("$dir = \"{}_files\"\n", self.id));
874 out.push_str("New-Item -ItemType Directory -Force -Path $dir | Out-Null\n");
875 for file in &self.files {
876 let url = self.download_url_with_base(file, base_url);
877 out.push_str(&format!(
880 "Invoke-WebRequest -UseBasicParsing -Uri {} -OutFile (Join-Path $dir '{}') # {} bytes\n",
881 quote_url_ps(&url),
882 file.local_name(),
883 file.size,
884 ));
885 }
886 out.push_str(&format!(
887 "cass models install {} --from-file $dir -y\n",
888 self.id
889 ));
890 out
891 }
892}
893
/// Shared, thread-safe callback that receives a [`DownloadProgress`] snapshot.
pub type ProgressCallback = Arc<dyn Fn(DownloadProgress) + Send + Sync>;
896
/// Progress snapshot handed to a [`ProgressCallback`] after each chunk is written.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// File name of the destination path currently being written.
    pub current_file: String,
    /// 1-based position of the current file within the manifest.
    pub file_index: usize,
    /// Number of files in the manifest.
    pub total_files: usize,
    /// Bytes of the current file on disk so far (includes resumed bytes).
    pub file_bytes: u64,
    /// Expected size of the current file in bytes.
    pub file_total: u64,
    /// Bytes accounted for across all files so far.
    pub total_bytes: u64,
    /// Sum of the expected sizes of all files.
    pub grand_total: u64,
    /// `total_bytes / grand_total` as a percentage capped at 100; 0 when `grand_total` is 0.
    pub progress_pct: u8,
}
917
/// Errors produced while validating, downloading or installing a model.
#[derive(Debug, Error)]
pub enum DownloadError {
    /// Transport-level failure other than a timeout; treated as retryable.
    #[error("network error: {0}")]
    NetworkError(String),
    /// Local I/O failure (converted automatically from `std::io::Error`); treated as retryable.
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),
    /// A downloaded file's digest did not match the manifest; never retried.
    #[error("verification failed for {file}: expected {expected}, got {actual}")]
    VerificationFailed {
        file: String,
        expected: String,
        actual: String,
    },
    /// The cancellation flag was observed; never retried.
    #[error("download cancelled")]
    Cancelled,
    /// A connect/request/read deadline elapsed; treated as retryable.
    #[error("download timed out")]
    Timeout,
    /// Non-success HTTP status; retryable only for 408, 429 and 5xx.
    #[error("HTTP error {status}: {message}")]
    HttpError { status: u16, message: String },
    /// The manifest has placeholder checksums and/or an unpinned revision; never retried.
    #[error(
        "model '{model_id}' is not production-ready: {} file(s) have placeholder checksums{}",
        unverified_files.len(),
        if *revision_unpinned {
            " and revision is not pinned"
        } else {
            ""
        }
    )]
    ManifestNotVerified {
        model_id: String,
        /// Names of the files whose checksum is still the placeholder.
        unverified_files: Vec<String>,
        /// `true` when the manifest revision is not pinned.
        revision_unpinned: bool,
    },
    /// A mirror base URL failed validation; never retried.
    #[error("invalid mirror URL '{url}': {reason}")]
    InvalidMirrorUrl { url: String, reason: String },
}
967
968impl DownloadError {
969 fn is_retryable(&self) -> bool {
970 match self {
971 DownloadError::NetworkError(_) | DownloadError::IoError(_) | DownloadError::Timeout => {
972 true
973 }
974 DownloadError::HttpError { status, .. } => {
975 *status == 408 || *status == 429 || (500..=599).contains(status)
976 }
977 DownloadError::VerificationFailed { .. }
978 | DownloadError::Cancelled
979 | DownloadError::ManifestNotVerified { .. }
980 | DownloadError::InvalidMirrorUrl { .. } => false,
981 }
982 }
983
984 fn should_discard_temp(&self) -> bool {
985 matches!(self, DownloadError::VerificationFailed { .. })
986 }
987}
988
/// Minimum response-body cap (500 MiB) applied regardless of the expected file size.
const MODEL_HTTP_BODY_SIZE_FLOOR_BYTES: u64 = 500 * 1024 * 1024;
/// Headroom (32 MiB) added to a file's expected size when computing its body cap.
const MODEL_HTTP_BODY_SIZE_MARGIN_BYTES: u64 = 32 * 1024 * 1024;
991
992fn map_reqwest_download_error(err: reqwest::Error) -> DownloadError {
993 if err.is_timeout() {
994 DownloadError::Timeout
995 } else {
996 DownloadError::NetworkError(err.to_string())
997 }
998}
999
1000fn map_model_response_read_error(err: std::io::Error) -> DownloadError {
1001 if err.kind() == std::io::ErrorKind::TimedOut {
1002 DownloadError::Timeout
1003 } else {
1004 DownloadError::NetworkError(format!("reading model artifact response: {err}"))
1005 }
1006}
1007
1008fn model_http_max_body_size(expected_size: u64) -> usize {
1009 let size_with_margin = expected_size.saturating_add(MODEL_HTTP_BODY_SIZE_MARGIN_BYTES);
1010 let desired = size_with_margin.max(MODEL_HTTP_BODY_SIZE_FLOOR_BYTES);
1011 let usize_max = u64::try_from(usize::MAX).unwrap_or(u64::MAX);
1012 let capped = desired.min(usize_max);
1013 usize::try_from(capped).unwrap_or(usize::MAX)
1014}
1015
/// Downloads the files of a model manifest into a staging directory and
/// installs them into the target directory.
pub struct ModelDownloader {
    /// Final location of the installed model.
    target_dir: PathBuf,
    /// Staging directory; `new` places it next to `target_dir` as `<name>.downloading`.
    temp_dir: PathBuf,
    /// Shared cancellation flag, checked between files, between attempts and between chunks.
    cancelled: Arc<AtomicBool>,
    /// Connect timeout for the HTTP client (30 s from `new`).
    connect_timeout: Duration,
    /// Per-file request timeout, also enforced in the read loop (300 s from `new`).
    file_timeout: Duration,
    /// Number of attempts made per file (3 from `new`).
    max_retries: u32,
}
1031
1032impl ModelDownloader {
1033 pub fn new(target_dir: PathBuf) -> Self {
1035 let temp_dir = if let Some(parent) = target_dir.parent() {
1038 let dir_name = target_dir
1039 .file_name()
1040 .and_then(|n| n.to_str())
1041 .unwrap_or("model");
1042 parent.join(format!("{}.downloading", dir_name))
1043 } else {
1044 target_dir.with_extension("downloading")
1046 };
1047 Self {
1048 target_dir,
1049 temp_dir,
1050 cancelled: Arc::new(AtomicBool::new(false)),
1051 connect_timeout: Duration::from_secs(30),
1052 file_timeout: Duration::from_secs(300), max_retries: 3,
1054 }
1055 }
1056
1057 pub fn cancellation_handle(&self) -> Arc<AtomicBool> {
1059 Arc::clone(&self.cancelled)
1060 }
1061
1062 pub fn cancel(&self) {
1064 self.cancelled.store(true, Ordering::SeqCst);
1065 }
1066
1067 pub fn is_cancelled(&self) -> bool {
1069 self.cancelled.load(Ordering::SeqCst)
1070 }
1071
1072 pub fn download(
1089 &self,
1090 manifest: &ModelManifest,
1091 on_progress: Option<ProgressCallback>,
1092 ) -> Result<(), DownloadError> {
1093 self.download_with_mirror(manifest, None, on_progress)
1094 }
1095
1096 pub fn download_with_mirror(
1098 &self,
1099 manifest: &ModelManifest,
1100 mirror_base_url: Option<&str>,
1101 on_progress: Option<ProgressCallback>,
1102 ) -> Result<(), DownloadError> {
1103 if !manifest.is_production_ready() {
1106 let unverified_files: Vec<String> = manifest
1107 .files
1108 .iter()
1109 .filter(|f| f.sha256 == PLACEHOLDER_CHECKSUM)
1110 .map(|f| f.name.clone())
1111 .collect();
1112 return Err(DownloadError::ManifestNotVerified {
1113 model_id: manifest.id.clone(),
1114 unverified_files,
1115 revision_unpinned: !manifest.has_pinned_revision(),
1116 });
1117 }
1118
1119 self.cancelled.store(false, Ordering::SeqCst);
1121
1122 self.prepare_temp_dir(manifest)?;
1125
1126 let grand_total = manifest.total_size();
1127 let total_files = manifest.files.len();
1128 let bytes_downloaded = Arc::new(AtomicU64::new(0));
1129
1130 for (idx, file) in manifest.files.iter().enumerate() {
1131 self.fail_if_cancelled()?;
1132
1133 let file_path = self.temp_dir.join(file.local_name());
1135 let url = manifest.download_url_with_base(file, mirror_base_url);
1136
1137 let bytes_before_file = bytes_downloaded.load(Ordering::SeqCst);
1139
1140 let mut last_error = None;
1142 for attempt in 0..self.max_retries {
1143 self.fail_if_cancelled()?;
1144
1145 if attempt > 0 {
1147 bytes_downloaded.store(bytes_before_file, Ordering::SeqCst);
1148 }
1149
1150 if attempt > 0 {
1152 let delay = Duration::from_secs(5 * (1 << (attempt - 1)));
1153 std::thread::sleep(delay);
1154 }
1155
1156 match self.download_file(
1157 &url,
1158 &file_path,
1159 file.size,
1160 idx,
1161 total_files,
1162 &bytes_downloaded,
1163 grand_total,
1164 on_progress.as_ref(),
1165 ) {
1166 Ok(()) => {
1167 last_error = None;
1168 break;
1169 }
1170 Err(DownloadError::Cancelled) => {
1171 return Err(DownloadError::Cancelled);
1172 }
1173 Err(e) => {
1174 if !e.is_retryable() {
1175 self.cleanup_temp_for_error(&e);
1176 return Err(e);
1177 }
1178 last_error = Some(e);
1179 }
1180 }
1181 }
1182
1183 if let Some(err) = last_error {
1184 self.cleanup_temp_for_error(&err);
1185 return Err(err);
1186 }
1187
1188 self.fail_if_cancelled()?;
1190
1191 let actual_hash = compute_sha256(&file_path)?;
1192 if actual_hash != file.sha256 {
1193 let err = DownloadError::VerificationFailed {
1194 file: file.name.clone(),
1195 expected: file.sha256.clone(),
1196 actual: actual_hash,
1197 };
1198 self.cleanup_temp_for_error(&err);
1199 return Err(err);
1200 }
1201 }
1202
1203 self.atomic_install()?;
1205
1206 self.write_verified_marker(manifest, mirror_base_url)?;
1208
1209 Ok(())
1210 }
1211
1212 fn prepare_temp_dir(&self, manifest: &ModelManifest) -> Result<(), DownloadError> {
1213 ensure_model_download_temp_dir(&self.temp_dir)?;
1214
1215 let expected_files: HashSet<String> = manifest
1216 .files
1217 .iter()
1218 .map(|file| file.local_name().to_string())
1219 .collect();
1220
1221 for entry in fs::read_dir(&self.temp_dir)? {
1222 let entry = entry?;
1223 let entry_type = entry.file_type()?;
1224 let entry_name = entry.file_name();
1225 let keep_entry = entry_type.is_file()
1226 && entry_name
1227 .to_str()
1228 .is_some_and(|name| expected_files.contains(name));
1229
1230 if keep_entry {
1231 continue;
1232 }
1233
1234 let entry_path = entry.path();
1235 if entry_type.is_dir() {
1236 fs::remove_dir_all(entry_path)?;
1237 } else {
1238 fs::remove_file(entry_path)?;
1239 }
1240 }
1241
1242 Ok(())
1243 }
1244
1245 #[allow(clippy::too_many_arguments)]
1247 fn download_file(
1248 &self,
1249 url: &str,
1250 path: &Path,
1251 expected_size: u64,
1252 file_idx: usize,
1253 total_files: usize,
1254 bytes_downloaded: &Arc<AtomicU64>,
1255 grand_total: u64,
1256 on_progress: Option<&ProgressCallback>,
1257 ) -> Result<(), DownloadError> {
1258 let mut existing_size = if path.exists() {
1260 fs::metadata(path).map(|m| m.len()).unwrap_or(0)
1261 } else {
1262 0
1263 };
1264
1265 if existing_size > expected_size {
1267 let _ = fs::remove_file(path);
1268 existing_size = 0;
1269 }
1270
1271 if existing_size == expected_size {
1273 bytes_downloaded.fetch_add(expected_size, Ordering::SeqCst);
1274 return Ok(());
1275 }
1276
1277 let url = url.to_string();
1278 let path = path.to_path_buf();
1279 let bytes_downloaded = Arc::clone(bytes_downloaded);
1280 let cancelled = Arc::clone(&self.cancelled);
1281 let progress_callback = on_progress.cloned();
1282 let max_body_size = model_http_max_body_size(expected_size);
1283
1284 let client = reqwest::blocking::Client::builder()
1285 .user_agent(concat!(
1286 "cass/",
1287 env!("CARGO_PKG_VERSION"),
1288 " (model-download)"
1289 ))
1290 .connect_timeout(self.connect_timeout)
1291 .timeout(self.file_timeout)
1292 .build()
1293 .map_err(map_reqwest_download_error)?;
1294
1295 let mut headers = reqwest::header::HeaderMap::new();
1296 headers.insert(
1297 reqwest::header::ACCEPT,
1298 reqwest::header::HeaderValue::from_static("application/octet-stream"),
1299 );
1300
1301 if existing_size > 0 {
1302 let range_header =
1303 reqwest::header::HeaderValue::from_str(&format!("bytes={existing_size}-"))
1304 .map_err(|err| {
1305 DownloadError::NetworkError(format!(
1306 "building model download Range header: {err}"
1307 ))
1308 })?;
1309 headers.insert(reqwest::header::RANGE, range_header);
1310 bytes_downloaded.fetch_add(existing_size, Ordering::SeqCst);
1311 }
1312
1313 let mut response = client
1314 .get(&url)
1315 .headers(headers)
1316 .send()
1317 .map_err(map_reqwest_download_error)?;
1318 let status = response.status();
1319 if !status.is_success() {
1320 return Err(DownloadError::HttpError {
1321 status: status.as_u16(),
1322 message: status
1323 .canonical_reason()
1324 .map_or_else(|| status.as_u16().to_string(), |reason| reason.to_string()),
1325 });
1326 }
1327
1328 let content_length = response.content_length().unwrap_or(0);
1329 if content_length > max_body_size as u64 {
1330 return Err(DownloadError::NetworkError(format!(
1331 "model artifact response exceeds transport cap: {content_length} > {max_body_size}"
1332 )));
1333 }
1334
1335 let actually_resuming = existing_size > 0 && status.as_u16() == 206;
1337 if existing_size > 0 && status.as_u16() == 200 {
1338 bytes_downloaded.fetch_sub(existing_size, Ordering::SeqCst);
1339 existing_size = 0;
1340 }
1341
1342 let mut file = fs::OpenOptions::new()
1343 .create(true)
1344 .append(actually_resuming)
1345 .write(true)
1346 .truncate(!actually_resuming)
1347 .open(&path)?;
1348
1349 let file_name = path
1350 .file_name()
1351 .and_then(|n| n.to_str())
1352 .unwrap_or("unknown")
1353 .to_string();
1354 let start = Instant::now();
1355 let mut file_bytes = if actually_resuming { existing_size } else { 0 };
1356 let mut buffer = [0_u8; 8192];
1357
1358 loop {
1359 if cancelled.load(Ordering::SeqCst) {
1360 return Err(DownloadError::Cancelled);
1361 }
1362 if start.elapsed() >= self.file_timeout {
1363 return Err(DownloadError::Timeout);
1364 }
1365
1366 let read = response
1367 .read(&mut buffer)
1368 .map_err(map_model_response_read_error)?;
1369 if read == 0 {
1370 break;
1371 }
1372
1373 file.write_all(&buffer[..read])?;
1374 file_bytes = file_bytes.saturating_add(read as u64);
1375 if file_bytes > max_body_size as u64 {
1376 return Err(DownloadError::NetworkError(format!(
1377 "model artifact body exceeds transport cap: {file_bytes} > {max_body_size}"
1378 )));
1379 }
1380 bytes_downloaded.fetch_add(read as u64, Ordering::SeqCst);
1381
1382 if let Some(callback) = progress_callback.as_ref() {
1383 let total_downloaded = bytes_downloaded.load(Ordering::SeqCst);
1384 let progress_pct = if grand_total > 0 {
1385 ((total_downloaded as f64 / grand_total as f64) * 100.0).min(100.0) as u8
1386 } else {
1387 0
1388 };
1389
1390 callback(DownloadProgress {
1391 current_file: file_name.clone(),
1392 file_index: file_idx + 1,
1393 total_files,
1394 file_bytes,
1395 file_total: expected_size,
1396 total_bytes: total_downloaded,
1397 grand_total,
1398 progress_pct,
1399 });
1400 }
1401 }
1402
1403 file.sync_all()?;
1404 Ok(())
1405 }
1406
/// Moves the staged download directory into place as the model directory.
///
/// Steps, in order: sync the staging tree, move an existing target directory
/// aside to a unique backup path, rename the staging directory onto the
/// target, then delete the backup. If the final rename fails and a backup
/// exists, the backup is renamed back onto the target.
///
/// NOTE(review): the swap consists of two renames, so there is a window in
/// which `target_dir` does not exist.
///
/// # Errors
///
/// Returns an I/O-backed `DownloadError` when syncing or renaming fails. When
/// the install rename fails, the message states whether the original model
/// was restored or, if restoring also failed, where the staged model was left.
fn atomic_install(&self) -> Result<(), DownloadError> {
    let backup_dir = unique_model_backup_dir(&self.target_dir);
    sync_tree(&self.temp_dir)?;

    // Presumably `true` means an existing, replaceable model directory is
    // present — confirm against `ensure_replaceable_model_dir`.
    let had_existing = if ensure_replaceable_model_dir(&self.target_dir)? {
        fs::rename(&self.target_dir, &backup_dir)?;
        true
    } else {
        false
    };

    match fs::rename(&self.temp_dir, &self.target_dir) {
        Ok(()) => {
            sync_parent_directory(&self.target_dir)?;
            if had_existing {
                // Best effort: a leftover backup does not fail the install.
                let _ = fs::remove_dir_all(&backup_dir);
                sync_parent_directory(&self.target_dir)?;
            }
        }
        Err(e) => {
            if had_existing && backup_dir.exists() {
                // Roll back so the previously installed model stays available.
                match fs::rename(&backup_dir, &self.target_dir) {
                    Ok(()) => {
                        sync_parent_directory(&self.target_dir)?;
                        return Err(std::io::Error::other(format!(
                            "failed installing {} from {}: {e}; restored original model",
                            self.target_dir.display(),
                            self.temp_dir.display()
                        ))
                        .into());
                    }
                    Err(restore_err) => {
                        return Err(std::io::Error::other(format!(
                            "failed installing {} from {}: {e}; restore error: {restore_err}; temp model retained at {}",
                            self.target_dir.display(),
                            self.temp_dir.display(),
                            self.temp_dir.display()
                        ))
                        .into());
                    }
                }
            }
            // Nothing to restore: surface the original rename error.
            return Err(e.into());
        }
    }

    Ok(())
}
1465
1466 fn write_verified_marker(
1468 &self,
1469 manifest: &ModelManifest,
1470 mirror_base_url: Option<&str>,
1471 ) -> Result<(), DownloadError> {
1472 let marker_path = self.target_dir.join(".verified");
1473 let source = mirror_base_url
1474 .map(|url| format!("mirror:{url}"))
1475 .unwrap_or_else(|| "registry".to_string());
1476 let content = format!(
1477 "revision={}\nverified_at={}\nsource={}\n",
1478 manifest.revision,
1479 chrono::Utc::now().to_rfc3339(),
1480 source
1481 );
1482 let temp_path = unique_model_sidecar_path(&marker_path, "tmp", ".verified");
1483 let mut file = File::create(&temp_path)?;
1484 file.write_all(content.as_bytes())?;
1485 file.sync_all()?;
1486 replace_file_from_temp(&temp_path, &marker_path)?;
1487 sync_parent_directory(&marker_path)?;
1488 Ok(())
1489 }
1490
1491 fn cleanup_temp(&self) {
1493 if model_dir_is_real_directory(&self.temp_dir).unwrap_or(false) {
1494 let _ = fs::remove_dir_all(&self.temp_dir);
1495 }
1496 }
1497
1498 fn cleanup_temp_for_error(&self, err: &DownloadError) {
1499 if err.should_discard_temp() {
1500 self.cleanup_temp();
1501 }
1502 }
1503
1504 fn fail_if_cancelled(&self) -> Result<(), DownloadError> {
1505 if self.is_cancelled() {
1506 Err(DownloadError::Cancelled)
1507 } else {
1508 Ok(())
1509 }
1510 }
1511}
1512
1513pub fn compute_sha256(path: &Path) -> Result<String, DownloadError> {
1515 let file = File::open(path)?;
1516 let mut reader = BufReader::new(file);
1517 let mut hasher = Sha256::new();
1518
1519 let mut buffer = [0u8; 8192];
1520 loop {
1521 let n = reader.read(&mut buffer)?;
1522 if n == 0 {
1523 break;
1524 }
1525 hasher.update(&buffer[..n]);
1526 }
1527
1528 let hash = hasher.finalize();
1529 Ok(hex::encode(hash))
1530}
1531
1532pub fn classify_model_cache(
1538 model_dir: &Path,
1539 manifest: &ModelManifest,
1540 policy: &ModelAcquisitionPolicy,
1541) -> ModelCacheReport {
1542 classify_model_cache_with_integrity(model_dir, manifest, policy, ModelCacheIntegrity::Full)
1543}
1544
1545pub(crate) fn classify_model_cache_metadata(
1552 model_dir: &Path,
1553 manifest: &ModelManifest,
1554 policy: &ModelAcquisitionPolicy,
1555) -> ModelCacheReport {
1556 classify_model_cache_with_integrity(model_dir, manifest, policy, ModelCacheIntegrity::Metadata)
1557}
1558
/// How thoroughly cache classification inspects the installed files.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ModelCacheIntegrity {
    /// Hash every manifest file and compare against the manifest checksum.
    Full,
    /// Rely on file presence and marker files only; no hashing.
    Metadata,
}
1564
1565fn classify_model_cache_with_integrity(
1566 model_dir: &Path,
1567 manifest: &ModelManifest,
1568 policy: &ModelAcquisitionPolicy,
1569 integrity: ModelCacheIntegrity,
1570) -> ModelCacheReport {
1571 let required_size_bytes = manifest.total_size();
1572 let installed_size_bytes = installed_manifest_size(model_dir, manifest);
1573 let missing_files = missing_manifest_files(model_dir, manifest);
1574 let state = classify_model_cache_state(model_dir, manifest, policy, &missing_files, integrity);
1575
1576 ModelCacheReport {
1577 model_id: manifest.id.clone(),
1578 model_dir: model_dir.to_path_buf(),
1579 state,
1580 required_size_bytes,
1581 installed_size_bytes,
1582 policy_source: policy.config_source.clone(),
1583 }
1584}
1585
1586fn classify_model_cache_state(
1587 model_dir: &Path,
1588 manifest: &ModelManifest,
1589 policy: &ModelAcquisitionPolicy,
1590 missing_files: &[String],
1591 integrity: ModelCacheIntegrity,
1592) -> ModelCacheState {
1593 if !policy.downloads_enabled {
1594 return ModelCacheState::DisabledByPolicy {
1595 reason: "semantic model downloads disabled by policy".to_string(),
1596 };
1597 }
1598
1599 let quarantine_marker = model_dir.join(".quarantined");
1600 if quarantine_marker.is_file() {
1601 let reason = fs::read_to_string(&quarantine_marker)
1602 .ok()
1603 .map(|s| s.trim().to_string())
1604 .filter(|s| !s.is_empty())
1605 .unwrap_or_else(|| "model cache quarantined after integrity failure".to_string());
1606 return ModelCacheState::QuarantinedCorrupt {
1607 marker_path: quarantine_marker,
1608 reason,
1609 };
1610 }
1611
1612 let staging_dir = model_download_temp_dir(model_dir);
1613 if staging_dir.is_dir() {
1614 return ModelCacheState::Acquiring {
1615 bytes_present: directory_size_bytes(&staging_dir),
1616 staging_dir,
1617 total_bytes: manifest.total_size(),
1618 };
1619 }
1620
1621 if !missing_files.is_empty() {
1622 if policy.offline {
1623 return ModelCacheState::OfflineBlocked {
1624 missing_files: missing_files.to_vec(),
1625 };
1626 }
1627
1628 if policy.metered && !policy.allow_metered {
1629 return ModelCacheState::DisabledByPolicy {
1630 reason: "metered network disallows model acquisition".to_string(),
1631 };
1632 }
1633
1634 if let Some(max_bytes) = policy.max_model_bytes
1635 && manifest.total_size() > max_bytes
1636 {
1637 return ModelCacheState::BudgetBlocked {
1638 required_bytes: manifest.total_size(),
1639 max_bytes,
1640 };
1641 }
1642
1643 return ModelCacheState::NotAcquired {
1644 missing_files: missing_files.to_vec(),
1645 needs_consent: policy.requires_consent,
1646 };
1647 }
1648
1649 if integrity == ModelCacheIntegrity::Full {
1650 for file in &manifest.files {
1651 let Some(path) = model_file_path(model_dir, file) else {
1652 continue;
1653 };
1654 match compute_sha256(&path) {
1655 Ok(actual) if actual == file.sha256 => {}
1656 Ok(actual) => {
1657 return ModelCacheState::ChecksumMismatch {
1658 file: file.local_name().to_string(),
1659 expected: file.sha256.clone(),
1660 actual,
1661 };
1662 }
1663 Err(err) => {
1664 return ModelCacheState::QuarantinedCorrupt {
1665 marker_path: path,
1666 reason: format!("unable to hash model file {}: {err}", file.local_name()),
1667 };
1668 }
1669 }
1670 }
1671 }
1672
1673 let verified_marker = model_dir.join(".verified");
1674 if !verified_marker.is_file() {
1675 return ModelCacheState::PreseededLocal {
1676 model_dir: model_dir.to_path_buf(),
1677 };
1678 }
1679
1680 let marker = match fs::read_to_string(&verified_marker) {
1681 Ok(marker) => marker,
1682 Err(err) => {
1683 return ModelCacheState::QuarantinedCorrupt {
1684 marker_path: verified_marker,
1685 reason: format!("unable to read verified marker: {err}"),
1686 };
1687 }
1688 };
1689
1690 let current_revision =
1691 marker_field(&marker, "revision").unwrap_or_else(|| "<unknown>".to_string());
1692 if current_revision != manifest.revision {
1693 return ModelCacheState::IncompatibleVersion {
1694 current_revision,
1695 expected_revision: manifest.revision.clone(),
1696 };
1697 }
1698
1699 match marker_field(&marker, "source") {
1700 Some(source) if source == "preseeded_local" => ModelCacheState::PreseededLocal {
1701 model_dir: model_dir.to_path_buf(),
1702 },
1703 Some(source) if source.starts_with("mirror:") => ModelCacheState::MirrorSourced {
1704 model_dir: model_dir.to_path_buf(),
1705 mirror_base_url: source.trim_start_matches("mirror:").to_string(),
1706 },
1707 _ => ModelCacheState::Acquired {
1708 model_dir: model_dir.to_path_buf(),
1709 },
1710 }
1711}
1712
1713pub fn check_model_installed(model_dir: &Path, manifest: &ModelManifest) -> ModelState {
1723 if !model_dir.is_dir() {
1724 return ModelState::NotInstalled;
1725 }
1726
1727 let verified_marker = model_dir.join(".verified");
1728 if !verified_marker.is_file() {
1729 return ModelState::NotInstalled;
1730 }
1731
1732 for file in &manifest.files {
1736 if model_file_path(model_dir, file).is_none() {
1737 return ModelState::NotInstalled;
1738 }
1739 }
1740
1741 ModelState::Ready
1742}
1743
1744pub fn check_version_mismatch(model_dir: &Path, manifest: &ModelManifest) -> Option<ModelState> {
1746 let verified_marker = model_dir.join(".verified");
1747 if !verified_marker.is_file() {
1748 return None;
1749 }
1750
1751 let content = fs::read_to_string(&verified_marker).ok()?;
1753 let installed_revision = content
1754 .lines()
1755 .find(|l| l.starts_with("revision="))
1756 .map(|l| l.trim_start_matches("revision=").to_string())?;
1757
1758 if installed_revision != manifest.revision {
1759 Some(ModelState::UpdateAvailable {
1760 current_revision: installed_revision,
1761 latest_revision: manifest.revision.clone(),
1762 })
1763 } else {
1764 None
1765 }
1766}
1767
1768fn ensure_replaceable_model_dir(path: &Path) -> Result<bool, DownloadError> {
1769 match fs::symlink_metadata(path) {
1770 Ok(metadata) => {
1771 ensure_real_model_directory_metadata(
1772 path,
1773 &metadata,
1774 "refusing to install model through symlink",
1775 "refusing to replace model target because it is not a directory",
1776 )?;
1777 Ok(true)
1778 }
1779 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
1780 Err(err) => Err(std::io::Error::new(
1781 err.kind(),
1782 format!(
1783 "failed inspecting model target before install {}: {err}",
1784 path.display()
1785 ),
1786 )
1787 .into()),
1788 }
1789}
1790
1791fn ensure_model_download_temp_dir(path: &Path) -> Result<(), DownloadError> {
1792 match fs::symlink_metadata(path) {
1793 Ok(metadata) => {
1794 ensure_real_model_directory_metadata(
1795 path,
1796 &metadata,
1797 "refusing to prepare model download temp dir through symlink",
1798 "refusing to prepare model download temp dir because it is not a directory",
1799 )?;
1800 }
1801 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
1802 fs::create_dir_all(path)?;
1803 let metadata = fs::symlink_metadata(path).map_err(|err| {
1804 std::io::Error::new(
1805 err.kind(),
1806 format!(
1807 "failed inspecting model download temp dir after create {}: {err}",
1808 path.display()
1809 ),
1810 )
1811 })?;
1812 ensure_real_model_directory_metadata(
1813 path,
1814 &metadata,
1815 "refusing to prepare model download temp dir through symlink",
1816 "refusing to prepare model download temp dir because it is not a directory",
1817 )?;
1818 }
1819 Err(err) => {
1820 return Err(std::io::Error::new(
1821 err.kind(),
1822 format!(
1823 "failed inspecting model download temp dir before prepare {}: {err}",
1824 path.display()
1825 ),
1826 )
1827 .into());
1828 }
1829 }
1830 Ok(())
1831}
1832
1833fn model_dir_is_real_directory(path: &Path) -> Result<bool, DownloadError> {
1834 match fs::symlink_metadata(path) {
1835 Ok(metadata) => {
1836 let file_type = metadata.file_type();
1837 Ok(file_type.is_dir() && !file_type.is_symlink())
1838 }
1839 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
1840 Err(err) => Err(err.into()),
1841 }
1842}
1843
1844fn ensure_real_model_directory_metadata(
1845 path: &Path,
1846 metadata: &fs::Metadata,
1847 symlink_message: &str,
1848 non_dir_message: &str,
1849) -> Result<(), DownloadError> {
1850 let file_type = metadata.file_type();
1851 if file_type.is_symlink() {
1852 return Err(std::io::Error::other(format!("{symlink_message}: {}", path.display())).into());
1853 }
1854 if !file_type.is_dir() {
1855 return Err(std::io::Error::other(format!("{non_dir_message}: {}", path.display())).into());
1856 }
1857 Ok(())
1858}
1859
/// Derives the staging directory used while downloading into `target_dir`:
/// a sibling named `<dir name>.downloading`.
///
/// Falls back to the name `model` when the final component is missing or not
/// valid UTF-8, and to `with_extension("downloading")` when the path has no
/// parent.
fn model_download_temp_dir(target_dir: &Path) -> PathBuf {
    let Some(parent) = target_dir.parent() else {
        return target_dir.with_extension("downloading");
    };

    let dir_name = match target_dir.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "model",
    };
    parent.join(format!("{dir_name}.downloading"))
}
1871
1872pub fn model_file_path(model_dir: &Path, file: &ModelFile) -> Option<PathBuf> {
1877 let canonical = model_dir.join(&file.name);
1878 if canonical.is_file() {
1879 return Some(canonical);
1880 }
1881
1882 let local = model_dir.join(file.local_name());
1883 if local.is_file() {
1884 return Some(local);
1885 }
1886
1887 None
1888}
1889
1890fn missing_manifest_files(model_dir: &Path, manifest: &ModelManifest) -> Vec<String> {
1891 manifest
1892 .files
1893 .iter()
1894 .filter(|file| model_file_path(model_dir, file).is_none())
1895 .map(|file| file.local_name().to_string())
1896 .collect()
1897}
1898
1899fn installed_manifest_size(model_dir: &Path, manifest: &ModelManifest) -> u64 {
1900 manifest
1901 .files
1902 .iter()
1903 .filter_map(|file| model_file_path(model_dir, file))
1904 .filter_map(|path| path.metadata().ok())
1905 .map(|metadata| metadata.len())
1906 .sum()
1907}
1908
/// Recursively totals the sizes of regular files under `path`.
///
/// Best-effort: unreadable directories, entries and metadata count as zero,
/// and entries that are neither regular files nor directories (for example
/// symlinks) are ignored.
fn directory_size_bytes(path: &Path) -> u64 {
    let Ok(entries) = fs::read_dir(path) else {
        return 0;
    };

    let mut total: u64 = 0;
    for entry in entries.filter_map(Result::ok) {
        let Ok(kind) = entry.file_type() else {
            continue;
        };
        if kind.is_file() {
            total += entry.metadata().map_or(0, |metadata| metadata.len());
        } else if kind.is_dir() {
            total += directory_size_bytes(&entry.path());
        }
    }
    total
}
1928
/// Reads a `field=value` entry from marker file `content`.
///
/// Only the first line starting with `field=` is considered. Its value is
/// trimmed; an empty value yields `None` (later lines are not consulted).
fn marker_field(content: &str, field: &str) -> Option<String> {
    for line in content.lines() {
        let Some(rest) = line
            .strip_prefix(field)
            .and_then(|rest| rest.strip_prefix('='))
        else {
            continue;
        };
        let value = rest.trim();
        return if value.is_empty() {
            None
        } else {
            Some(value.to_owned())
        };
    }
    None
}
1937
1938fn unique_model_backup_dir(path: &Path) -> PathBuf {
1939 unique_model_sidecar_path(path, "bak", "model")
1940}
1941
/// Builds a unique hidden sibling path next to `path`.
///
/// The file name has the form `.<name>.<suffix>.<pid>.<unix nanos>.<nonce>`,
/// where `<name>` is the final component of `path` (or `fallback_name` when it
/// is missing or not UTF-8) and `<nonce>` is a counter that increases on each
/// call within this process.
fn unique_model_sidecar_path(path: &Path, suffix: &str, fallback_name: &str) -> PathBuf {
    static NEXT_NONCE: AtomicU64 = AtomicU64::new(0);

    let file_name = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => fallback_name,
    };
    // A clock set before the Unix epoch degrades to a zero timestamp.
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos());
    let nonce = NEXT_NONCE.fetch_add(1, Ordering::Relaxed);

    let sidecar_name = format!(
        ".{file_name}.{suffix}.{}.{}.{}",
        std::process::id(),
        timestamp,
        nonce
    );
    path.with_file_name(sidecar_name)
}
1962
1963fn replace_file_from_temp(temp_path: &Path, final_path: &Path) -> Result<(), DownloadError> {
1964 #[cfg(windows)]
1965 {
1966 match fs::rename(temp_path, final_path) {
1967 Ok(()) => sync_parent_directory(final_path),
1968 Err(first_err)
1969 if final_path.exists()
1970 && matches!(
1971 first_err.kind(),
1972 std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::PermissionDenied
1973 ) =>
1974 {
1975 let backup_path = unique_model_backup_dir(final_path);
1976 fs::rename(final_path, &backup_path).map_err(|backup_err| {
1977 let _ = fs::remove_file(temp_path);
1978 DownloadError::IoError(std::io::Error::other(format!(
1979 "failed preparing backup {} before replacing {}: first error: {first_err}; backup error: {backup_err}",
1980 backup_path.display(),
1981 final_path.display()
1982 )))
1983 })?;
1984 match fs::rename(temp_path, final_path) {
1985 Ok(()) => {
1986 let _ = fs::remove_file(&backup_path);
1987 sync_parent_directory(final_path)
1988 }
1989 Err(second_err) => match fs::rename(&backup_path, final_path) {
1990 Ok(()) => {
1991 let _ = fs::remove_file(temp_path);
1992 sync_parent_directory(final_path)?;
1993 Err(std::io::Error::other(format!(
1994 "failed replacing {} with {}: first error: {first_err}; second error: {second_err}; restored original file",
1995 final_path.display(),
1996 temp_path.display()
1997 ))
1998 .into())
1999 }
2000 Err(restore_err) => Err(std::io::Error::other(format!(
2001 "failed replacing {} with {}: first error: {first_err}; second error: {second_err}; restore error: {restore_err}; temp file retained at {}",
2002 final_path.display(),
2003 temp_path.display(),
2004 temp_path.display()
2005 ))
2006 .into()),
2007 },
2008 }
2009 }
2010 Err(rename_err) => Err(rename_err.into()),
2011 }
2012 }
2013
2014 #[cfg(not(windows))]
2015 {
2016 fs::rename(temp_path, final_path)?;
2017 sync_parent_directory(final_path)
2018 }
2019}
2020
2021#[cfg(not(windows))]
2022fn sync_tree(path: &Path) -> Result<(), DownloadError> {
2023 sync_tree_inner(path)?;
2024 sync_parent_directory(path)
2025}
2026
2027#[cfg(not(windows))]
2028fn sync_tree_inner(path: &Path) -> Result<(), DownloadError> {
2029 let metadata = fs::metadata(path)?;
2030 if metadata.is_dir() {
2031 for entry in fs::read_dir(path)? {
2032 let entry = entry?;
2033 sync_tree_inner(&entry.path())?;
2034 }
2035 File::open(path)?.sync_all()?;
2036 } else if metadata.is_file() {
2037 File::open(path)?.sync_all()?;
2038 }
2039 Ok(())
2040}
2041
/// Windows variant of `sync_tree`: performs no work and always succeeds.
#[cfg(windows)]
fn sync_tree(_path: &Path) -> Result<(), DownloadError> {
    // Intentionally a no-op on this platform.
    Ok(())
}
2046
2047#[cfg(not(windows))]
2048fn sync_parent_directory(path: &Path) -> Result<(), DownloadError> {
2049 let Some(parent) = path.parent() else {
2050 return Ok(());
2051 };
2052 File::open(parent)?.sync_all()?;
2053 Ok(())
2054}
2055
/// Windows variant of `sync_parent_directory`: performs no work and always
/// succeeds.
#[cfg(windows)]
fn sync_parent_directory(_path: &Path) -> Result<(), DownloadError> {
    // Intentionally a no-op on this platform.
    Ok(())
}
2060
2061#[cfg(test)]
2062mod tests {
2063 use super::*;
2064 use std::collections::BTreeMap;
2065 use std::error::Error as _;
2066 use std::io::{Read, Write};
2067 use std::net::{Shutdown, TcpListener, TcpStream};
2068 use std::sync::atomic::{AtomicBool, Ordering};
2069 use std::sync::{Arc, Mutex};
2070 use std::thread;
2071 use std::time::Duration;
2072
2073 fn copy_model_fixtures(target_dir: &Path) -> std::io::Result<()> {
2076 let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/models");
2077 fs::create_dir_all(target_dir)?;
2078
2079 fs::copy(
2081 fixture_dir.join("model.onnx"),
2082 target_dir.join("model.onnx"),
2083 )?;
2084
2085 for file in &[
2087 "tokenizer.json",
2088 "config.json",
2089 "special_tokens_map.json",
2090 "tokenizer_config.json",
2091 ] {
2092 fs::copy(fixture_dir.join(file), target_dir.join(file))?;
2093 }
2094
2095 Ok(())
2096 }
2097
// An 86 MiB model must still get the floor transport cap.
#[test]
fn model_http_body_size_keeps_floor_for_small_models() {
    let small_model_bytes = 86 * 1024 * 1024;
    let cap = model_http_max_body_size(small_model_bytes);
    assert_eq!(cap, MODEL_HTTP_BODY_SIZE_FLOOR_BYTES as usize);
}
2105
// The transport cap for the nomic ONNX file must cover its size plus margin.
#[test]
fn model_http_body_size_covers_nomic_onnx_manifest() {
    let manifest = ModelManifest::nomic_embed();
    let onnx_size = manifest
        .files
        .iter()
        .find(|file| file.name == "onnx/model.onnx")
        .map(|file| file.size)
        .expect("nomic manifest should include onnx/model.onnx");

    let max_body_size = model_http_max_body_size(onnx_size);
    let cap = u64::try_from(max_body_size).unwrap_or(u64::MAX);
    let required = onnx_size + MODEL_HTTP_BODY_SIZE_MARGIN_BYTES;
    assert!(
        cap >= required,
        "transport cap {max_body_size} must cover nomic ONNX size {} plus margin",
        onnx_size
    );
}
2122
/// One request observed by the fixture mirror server.
#[derive(Debug, Clone)]
struct MirrorRequest {
    /// Request path with any query/fragment removed.
    path: String,
    /// Start offset from a `Range: bytes=<start>-` header, if one was sent.
    range_start: Option<u64>,
}
2128
/// Canned response served by the fixture mirror server for one path.
#[derive(Clone)]
struct MirrorRoute {
    /// Full response body.
    body: Vec<u8>,
    /// Value of the `Content-Type` response header.
    content_type: &'static str,
    /// Body bytes written per chunk (treated as at least 1).
    chunk_size: usize,
    /// Pause after each chunk; zero disables the pause.
    chunk_delay: Duration,
}
2136
2137 struct MirrorFixtureServer {
2138 base_url: String,
2139 stop: Arc<AtomicBool>,
2140 wake_addr: String,
2141 requests: Arc<Mutex<Vec<MirrorRequest>>>,
2142 handle: Option<std::thread::JoinHandle<()>>,
2143 }
2144
2145 impl MirrorFixtureServer {
2146 fn requests(&self) -> Vec<MirrorRequest> {
2147 self.requests.lock().unwrap().clone()
2148 }
2149 }
2150
2151 impl Drop for MirrorFixtureServer {
2152 fn drop(&mut self) {
2153 self.stop.store(true, Ordering::SeqCst);
2154 if let Ok(stream) = TcpStream::connect(&self.wake_addr) {
2155 let _ = stream.shutdown(Shutdown::Both);
2156 }
2157 if let Some(handle) = self.handle.take() {
2158 let _ = handle.join();
2159 }
2160 }
2161 }
2162
2163 fn start_mirror_fixture_server(routes: Vec<(String, MirrorRoute)>) -> MirrorFixtureServer {
2164 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test mirror server");
2165 listener
2166 .set_nonblocking(true)
2167 .expect("set test mirror server nonblocking");
2168 let addr = listener.local_addr().expect("read server address");
2169 let wake_addr = addr.to_string();
2170 let base_url = format!("http://{wake_addr}");
2171 let stop = Arc::new(AtomicBool::new(false));
2172 let stop_flag = Arc::clone(&stop);
2173 let requests = Arc::new(Mutex::new(Vec::new()));
2174 let request_log = Arc::clone(&requests);
2175 let route_map: BTreeMap<String, MirrorRoute> = routes.into_iter().collect();
2176 let handle = thread::spawn(move || {
2177 while !stop_flag.load(Ordering::SeqCst) {
2178 match listener.accept() {
2179 Ok((stream, _)) => {
2180 handle_mirror_request(stream, &route_map, &request_log);
2181 }
2182 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
2183 thread::sleep(Duration::from_millis(10));
2184 }
2185 Err(_) => break,
2186 }
2187 }
2188 });
2189 MirrorFixtureServer {
2190 base_url,
2191 stop,
2192 wake_addr,
2193 requests,
2194 handle: Some(handle),
2195 }
2196 }
2197
2198 fn handle_mirror_request(
2199 mut stream: TcpStream,
2200 routes: &BTreeMap<String, MirrorRoute>,
2201 request_log: &Arc<Mutex<Vec<MirrorRequest>>>,
2202 ) {
2203 let mut buffer = [0_u8; 8192];
2204 let read = match stream.read(&mut buffer) {
2205 Ok(read) => read,
2206 Err(_) => return,
2207 };
2208 let request = String::from_utf8_lossy(&buffer[..read]);
2209 let mut lines = request.lines();
2210 let target = lines
2211 .next()
2212 .and_then(|line| line.split_whitespace().nth(1))
2213 .unwrap_or("/");
2214 let path = target
2215 .split_once('?')
2216 .map(|(path, _)| path)
2217 .unwrap_or(target)
2218 .split_once('#')
2219 .map(|(path, _)| path)
2220 .unwrap_or(target)
2221 .to_string();
2222 let range_start = lines.find_map(parse_range_start_header);
2223 request_log.lock().unwrap().push(MirrorRequest {
2224 path: path.clone(),
2225 range_start,
2226 });
2227
2228 let Some(route) = routes.get(&path) else {
2229 let response = concat!(
2230 "HTTP/1.1 404 Not Found\r\n",
2231 "Content-Length: 9\r\n",
2232 "Content-Type: text/plain\r\n",
2233 "Connection: close\r\n\r\n",
2234 "not found"
2235 );
2236 let _ = stream.write_all(response.as_bytes());
2237 let _ = stream.flush();
2238 return;
2239 };
2240
2241 let start = range_start.unwrap_or(0) as usize;
2242 let mut status = "200 OK";
2243 let mut content_range = None;
2244 let body = if start >= route.body.len() {
2245 status = "416 Range Not Satisfiable";
2246 &[][..]
2247 } else if start > 0 {
2248 status = "206 Partial Content";
2249 content_range = Some(format!(
2250 "bytes {start}-{}/{}",
2251 route.body.len().saturating_sub(1),
2252 route.body.len()
2253 ));
2254 &route.body[start..]
2255 } else {
2256 route.body.as_slice()
2257 };
2258
2259 let mut response = format!(
2260 "HTTP/1.1 {status}\r\nContent-Length: {}\r\nContent-Type: {}\r\nConnection: close\r\n",
2261 body.len(),
2262 route.content_type
2263 );
2264 if let Some(content_range) = content_range {
2265 response.push_str(&format!("Content-Range: {content_range}\r\n"));
2266 }
2267 response.push_str("\r\n");
2268 let _ = stream.write_all(response.as_bytes());
2269 for chunk in body.chunks(route.chunk_size.max(1)) {
2270 if stream.write_all(chunk).is_err() {
2271 return;
2272 }
2273 let _ = stream.flush();
2274 if !route.chunk_delay.is_zero() {
2275 thread::sleep(route.chunk_delay);
2276 }
2277 }
2278 }
2279
/// Extracts the start offset from a `Range: bytes=<start>-...` header line.
///
/// The header name is matched case-insensitively. Returns `None` for other
/// headers, other range units, and ranges without a numeric start.
fn parse_range_start_header(line: &str) -> Option<u64> {
    line.split_once(':')
        .filter(|(name, _)| name.eq_ignore_ascii_case("range"))
        .and_then(|(_, value)| value.trim().strip_prefix("bytes="))
        .and_then(|spec| spec.split_once('-'))
        .and_then(|(start, _)| start.parse().ok())
}
2290
2291 fn build_test_manifest(repo: &str, revision: &str, files: &[(&str, &[u8])]) -> ModelManifest {
2292 ModelManifest {
2293 id: "mirror-test-model".into(),
2294 repo: repo.into(),
2295 revision: revision.into(),
2296 files: files
2297 .iter()
2298 .map(|(name, body)| ModelFile {
2299 name: (*name).into(),
2300 sha256: hex::encode(Sha256::digest(body)),
2301 size: body.len() as u64,
2302 })
2303 .collect(),
2304 license: "Apache-2.0".into(),
2305 }
2306 }
2307
2308 fn mirror_route_path(prefix: &str, manifest: &ModelManifest, file: &ModelFile) -> String {
2309 format!(
2310 "{}/{}/resolve/{}/{}",
2311 prefix.trim_end_matches('/'),
2312 manifest.repo.trim_start_matches('/'),
2313 manifest.revision,
2314 file.name.trim_start_matches('/')
2315 )
2316 }
2317
// Each state renders its expected one-line summary.
#[test]
fn test_model_state_summary() {
    let halfway = ModelState::Downloading {
        progress_pct: 50,
        bytes_downloaded: 1000,
        total_bytes: 2000,
    };
    let expectations = [
        (ModelState::NotInstalled, "not installed"),
        (ModelState::NeedsConsent, "needs consent"),
        (ModelState::Ready, "ready"),
        (halfway, "downloading (50%)"),
    ];

    for (state, expected) in expectations {
        assert_eq!(state.summary(), expected);
    }
}
2333
// Only `Ready` reports itself as ready.
#[test]
fn test_model_state_is_ready() {
    assert!(ModelState::Ready.is_ready());

    let not_ready = [
        ModelState::NotInstalled,
        ModelState::NeedsConsent,
        ModelState::Downloading {
            progress_pct: 0,
            bytes_downloaded: 0,
            total_bytes: 0,
        },
    ];
    for state in not_ready {
        assert!(!state.is_ready());
    }
}
2348
// The MiniLM manifest must total more than 20 MB.
#[test]
fn test_model_manifest_total_size() {
    let total = ModelManifest::minilm_v2().total_size();
    assert!(total > 20_000_000);
}
2354
// The default download URL points at the registry, the repo and the file.
#[test]
fn test_model_manifest_download_url() {
    let manifest = ModelManifest::minilm_v2();
    let url = manifest.download_url(&manifest.files[0]);

    for fragment in [
        "huggingface.co",
        "sentence-transformers/all-MiniLM-L6-v2",
        "model.onnx",
    ] {
        assert!(url.contains(fragment));
    }
}
2363
// A mirror base with a trailing slash must not produce a double slash.
#[test]
fn test_model_manifest_download_url_with_mirror_base() {
    let manifest = ModelManifest::minilm_v2();
    let first_file = &manifest.files[0];

    let expected = format!(
        "https://mirror.example/cache/{}/resolve/{}/{}",
        manifest.repo, manifest.revision, first_file.name
    );
    let url = manifest.download_url_with_base(first_file, Some("https://mirror.example/cache/"));
    assert_eq!(url, expected);
}
2377
// The bash helper must be strict, write each file via `-o`, and finish with
// an offline install.
#[test]
fn air_gap_bash_script_uses_explicit_output_filenames() {
    let manifest = ModelManifest::minilm_v2();
    let script = manifest.air_gap_bash_script(None);

    for required in [
        "set -euo pipefail",
        "DIR=\"${DIR:-./all-minilm-l6-v2_files}\"",
    ] {
        assert!(script.contains(required));
    }

    for file in &manifest.files {
        let local = file.local_name();
        let output_flag = format!("-o \"$DIR/{local}\"");
        assert!(
            script.contains(&output_flag),
            "bash script must write {local} via explicit -o, got:\n{script}"
        );
    }

    let invokes_install =
        script.contains("cass models install all-minilm-l6-v2 --from-file \"$DIR\" -y");
    assert!(
        invokes_install,
        "bash script must invoke install with --from-file"
    );
}
2401
// Download URLs must appear single-quoted in the bash helper.
#[test]
fn air_gap_bash_script_quotes_urls_with_single_quotes() {
    let manifest = ModelManifest::minilm_v2();
    let sample_url = manifest.download_url(&manifest.files[0]);
    let quoted = format!("'{sample_url}'");

    let script = manifest.air_gap_bash_script(None);
    assert!(script.contains(&quoted));
}
2410
// The PowerShell helper must stay compatible with Windows PowerShell 5.1.
#[test]
fn air_gap_powershell_script_forces_tls12_and_basic_parsing() {
    let manifest = ModelManifest::minilm_v2();
    let script = manifest.air_gap_powershell_script(None);

    let forces_tls12 = script.contains("SecurityProtocolType]::Tls12");
    assert!(
        forces_tls12,
        "PowerShell script must opt into TLS 1.2 for Windows PowerShell 5.1 compat"
    );

    let uses_basic_parsing = script.contains("Invoke-WebRequest -UseBasicParsing");
    assert!(
        uses_basic_parsing,
        "PowerShell script must use -UseBasicParsing for PS 5.1 compat"
    );

    for file in &manifest.files {
        let local = file.local_name();
        let join_expr = format!("(Join-Path $dir '{local}')");
        assert!(
            script.contains(&join_expr),
            "PowerShell script must materialize output path for {local}, got:\n{script}"
        );
    }

    let invokes_install =
        script.contains("cass models install all-minilm-l6-v2 --from-file $dir -y");
    assert!(
        invokes_install,
        "PowerShell script must invoke install with --from-file"
    );
}
2435
// With a mirror configured, neither script may mention the default registry.
#[test]
fn air_gap_scripts_honor_mirror_base_url() {
    let manifest = ModelManifest::minilm_v2();
    let mirror = Some("https://mirror.example/cache");

    let assert_mirrored = |script: &str| {
        assert!(script.contains("https://mirror.example/cache"));
        assert!(!script.contains("huggingface.co"));
    };

    let bash = manifest.air_gap_bash_script(mirror);
    let ps = manifest.air_gap_powershell_script(mirror);
    assert_mirrored(&bash);
    assert_mirrored(&ps);
}
2447
// A trailing slash on the mirror base URL is dropped during normalization.
#[test]
fn test_normalize_mirror_base_url_trims_trailing_slash() {
    let with_slash = "https://mirror.example/cache/";
    let normalized = normalize_mirror_base_url(with_slash).unwrap();
    assert_eq!(normalized, "https://mirror.example/cache");
}
2453
// Each invalid mirror URL must be rejected with a recognizable message.
#[test]
fn test_normalize_mirror_base_url_rejects_invalid_values() {
    let cases = [
        ("mirror.example", "invalid mirror URL"),
        ("file:///tmp/mirror", "unsupported URL scheme"),
        (
            "https://mirror.example/cache?trace=abc",
            "must not include query or fragment",
        ),
    ];

    cases.iter().for_each(|&(input, expected_fragment)| {
        let message = normalize_mirror_base_url(input).unwrap_err().to_string();
        assert!(
            message.contains(expected_fragment),
            "expected error for {input:?} to contain {expected_fragment:?}, got {message:?}"
        );
    });
}
2474
// `invalid_mirror_url` must build the right variant, message and retry class.
#[test]
fn test_invalid_mirror_url_helper_shape() {
    let err = invalid_mirror_url("ftp://mirror.example/model.onnx", "unsupported scheme");

    let has_expected_shape = matches!(
        &err,
        DownloadError::InvalidMirrorUrl { url, reason }
            if url == "ftp://mirror.example/model.onnx" && reason == "unsupported scheme"
    );
    assert!(has_expected_shape);

    let rendered = err.to_string();
    assert_eq!(
        rendered,
        "invalid mirror URL 'ftp://mirror.example/model.onnx': unsupported scheme"
    );
    assert!(!err.is_retryable());
}
2490
// A directory that does not exist is reported as not installed.
#[test]
fn test_check_model_installed_missing() {
    let tmp = tempfile::tempdir().unwrap();
    let absent_dir = tmp.path().join("nonexistent");

    let state = check_model_installed(&absent_dir, &ModelManifest::minilm_v2());
    assert_eq!(state, ModelState::NotInstalled);
}
2500
// A model file without the `.verified` marker is still not installed.
#[test]
fn test_check_model_installed_no_marker() {
    let tmp = tempfile::tempdir().unwrap();
    let model_dir = tmp.path().join("model");
    fs::create_dir_all(&model_dir).unwrap();

    let fixture_onnx = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("tests/fixtures/models")
        .join("model.onnx");
    fs::copy(fixture_onnx, model_dir.join("model.onnx")).unwrap();

    let state = check_model_installed(&model_dir, &ModelManifest::minilm_v2());
    assert_eq!(state, ModelState::NotInstalled);
}
2514
// All fixture files plus a `.verified` marker make the model ready.
#[test]
fn test_check_model_installed_ready() {
    let tmp = tempfile::tempdir().unwrap();
    let model_dir = tmp.path().join("model");
    copy_model_fixtures(&model_dir).unwrap();

    let marker = model_dir.join(".verified");
    fs::write(marker, "revision=test\n").unwrap();

    let state = check_model_installed(&model_dir, &ModelManifest::minilm_v2());
    assert_eq!(state, ModelState::Ready);
}
2527
// Disabled downloads win even when offline and budget limits would also apply.
#[test]
fn classify_cache_policy_disabled_takes_precedence_over_missing() {
    let tmp = tempfile::tempdir().unwrap();
    let body: &[u8] = b"model";
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", body)]);
    let policy = ModelAcquisitionPolicy {
        downloads_enabled: false,
        offline: true,
        max_model_bytes: Some(1),
        ..ModelAcquisitionPolicy::default()
    };

    let report = classify_model_cache(tmp.path(), &manifest, &policy);

    assert_eq!(report.state_code(), "disabled_by_policy");
    let is_disabled = matches!(report.state, ModelCacheState::DisabledByPolicy { .. });
    assert!(is_disabled);
}
2546
/// A populated `model.downloading` sibling directory is classified as `acquiring`
/// even though the model directory itself does not exist.
#[test]
fn classify_cache_detects_resume_stage_before_missing() {
    let scratch = tempfile::tempdir().unwrap();
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

    // Only the staging directory is created; `model` is left absent.
    let install_dir = scratch.path().join("model");
    let staging = scratch.path().join("model.downloading");
    fs::create_dir_all(&staging).unwrap();
    fs::write(staging.join("model.onnx"), b"partial").unwrap();

    let default_policy = ModelAcquisitionPolicy::default();
    let report = classify_model_cache(&install_dir, &manifest, &default_policy);

    assert_eq!(report.state_code(), "acquiring");
    // b"partial" is 7 bytes on disk; the manifest payload b"model" is 5 bytes.
    let sizes_as_expected = matches!(
        &report.state,
        ModelCacheState::Acquiring {
            bytes_present: 7,
            total_bytes: 5,
            ..
        }
    );
    assert!(sizes_as_expected);
}
2568
/// Offline mode and an undersized byte budget produce distinct state codes.
#[test]
fn classify_cache_distinguishes_offline_and_budget_blocks() {
    let cache_root = tempfile::tempdir().unwrap();
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

    // Each scenario flips exactly one knob relative to the default policy.
    let scenarios = [
        (
            ModelAcquisitionPolicy {
                offline: true,
                ..ModelAcquisitionPolicy::default()
            },
            "offline_blocked",
        ),
        (
            ModelAcquisitionPolicy {
                max_model_bytes: Some(1),
                ..ModelAcquisitionPolicy::default()
            },
            "budget_blocked",
        ),
    ];

    for (policy, expected_code) in scenarios {
        let report = classify_model_cache(cache_root.path(), &manifest, &policy);
        assert_eq!(report.state_code(), expected_code);
    }
}
2588
/// Files placed at their manifest-relative paths, with no marker, are classified as
/// `preseeded_local` and usable.
#[test]
fn classify_cache_accepts_preseeded_local_manifest_files() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let manifest = build_test_manifest(
        "repo/model",
        "rev1",
        &[("onnx/model.onnx", b"model"), ("tokenizer.json", b"tok")],
    );

    // Lay the payloads out by hand, mirroring the manifest's relative paths.
    let onnx_dir = install_dir.join("onnx");
    fs::create_dir_all(&onnx_dir).unwrap();
    fs::write(onnx_dir.join("model.onnx"), b"model").unwrap();
    fs::write(install_dir.join("tokenizer.json"), b"tok").unwrap();

    let default_policy = ModelAcquisitionPolicy::default();
    let report = classify_model_cache(&install_dir, &manifest, &default_policy);

    assert_eq!(report.state_code(), "preseeded_local");
    let usable = report.is_usable();
    assert!(usable);
}
2607
/// A payload whose bytes differ from the manifest's is reported as `checksum_mismatch`.
#[test]
fn classify_cache_detects_checksum_mismatch() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

    // Same length as the manifest payload, different content.
    fs::create_dir_all(&install_dir).unwrap();
    fs::write(install_dir.join("model.onnx"), b"wrong").unwrap();

    let default_policy = ModelAcquisitionPolicy::default();
    let report = classify_model_cache(&install_dir, &manifest, &default_policy);

    assert_eq!(report.state_code(), "checksum_mismatch");
    let is_mismatch = matches!(&report.state, ModelCacheState::ChecksumMismatch { .. });
    assert!(is_mismatch);
}
2624
/// The metadata-only classifier accepts a matching `.verified` marker even when the
/// payload bytes are wrong, while the full classifier reports the mismatch.
#[test]
fn classify_cache_metadata_trusts_verified_marker_without_hashing_payload() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
    let default_policy = ModelAcquisitionPolicy::default();

    fs::create_dir_all(&install_dir).unwrap();
    // "m0del" has the same length as the manifest payload "model" but different bytes.
    fs::write(install_dir.join("model.onnx"), b"m0del").unwrap();
    let marker_path = install_dir.join(".verified");
    fs::write(marker_path, "revision=rev1\nsource=registry\n").unwrap();

    let quick = classify_model_cache_metadata(&install_dir, &manifest, &default_policy);
    assert_eq!(quick.state_code(), "acquired");
    assert!(quick.is_usable());

    let thorough = classify_model_cache(&install_dir, &manifest, &default_policy);
    assert_eq!(thorough.state_code(), "checksum_mismatch");
}
2650
/// A marker recording a different revision than the manifest yields `incompatible_version`
/// with both revisions reported.
#[test]
fn classify_cache_detects_incompatible_revision() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

    // The payload matches the manifest; only the recorded revision is stale.
    fs::create_dir_all(&install_dir).unwrap();
    fs::write(install_dir.join("model.onnx"), b"model").unwrap();
    fs::write(install_dir.join(".verified"), "revision=old\n").unwrap();

    let default_policy = ModelAcquisitionPolicy::default();
    let report = classify_model_cache(&install_dir, &manifest, &default_policy);

    assert_eq!(report.state_code(), "incompatible_version");
    let revisions_reported = matches!(
        report.state,
        ModelCacheState::IncompatibleVersion {
            current_revision,
            expected_revision
        } if current_revision == "old" && expected_revision == "rev1"
    );
    assert!(revisions_reported);
}
2671
/// A marker whose `source=` line names a mirror yields `mirror_sourced` with that base URL.
#[test]
fn classify_cache_reports_mirror_sourced_marker() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

    fs::create_dir_all(&install_dir).unwrap();
    fs::write(install_dir.join("model.onnx"), b"model").unwrap();
    let marker_body = "revision=rev1\nsource=mirror:https://mirror.example/cache\n";
    fs::write(install_dir.join(".verified"), marker_body).unwrap();

    let default_policy = ModelAcquisitionPolicy::default();
    let report = classify_model_cache(&install_dir, &manifest, &default_policy);

    assert_eq!(report.state_code(), "mirror_sourced");
    let mirror_reported = matches!(
        report.state,
        ModelCacheState::MirrorSourced {
            mirror_base_url,
            ..
        } if mirror_base_url == "https://mirror.example/cache"
    );
    assert!(mirror_reported);
}
2696
/// A `.quarantined` file yields `quarantined_corrupt`; the expected reason is the file's
/// text without its trailing newline.
#[test]
fn classify_cache_reports_quarantine_marker() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

    // Only the quarantine marker exists; no payload is written.
    fs::create_dir_all(&install_dir).unwrap();
    fs::write(install_dir.join(".quarantined"), "bad checksum\n").unwrap();

    let default_policy = ModelAcquisitionPolicy::default();
    let report = classify_model_cache(&install_dir, &manifest, &default_policy);

    assert_eq!(report.state_code(), "quarantined_corrupt");
    let reason_reported = matches!(
        report.state,
        ModelCacheState::QuarantinedCorrupt { reason, .. } if reason == "bad checksum"
    );
    assert!(reason_reported);
}
2713
/// `compute_sha256` returns the lowercase hex SHA-256 digest of a file's contents.
#[test]
fn test_compute_sha256() {
    // Well-known digest of the ASCII bytes "hello world".
    let expected_digest = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";

    let scratch = tempfile::tempdir().unwrap();
    let sample = scratch.path().join("test.txt");
    fs::write(&sample, b"hello world").unwrap();

    let digest = compute_sha256(&sample).unwrap();
    assert_eq!(digest, expected_digest);
}
2726
/// No mismatch is reported when the marker's revision equals the manifest's revision.
#[test]
fn test_check_version_mismatch_none() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    fs::create_dir_all(&install_dir).unwrap();

    // Write a marker that records exactly the manifest's own revision.
    let manifest = ModelManifest::minilm_v2();
    let marker_body = format!("revision={}\n", manifest.revision);
    fs::write(install_dir.join(".verified"), marker_body).unwrap();

    let mismatch = check_version_mismatch(&install_dir, &manifest);
    assert!(mismatch.is_none());
}
2743
/// `local_name` yields the final path component of a manifest file name.
#[test]
fn test_model_file_local_name() {
    // (manifest name, sha256 placeholder, size, expected local name)
    let cases = [
        ("onnx/model.onnx", "abc123", 1000, "model.onnx"),
        ("tokenizer.json", "def456", 500, "tokenizer.json"),
        ("path/to/deep/model.bin", "ghi789", 2000, "model.bin"),
    ];

    for (name, sha256, size, expected_local) in cases {
        let file = ModelFile {
            name: name.into(),
            sha256: sha256.into(),
            size,
        };
        assert_eq!(file.local_name(), expected_local);
    }
}
2770
/// A marker recording a revision other than the manifest's yields `UpdateAvailable`.
#[test]
fn test_check_version_mismatch_found() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    fs::create_dir_all(&install_dir).unwrap();
    let marker_path = install_dir.join(".verified");
    fs::write(marker_path, "revision=old_version\n").unwrap();

    let manifest = ModelManifest::minilm_v2();
    let outcome = check_version_mismatch(&install_dir, &manifest);

    let update_flagged = matches!(outcome, Some(ModelState::UpdateAvailable { .. }));
    assert!(update_flagged);
}
2782
/// `atomic_install` replaces the target with the staged contents and leaves an unrelated,
/// pre-existing `model.bak` directory untouched.
#[test]
fn test_atomic_install_preserves_preexisting_legacy_backup_dir() {
    let scratch = tempfile::tempdir().unwrap();

    // Existing install, marked with the old revision.
    let install_dir = scratch.path().join("model");
    copy_model_fixtures(&install_dir).unwrap();
    fs::write(install_dir.join(".verified"), "revision=old\n").unwrap();

    // A sibling backup directory that predates this install attempt.
    let legacy_backup = scratch.path().join("model.bak");
    let sentinel = legacy_backup.join("sentinel.txt");
    fs::create_dir_all(&legacy_backup).unwrap();
    fs::write(&sentinel, "keep me").unwrap();

    // Stage a replacement marked with the new revision.
    let downloader = ModelDownloader::new(install_dir.clone());
    copy_model_fixtures(&downloader.temp_dir).unwrap();
    fs::write(downloader.temp_dir.join(".verified"), "revision=new\n").unwrap();

    downloader.atomic_install().unwrap();

    let sentinel_text = fs::read_to_string(&sentinel).unwrap();
    assert_eq!(sentinel_text, "keep me");
    let installed_marker = fs::read_to_string(install_dir.join(".verified")).unwrap();
    assert_eq!(installed_marker, "revision=new\n");
}
2809
/// When the target path is a regular file, `atomic_install` fails and leaves both the
/// file and the staged temp directory in place.
#[test]
fn test_atomic_install_rejects_file_target() {
    let scratch = tempfile::tempdir().unwrap();
    let target_path = scratch.path().join("model");
    // Occupy the target path with a plain file instead of a directory.
    fs::write(&target_path, "not a directory").unwrap();

    let downloader = ModelDownloader::new(target_path.clone());
    copy_model_fixtures(&downloader.temp_dir).unwrap();

    let err = downloader.atomic_install().unwrap_err();

    let mentions_cause = err.to_string().contains("not a directory");
    assert!(mentions_cause, "unexpected error: {err}");
    assert!(downloader.temp_dir.exists());
    let surviving_text = fs::read_to_string(&target_path).unwrap();
    assert_eq!(surviving_text, "not a directory");
}
2828
/// When the target path is a dangling symlink, `atomic_install` fails without following
/// it: the link stays, its destination is not created, and the staged directory is kept.
#[test]
#[cfg(unix)]
fn test_atomic_install_rejects_dangling_symlink_target() {
    let scratch = tempfile::tempdir().unwrap();
    let link_path = scratch.path().join("model");
    let dangling_destination = scratch.path().join("missing-model");
    // The destination is never created, so the link dangles.
    std::os::unix::fs::symlink(&dangling_destination, &link_path).unwrap();

    let downloader = ModelDownloader::new(link_path.clone());
    copy_model_fixtures(&downloader.temp_dir).unwrap();

    let err = downloader.atomic_install().unwrap_err();

    let mentions_symlink = err.to_string().contains("through symlink");
    assert!(mentions_symlink, "unexpected error: {err}");
    assert!(downloader.temp_dir.exists());

    let link_kind = fs::symlink_metadata(&link_path).unwrap().file_type();
    assert!(link_kind.is_symlink());
    assert!(!dangling_destination.exists());
}
2857
/// `write_verified_marker` replaces a stale marker with one recording the manifest
/// revision, a `verified_at=` field, and the registry source.
#[test]
fn test_write_verified_marker_overwrites_existing_marker() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let marker_path = install_dir.join(".verified");
    fs::create_dir_all(&install_dir).unwrap();
    fs::write(&marker_path, "revision=old\n").unwrap();

    let manifest = ModelManifest::minilm_v2();
    let downloader = ModelDownloader::new(install_dir.clone());
    downloader.write_verified_marker(&manifest, None).unwrap();

    let marker = fs::read_to_string(&marker_path).unwrap();
    let revision_field = format!("revision={}", manifest.revision);
    assert!(marker.contains(&revision_field));
    assert!(marker.contains("verified_at="));
    assert!(marker.contains("source=registry"));
}
2874
/// Pins the `Display` text of each `DownloadError` variant listed below and checks which
/// variants expose an underlying error source.
#[test]
fn test_download_error_display() {
    // (error value, exact rendered message)
    let rendered_expectations = [
        (
            DownloadError::NetworkError("connection refused".into()),
            "network error: connection refused",
        ),
        (
            DownloadError::VerificationFailed {
                file: "test.onnx".into(),
                expected: "abc".into(),
                actual: "def".into(),
            },
            "verification failed for test.onnx: expected abc, got def",
        ),
        (DownloadError::Cancelled, "download cancelled"),
        (DownloadError::Timeout, "download timed out"),
        (
            DownloadError::HttpError {
                status: 503,
                message: "service unavailable".into(),
            },
            "HTTP error 503: service unavailable",
        ),
        (
            DownloadError::ManifestNotVerified {
                model_id: "test-model".into(),
                unverified_files: vec!["model.onnx".into(), "config.json".into()],
                revision_unpinned: true,
            },
            "model 'test-model' is not production-ready: 2 file(s) have placeholder checksums and revision is not pinned",
        ),
        (
            DownloadError::ManifestNotVerified {
                model_id: "test-model".into(),
                unverified_files: vec!["model.onnx".into()],
                revision_unpinned: false,
            },
            "model 'test-model' is not production-ready: 1 file(s) have placeholder checksums",
        ),
        (
            DownloadError::InvalidMirrorUrl {
                url: "ftp://mirror.example/model.onnx".into(),
                reason: "unsupported scheme".into(),
            },
            "invalid mirror URL 'ftp://mirror.example/model.onnx': unsupported scheme",
        ),
    ];

    rendered_expectations
        .into_iter()
        .for_each(|(error, message)| assert_eq!(error.to_string(), message));

    // An error converted from `std::io::Error` renders with a prefix and keeps the
    // original error reachable through `source()`.
    let io_backed = DownloadError::from(std::io::Error::other("disk full"));
    assert_eq!(io_backed.to_string(), "I/O error: disk full");
    let inner = io_backed
        .source()
        .expect("I/O errors expose their source");
    assert_eq!(inner.to_string(), "disk full");

    let plain = DownloadError::NetworkError("connection refused".into());
    assert!(
        plain.source().is_none(),
        "non-source variants must not gain an error source"
    );
}
2941
/// The built-in MiniLM manifest has verified checksums, a pinned revision, and is
/// production-ready.
#[test]
fn test_manifest_production_ready_minilm() {
    let minilm = ModelManifest::minilm_v2();

    let checksums_verified = minilm.has_verified_checksums();
    assert!(checksums_verified);

    let revision_pinned = minilm.has_pinned_revision();
    assert!(revision_pinned);

    let production_ready = minilm.is_production_ready();
    assert!(production_ready);
}
2950
/// There are exactly three bake-off candidates, each production-ready, and the three
/// expected model ids are all present.
#[test]
fn test_all_bakeoff_candidates_production_ready() {
    let candidates = ModelManifest::bakeoff_candidates();

    assert_eq!(candidates.len(), 3, "Expected 3 bake-off candidates");

    // Per-candidate readiness checks, each reported with the candidate's id.
    candidates.iter().for_each(|candidate| {
        assert!(
            candidate.is_production_ready(),
            "Model {} should be production-ready",
            candidate.id
        );
        assert!(
            candidate.has_verified_checksums(),
            "Model {} should have verified checksums",
            candidate.id
        );
        assert!(
            candidate.has_pinned_revision(),
            "Model {} should have pinned revision",
            candidate.id
        );
    });

    // (expected id, failure message)
    let required_ids = [
        (
            "snowflake-arctic-embed-s",
            "Snowflake should be in candidates",
        ),
        ("nomic-embed-text-v1.5", "Nomic should be in candidates"),
        (
            "jina-reranker-v1-turbo-en",
            "Jina Turbo should be in candidates",
        ),
    ];
    for (wanted_id, failure_message) in required_ids {
        let present = candidates.iter().any(|m| m.id == wanted_id);
        assert!(present, "{failure_message}");
    }
}
2996
/// A new downloader is not cancelled; calling `cancel` flips `is_cancelled` to true.
#[test]
fn test_downloader_cancellation() {
    let scratch = tempfile::tempdir().unwrap();
    let install_dir = scratch.path().join("model");
    let downloader = ModelDownloader::new(install_dir);

    let cancelled_before = downloader.is_cancelled();
    assert!(!cancelled_before);

    downloader.cancel();

    let cancelled_after = downloader.is_cancelled();
    assert!(cancelled_after);
}
3006
/// `prepare_temp_dir` keeps the partial `model.onnx` but removes an unrelated file and a
/// nested directory from the temp dir.
#[test]
fn test_prepare_temp_dir_prunes_stale_entries() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));

    let kept_partial = downloader.temp_dir.join("model.onnx");
    let stale_file = downloader.temp_dir.join("stale.bin");
    let stale_dir = downloader.temp_dir.join("nested");

    // Seed the temp dir with one resumable partial and two kinds of leftovers.
    fs::create_dir_all(&downloader.temp_dir).unwrap();
    fs::write(&kept_partial, b"partial").unwrap();
    fs::write(&stale_file, b"stale").unwrap();
    fs::create_dir_all(&stale_dir).unwrap();
    fs::write(stale_dir.join("should-remove.txt"), b"stale").unwrap();

    let manifest = ModelManifest::minilm_v2();
    downloader.prepare_temp_dir(&manifest).unwrap();

    assert!(kept_partial.exists());
    assert!(!stale_file.exists());
    assert!(!stale_dir.exists());
}
3029
/// A symlink sitting at `model.onnx` inside the temp dir is removed by `prepare_temp_dir`,
/// while the file it points at is left alone.
#[test]
#[cfg(unix)]
fn test_prepare_temp_dir_removes_symlink_entries() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));
    fs::create_dir_all(&downloader.temp_dir).unwrap();

    // Put a symlink where a partial download would be, pointing outside the temp dir.
    let link_destination = scratch.path().join("outside.bin");
    fs::write(&link_destination, b"outside").unwrap();
    let planted_link = downloader.temp_dir.join("model.onnx");
    std::os::unix::fs::symlink(&link_destination, &planted_link).unwrap();

    let manifest = ModelManifest::minilm_v2();
    downloader.prepare_temp_dir(&manifest).unwrap();

    // `symlink_metadata` does not follow links, so an error means the link itself is gone.
    let link_lookup = fs::symlink_metadata(&planted_link);
    assert!(
        link_lookup.is_err(),
        "symlink should be removed before resume"
    );
    assert!(
        link_destination.exists(),
        "cleanup must not touch the symlink target"
    );
}
3053
/// If the temp dir path is itself a symlink, `prepare_temp_dir` errors out and neither
/// prunes the linked directory nor removes the link.
#[test]
#[cfg(unix)]
fn test_prepare_temp_dir_rejects_symlinked_temp_dir_without_pruning_target() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));

    // Point the temp dir path at an unrelated directory containing one file.
    let foreign_dir = scratch.path().join("outside-download-cache");
    let foreign_file = foreign_dir.join("stale.bin");
    fs::create_dir_all(&foreign_dir).unwrap();
    fs::write(&foreign_file, b"must remain").unwrap();
    std::os::unix::fs::symlink(&foreign_dir, &downloader.temp_dir).unwrap();

    let manifest = ModelManifest::minilm_v2();
    let err = downloader
        .prepare_temp_dir(&manifest)
        .expect_err("symlinked temp dir must be rejected before pruning");

    let mentions_symlink = err.to_string().contains("temp dir through symlink");
    assert!(
        mentions_symlink,
        "unexpected symlink-temp-dir error: {err}"
    );
    assert_eq!(fs::read(&foreign_file).unwrap(), b"must remain");

    let temp_kind = fs::symlink_metadata(&downloader.temp_dir)
        .unwrap()
        .file_type();
    assert!(temp_kind.is_symlink());
}
3082
/// `cleanup_temp` leaves a symlinked temp dir and the contents of its destination intact.
#[test]
#[cfg(unix)]
fn test_cleanup_temp_skips_symlinked_temp_dir() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));

    // Point the temp dir path at an unrelated directory containing a sentinel file.
    let foreign_dir = scratch.path().join("outside-download-cache");
    let sentinel = foreign_dir.join("sentinel.bin");
    fs::create_dir_all(&foreign_dir).unwrap();
    fs::write(&sentinel, b"must remain").unwrap();
    std::os::unix::fs::symlink(&foreign_dir, &downloader.temp_dir).unwrap();

    downloader.cleanup_temp();

    let sentinel_bytes = fs::read(&sentinel).unwrap();
    assert_eq!(sentinel_bytes, b"must remain");

    let temp_kind = fs::symlink_metadata(&downloader.temp_dir)
        .unwrap()
        .file_type();
    assert!(temp_kind.is_symlink());
}
3108
/// Network errors, timeouts and HTTP 503 are retryable; HTTP 404, cancellation and
/// verification failures are not.
#[test]
fn test_retryable_error_classification() {
    let retryable = [
        DownloadError::NetworkError("boom".into()),
        DownloadError::Timeout,
        DownloadError::HttpError {
            status: 503,
            message: "unavailable".into(),
        },
    ];
    let terminal = [
        DownloadError::HttpError {
            status: 404,
            message: "missing".into(),
        },
        DownloadError::Cancelled,
        DownloadError::VerificationFailed {
            file: "model.onnx".into(),
            expected: "a".into(),
            actual: "b".into(),
        },
    ];

    // Tag each error with its expected classification, retryable group first.
    let labelled = retryable
        .into_iter()
        .map(|err| (err, true))
        .chain(terminal.into_iter().map(|err| (err, false)));

    for (err, expected) in labelled {
        let actual = err.is_retryable();
        assert_eq!(actual, expected, "retryability mismatch for {err}");
    }
}
3147
/// Cleaning up after a `Cancelled` error keeps the partial file in the temp dir.
#[test]
fn test_cleanup_temp_for_error_preserves_partial_downloads_on_cancelled() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));

    let partial_file = downloader.temp_dir.join("model.onnx");
    fs::create_dir_all(&downloader.temp_dir).unwrap();
    fs::write(&partial_file, b"partial").unwrap();

    let cancellation = DownloadError::Cancelled;
    downloader.cleanup_temp_for_error(&cancellation);

    let still_present = partial_file.exists();
    assert!(
        still_present,
        "cancelled downloads should keep partial files for a resumable retry"
    );
}
3163
/// After `cancel`, `fail_if_cancelled` returns `Err(Cancelled)` and does not delete the
/// partial file in the temp dir.
#[test]
fn test_fail_if_cancelled_preserves_partial_downloads() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));

    let partial_file = downloader.temp_dir.join("model.onnx");
    fs::create_dir_all(&downloader.temp_dir).unwrap();
    fs::write(&partial_file, b"partial").unwrap();

    downloader.cancel();
    let outcome = downloader.fail_if_cancelled();

    let reported_cancelled = matches!(outcome, Err(DownloadError::Cancelled));
    assert!(reported_cancelled);
    let still_present = partial_file.exists();
    assert!(
        still_present,
        "early cancellation checks should not discard resumable partial files"
    );
}
3181
/// Cleaning up after a `VerificationFailed` error removes the whole temp directory.
#[test]
fn test_cleanup_temp_for_error_discards_temp_after_verification_failure() {
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));

    let partial_file = downloader.temp_dir.join("model.onnx");
    fs::create_dir_all(&downloader.temp_dir).unwrap();
    fs::write(&partial_file, b"partial").unwrap();

    let checksum_failure = DownloadError::VerificationFailed {
        file: "model.onnx".into(),
        expected: "good".into(),
        actual: "bad".into(),
    };
    downloader.cleanup_temp_for_error(&checksum_failure);

    let temp_dir_remains = downloader.temp_dir.exists();
    assert!(
        !temp_dir_remains,
        "verification failures should discard the temp directory to avoid reusing corrupt data"
    );
}
3201
/// Downloading through an HTTP mirror installs every manifest file, records the mirror in
/// the verified marker, and issues exactly one request per file under the mirror prefix.
#[test]
fn test_download_with_mirror_installs_verified_model_from_http_mirror() {
    let files = [
        ("onnx/model.onnx", b"mirror-model".as_slice()),
        ("tokenizer.json", br#"{"tokenizer":"ok"}"#.as_slice()),
    ];
    let manifest = build_test_manifest("mirror/test-model", "rev123", &files);
    let route_prefix = "/cache";

    // Register one fixture route per manifest entry, paired with its payload by position.
    let mut routes: Vec<(String, MirrorRoute)> = Vec::with_capacity(files.len());
    for (manifest_file, (_, payload)) in manifest.files.iter().zip(files.iter()) {
        let route = MirrorRoute {
            body: payload.to_vec(),
            content_type: "application/octet-stream",
            chunk_size: 64,
            chunk_delay: Duration::ZERO,
        };
        routes.push((
            mirror_route_path(route_prefix, &manifest, manifest_file),
            route,
        ));
    }

    let server = start_mirror_fixture_server(routes);
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));
    // Note the trailing slash on the mirror base for this test.
    let mirror_base = format!("{}/cache/", server.base_url);

    downloader
        .download_with_mirror(&manifest, Some(&mirror_base), None)
        .unwrap();

    // Each file is expected under the target dir by its final path component.
    for (name, body) in files {
        let leaf = Path::new(name).file_name().unwrap().to_string_lossy();
        let installed = downloader.target_dir.join(leaf.as_ref());
        let on_disk = fs::read(installed).unwrap();
        assert_eq!(
            on_disk, body,
            "mirror install should persist the downloaded payload"
        );
    }

    let marker_path = downloader.target_dir.join(".verified");
    let marker = fs::read_to_string(marker_path).unwrap();
    assert!(
        marker.contains("revision=rev123"),
        "verified marker should preserve manifest identity after mirror install"
    );
    assert!(
        marker.contains("source=mirror:"),
        "verified marker should record mirror source"
    );

    let requests = server.requests();
    assert_eq!(
        requests.len(),
        manifest.files.len(),
        "expected one request per manifest file"
    );
    let all_under_prefix = requests
        .iter()
        .all(|request| request.path.starts_with("/cache/"));
    assert!(
        all_under_prefix,
        "mirror requests should stay under the configured mirror prefix: {requests:?}"
    );
}
3272
/// A mirror with no registered routes makes the download fail with HTTP 404 after a
/// single request to a `/resolve/` path.
#[test]
fn test_download_with_mirror_reports_missing_artifact_from_http_mirror() {
    let file_body = b"mirror-model".as_slice();
    let manifest = build_test_manifest(
        "mirror/test-model",
        "rev404",
        &[("onnx/model.onnx", file_body)],
    );

    // The fixture server is started with an empty route table.
    let server = start_mirror_fixture_server(Vec::new());
    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));
    let mirror_base = format!("{}/cache", server.base_url);

    let outcome = downloader.download_with_mirror(&manifest, Some(&mirror_base), None);
    let err = outcome.unwrap_err();

    let is_not_found = matches!(err, DownloadError::HttpError { status: 404, .. });
    assert!(
        is_not_found,
        "missing mirror artifacts should surface as HTTP 404, got: {err}"
    );

    let requests = server.requests();
    assert_eq!(requests.len(), 1);
    let hit_resolve_path = requests[0].path.contains("/resolve/");
    assert!(
        hit_resolve_path,
        "mirror request should target the resolved artifact path: {requests:?}"
    );
}
3301
/// When the mirror serves bytes that differ from the manifest payload, the download fails
/// verification, the temp dir is removed, and nothing is installed.
#[test]
fn test_download_with_mirror_discards_corrupt_payload_from_http_mirror() {
    let manifest = build_test_manifest(
        "mirror/test-model",
        "revbad",
        &[("onnx/model.onnx", b"expected-bytes".as_slice())],
    );
    let route_prefix = "/cache";

    // Serve a payload that differs from the one the manifest was built from.
    let corrupt_route = MirrorRoute {
        body: b"corrupt-bytes".to_vec(),
        content_type: "application/octet-stream",
        chunk_size: 64,
        chunk_delay: Duration::ZERO,
    };
    let route_path = mirror_route_path(route_prefix, &manifest, &manifest.files[0]);
    let server = start_mirror_fixture_server(vec![(route_path, corrupt_route)]);

    let scratch = tempfile::tempdir().unwrap();
    let downloader = ModelDownloader::new(scratch.path().join("model"));
    let mirror_base = format!("{server_base}/cache", server_base = server.base_url);

    let outcome = downloader.download_with_mirror(&manifest, Some(&mirror_base), None);
    let err = outcome.unwrap_err();

    let failed_verification = matches!(err, DownloadError::VerificationFailed { .. });
    assert!(
        failed_verification,
        "corrupt mirror payloads must fail checksum verification, got: {err}"
    );
    let temp_dir_remains = downloader.temp_dir.exists();
    assert!(
        !temp_dir_remains,
        "verification failures should discard the temp directory so corrupt payloads are not reused"
    );
    let target_dir_created = downloader.target_dir.exists();
    assert!(
        !target_dir_created,
        "corrupt mirror payloads must not be promoted into the installed model directory"
    );
}
3340
/// Cancelling a mirror download mid-transfer leaves a partial file; a second run finishes
/// the install and issues a Range request starting at the partial's size.
#[test]
fn test_download_with_mirror_resumes_after_cancelled_partial_download() {
    let large_payload = vec![b'x'; 128 * 1024];
    let manifest = build_test_manifest(
        "mirror/test-model",
        "revresume",
        &[("onnx/model.onnx", &large_payload)],
    );
    let route_prefix = "/cache";

    // Small chunks with a short delay between them, per the route settings below.
    let slow_route = MirrorRoute {
        body: large_payload.clone(),
        content_type: "application/octet-stream",
        chunk_size: 1024,
        chunk_delay: Duration::from_millis(2),
    };
    let route_path = mirror_route_path(route_prefix, &manifest, &manifest.files[0]);
    let server = start_mirror_fixture_server(vec![(route_path, slow_route)]);

    let scratch = tempfile::tempdir().unwrap();
    let downloader = Arc::new(ModelDownloader::new(scratch.path().join("model")));
    let mirror_base = format!("{server_base}/cache", server_base = server.base_url);

    // Handles moved into the progress callback; the flag makes cancellation fire only once.
    let cancel_once = Arc::new(AtomicBool::new(false));
    let canceller = Arc::clone(&downloader);
    let cancel_flag = Arc::clone(&cancel_once);

    let first_attempt = downloader.download_with_mirror(
        &manifest,
        Some(&mirror_base),
        Some(Arc::new(move |progress| {
            if progress.total_bytes < 16 * 1024 {
                return;
            }
            let first_trigger = cancel_flag
                .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok();
            if first_trigger {
                canceller.cancel();
            }
        })),
    );

    let was_cancelled = matches!(first_attempt, Err(DownloadError::Cancelled));
    assert!(
        was_cancelled,
        "first mirror attempt should stop with a cancellation so we can verify resumable recovery"
    );

    let partial_path = downloader.temp_dir.join("model.onnx");
    let partial_size = fs::metadata(&partial_path).unwrap().len();
    let full_size = large_payload.len() as u64;
    assert!(
        partial_size > 0 && partial_size < full_size,
        "cancelled run should preserve a partial download for resume; got {partial_size} bytes"
    );

    // Second attempt on the same downloader, this time without a progress callback.
    downloader
        .download_with_mirror(&manifest, Some(&mirror_base), None)
        .unwrap();

    let installed_bytes = fs::read(downloader.target_dir.join("model.onnx")).unwrap();
    assert_eq!(
        installed_bytes, large_payload,
        "rerun after cancellation should finish the mirrored download and install the exact payload"
    );

    let requests = server.requests();
    let resumed_with_range = requests
        .iter()
        .any(|request| request.range_start == Some(partial_size));
    assert!(
        resumed_with_range,
        "rerun should resume from the preserved partial via Range requests; saw requests: {requests:?}"
    );
}
3408}