1use std::collections::HashSet;
20use std::fs::{self, File};
21use std::future::{Future, poll_fn};
22use std::io::{BufReader, Read, Write};
23use std::path::{Path, PathBuf};
24use std::pin::Pin;
25use std::sync::Arc;
26use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
27use std::sync::mpsc::TryRecvError;
28use std::time::{Duration, Instant};
29
30use asupersync::bytes::Buf;
31use asupersync::http::Body;
32use serde::{Deserialize, Serialize};
33use sha2::{Digest, Sha256};
34use thiserror::Error;
35use url::Url;
36
37use crate::search::policy::{ModelDownloadPolicy, SemanticPolicy};
38
/// High-level lifecycle state of a semantic model, rendered for humans by
/// [`ModelState::summary`].
#[derive(Debug, Clone, PartialEq)]
pub enum ModelState {
    /// No model files are installed.
    NotInstalled,
    /// Acquisition is waiting for user consent.
    NeedsConsent,
    /// A download is in progress.
    Downloading {
        /// Progress as a percentage (rendered as `downloading (N%)`).
        progress_pct: u8,
        /// Bytes received so far.
        bytes_downloaded: u64,
        /// Total bytes expected.
        total_bytes: u64,
    },
    /// Downloaded files are being verified.
    Verifying,
    /// The model is available (`is_ready` returns `true` only for this variant).
    Ready,
    /// Model use is disabled; `reason` is shown in the summary.
    Disabled { reason: String },
    /// Verification failed; `reason` is shown in the summary.
    VerificationFailed { reason: String },
    /// A different revision is available than the one installed.
    UpdateAvailable {
        /// Revision currently installed.
        current_revision: String,
        /// Revision that is available.
        latest_revision: String,
    },
    /// A download was cancelled.
    Cancelled,
}
86
87impl ModelState {
88 pub fn is_ready(&self) -> bool {
90 matches!(self, ModelState::Ready)
91 }
92
93 pub fn is_downloading(&self) -> bool {
95 matches!(self, ModelState::Downloading { .. })
96 }
97
98 pub fn needs_consent(&self) -> bool {
100 matches!(self, ModelState::NeedsConsent)
101 }
102
103 pub fn summary(&self) -> String {
105 match self {
106 ModelState::NotInstalled => "not installed".into(),
107 ModelState::NeedsConsent => "needs consent".into(),
108 ModelState::Downloading { progress_pct, .. } => {
109 format!("downloading ({progress_pct}%)")
110 }
111 ModelState::Verifying => "verifying".into(),
112 ModelState::Ready => "ready".into(),
113 ModelState::Disabled { reason } => format!("disabled: {reason}"),
114 ModelState::VerificationFailed { reason } => format!("verification failed: {reason}"),
115 ModelState::UpdateAvailable {
116 current_revision,
117 latest_revision,
118 } => {
119 format!("update available: {current_revision} -> {latest_revision}")
120 }
121 ModelState::Cancelled => "cancelled".into(),
122 }
123 }
124}
125
/// Policy knobs that govern whether and how model files may be acquired.
///
/// Built either from compiled defaults (`Default`) or from a `SemanticPolicy`
/// (`from_semantic_policy`); `config_source` records which.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelAcquisitionPolicy {
    /// Whether model downloads are permitted at all.
    pub downloads_enabled: bool,
    /// Whether explicit user consent is required before acquiring a model.
    pub requires_consent: bool,
    /// Whether the host is considered offline (consumers are not visible here).
    pub offline: bool,
    /// Whether the current network is considered metered.
    pub metered: bool,
    /// Whether downloads may proceed on a metered network.
    pub allow_metered: bool,
    /// Optional upper bound on model size in bytes (`None` = no limit configured).
    pub max_model_bytes: Option<u64>,
    /// Optional mirror base URL to use instead of the default registry.
    pub mirror_base_url: Option<String>,
    /// Where this policy came from (e.g. `"compiled_default"`, `"semantic_policy"`).
    pub config_source: String,
}
150
151impl Default for ModelAcquisitionPolicy {
152 fn default() -> Self {
153 Self {
154 downloads_enabled: true,
155 requires_consent: true,
156 offline: false,
157 metered: false,
158 allow_metered: false,
159 max_model_bytes: None,
160 mirror_base_url: None,
161 config_source: "compiled_default".to_string(),
162 }
163 }
164}
165
166impl ModelAcquisitionPolicy {
167 pub fn from_semantic_policy(policy: &SemanticPolicy) -> Self {
169 const MIB: u64 = 1_048_576;
170
171 Self {
172 downloads_enabled: policy.mode.should_build_semantic(),
173 requires_consent: matches!(policy.download_policy, ModelDownloadPolicy::OptIn),
174 max_model_bytes: Some(policy.max_model_size_mb.saturating_mul(MIB)),
175 config_source: "semantic_policy".to_string(),
176 ..Self::default()
177 }
178 }
179}
180
/// Detailed state of the on-disk model cache.
///
/// Serialized as an internally tagged object: the `state` field holds the
/// snake_case variant name (the same strings returned by `code()`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "state")]
pub enum ModelCacheState {
    /// Model files are missing.
    NotAcquired {
        /// Names of the files that are not present.
        missing_files: Vec<String>,
        /// Whether acquisition must wait for user consent.
        needs_consent: bool,
    },
    /// Acquisition is in progress in a staging directory.
    Acquiring {
        /// Staging directory holding partial files.
        staging_dir: PathBuf,
        bytes_present: u64,
        total_bytes: u64,
    },
    /// Model files are present and verified.
    Acquired { model_dir: PathBuf },
    /// A file's digest did not match the expected value.
    ChecksumMismatch {
        /// File whose digest mismatched.
        file: String,
        expected: String,
        actual: String,
    },
    /// The installed revision differs from the expected one.
    IncompatibleVersion {
        /// Revision found on disk.
        current_revision: String,
        expected_revision: String,
    },
    /// Acquisition is disabled by policy; `reason` explains why.
    DisabledByPolicy { reason: String },
    /// The model is larger than the policy's size budget.
    BudgetBlocked { required_bytes: u64, max_bytes: u64 },
    /// The cache has been quarantined as corrupt.
    QuarantinedCorrupt {
        /// Path of the quarantine marker file.
        marker_path: PathBuf,
        reason: String,
    },
    /// A locally pre-seeded model cache was found and verified.
    PreseededLocal { model_dir: PathBuf },
    /// The model cache was verified and sourced from a mirror.
    MirrorSourced {
        /// Directory holding the model files.
        model_dir: PathBuf,
        mirror_base_url: String,
    },
    /// Offline and files are missing.
    OfflineBlocked { missing_files: Vec<String> },
}
230
231impl ModelCacheState {
232 pub fn code(&self) -> &'static str {
234 match self {
235 Self::NotAcquired { .. } => "not_acquired",
236 Self::Acquiring { .. } => "acquiring",
237 Self::Acquired { .. } => "acquired",
238 Self::ChecksumMismatch { .. } => "checksum_mismatch",
239 Self::IncompatibleVersion { .. } => "incompatible_version",
240 Self::DisabledByPolicy { .. } => "disabled_by_policy",
241 Self::BudgetBlocked { .. } => "budget_blocked",
242 Self::QuarantinedCorrupt { .. } => "quarantined_corrupt",
243 Self::PreseededLocal { .. } => "preseeded_local",
244 Self::MirrorSourced { .. } => "mirror_sourced",
245 Self::OfflineBlocked { .. } => "offline_blocked",
246 }
247 }
248
249 pub fn summary(&self) -> String {
251 match self {
252 Self::NotAcquired {
253 missing_files,
254 needs_consent,
255 } => {
256 let action = if *needs_consent {
257 "user consent required"
258 } else {
259 "ready to acquire"
260 };
261 format!(
262 "model not acquired ({action}); missing {}",
263 missing_files.join(", ")
264 )
265 }
266 Self::Acquiring {
267 bytes_present,
268 total_bytes,
269 staging_dir,
270 } => format!(
271 "model acquisition in progress at {} ({bytes_present}/{total_bytes} bytes)",
272 staging_dir.display()
273 ),
274 Self::Acquired { .. } => "model cache acquired and verified".to_string(),
275 Self::ChecksumMismatch {
276 file,
277 expected,
278 actual,
279 } => format!("checksum mismatch for {file}: expected {expected}, got {actual}"),
280 Self::IncompatibleVersion {
281 current_revision,
282 expected_revision,
283 } => format!("model revision mismatch: {current_revision} != {expected_revision}"),
284 Self::DisabledByPolicy { reason } => format!("model acquisition disabled: {reason}"),
285 Self::BudgetBlocked {
286 required_bytes,
287 max_bytes,
288 } => {
289 format!("model requires {required_bytes} bytes but policy allows {max_bytes} bytes")
290 }
291 Self::QuarantinedCorrupt { reason, .. } => {
292 format!("model cache quarantined: {reason}")
293 }
294 Self::PreseededLocal { .. } => "preseeded local model cache verified".to_string(),
295 Self::MirrorSourced {
296 mirror_base_url, ..
297 } => {
298 format!("model cache verified from mirror {mirror_base_url}")
299 }
300 Self::OfflineBlocked { missing_files } => {
301 format!(
302 "offline and model is not acquired; missing {}",
303 missing_files.join(", ")
304 )
305 }
306 }
307 }
308
309 pub fn next_step(&self) -> Option<&'static str> {
311 match self {
312 Self::NotAcquired { .. } => {
313 Some("Run `cass models install`, or keep using lexical search.")
314 }
315 Self::Acquiring { .. } => {
316 Some("Wait for model acquisition to finish, or keep using lexical search.")
317 }
318 Self::Acquired { .. } | Self::PreseededLocal { .. } | Self::MirrorSourced { .. } => {
319 None
320 }
321 Self::ChecksumMismatch { .. } | Self::QuarantinedCorrupt { .. } => Some(
322 "Run `cass models verify --repair`, or reinstall with `cass models install -y`.",
323 ),
324 Self::IncompatibleVersion { .. } => {
325 Some("Run `cass models install -y` to refresh the model cache.")
326 }
327 Self::DisabledByPolicy { .. } => {
328 Some("Use lexical search or re-enable semantic model acquisition in policy.")
329 }
330 Self::BudgetBlocked { .. } => {
331 Some("Increase the semantic model budget or keep using lexical search.")
332 }
333 Self::OfflineBlocked { .. } => Some(
334 "Reconnect or install from local files with `cass models install --from-file`.",
335 ),
336 }
337 }
338
339 pub fn is_usable(&self) -> bool {
341 matches!(
342 self,
343 Self::Acquired { .. } | Self::PreseededLocal { .. } | Self::MirrorSourced { .. }
344 )
345 }
346}
347
/// Status report for one model's cache: where it lives, its state, and sizes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelCacheReport {
    /// Identifier of the model being reported on.
    pub model_id: String,
    /// Directory the model is (or would be) installed in.
    pub model_dir: PathBuf,
    /// Current cache state.
    pub state: ModelCacheState,
    /// Bytes required for the full model (presumably the manifest total — confirm with producer).
    pub required_size_bytes: u64,
    /// Bytes currently present on disk (presumably — confirm with producer).
    pub installed_size_bytes: u64,
    /// Origin of the policy used to build this report (cf. `ModelAcquisitionPolicy::config_source`).
    pub policy_source: String,
}
358
359impl ModelCacheReport {
360 pub fn state_code(&self) -> &'static str {
362 self.state.code()
363 }
364
365 pub fn is_usable(&self) -> bool {
367 self.state.is_usable()
368 }
369}
370
/// One file belonging to a model manifest.
#[derive(Debug, Clone)]
pub struct ModelFile {
    /// Path of the file inside the repository (may contain `/`, e.g. `onnx/model.onnx`).
    pub name: String,
    /// Expected lowercase hex SHA-256 digest, or [`PLACEHOLDER_CHECKSUM`] if not yet known.
    pub sha256: String,
    /// Expected file size in bytes.
    pub size: u64,
}
383
384impl ModelFile {
385 pub fn local_name(&self) -> &str {
390 self.name.rsplit('/').next().unwrap_or(&self.name)
391 }
392}
393
/// Description of a downloadable model: its repository, pinned revision, and
/// the files (with expected digests and sizes) that make it up.
#[derive(Debug, Clone)]
pub struct ModelManifest {
    /// Short model identifier (used in install commands and air-gap scripts).
    pub id: String,
    /// Repository path on the registry, e.g. `sentence-transformers/all-MiniLM-L6-v2`.
    pub repo: String,
    /// Revision to download; the literal `"main"` is treated as unpinned.
    pub revision: String,
    /// Files to download and verify.
    pub files: Vec<ModelFile>,
    /// License identifier string (the built-in manifests use `"Apache-2.0"`).
    pub license: String,
}
412
/// Sentinel stored in [`ModelFile::sha256`] when the real digest is not yet
/// known. Any manifest containing it fails `has_verified_checksums`, and
/// `ModelDownloader` refuses to download it.
pub const PLACEHOLDER_CHECKSUM: &str = "PLACEHOLDER_VERIFY_AFTER_DOWNLOAD";
418
419pub fn normalize_mirror_base_url(base_url: &str) -> Result<String, DownloadError> {
424 let trimmed = base_url.trim();
425 if trimmed.is_empty() {
426 return Err(invalid_mirror_url(base_url, "mirror URL cannot be empty"));
427 }
428
429 let parsed = Url::parse(trimmed).map_err(|err| invalid_mirror_url(trimmed, err.to_string()))?;
430
431 match parsed.scheme() {
432 "http" | "https" => {}
433 scheme => {
434 return Err(invalid_mirror_url(
435 trimmed,
436 format!("unsupported URL scheme '{scheme}' (expected http or https)"),
437 ));
438 }
439 }
440
441 if parsed.host_str().is_none() {
442 return Err(invalid_mirror_url(
443 trimmed,
444 "mirror URL must include a host",
445 ));
446 }
447
448 if parsed.query().is_some() || parsed.fragment().is_some() {
449 return Err(invalid_mirror_url(
450 trimmed,
451 "mirror URL must not include query or fragment components",
452 ));
453 }
454
455 Ok(parsed.to_string().trim_end_matches('/').to_string())
456}
457
458fn invalid_mirror_url(url: impl Into<String>, reason: impl Into<String>) -> DownloadError {
459 DownloadError::InvalidMirrorUrl {
460 url: url.into(),
461 reason: reason.into(),
462 }
463}
464
465impl ModelManifest {
466 pub fn has_verified_checksums(&self) -> bool {
471 self.files.iter().all(|f| f.sha256 != PLACEHOLDER_CHECKSUM)
472 }
473
474 pub fn has_pinned_revision(&self) -> bool {
479 self.revision != "main"
480 }
481
482 pub fn is_production_ready(&self) -> bool {
488 self.has_verified_checksums() && self.has_pinned_revision()
489 }
490
491 pub fn minilm_v2() -> Self {
496 Self {
497 id: "all-minilm-l6-v2".into(),
498 repo: "sentence-transformers/all-MiniLM-L6-v2".into(),
499 revision: "c9745ed1d9f207416be6d2e6f8de32d1f16199bf".into(),
501 files: vec![
502 ModelFile {
503 name: "onnx/model.onnx".into(),
505 sha256: "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452"
506 .into(),
507 size: 90405214,
508 },
509 ModelFile {
510 name: "tokenizer.json".into(),
511 sha256: "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037"
512 .into(),
513 size: 466247,
514 },
515 ModelFile {
516 name: "config.json".into(),
517 sha256: "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41"
518 .into(),
519 size: 612,
520 },
521 ModelFile {
525 name: "special_tokens_map.json".into(),
526 sha256: "303df45a03609e4ead04bc3dc1536d0ab19b5358db685b6f3da123d05ec200e3"
527 .into(),
528 size: 112,
529 },
530 ModelFile {
531 name: "tokenizer_config.json".into(),
532 sha256: "acb92769e8195aabd29b7b2137a9e6d6e25c476a4f15aa4355c233426c61576b"
533 .into(),
534 size: 350,
535 },
536 ],
537 license: "Apache-2.0".into(),
538 }
539 }
540
541 pub fn snowflake_arctic_s() -> Self {
555 Self {
556 id: "snowflake-arctic-embed-s".into(),
557 repo: "Snowflake/snowflake-arctic-embed-s".into(),
558 revision: "e596f507467533e48a2e17c007f0e1dacc837b33".into(),
559 files: vec![
560 ModelFile {
561 name: "onnx/model.onnx".into(),
562 sha256: "579c1f1778a0993eb0d2a1403340ffb491c769247fb46acc4f5cf8ac5b89c1e1"
563 .into(),
564 size: 133_093_492,
565 },
566 ModelFile {
567 name: "tokenizer.json".into(),
568 sha256: "91f1def9b9391fdabe028cd3f3fcc4efd34e5d1f08c3bf2de513ebb5911a1854"
569 .into(),
570 size: 711_649,
571 },
572 ModelFile {
573 name: "config.json".into(),
574 sha256: "4e519aa92ec40943356032afe458c8829d70c5766b109e4a57490b82f72dcfb7"
575 .into(),
576 size: 703,
577 },
578 ModelFile {
579 name: "special_tokens_map.json".into(),
580 sha256: "5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a"
581 .into(),
582 size: 695,
583 },
584 ModelFile {
585 name: "tokenizer_config.json".into(),
586 sha256: "9ca59277519f6e3692c8685e26b94d4afca2d5438deff66483db495e48735810"
587 .into(),
588 size: 1_433,
589 },
590 ],
591 license: "Apache-2.0".into(),
592 }
593 }
594
595 pub fn nomic_embed() -> Self {
603 Self {
604 id: "nomic-embed-text-v1.5".into(),
605 repo: "nomic-ai/nomic-embed-text-v1.5".into(),
606 revision: "e5cf08aadaa33385f5990def41f7a23405aec398".into(),
607 files: vec![
608 ModelFile {
609 name: "onnx/model.onnx".into(),
610 sha256: "147d5aa88c2101237358e17796cf3a227cead1ec304ec34b465bb08e9d952965"
611 .into(),
612 size: 547_310_275,
613 },
614 ModelFile {
615 name: "tokenizer.json".into(),
616 sha256: "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
617 .into(),
618 size: 711_396,
619 },
620 ModelFile {
621 name: "config.json".into(),
622 sha256: "0168e0883705b0bf8f2b381e10f45a9f3e1ef4b13869b43c160e4c8a70ddf442"
623 .into(),
624 size: 2_331,
625 },
626 ModelFile {
627 name: "special_tokens_map.json".into(),
628 sha256: "5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a"
629 .into(),
630 size: 695,
631 },
632 ModelFile {
633 name: "tokenizer_config.json".into(),
634 sha256: "d7e0000bcc80134debd2222220427e6bf5fa20a669f40a0d0d1409cc18e0a9bc"
635 .into(),
636 size: 1_191,
637 },
638 ],
639 license: "Apache-2.0".into(),
640 }
641 }
642
643 pub fn msmarco_reranker() -> Self {
655 Self {
656 id: "ms-marco-MiniLM-L-6-v2".into(),
657 repo: "cross-encoder/ms-marco-MiniLM-L6-v2".into(),
658 revision: "c5ee24cb16019beea0893ab7796b1df96625c6b8".into(),
659 files: vec![
660 ModelFile {
661 name: "onnx/model.onnx".into(),
662 sha256: "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
663 .into(),
664 size: 91_011_230,
665 },
666 ModelFile {
667 name: "tokenizer.json".into(),
668 sha256: "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
669 .into(),
670 size: 711_396,
671 },
672 ModelFile {
673 name: "config.json".into(),
674 sha256: "380e02c93f431831be65d99a4e7e5f67c133985bf2e77d9d4eba46847190bacc"
675 .into(),
676 size: 794,
677 },
678 ModelFile {
679 name: "special_tokens_map.json".into(),
680 sha256: "3c3507f36dff57bce437223db3b3081d1e2b52ec3e56ee55438193ecb2c94dd6"
681 .into(),
682 size: 132,
683 },
684 ModelFile {
685 name: "tokenizer_config.json".into(),
686 sha256: "a5c2e5a7b1a29a0702cd28c08a399b5ecc110c263009d17f7e3b415f25905fd8"
687 .into(),
688 size: 1_330,
689 },
690 ],
691 license: "Apache-2.0".into(),
692 }
693 }
694
695 pub fn jina_reranker_turbo() -> Self {
702 Self {
703 id: "jina-reranker-v1-turbo-en".into(),
704 repo: "jinaai/jina-reranker-v1-turbo-en".into(),
705 revision: "b8c14f4e723d9e0aab4732a7b7b93741eeeb77c2".into(),
706 files: vec![
707 ModelFile {
708 name: "onnx/model.onnx".into(),
709 sha256: "c1296c66c119de645fa9cdee536d8637740efe85224cfa270281e50f213aa565"
710 .into(),
711 size: 151_296_975,
712 },
713 ModelFile {
714 name: "tokenizer.json".into(),
715 sha256: "0046da43cc8c424b317f56b092b0512aaaa65c4f925d2f16af9d9eeb4d0ef902"
716 .into(),
717 size: 2_030_772,
718 },
719 ModelFile {
720 name: "config.json".into(),
721 sha256: "e050ff6a15ae9295e84882fa0e98051bd8754856cd5201395ebf00ce9f2d609b"
722 .into(),
723 size: 1_206,
724 },
725 ModelFile {
726 name: "special_tokens_map.json".into(),
727 sha256: "06e405a36dfe4b9604f484f6a1e619af1a7f7d09e34a8555eb0b77b66318067f"
728 .into(),
729 size: 280,
730 },
731 ModelFile {
732 name: "tokenizer_config.json".into(),
733 sha256: "d291c6652d96d56ffdbcf1ea19d9bae5ed79003f7648c627e725a619227ce8fa"
734 .into(),
735 size: 1_215,
736 },
737 ],
738 license: "Apache-2.0".into(),
739 }
740 }
741
742 pub fn for_embedder(name: &str) -> Option<Self> {
746 match name {
747 "minilm" => Some(Self::minilm_v2()),
748 "snowflake-arctic-s" => Some(Self::snowflake_arctic_s()),
749 "nomic-embed" => Some(Self::nomic_embed()),
750 _ => None,
751 }
752 }
753
754 pub fn for_reranker(name: &str) -> Option<Self> {
756 match name {
757 "ms-marco" => Some(Self::msmarco_reranker()),
758 "jina-reranker-turbo" => Some(Self::jina_reranker_turbo()),
759 _ => None,
760 }
761 }
762
763 pub fn bakeoff_embedder_candidates() -> Vec<Self> {
767 vec![Self::snowflake_arctic_s(), Self::nomic_embed()]
768 }
769
770 pub fn bakeoff_reranker_candidates() -> Vec<Self> {
774 vec![Self::jina_reranker_turbo()]
775 }
776
777 pub fn bakeoff_candidates() -> Vec<Self> {
781 let mut candidates = Self::bakeoff_embedder_candidates();
782 candidates.extend(Self::bakeoff_reranker_candidates());
783 candidates
784 }
785
786 pub fn total_size(&self) -> u64 {
788 self.files.iter().map(|f| f.size).sum()
789 }
790
791 pub fn download_url_with_base(&self, file: &ModelFile, base_url: Option<&str>) -> String {
793 let root = base_url.unwrap_or("https://huggingface.co");
794 format!(
795 "{}/{}/resolve/{}/{}",
796 root.trim_end_matches('/'),
797 self.repo.trim_start_matches('/'),
798 self.revision,
799 file.name.trim_start_matches('/')
800 )
801 }
802
803 pub fn download_url(&self, file: &ModelFile) -> String {
805 self.download_url_with_base(file, None)
806 }
807
808 pub fn air_gap_bash_script(&self, base_url: Option<&str>) -> String {
815 fn quote_url(url: &str) -> String {
821 debug_assert!(
822 !url.contains('\''),
823 "model download URL unexpectedly contains a single quote: {url}"
824 );
825 format!("'{url}'")
826 }
827
828 let mut out = String::new();
829 out.push_str("# Air-gap model install (bash / Git Bash / MSYS2)\n");
830 out.push_str(
831 "# Run these commands, then re-run `cass models install --from-file \"$DIR\"`.\n",
832 );
833 out.push_str("set -euo pipefail\n");
834 out.push_str(&format!("DIR=\"${{DIR:-./{}_files}}\"\n", self.id));
835 out.push_str("mkdir -p \"$DIR\"\n");
836 for file in &self.files {
837 let url = self.download_url_with_base(file, base_url);
842 out.push_str(&format!(
843 "curl -fL --retry 3 {} -o \"$DIR/{}\" # {} bytes\n",
844 quote_url(&url),
845 file.local_name(),
846 file.size,
847 ));
848 }
849 out.push_str(&format!(
850 "cass models install {} --from-file \"$DIR\" -y\n",
851 self.id
852 ));
853 out
854 }
855
856 pub fn air_gap_powershell_script(&self, base_url: Option<&str>) -> String {
859 fn quote_url_ps(url: &str) -> String {
861 debug_assert!(
862 !url.contains('\''),
863 "model download URL unexpectedly contains a single quote: {url}"
864 );
865 format!("'{url}'")
866 }
867
868 let mut out = String::new();
869 out.push_str("# Air-gap model install (PowerShell 5.1+ and 7+)\n");
870 out.push_str("$ErrorActionPreference = 'Stop'\n");
871 out.push_str(
874 "[System.Net.ServicePointManager]::SecurityProtocol = \
875 [System.Net.ServicePointManager]::SecurityProtocol -bor \
876 [System.Net.SecurityProtocolType]::Tls12\n",
877 );
878 out.push_str(&format!("$dir = \"{}_files\"\n", self.id));
879 out.push_str("New-Item -ItemType Directory -Force -Path $dir | Out-Null\n");
880 for file in &self.files {
881 let url = self.download_url_with_base(file, base_url);
882 out.push_str(&format!(
885 "Invoke-WebRequest -UseBasicParsing -Uri {} -OutFile (Join-Path $dir '{}') # {} bytes\n",
886 quote_url_ps(&url),
887 file.local_name(),
888 file.size,
889 ));
890 }
891 out.push_str(&format!(
892 "cass models install {} --from-file $dir -y\n",
893 self.id
894 ));
895 out
896 }
897}
898
/// Callback receiving [`DownloadProgress`] snapshots. It is reference-counted
/// and `Send + Sync` so it can be cloned into the download task.
pub type ProgressCallback = Arc<dyn Fn(DownloadProgress) + Send + Sync>;
901
/// Progress snapshot emitted after each chunk written during a download.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// On-disk name of the file currently being downloaded.
    pub current_file: String,
    /// 1-based index of the current file.
    pub file_index: usize,
    /// Number of files in the manifest.
    pub total_files: usize,
    /// Bytes of the current file present so far (including resumed bytes).
    pub file_bytes: u64,
    /// Expected size of the current file.
    pub file_total: u64,
    /// Bytes counted across all files so far.
    pub total_bytes: u64,
    /// Expected size of all files combined.
    pub grand_total: u64,
    /// `total_bytes / grand_total` as a percentage, clamped to 100 (0 if `grand_total` is 0).
    pub progress_pct: u8,
}
922
/// Errors produced while validating, downloading, verifying, or installing a model.
#[derive(Debug, Error)]
pub enum DownloadError {
    #[error("network error: {0}")]
    /// Transport or body failure; also used for download-runtime setup failures.
    NetworkError(String),
    #[error("I/O error: {0}")]
    /// Local filesystem error.
    IoError(#[from] std::io::Error),
    #[error("verification failed for {file}: expected {expected}, got {actual}")]
    /// A downloaded file's SHA-256 did not match the manifest.
    VerificationFailed {
        file: String,
        expected: String,
        actual: String,
    },
    #[error("download cancelled")]
    /// The cancellation flag was set.
    Cancelled,
    #[error("download timed out")]
    /// The connect or per-file deadline elapsed.
    Timeout,
    #[error("HTTP error {status}: {message}")]
    /// The server answered with an error status.
    HttpError { status: u16, message: String },
    #[error(
        "model '{model_id}' is not production-ready: {} file(s) have placeholder checksums{}",
        unverified_files.len(),
        if *revision_unpinned {
            " and revision is not pinned"
        } else {
            ""
        }
    )]
    /// The manifest has placeholder checksums and/or an unpinned revision;
    /// returned before any network activity.
    ManifestNotVerified {
        model_id: String,
        unverified_files: Vec<String>,
        revision_unpinned: bool,
    },
    #[error("invalid mirror URL '{url}': {reason}")]
    /// A mirror base URL failed validation.
    InvalidMirrorUrl { url: String, reason: String },
}
972
973impl DownloadError {
974 fn is_retryable(&self) -> bool {
975 match self {
976 DownloadError::NetworkError(_) | DownloadError::IoError(_) | DownloadError::Timeout => {
977 true
978 }
979 DownloadError::HttpError { status, .. } => {
980 *status == 408 || *status == 429 || (500..=599).contains(status)
981 }
982 DownloadError::VerificationFailed { .. }
983 | DownloadError::Cancelled
984 | DownloadError::ManifestNotVerified { .. }
985 | DownloadError::InvalidMirrorUrl { .. } => false,
986 }
987 }
988
989 fn should_discard_temp(&self) -> bool {
990 matches!(self, DownloadError::VerificationFailed { .. })
991 }
992}
993
/// Runs an async download job to completion on a fresh single-threaded
/// asupersync runtime and returns its result synchronously.
///
/// `f` is spawned as a task via `try_spawn_with_cx` and receives that task's
/// `Cx`. Its result travels back over a `std::sync::mpsc` channel that the
/// outer `block_on` future polls.
///
/// # Errors
/// Returns `DownloadError::NetworkError` if the runtime cannot be built, the
/// runtime handle is unavailable, the task cannot be spawned, or the task
/// drops its sender without sending a result. Otherwise it returns whatever
/// `f` produced.
fn run_download_with_cx<T, F, Fut>(f: F) -> Result<T, DownloadError>
where
    T: Send + 'static,
    F: FnOnce(asupersync::Cx) -> Fut + Send + 'static,
    Fut: Future<Output = Result<T, DownloadError>> + Send + 'static,
{
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .map_err(|e| {
            DownloadError::NetworkError(format!("failed to build download runtime: {e}"))
        })?;

    runtime.block_on(async move {
        let handle = asupersync::runtime::Runtime::current_handle().ok_or_else(|| {
            DownloadError::NetworkError("download runtime handle unavailable".into())
        })?;
        let (tx, rx) = std::sync::mpsc::channel();
        handle
            .try_spawn_with_cx(move |cx| async move {
                // A failed send is ignored; the receiver side reports
                // `Disconnected` only if the sender is dropped without sending.
                let _ = tx.send(f(cx).await);
            })
            .map_err(|e| {
                DownloadError::NetworkError(format!("failed to spawn download task: {e}"))
            })?;

        // Poll the channel without blocking the runtime thread, yielding
        // between checks so the spawned task can make progress.
        // NOTE(review): this re-polls after every yield, so it may spin while
        // the task is waiting on I/O. Confirm whether asupersync offers an
        // awaitable join handle instead.
        loop {
            match rx.try_recv() {
                Ok(result) => return result,
                Err(TryRecvError::Empty) => asupersync::runtime::yield_now().await,
                Err(TryRecvError::Disconnected) => {
                    return Err(DownloadError::NetworkError(
                        "download task exited before returning a result".into(),
                    ));
                }
            }
        }
    })
}
1032
/// Downloads model files into a staging directory, verifies them, and
/// atomically installs them into a target directory.
pub struct ModelDownloader {
    /// Final install location for the verified model.
    target_dir: PathBuf,
    /// Staging directory holding (possibly partial) downloads; see `new` for naming.
    temp_dir: PathBuf,
    /// Shared cancellation flag, checked between files, between attempts, and per body frame.
    cancelled: Arc<AtomicBool>,
    /// Timeout applied to the initial `request_streaming` call in `download_file`.
    connect_timeout: Duration,
    /// Upper bound on the time spent streaming one file's body.
    file_timeout: Duration,
    /// Number of download attempts made per file (including the first).
    max_retries: u32,
}
1048
1049impl ModelDownloader {
1050 pub fn new(target_dir: PathBuf) -> Self {
1052 let temp_dir = if let Some(parent) = target_dir.parent() {
1055 let dir_name = target_dir
1056 .file_name()
1057 .and_then(|n| n.to_str())
1058 .unwrap_or("model");
1059 parent.join(format!("{}.downloading", dir_name))
1060 } else {
1061 target_dir.with_extension("downloading")
1063 };
1064 Self {
1065 target_dir,
1066 temp_dir,
1067 cancelled: Arc::new(AtomicBool::new(false)),
1068 connect_timeout: Duration::from_secs(30),
1069 file_timeout: Duration::from_secs(300), max_retries: 3,
1071 }
1072 }
1073
1074 pub fn cancellation_handle(&self) -> Arc<AtomicBool> {
1076 Arc::clone(&self.cancelled)
1077 }
1078
1079 pub fn cancel(&self) {
1081 self.cancelled.store(true, Ordering::SeqCst);
1082 }
1083
1084 pub fn is_cancelled(&self) -> bool {
1086 self.cancelled.load(Ordering::SeqCst)
1087 }
1088
1089 pub fn download(
1106 &self,
1107 manifest: &ModelManifest,
1108 on_progress: Option<ProgressCallback>,
1109 ) -> Result<(), DownloadError> {
1110 self.download_with_mirror(manifest, None, on_progress)
1111 }
1112
1113 pub fn download_with_mirror(
1115 &self,
1116 manifest: &ModelManifest,
1117 mirror_base_url: Option<&str>,
1118 on_progress: Option<ProgressCallback>,
1119 ) -> Result<(), DownloadError> {
1120 if !manifest.is_production_ready() {
1123 let unverified_files: Vec<String> = manifest
1124 .files
1125 .iter()
1126 .filter(|f| f.sha256 == PLACEHOLDER_CHECKSUM)
1127 .map(|f| f.name.clone())
1128 .collect();
1129 return Err(DownloadError::ManifestNotVerified {
1130 model_id: manifest.id.clone(),
1131 unverified_files,
1132 revision_unpinned: !manifest.has_pinned_revision(),
1133 });
1134 }
1135
1136 self.cancelled.store(false, Ordering::SeqCst);
1138
1139 self.prepare_temp_dir(manifest)?;
1142
1143 let grand_total = manifest.total_size();
1144 let total_files = manifest.files.len();
1145 let bytes_downloaded = Arc::new(AtomicU64::new(0));
1146
1147 for (idx, file) in manifest.files.iter().enumerate() {
1148 self.fail_if_cancelled()?;
1149
1150 let file_path = self.temp_dir.join(file.local_name());
1152 let url = manifest.download_url_with_base(file, mirror_base_url);
1153
1154 let bytes_before_file = bytes_downloaded.load(Ordering::SeqCst);
1156
1157 let mut last_error = None;
1159 for attempt in 0..self.max_retries {
1160 self.fail_if_cancelled()?;
1161
1162 if attempt > 0 {
1164 bytes_downloaded.store(bytes_before_file, Ordering::SeqCst);
1165 }
1166
1167 if attempt > 0 {
1169 let delay = Duration::from_secs(5 * (1 << (attempt - 1)));
1170 std::thread::sleep(delay);
1171 }
1172
1173 match self.download_file(
1174 &url,
1175 &file_path,
1176 file.size,
1177 idx,
1178 total_files,
1179 &bytes_downloaded,
1180 grand_total,
1181 on_progress.as_ref(),
1182 ) {
1183 Ok(()) => {
1184 last_error = None;
1185 break;
1186 }
1187 Err(DownloadError::Cancelled) => {
1188 return Err(DownloadError::Cancelled);
1189 }
1190 Err(e) => {
1191 if !e.is_retryable() {
1192 self.cleanup_temp_for_error(&e);
1193 return Err(e);
1194 }
1195 last_error = Some(e);
1196 }
1197 }
1198 }
1199
1200 if let Some(err) = last_error {
1201 self.cleanup_temp_for_error(&err);
1202 return Err(err);
1203 }
1204
1205 self.fail_if_cancelled()?;
1207
1208 let actual_hash = compute_sha256(&file_path)?;
1209 if actual_hash != file.sha256 {
1210 let err = DownloadError::VerificationFailed {
1211 file: file.name.clone(),
1212 expected: file.sha256.clone(),
1213 actual: actual_hash,
1214 };
1215 self.cleanup_temp_for_error(&err);
1216 return Err(err);
1217 }
1218 }
1219
1220 self.atomic_install()?;
1222
1223 self.write_verified_marker(manifest, mirror_base_url)?;
1225
1226 Ok(())
1227 }
1228
1229 fn prepare_temp_dir(&self, manifest: &ModelManifest) -> Result<(), DownloadError> {
1230 ensure_model_download_temp_dir(&self.temp_dir)?;
1231
1232 let expected_files: HashSet<String> = manifest
1233 .files
1234 .iter()
1235 .map(|file| file.local_name().to_string())
1236 .collect();
1237
1238 for entry in fs::read_dir(&self.temp_dir)? {
1239 let entry = entry?;
1240 let entry_type = entry.file_type()?;
1241 let entry_name = entry.file_name();
1242 let keep_entry = entry_type.is_file()
1243 && entry_name
1244 .to_str()
1245 .is_some_and(|name| expected_files.contains(name));
1246
1247 if keep_entry {
1248 continue;
1249 }
1250
1251 let entry_path = entry.path();
1252 if entry_type.is_dir() {
1253 fs::remove_dir_all(entry_path)?;
1254 } else {
1255 fs::remove_file(entry_path)?;
1256 }
1257 }
1258
1259 Ok(())
1260 }
1261
1262 #[allow(clippy::too_many_arguments)]
1264 fn download_file(
1265 &self,
1266 url: &str,
1267 path: &Path,
1268 expected_size: u64,
1269 file_idx: usize,
1270 total_files: usize,
1271 bytes_downloaded: &Arc<AtomicU64>,
1272 grand_total: u64,
1273 on_progress: Option<&ProgressCallback>,
1274 ) -> Result<(), DownloadError> {
1275 let mut existing_size = if path.exists() {
1277 fs::metadata(path).map(|m| m.len()).unwrap_or(0)
1278 } else {
1279 0
1280 };
1281
1282 if existing_size > expected_size {
1284 let _ = fs::remove_file(path);
1285 existing_size = 0;
1286 }
1287
1288 if existing_size == expected_size {
1290 bytes_downloaded.fetch_add(expected_size, Ordering::SeqCst);
1291 return Ok(());
1292 }
1293
1294 let url = url.to_string();
1295 let path = path.to_path_buf();
1296 let bytes_downloaded = Arc::clone(bytes_downloaded);
1297 let cancelled = Arc::clone(&self.cancelled);
1298 let progress_callback = on_progress.cloned();
1299 let connect_timeout = self.connect_timeout;
1300 let file_timeout = self.file_timeout;
1301
1302 run_download_with_cx(move |cx| async move {
1303 const MODEL_MAX_BODY_SIZE: usize = 500 * 1024 * 1024;
1307
1308 let client = asupersync::http::h1::HttpClient::builder()
1309 .user_agent(concat!(
1310 "cass/",
1311 env!("CARGO_PKG_VERSION"),
1312 " (model-download)"
1313 ))
1314 .max_body_size(MODEL_MAX_BODY_SIZE)
1315 .build();
1316 let mut headers = vec![("Accept".to_string(), "application/octet-stream".to_string())];
1317
1318 if existing_size > 0 {
1319 headers.push(("Range".to_string(), format!("bytes={existing_size}-")));
1320 bytes_downloaded.fetch_add(existing_size, Ordering::SeqCst);
1321 }
1322
1323 let mut response = asupersync::time::timeout(
1324 cx.now(),
1325 connect_timeout,
1326 client.request_streaming(
1327 &cx,
1328 asupersync::http::h1::Method::Get,
1329 &url,
1330 headers,
1331 Vec::new(),
1332 ),
1333 )
1334 .await
1335 .map_err(|_| DownloadError::Timeout)?
1336 .map_err(|e| DownloadError::NetworkError(e.to_string()))?;
1337
1338 let status = response.head.status;
1339 if status >= 400 {
1340 return Err(DownloadError::HttpError {
1341 status,
1342 message: if response.head.reason.is_empty() {
1343 status.to_string()
1344 } else {
1345 format!("{} {}", status, response.head.reason)
1346 },
1347 });
1348 }
1349
1350 let actually_resuming = existing_size > 0 && status == 206;
1352 if existing_size > 0 && status == 200 {
1353 bytes_downloaded.fetch_sub(existing_size, Ordering::SeqCst);
1354 existing_size = 0;
1355 }
1356
1357 let mut file = fs::OpenOptions::new()
1358 .create(true)
1359 .append(actually_resuming)
1360 .write(true)
1361 .truncate(!actually_resuming)
1362 .open(&path)?;
1363
1364 let file_name = path
1365 .file_name()
1366 .and_then(|n| n.to_str())
1367 .unwrap_or("unknown")
1368 .to_string();
1369 let start = Instant::now();
1370 let mut file_bytes = if actually_resuming { existing_size } else { 0 };
1371
1372 loop {
1373 if cancelled.load(Ordering::SeqCst) {
1374 return Err(DownloadError::Cancelled);
1375 }
1376
1377 let remaining = file_timeout.saturating_sub(start.elapsed());
1378 if remaining.is_zero() {
1379 return Err(DownloadError::Timeout);
1380 }
1381
1382 let frame = asupersync::time::timeout(
1383 cx.now(),
1384 remaining,
1385 poll_fn(|task_cx| Pin::new(&mut response.body).poll_frame(task_cx)),
1386 )
1387 .await
1388 .map_err(|_| DownloadError::Timeout)?;
1389
1390 let Some(frame) = frame else {
1391 break;
1392 };
1393
1394 match frame.map_err(|e| DownloadError::NetworkError(e.to_string()))? {
1395 asupersync::http::body::Frame::Data(mut buf) => {
1396 while buf.has_remaining() {
1397 let chunk = buf.chunk();
1398 if chunk.is_empty() {
1399 break;
1400 }
1401 file.write_all(chunk)?;
1402 let chunk_len = chunk.len();
1403 buf.advance(chunk_len);
1404 file_bytes = file_bytes.saturating_add(chunk_len as u64);
1405 bytes_downloaded.fetch_add(chunk_len as u64, Ordering::SeqCst);
1406
1407 if let Some(callback) = progress_callback.as_ref() {
1408 let total_downloaded = bytes_downloaded.load(Ordering::SeqCst);
1409 let progress_pct = if grand_total > 0 {
1410 ((total_downloaded as f64 / grand_total as f64) * 100.0)
1411 .min(100.0) as u8
1412 } else {
1413 0
1414 };
1415
1416 callback(DownloadProgress {
1417 current_file: file_name.clone(),
1418 file_index: file_idx + 1,
1419 total_files,
1420 file_bytes,
1421 file_total: expected_size,
1422 total_bytes: total_downloaded,
1423 grand_total,
1424 progress_pct,
1425 });
1426 }
1427 }
1428 }
1429 asupersync::http::body::Frame::Trailers(_) => {}
1430 }
1431 }
1432
1433 file.sync_all()?;
1434 Ok(())
1435 })
1436 }
1437
1438 fn atomic_install(&self) -> Result<(), DownloadError> {
1445 let backup_dir = unique_model_backup_dir(&self.target_dir);
1446 sync_tree(&self.temp_dir)?;
1447
1448 let had_existing = if ensure_replaceable_model_dir(&self.target_dir)? {
1450 fs::rename(&self.target_dir, &backup_dir)?;
1451 true
1452 } else {
1453 false
1454 };
1455
1456 match fs::rename(&self.temp_dir, &self.target_dir) {
1458 Ok(()) => {
1459 sync_parent_directory(&self.target_dir)?;
1460 if had_existing {
1462 let _ = fs::remove_dir_all(&backup_dir);
1463 sync_parent_directory(&self.target_dir)?;
1464 }
1465 }
1466 Err(e) => {
1467 if had_existing && backup_dir.exists() {
1469 match fs::rename(&backup_dir, &self.target_dir) {
1470 Ok(()) => {
1471 sync_parent_directory(&self.target_dir)?;
1472 return Err(std::io::Error::other(format!(
1473 "failed installing {} from {}: {e}; restored original model",
1474 self.target_dir.display(),
1475 self.temp_dir.display()
1476 ))
1477 .into());
1478 }
1479 Err(restore_err) => {
1480 return Err(std::io::Error::other(format!(
1481 "failed installing {} from {}: {e}; restore error: {restore_err}; temp model retained at {}",
1482 self.target_dir.display(),
1483 self.temp_dir.display(),
1484 self.temp_dir.display()
1485 ))
1486 .into());
1487 }
1488 }
1489 }
1490 return Err(e.into());
1491 }
1492 }
1493
1494 Ok(())
1495 }
1496
1497 fn write_verified_marker(
1499 &self,
1500 manifest: &ModelManifest,
1501 mirror_base_url: Option<&str>,
1502 ) -> Result<(), DownloadError> {
1503 let marker_path = self.target_dir.join(".verified");
1504 let source = mirror_base_url
1505 .map(|url| format!("mirror:{url}"))
1506 .unwrap_or_else(|| "registry".to_string());
1507 let content = format!(
1508 "revision={}\nverified_at={}\nsource={}\n",
1509 manifest.revision,
1510 chrono::Utc::now().to_rfc3339(),
1511 source
1512 );
1513 let temp_path = unique_model_sidecar_path(&marker_path, "tmp", ".verified");
1514 let mut file = File::create(&temp_path)?;
1515 file.write_all(content.as_bytes())?;
1516 file.sync_all()?;
1517 replace_file_from_temp(&temp_path, &marker_path)?;
1518 sync_parent_directory(&marker_path)?;
1519 Ok(())
1520 }
1521
1522 fn cleanup_temp(&self) {
1524 if model_dir_is_real_directory(&self.temp_dir).unwrap_or(false) {
1525 let _ = fs::remove_dir_all(&self.temp_dir);
1526 }
1527 }
1528
1529 fn cleanup_temp_for_error(&self, err: &DownloadError) {
1530 if err.should_discard_temp() {
1531 self.cleanup_temp();
1532 }
1533 }
1534
1535 fn fail_if_cancelled(&self) -> Result<(), DownloadError> {
1536 if self.is_cancelled() {
1537 Err(DownloadError::Cancelled)
1538 } else {
1539 Ok(())
1540 }
1541 }
1542}
1543
1544pub fn compute_sha256(path: &Path) -> Result<String, DownloadError> {
1546 let file = File::open(path)?;
1547 let mut reader = BufReader::new(file);
1548 let mut hasher = Sha256::new();
1549
1550 let mut buffer = [0u8; 8192];
1551 loop {
1552 let n = reader.read(&mut buffer)?;
1553 if n == 0 {
1554 break;
1555 }
1556 hasher.update(&buffer[..n]);
1557 }
1558
1559 let hash = hasher.finalize();
1560 Ok(hex::encode(hash))
1561}
1562
1563pub fn classify_model_cache(
1569 model_dir: &Path,
1570 manifest: &ModelManifest,
1571 policy: &ModelAcquisitionPolicy,
1572) -> ModelCacheReport {
1573 classify_model_cache_with_integrity(model_dir, manifest, policy, ModelCacheIntegrity::Full)
1574}
1575
1576pub(crate) fn classify_model_cache_metadata(
1583 model_dir: &Path,
1584 manifest: &ModelManifest,
1585 policy: &ModelAcquisitionPolicy,
1586) -> ModelCacheReport {
1587 classify_model_cache_with_integrity(model_dir, manifest, policy, ModelCacheIntegrity::Metadata)
1588}
1589
/// How much verification the cache classifier performs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ModelCacheIntegrity {
    /// Hash every present manifest file and compare it with the manifest SHA-256.
    Full,
    /// Skip hashing; rely on file presence and marker files only.
    Metadata,
}
1595
1596fn classify_model_cache_with_integrity(
1597 model_dir: &Path,
1598 manifest: &ModelManifest,
1599 policy: &ModelAcquisitionPolicy,
1600 integrity: ModelCacheIntegrity,
1601) -> ModelCacheReport {
1602 let required_size_bytes = manifest.total_size();
1603 let installed_size_bytes = installed_manifest_size(model_dir, manifest);
1604 let missing_files = missing_manifest_files(model_dir, manifest);
1605 let state = classify_model_cache_state(model_dir, manifest, policy, &missing_files, integrity);
1606
1607 ModelCacheReport {
1608 model_id: manifest.id.clone(),
1609 model_dir: model_dir.to_path_buf(),
1610 state,
1611 required_size_bytes,
1612 installed_size_bytes,
1613 policy_source: policy.config_source.clone(),
1614 }
1615}
1616
1617fn classify_model_cache_state(
1618 model_dir: &Path,
1619 manifest: &ModelManifest,
1620 policy: &ModelAcquisitionPolicy,
1621 missing_files: &[String],
1622 integrity: ModelCacheIntegrity,
1623) -> ModelCacheState {
1624 if !policy.downloads_enabled {
1625 return ModelCacheState::DisabledByPolicy {
1626 reason: "semantic model downloads disabled by policy".to_string(),
1627 };
1628 }
1629
1630 let quarantine_marker = model_dir.join(".quarantined");
1631 if quarantine_marker.is_file() {
1632 let reason = fs::read_to_string(&quarantine_marker)
1633 .ok()
1634 .map(|s| s.trim().to_string())
1635 .filter(|s| !s.is_empty())
1636 .unwrap_or_else(|| "model cache quarantined after integrity failure".to_string());
1637 return ModelCacheState::QuarantinedCorrupt {
1638 marker_path: quarantine_marker,
1639 reason,
1640 };
1641 }
1642
1643 let staging_dir = model_download_temp_dir(model_dir);
1644 if staging_dir.is_dir() {
1645 return ModelCacheState::Acquiring {
1646 bytes_present: directory_size_bytes(&staging_dir),
1647 staging_dir,
1648 total_bytes: manifest.total_size(),
1649 };
1650 }
1651
1652 if !missing_files.is_empty() {
1653 if policy.offline {
1654 return ModelCacheState::OfflineBlocked {
1655 missing_files: missing_files.to_vec(),
1656 };
1657 }
1658
1659 if policy.metered && !policy.allow_metered {
1660 return ModelCacheState::DisabledByPolicy {
1661 reason: "metered network disallows model acquisition".to_string(),
1662 };
1663 }
1664
1665 if let Some(max_bytes) = policy.max_model_bytes
1666 && manifest.total_size() > max_bytes
1667 {
1668 return ModelCacheState::BudgetBlocked {
1669 required_bytes: manifest.total_size(),
1670 max_bytes,
1671 };
1672 }
1673
1674 return ModelCacheState::NotAcquired {
1675 missing_files: missing_files.to_vec(),
1676 needs_consent: policy.requires_consent,
1677 };
1678 }
1679
1680 if integrity == ModelCacheIntegrity::Full {
1681 for file in &manifest.files {
1682 let Some(path) = model_file_path(model_dir, file) else {
1683 continue;
1684 };
1685 match compute_sha256(&path) {
1686 Ok(actual) if actual == file.sha256 => {}
1687 Ok(actual) => {
1688 return ModelCacheState::ChecksumMismatch {
1689 file: file.local_name().to_string(),
1690 expected: file.sha256.clone(),
1691 actual,
1692 };
1693 }
1694 Err(err) => {
1695 return ModelCacheState::QuarantinedCorrupt {
1696 marker_path: path,
1697 reason: format!("unable to hash model file {}: {err}", file.local_name()),
1698 };
1699 }
1700 }
1701 }
1702 }
1703
1704 let verified_marker = model_dir.join(".verified");
1705 if !verified_marker.is_file() {
1706 return ModelCacheState::PreseededLocal {
1707 model_dir: model_dir.to_path_buf(),
1708 };
1709 }
1710
1711 let marker = match fs::read_to_string(&verified_marker) {
1712 Ok(marker) => marker,
1713 Err(err) => {
1714 return ModelCacheState::QuarantinedCorrupt {
1715 marker_path: verified_marker,
1716 reason: format!("unable to read verified marker: {err}"),
1717 };
1718 }
1719 };
1720
1721 let current_revision =
1722 marker_field(&marker, "revision").unwrap_or_else(|| "<unknown>".to_string());
1723 if current_revision != manifest.revision {
1724 return ModelCacheState::IncompatibleVersion {
1725 current_revision,
1726 expected_revision: manifest.revision.clone(),
1727 };
1728 }
1729
1730 match marker_field(&marker, "source") {
1731 Some(source) if source == "preseeded_local" => ModelCacheState::PreseededLocal {
1732 model_dir: model_dir.to_path_buf(),
1733 },
1734 Some(source) if source.starts_with("mirror:") => ModelCacheState::MirrorSourced {
1735 model_dir: model_dir.to_path_buf(),
1736 mirror_base_url: source.trim_start_matches("mirror:").to_string(),
1737 },
1738 _ => ModelCacheState::Acquired {
1739 model_dir: model_dir.to_path_buf(),
1740 },
1741 }
1742}
1743
1744pub fn check_model_installed(model_dir: &Path, manifest: &ModelManifest) -> ModelState {
1754 if !model_dir.is_dir() {
1755 return ModelState::NotInstalled;
1756 }
1757
1758 let verified_marker = model_dir.join(".verified");
1759 if !verified_marker.is_file() {
1760 return ModelState::NotInstalled;
1761 }
1762
1763 for file in &manifest.files {
1767 if model_file_path(model_dir, file).is_none() {
1768 return ModelState::NotInstalled;
1769 }
1770 }
1771
1772 ModelState::Ready
1773}
1774
1775pub fn check_version_mismatch(model_dir: &Path, manifest: &ModelManifest) -> Option<ModelState> {
1777 let verified_marker = model_dir.join(".verified");
1778 if !verified_marker.is_file() {
1779 return None;
1780 }
1781
1782 let content = fs::read_to_string(&verified_marker).ok()?;
1784 let installed_revision = content
1785 .lines()
1786 .find(|l| l.starts_with("revision="))
1787 .map(|l| l.trim_start_matches("revision=").to_string())?;
1788
1789 if installed_revision != manifest.revision {
1790 Some(ModelState::UpdateAvailable {
1791 current_revision: installed_revision,
1792 latest_revision: manifest.revision.clone(),
1793 })
1794 } else {
1795 None
1796 }
1797}
1798
1799fn ensure_replaceable_model_dir(path: &Path) -> Result<bool, DownloadError> {
1800 match fs::symlink_metadata(path) {
1801 Ok(metadata) => {
1802 ensure_real_model_directory_metadata(
1803 path,
1804 &metadata,
1805 "refusing to install model through symlink",
1806 "refusing to replace model target because it is not a directory",
1807 )?;
1808 Ok(true)
1809 }
1810 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
1811 Err(err) => Err(std::io::Error::new(
1812 err.kind(),
1813 format!(
1814 "failed inspecting model target before install {}: {err}",
1815 path.display()
1816 ),
1817 )
1818 .into()),
1819 }
1820}
1821
1822fn ensure_model_download_temp_dir(path: &Path) -> Result<(), DownloadError> {
1823 match fs::symlink_metadata(path) {
1824 Ok(metadata) => {
1825 ensure_real_model_directory_metadata(
1826 path,
1827 &metadata,
1828 "refusing to prepare model download temp dir through symlink",
1829 "refusing to prepare model download temp dir because it is not a directory",
1830 )?;
1831 }
1832 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
1833 fs::create_dir_all(path)?;
1834 let metadata = fs::symlink_metadata(path).map_err(|err| {
1835 std::io::Error::new(
1836 err.kind(),
1837 format!(
1838 "failed inspecting model download temp dir after create {}: {err}",
1839 path.display()
1840 ),
1841 )
1842 })?;
1843 ensure_real_model_directory_metadata(
1844 path,
1845 &metadata,
1846 "refusing to prepare model download temp dir through symlink",
1847 "refusing to prepare model download temp dir because it is not a directory",
1848 )?;
1849 }
1850 Err(err) => {
1851 return Err(std::io::Error::new(
1852 err.kind(),
1853 format!(
1854 "failed inspecting model download temp dir before prepare {}: {err}",
1855 path.display()
1856 ),
1857 )
1858 .into());
1859 }
1860 }
1861 Ok(())
1862}
1863
1864fn model_dir_is_real_directory(path: &Path) -> Result<bool, DownloadError> {
1865 match fs::symlink_metadata(path) {
1866 Ok(metadata) => {
1867 let file_type = metadata.file_type();
1868 Ok(file_type.is_dir() && !file_type.is_symlink())
1869 }
1870 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
1871 Err(err) => Err(err.into()),
1872 }
1873}
1874
1875fn ensure_real_model_directory_metadata(
1876 path: &Path,
1877 metadata: &fs::Metadata,
1878 symlink_message: &str,
1879 non_dir_message: &str,
1880) -> Result<(), DownloadError> {
1881 let file_type = metadata.file_type();
1882 if file_type.is_symlink() {
1883 return Err(std::io::Error::other(format!("{symlink_message}: {}", path.display())).into());
1884 }
1885 if !file_type.is_dir() {
1886 return Err(std::io::Error::other(format!("{non_dir_message}: {}", path.display())).into());
1887 }
1888 Ok(())
1889}
1890
/// Returns the staging directory used while downloading into `target_dir`.
///
/// The staging directory is a sibling of `target_dir` named
/// `<dir_name>.downloading`. The directory name is handled as an `OsStr`, so
/// non-UTF-8 names each get their own staging directory instead of all sharing
/// `model.downloading`.
///
/// Edge cases:
/// - `target_dir` has no final component (for example it ends in `..`): `model`
///   is used as the name.
/// - `target_dir` has no parent at all: the `downloading` extension is applied
///   in place.
fn model_download_temp_dir(target_dir: &Path) -> PathBuf {
    let Some(parent) = target_dir.parent() else {
        return target_dir.with_extension("downloading");
    };
    let mut dir_name = target_dir
        .file_name()
        .map(|name| name.to_os_string())
        .unwrap_or_else(|| "model".into());
    dir_name.push(".downloading");
    parent.join(dir_name)
}
1902
1903pub fn model_file_path(model_dir: &Path, file: &ModelFile) -> Option<PathBuf> {
1908 let canonical = model_dir.join(&file.name);
1909 if canonical.is_file() {
1910 return Some(canonical);
1911 }
1912
1913 let local = model_dir.join(file.local_name());
1914 if local.is_file() {
1915 return Some(local);
1916 }
1917
1918 None
1919}
1920
1921fn missing_manifest_files(model_dir: &Path, manifest: &ModelManifest) -> Vec<String> {
1922 manifest
1923 .files
1924 .iter()
1925 .filter(|file| model_file_path(model_dir, file).is_none())
1926 .map(|file| file.local_name().to_string())
1927 .collect()
1928}
1929
1930fn installed_manifest_size(model_dir: &Path, manifest: &ModelManifest) -> u64 {
1931 manifest
1932 .files
1933 .iter()
1934 .filter_map(|file| model_file_path(model_dir, file))
1935 .filter_map(|path| path.metadata().ok())
1936 .map(|metadata| metadata.len())
1937 .sum()
1938}
1939
/// Recursively totals the sizes of the regular files under `path`.
///
/// Entry types come from `DirEntry::file_type`, which does not follow symlinks,
/// so symlinks contribute nothing. Unreadable directories or entries count as
/// zero, as does a missing `path`.
fn directory_size_bytes(path: &Path) -> u64 {
    let Ok(listing) = fs::read_dir(path) else {
        return 0;
    };

    let mut total = 0u64;
    for entry in listing.flatten() {
        let Ok(kind) = entry.file_type() else {
            continue;
        };
        if kind.is_file() {
            total += entry.metadata().map_or(0, |meta| meta.len());
        } else if kind.is_dir() {
            total += directory_size_bytes(&entry.path());
        }
    }
    total
}
1959
/// Reads a `field=value` entry from marker-file text.
///
/// Only the first line that starts exactly with `field=` is used. Its value is
/// whitespace-trimmed, and an empty value yields `None`, as does a missing field.
fn marker_field(content: &str, field: &str) -> Option<String> {
    let prefix = format!("{field}=");
    let raw = content
        .lines()
        .find_map(|line| line.strip_prefix(prefix.as_str()))?;
    let value = raw.trim();
    (!value.is_empty()).then(|| value.to_string())
}
1968
1969fn unique_model_backup_dir(path: &Path) -> PathBuf {
1970 unique_model_sidecar_path(path, "bak", "model")
1971}
1972
/// Builds a hidden sibling path `.<name>.<suffix>.<pid>.<nanos>.<counter>` next to `path`.
///
/// - `<name>` is `path`'s UTF-8 file name, or `fallback_name` if there is none.
/// - The process id, the wall-clock nanoseconds since the Unix epoch (0 if the
///   clock is before the epoch) and a process-wide counter make collisions unlikely.
fn unique_model_sidecar_path(path: &Path, suffix: &str, fallback_name: &str) -> PathBuf {
    // Disambiguates calls made within the same clock tick in this process.
    static SIDECAR_COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos_since_epoch =
        match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_nanos(),
            Err(_) => 0,
        };
    let sequence = SIDECAR_COUNTER.fetch_add(1, Ordering::Relaxed);
    let base_name = path
        .file_name()
        .and_then(std::ffi::OsStr::to_str)
        .unwrap_or(fallback_name);
    let pid = std::process::id();

    path.with_file_name(format!(
        ".{base_name}.{suffix}.{pid}.{nanos_since_epoch}.{sequence}"
    ))
}
1993
1994fn replace_file_from_temp(temp_path: &Path, final_path: &Path) -> Result<(), DownloadError> {
1995 #[cfg(windows)]
1996 {
1997 match fs::rename(temp_path, final_path) {
1998 Ok(()) => sync_parent_directory(final_path),
1999 Err(first_err)
2000 if final_path.exists()
2001 && matches!(
2002 first_err.kind(),
2003 std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::PermissionDenied
2004 ) =>
2005 {
2006 let backup_path = unique_model_backup_dir(final_path);
2007 fs::rename(final_path, &backup_path).map_err(|backup_err| {
2008 let _ = fs::remove_file(temp_path);
2009 DownloadError::IoError(std::io::Error::other(format!(
2010 "failed preparing backup {} before replacing {}: first error: {first_err}; backup error: {backup_err}",
2011 backup_path.display(),
2012 final_path.display()
2013 )))
2014 })?;
2015 match fs::rename(temp_path, final_path) {
2016 Ok(()) => {
2017 let _ = fs::remove_file(&backup_path);
2018 sync_parent_directory(final_path)
2019 }
2020 Err(second_err) => match fs::rename(&backup_path, final_path) {
2021 Ok(()) => {
2022 let _ = fs::remove_file(temp_path);
2023 sync_parent_directory(final_path)?;
2024 Err(std::io::Error::other(format!(
2025 "failed replacing {} with {}: first error: {first_err}; second error: {second_err}; restored original file",
2026 final_path.display(),
2027 temp_path.display()
2028 ))
2029 .into())
2030 }
2031 Err(restore_err) => Err(std::io::Error::other(format!(
2032 "failed replacing {} with {}: first error: {first_err}; second error: {second_err}; restore error: {restore_err}; temp file retained at {}",
2033 final_path.display(),
2034 temp_path.display(),
2035 temp_path.display()
2036 ))
2037 .into()),
2038 },
2039 }
2040 }
2041 Err(rename_err) => Err(rename_err.into()),
2042 }
2043 }
2044
2045 #[cfg(not(windows))]
2046 {
2047 fs::rename(temp_path, final_path)?;
2048 sync_parent_directory(final_path)
2049 }
2050}
2051
2052#[cfg(not(windows))]
2053fn sync_tree(path: &Path) -> Result<(), DownloadError> {
2054 sync_tree_inner(path)?;
2055 sync_parent_directory(path)
2056}
2057
2058#[cfg(not(windows))]
2059fn sync_tree_inner(path: &Path) -> Result<(), DownloadError> {
2060 let metadata = fs::metadata(path)?;
2061 if metadata.is_dir() {
2062 for entry in fs::read_dir(path)? {
2063 let entry = entry?;
2064 sync_tree_inner(&entry.path())?;
2065 }
2066 File::open(path)?.sync_all()?;
2067 } else if metadata.is_file() {
2068 File::open(path)?.sync_all()?;
2069 }
2070 Ok(())
2071}
2072
/// Windows stub: directory handles cannot be fsynced the POSIX way, so tree
/// syncing is skipped and this always succeeds.
#[cfg(windows)]
fn sync_tree(_unused_path: &Path) -> Result<(), DownloadError> {
    Ok(())
}
2077
2078#[cfg(not(windows))]
2079fn sync_parent_directory(path: &Path) -> Result<(), DownloadError> {
2080 let Some(parent) = path.parent() else {
2081 return Ok(());
2082 };
2083 File::open(parent)?.sync_all()?;
2084 Ok(())
2085}
2086
/// Windows stub: parent-directory fsync is not performed on this platform, so
/// this always succeeds.
#[cfg(windows)]
fn sync_parent_directory(_unused_path: &Path) -> Result<(), DownloadError> {
    Ok(())
}
2091
2092#[cfg(test)]
2093mod tests {
2094 use super::*;
2095 use std::collections::BTreeMap;
2096 use std::error::Error as _;
2097 use std::io::{Read, Write};
2098 use std::net::{Shutdown, TcpListener, TcpStream};
2099 use std::sync::atomic::{AtomicBool, Ordering};
2100 use std::sync::{Arc, Mutex};
2101 use std::thread;
2102 use std::time::Duration;
2103
2104 fn copy_model_fixtures(target_dir: &Path) -> std::io::Result<()> {
2107 let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/models");
2108 fs::create_dir_all(target_dir)?;
2109
2110 fs::copy(
2112 fixture_dir.join("model.onnx"),
2113 target_dir.join("model.onnx"),
2114 )?;
2115
2116 for file in &[
2118 "tokenizer.json",
2119 "config.json",
2120 "special_tokens_map.json",
2121 "tokenizer_config.json",
2122 ] {
2123 fs::copy(fixture_dir.join(file), target_dir.join(file))?;
2124 }
2125
2126 Ok(())
2127 }
2128
2129 #[derive(Clone, Debug)]
2130 struct MirrorRequest {
2131 path: String,
2132 range_start: Option<u64>,
2133 }
2134
2135 #[derive(Clone)]
2136 struct MirrorRoute {
2137 body: Vec<u8>,
2138 content_type: &'static str,
2139 chunk_size: usize,
2140 chunk_delay: Duration,
2141 }
2142
2143 struct MirrorFixtureServer {
2144 base_url: String,
2145 stop: Arc<AtomicBool>,
2146 wake_addr: String,
2147 requests: Arc<Mutex<Vec<MirrorRequest>>>,
2148 handle: Option<std::thread::JoinHandle<()>>,
2149 }
2150
2151 impl MirrorFixtureServer {
2152 fn requests(&self) -> Vec<MirrorRequest> {
2153 self.requests.lock().unwrap().clone()
2154 }
2155 }
2156
2157 impl Drop for MirrorFixtureServer {
2158 fn drop(&mut self) {
2159 self.stop.store(true, Ordering::SeqCst);
2160 if let Ok(stream) = TcpStream::connect(&self.wake_addr) {
2161 let _ = stream.shutdown(Shutdown::Both);
2162 }
2163 if let Some(handle) = self.handle.take() {
2164 let _ = handle.join();
2165 }
2166 }
2167 }
2168
2169 fn start_mirror_fixture_server(routes: Vec<(String, MirrorRoute)>) -> MirrorFixtureServer {
2170 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test mirror server");
2171 listener
2172 .set_nonblocking(true)
2173 .expect("set test mirror server nonblocking");
2174 let addr = listener.local_addr().expect("read server address");
2175 let wake_addr = addr.to_string();
2176 let base_url = format!("http://{wake_addr}");
2177 let stop = Arc::new(AtomicBool::new(false));
2178 let stop_flag = Arc::clone(&stop);
2179 let requests = Arc::new(Mutex::new(Vec::new()));
2180 let request_log = Arc::clone(&requests);
2181 let route_map: BTreeMap<String, MirrorRoute> = routes.into_iter().collect();
2182 let handle = thread::spawn(move || {
2183 while !stop_flag.load(Ordering::SeqCst) {
2184 match listener.accept() {
2185 Ok((stream, _)) => {
2186 handle_mirror_request(stream, &route_map, &request_log);
2187 }
2188 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
2189 thread::sleep(Duration::from_millis(10));
2190 }
2191 Err(_) => break,
2192 }
2193 }
2194 });
2195 MirrorFixtureServer {
2196 base_url,
2197 stop,
2198 wake_addr,
2199 requests,
2200 handle: Some(handle),
2201 }
2202 }
2203
2204 fn handle_mirror_request(
2205 mut stream: TcpStream,
2206 routes: &BTreeMap<String, MirrorRoute>,
2207 request_log: &Arc<Mutex<Vec<MirrorRequest>>>,
2208 ) {
2209 let mut buffer = [0_u8; 8192];
2210 let read = match stream.read(&mut buffer) {
2211 Ok(read) => read,
2212 Err(_) => return,
2213 };
2214 let request = String::from_utf8_lossy(&buffer[..read]);
2215 let mut lines = request.lines();
2216 let target = lines
2217 .next()
2218 .and_then(|line| line.split_whitespace().nth(1))
2219 .unwrap_or("/");
2220 let path = target
2221 .split_once('?')
2222 .map(|(path, _)| path)
2223 .unwrap_or(target)
2224 .split_once('#')
2225 .map(|(path, _)| path)
2226 .unwrap_or(target)
2227 .to_string();
2228 let range_start = lines.find_map(parse_range_start_header);
2229 request_log.lock().unwrap().push(MirrorRequest {
2230 path: path.clone(),
2231 range_start,
2232 });
2233
2234 let Some(route) = routes.get(&path) else {
2235 let response = concat!(
2236 "HTTP/1.1 404 Not Found\r\n",
2237 "Content-Length: 9\r\n",
2238 "Content-Type: text/plain\r\n",
2239 "Connection: close\r\n\r\n",
2240 "not found"
2241 );
2242 let _ = stream.write_all(response.as_bytes());
2243 let _ = stream.flush();
2244 return;
2245 };
2246
2247 let start = range_start.unwrap_or(0) as usize;
2248 let mut status = "200 OK";
2249 let mut content_range = None;
2250 let body = if start >= route.body.len() {
2251 status = "416 Range Not Satisfiable";
2252 &[][..]
2253 } else if start > 0 {
2254 status = "206 Partial Content";
2255 content_range = Some(format!(
2256 "bytes {start}-{}/{}",
2257 route.body.len().saturating_sub(1),
2258 route.body.len()
2259 ));
2260 &route.body[start..]
2261 } else {
2262 route.body.as_slice()
2263 };
2264
2265 let mut response = format!(
2266 "HTTP/1.1 {status}\r\nContent-Length: {}\r\nContent-Type: {}\r\nConnection: close\r\n",
2267 body.len(),
2268 route.content_type
2269 );
2270 if let Some(content_range) = content_range {
2271 response.push_str(&format!("Content-Range: {content_range}\r\n"));
2272 }
2273 response.push_str("\r\n");
2274 let _ = stream.write_all(response.as_bytes());
2275 for chunk in body.chunks(route.chunk_size.max(1)) {
2276 if stream.write_all(chunk).is_err() {
2277 return;
2278 }
2279 let _ = stream.flush();
2280 if !route.chunk_delay.is_zero() {
2281 thread::sleep(route.chunk_delay);
2282 }
2283 }
2284 }
2285
2286 fn parse_range_start_header(line: &str) -> Option<u64> {
2287 let (name, value) = line.split_once(':')?;
2288 if !name.eq_ignore_ascii_case("range") {
2289 return None;
2290 }
2291 let value = value.trim();
2292 let value = value.strip_prefix("bytes=")?;
2293 let (start, _) = value.split_once('-')?;
2294 start.parse().ok()
2295 }
2296
2297 fn build_test_manifest(repo: &str, revision: &str, files: &[(&str, &[u8])]) -> ModelManifest {
2298 ModelManifest {
2299 id: "mirror-test-model".into(),
2300 repo: repo.into(),
2301 revision: revision.into(),
2302 files: files
2303 .iter()
2304 .map(|(name, body)| ModelFile {
2305 name: (*name).into(),
2306 sha256: hex::encode(Sha256::digest(body)),
2307 size: body.len() as u64,
2308 })
2309 .collect(),
2310 license: "Apache-2.0".into(),
2311 }
2312 }
2313
2314 fn mirror_route_path(prefix: &str, manifest: &ModelManifest, file: &ModelFile) -> String {
2315 format!(
2316 "{}/{}/resolve/{}/{}",
2317 prefix.trim_end_matches('/'),
2318 manifest.repo.trim_start_matches('/'),
2319 manifest.revision,
2320 file.name.trim_start_matches('/')
2321 )
2322 }
2323
2324 #[test]
2325 fn test_model_state_summary() {
2326 assert_eq!(ModelState::NotInstalled.summary(), "not installed");
2327 assert_eq!(ModelState::NeedsConsent.summary(), "needs consent");
2328 assert_eq!(ModelState::Ready.summary(), "ready");
2329 assert_eq!(
2330 ModelState::Downloading {
2331 progress_pct: 50,
2332 bytes_downloaded: 1000,
2333 total_bytes: 2000
2334 }
2335 .summary(),
2336 "downloading (50%)"
2337 );
2338 }
2339
2340 #[test]
2341 fn test_model_state_is_ready() {
2342 assert!(ModelState::Ready.is_ready());
2343 assert!(!ModelState::NotInstalled.is_ready());
2344 assert!(!ModelState::NeedsConsent.is_ready());
2345 assert!(
2346 !ModelState::Downloading {
2347 progress_pct: 0,
2348 bytes_downloaded: 0,
2349 total_bytes: 0
2350 }
2351 .is_ready()
2352 );
2353 }
2354
2355 #[test]
2356 fn test_model_manifest_total_size() {
2357 let manifest = ModelManifest::minilm_v2();
2358 assert!(manifest.total_size() > 20_000_000); }
2360
2361 #[test]
2362 fn test_model_manifest_download_url() {
2363 let manifest = ModelManifest::minilm_v2();
2364 let url = manifest.download_url(&manifest.files[0]);
2365 assert!(url.contains("huggingface.co"));
2366 assert!(url.contains("sentence-transformers/all-MiniLM-L6-v2"));
2367 assert!(url.contains("model.onnx"));
2368 }
2369
2370 #[test]
2371 fn test_model_manifest_download_url_with_mirror_base() {
2372 let manifest = ModelManifest::minilm_v2();
2373 let url = manifest
2374 .download_url_with_base(&manifest.files[0], Some("https://mirror.example/cache/"));
2375 assert_eq!(
2376 url,
2377 format!(
2378 "https://mirror.example/cache/{}/resolve/{}/{}",
2379 manifest.repo, manifest.revision, manifest.files[0].name
2380 )
2381 );
2382 }
2383
2384 #[test]
2385 fn air_gap_bash_script_uses_explicit_output_filenames() {
2386 let manifest = ModelManifest::minilm_v2();
2392 let script = manifest.air_gap_bash_script(None);
2393 assert!(script.contains("set -euo pipefail"));
2394 assert!(script.contains("DIR=\"${DIR:-./all-minilm-l6-v2_files}\""));
2395 for file in &manifest.files {
2396 let local = file.local_name();
2397 assert!(
2398 script.contains(&format!("-o \"$DIR/{local}\"")),
2399 "bash script must write {local} via explicit -o, got:\n{script}"
2400 );
2401 }
2402 assert!(
2403 script.contains("cass models install all-minilm-l6-v2 --from-file \"$DIR\" -y"),
2404 "bash script must invoke install with --from-file"
2405 );
2406 }
2407
2408 #[test]
2409 fn air_gap_bash_script_quotes_urls_with_single_quotes() {
2410 let manifest = ModelManifest::minilm_v2();
2412 let script = manifest.air_gap_bash_script(None);
2413 let sample_url = manifest.download_url(&manifest.files[0]);
2414 assert!(script.contains(&format!("'{sample_url}'")));
2415 }
2416
2417 #[test]
2418 fn air_gap_powershell_script_forces_tls12_and_basic_parsing() {
2419 let manifest = ModelManifest::minilm_v2();
2420 let script = manifest.air_gap_powershell_script(None);
2421 assert!(
2422 script.contains("SecurityProtocolType]::Tls12"),
2423 "PowerShell script must opt into TLS 1.2 for Windows PowerShell 5.1 compat"
2424 );
2425 assert!(
2426 script.contains("Invoke-WebRequest -UseBasicParsing"),
2427 "PowerShell script must use -UseBasicParsing for PS 5.1 compat"
2428 );
2429 for file in &manifest.files {
2430 let local = file.local_name();
2431 assert!(
2432 script.contains(&format!("(Join-Path $dir '{local}')")),
2433 "PowerShell script must materialize output path for {local}, got:\n{script}"
2434 );
2435 }
2436 assert!(
2437 script.contains("cass models install all-minilm-l6-v2 --from-file $dir -y"),
2438 "PowerShell script must invoke install with --from-file"
2439 );
2440 }
2441
2442 #[test]
2443 fn air_gap_scripts_honor_mirror_base_url() {
2444 let manifest = ModelManifest::minilm_v2();
2445 let mirror = Some("https://mirror.example/cache");
2446 let bash = manifest.air_gap_bash_script(mirror);
2447 let ps = manifest.air_gap_powershell_script(mirror);
2448 assert!(bash.contains("https://mirror.example/cache"));
2449 assert!(!bash.contains("huggingface.co"));
2450 assert!(ps.contains("https://mirror.example/cache"));
2451 assert!(!ps.contains("huggingface.co"));
2452 }
2453
2454 #[test]
2455 fn test_normalize_mirror_base_url_trims_trailing_slash() {
2456 let normalized = normalize_mirror_base_url("https://mirror.example/cache/").unwrap();
2457 assert_eq!(normalized, "https://mirror.example/cache");
2458 }
2459
2460 #[test]
2461 fn test_normalize_mirror_base_url_rejects_invalid_values() {
2462 let cases = [
2463 ("mirror.example", "invalid mirror URL"),
2464 ("file:///tmp/mirror", "unsupported URL scheme"),
2465 (
2466 "https://mirror.example/cache?trace=abc",
2467 "must not include query or fragment",
2468 ),
2469 ];
2470
2471 for (input, expected_fragment) in cases {
2472 let err = normalize_mirror_base_url(input).unwrap_err();
2473 let message = err.to_string();
2474 assert!(
2475 message.contains(expected_fragment),
2476 "expected error for {input:?} to contain {expected_fragment:?}, got {message:?}"
2477 );
2478 }
2479 }
2480
2481 #[test]
2482 fn test_invalid_mirror_url_helper_shape() {
2483 let err = invalid_mirror_url("ftp://mirror.example/model.onnx", "unsupported scheme");
2484
2485 assert!(matches!(
2486 &err,
2487 DownloadError::InvalidMirrorUrl { url, reason }
2488 if url == "ftp://mirror.example/model.onnx" && reason == "unsupported scheme"
2489 ));
2490 assert_eq!(
2491 err.to_string(),
2492 "invalid mirror URL 'ftp://mirror.example/model.onnx': unsupported scheme"
2493 );
2494 assert!(!err.is_retryable());
2495 }
2496
2497 #[test]
2498 fn test_check_model_installed_missing() {
2499 let tmp = tempfile::tempdir().unwrap();
2500 let model_dir = tmp.path().join("nonexistent");
2501 assert_eq!(
2502 check_model_installed(&model_dir, &ModelManifest::minilm_v2()),
2503 ModelState::NotInstalled
2504 );
2505 }
2506
2507 #[test]
2508 fn test_check_model_installed_no_marker() {
2509 let tmp = tempfile::tempdir().unwrap();
2510 let model_dir = tmp.path().join("model");
2511 let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/models");
2513 fs::create_dir_all(&model_dir).unwrap();
2514 fs::copy(fixture_dir.join("model.onnx"), model_dir.join("model.onnx")).unwrap();
2515 assert_eq!(
2516 check_model_installed(&model_dir, &ModelManifest::minilm_v2()),
2517 ModelState::NotInstalled
2518 );
2519 }
2520
2521 #[test]
2522 fn test_check_model_installed_ready() {
2523 let tmp = tempfile::tempdir().unwrap();
2524 let model_dir = tmp.path().join("model");
2525 copy_model_fixtures(&model_dir).unwrap();
2527 fs::write(model_dir.join(".verified"), "revision=test\n").unwrap();
2528 assert_eq!(
2529 check_model_installed(&model_dir, &ModelManifest::minilm_v2()),
2530 ModelState::Ready
2531 );
2532 }
2533
2534 #[test]
2535 fn classify_cache_policy_disabled_takes_precedence_over_missing() {
2536 let tmp = tempfile::tempdir().unwrap();
2537 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2538 let policy = ModelAcquisitionPolicy {
2539 downloads_enabled: false,
2540 offline: true,
2541 max_model_bytes: Some(1),
2542 ..ModelAcquisitionPolicy::default()
2543 };
2544
2545 let report = classify_model_cache(tmp.path(), &manifest, &policy);
2546 assert_eq!(report.state_code(), "disabled_by_policy");
2547 assert!(matches!(
2548 report.state,
2549 ModelCacheState::DisabledByPolicy { .. }
2550 ));
2551 }
2552
    /// A sibling `model.downloading` staging directory containing partial data
    /// is reported as `acquiring`, not as a missing model.
    ///
    /// NOTE(review): the staged partial (`b"partial"`, 7 bytes) is larger than
    /// the manifest payload (`b"model"`, 5 bytes). This test therefore also pins
    /// that `bytes_present` is reported as-is even when it exceeds `total_bytes`.
    /// Confirm that is the intended contract rather than an artifact of the fixture.
    #[test]
    fn classify_cache_detects_resume_stage_before_missing() {
        let tmp = tempfile::tempdir().unwrap();
        let model_dir = tmp.path().join("model");
        // The model dir itself is never created; only the staging dir exists.
        let staging_dir = tmp.path().join("model.downloading");
        fs::create_dir_all(&staging_dir).unwrap();
        fs::write(staging_dir.join("model.onnx"), b"partial").unwrap();
        let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);

        let report =
            classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
        assert_eq!(report.state_code(), "acquiring");
        assert!(matches!(
            report.state,
            ModelCacheState::Acquiring {
                bytes_present: 7,
                total_bytes: 5,
                ..
            }
        ));
    }
2574
2575 #[test]
2576 fn classify_cache_distinguishes_offline_and_budget_blocks() {
2577 let tmp = tempfile::tempdir().unwrap();
2578 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2579
2580 let offline = ModelAcquisitionPolicy {
2581 offline: true,
2582 ..ModelAcquisitionPolicy::default()
2583 };
2584 let report = classify_model_cache(tmp.path(), &manifest, &offline);
2585 assert_eq!(report.state_code(), "offline_blocked");
2586
2587 let budget = ModelAcquisitionPolicy {
2588 max_model_bytes: Some(1),
2589 ..ModelAcquisitionPolicy::default()
2590 };
2591 let report = classify_model_cache(tmp.path(), &manifest, &budget);
2592 assert_eq!(report.state_code(), "budget_blocked");
2593 }
2594
2595 #[test]
2596 fn classify_cache_accepts_preseeded_local_manifest_files() {
2597 let tmp = tempfile::tempdir().unwrap();
2598 let model_dir = tmp.path().join("model");
2599 fs::create_dir_all(model_dir.join("onnx")).unwrap();
2600 fs::write(model_dir.join("onnx/model.onnx"), b"model").unwrap();
2601 fs::write(model_dir.join("tokenizer.json"), b"tok").unwrap();
2602 let manifest = build_test_manifest(
2603 "repo/model",
2604 "rev1",
2605 &[("onnx/model.onnx", b"model"), ("tokenizer.json", b"tok")],
2606 );
2607
2608 let report =
2609 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2610 assert_eq!(report.state_code(), "preseeded_local");
2611 assert!(report.is_usable());
2612 }
2613
2614 #[test]
2615 fn classify_cache_detects_checksum_mismatch() {
2616 let tmp = tempfile::tempdir().unwrap();
2617 let model_dir = tmp.path().join("model");
2618 fs::create_dir_all(&model_dir).unwrap();
2619 fs::write(model_dir.join("model.onnx"), b"wrong").unwrap();
2620 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2621
2622 let report =
2623 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2624 assert_eq!(report.state_code(), "checksum_mismatch");
2625 assert!(matches!(
2626 report.state,
2627 ModelCacheState::ChecksumMismatch { .. }
2628 ));
2629 }
2630
2631 #[test]
2632 fn classify_cache_metadata_trusts_verified_marker_without_hashing_payload() {
2633 let tmp = tempfile::tempdir().unwrap();
2634 let model_dir = tmp.path().join("model");
2635 fs::create_dir_all(&model_dir).unwrap();
2636 fs::write(model_dir.join("model.onnx"), b"m0del").unwrap();
2637 fs::write(
2638 model_dir.join(".verified"),
2639 "revision=rev1\nsource=registry\n",
2640 )
2641 .unwrap();
2642 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2643
2644 let metadata_report = classify_model_cache_metadata(
2645 &model_dir,
2646 &manifest,
2647 &ModelAcquisitionPolicy::default(),
2648 );
2649 assert_eq!(metadata_report.state_code(), "acquired");
2650 assert!(metadata_report.is_usable());
2651
2652 let full_report =
2653 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2654 assert_eq!(full_report.state_code(), "checksum_mismatch");
2655 }
2656
2657 #[test]
2658 fn classify_cache_detects_incompatible_revision() {
2659 let tmp = tempfile::tempdir().unwrap();
2660 let model_dir = tmp.path().join("model");
2661 fs::create_dir_all(&model_dir).unwrap();
2662 fs::write(model_dir.join("model.onnx"), b"model").unwrap();
2663 fs::write(model_dir.join(".verified"), "revision=old\n").unwrap();
2664 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2665
2666 let report =
2667 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2668 assert_eq!(report.state_code(), "incompatible_version");
2669 assert!(matches!(
2670 report.state,
2671 ModelCacheState::IncompatibleVersion {
2672 current_revision,
2673 expected_revision
2674 } if current_revision == "old" && expected_revision == "rev1"
2675 ));
2676 }
2677
2678 #[test]
2679 fn classify_cache_reports_mirror_sourced_marker() {
2680 let tmp = tempfile::tempdir().unwrap();
2681 let model_dir = tmp.path().join("model");
2682 fs::create_dir_all(&model_dir).unwrap();
2683 fs::write(model_dir.join("model.onnx"), b"model").unwrap();
2684 fs::write(
2685 model_dir.join(".verified"),
2686 "revision=rev1\nsource=mirror:https://mirror.example/cache\n",
2687 )
2688 .unwrap();
2689 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2690
2691 let report =
2692 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2693 assert_eq!(report.state_code(), "mirror_sourced");
2694 assert!(matches!(
2695 report.state,
2696 ModelCacheState::MirrorSourced {
2697 mirror_base_url,
2698 ..
2699 } if mirror_base_url == "https://mirror.example/cache"
2700 ));
2701 }
2702
2703 #[test]
2704 fn classify_cache_reports_quarantine_marker() {
2705 let tmp = tempfile::tempdir().unwrap();
2706 let model_dir = tmp.path().join("model");
2707 fs::create_dir_all(&model_dir).unwrap();
2708 fs::write(model_dir.join(".quarantined"), "bad checksum\n").unwrap();
2709 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2710
2711 let report =
2712 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2713 assert_eq!(report.state_code(), "quarantined_corrupt");
2714 assert!(matches!(
2715 report.state,
2716 ModelCacheState::QuarantinedCorrupt { reason, .. } if reason == "bad checksum"
2717 ));
2718 }
2719
2720 #[test]
2721 fn test_compute_sha256() {
2722 let tmp = tempfile::tempdir().unwrap();
2723 let file_path = tmp.path().join("test.txt");
2724 fs::write(&file_path, b"hello world").unwrap();
2725 let hash = compute_sha256(&file_path).unwrap();
2726 assert_eq!(
2728 hash,
2729 "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
2730 );
2731 }
2732
2733 #[test]
2734 fn test_check_version_mismatch_none() {
2735 let tmp = tempfile::tempdir().unwrap();
2736 let model_dir = tmp.path().join("model");
2737 fs::create_dir_all(&model_dir).unwrap();
2738 let manifest = ModelManifest::minilm_v2();
2740 fs::write(
2741 model_dir.join(".verified"),
2742 format!("revision={}\n", manifest.revision),
2743 )
2744 .unwrap();
2745
2746 let result = check_version_mismatch(&model_dir, &manifest);
2747 assert!(result.is_none());
2748 }
2749
2750 #[test]
2751 fn test_model_file_local_name() {
2752 let file = ModelFile {
2754 name: "onnx/model.onnx".into(),
2755 sha256: "abc123".into(),
2756 size: 1000,
2757 };
2758 assert_eq!(file.local_name(), "model.onnx");
2759
2760 let file2 = ModelFile {
2762 name: "tokenizer.json".into(),
2763 sha256: "def456".into(),
2764 size: 500,
2765 };
2766 assert_eq!(file2.local_name(), "tokenizer.json");
2767
2768 let file3 = ModelFile {
2770 name: "path/to/deep/model.bin".into(),
2771 sha256: "ghi789".into(),
2772 size: 2000,
2773 };
2774 assert_eq!(file3.local_name(), "model.bin");
2775 }
2776
2777 #[test]
2778 fn test_check_version_mismatch_found() {
2779 let tmp = tempfile::tempdir().unwrap();
2780 let model_dir = tmp.path().join("model");
2781 fs::create_dir_all(&model_dir).unwrap();
2782 fs::write(model_dir.join(".verified"), "revision=old_version\n").unwrap();
2783
2784 let manifest = ModelManifest::minilm_v2();
2785 let result = check_version_mismatch(&model_dir, &manifest);
2786 assert!(matches!(result, Some(ModelState::UpdateAvailable { .. })));
2787 }
2788
    /// `atomic_install` swaps a staged model into place without disturbing a
    /// pre-existing sibling `model.bak` directory.
    #[test]
    fn test_atomic_install_preserves_preexisting_legacy_backup_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let target_dir = tmp.path().join("model");
        // Existing install at the "old" revision.
        copy_model_fixtures(&target_dir).unwrap();
        fs::write(target_dir.join(".verified"), "revision=old\n").unwrap();

        // Unrelated `model.bak` directory whose contents must survive the install.
        let legacy_backup_dir = tmp.path().join("model.bak");
        fs::create_dir_all(&legacy_backup_dir).unwrap();
        fs::write(legacy_backup_dir.join("sentinel.txt"), "keep me").unwrap();

        // Stage the "new" revision in the downloader's temp dir.
        let downloader = ModelDownloader::new(target_dir.clone());
        copy_model_fixtures(&downloader.temp_dir).unwrap();
        fs::write(downloader.temp_dir.join(".verified"), "revision=new\n").unwrap();

        downloader.atomic_install().unwrap();

        assert_eq!(
            fs::read_to_string(legacy_backup_dir.join("sentinel.txt")).unwrap(),
            "keep me"
        );
        // The staged marker now sits at the target, replacing the old one.
        assert_eq!(
            fs::read_to_string(target_dir.join(".verified")).unwrap(),
            "revision=new\n"
        );
    }
2815
2816 #[test]
2817 fn test_atomic_install_rejects_file_target() {
2818 let tmp = tempfile::tempdir().unwrap();
2819 let target_dir = tmp.path().join("model");
2820 fs::write(&target_dir, "not a directory").unwrap();
2821
2822 let downloader = ModelDownloader::new(target_dir.clone());
2823 copy_model_fixtures(&downloader.temp_dir).unwrap();
2824
2825 let err = downloader.atomic_install().unwrap_err();
2826
2827 assert!(
2828 err.to_string().contains("not a directory"),
2829 "unexpected error: {err}"
2830 );
2831 assert!(downloader.temp_dir.exists());
2832 assert_eq!(fs::read_to_string(&target_dir).unwrap(), "not a directory");
2833 }
2834
    /// `atomic_install` refuses to install through a symlinked target path, even
    /// a dangling one. The staged download, the symlink itself, and the
    /// (nonexistent) link destination are all left untouched.
    #[test]
    #[cfg(unix)]
    fn test_atomic_install_rejects_dangling_symlink_target() {
        use std::os::unix::fs::symlink;

        let tmp = tempfile::tempdir().unwrap();
        let target_dir = tmp.path().join("model");
        // `model` is a symlink pointing at a path that does not exist.
        let missing_target = tmp.path().join("missing-model");
        symlink(&missing_target, &target_dir).unwrap();

        let downloader = ModelDownloader::new(target_dir.clone());
        copy_model_fixtures(&downloader.temp_dir).unwrap();

        let err = downloader.atomic_install().unwrap_err();

        assert!(
            err.to_string().contains("through symlink"),
            "unexpected error: {err}"
        );
        assert!(downloader.temp_dir.exists());
        // The symlink is still there (not replaced by a real directory)...
        assert!(
            fs::symlink_metadata(&target_dir)
                .unwrap()
                .file_type()
                .is_symlink()
        );
        // ...and nothing was created at its destination.
        assert!(!missing_target.exists());
    }
2863
2864 #[test]
2865 fn test_write_verified_marker_overwrites_existing_marker() {
2866 let tmp = tempfile::tempdir().unwrap();
2867 let target_dir = tmp.path().join("model");
2868 fs::create_dir_all(&target_dir).unwrap();
2869 fs::write(target_dir.join(".verified"), "revision=old\n").unwrap();
2870
2871 let downloader = ModelDownloader::new(target_dir.clone());
2872 let manifest = ModelManifest::minilm_v2();
2873 downloader.write_verified_marker(&manifest, None).unwrap();
2874
2875 let marker = fs::read_to_string(target_dir.join(".verified")).unwrap();
2876 assert!(marker.contains(&format!("revision={}", manifest.revision)));
2877 assert!(marker.contains("verified_at="));
2878 assert!(marker.contains("source=registry"));
2879 }
2880
    /// Pins the user-facing `Display` text of each `DownloadError` variant. It also
    /// checks that only the I/O variant exposes an underlying `source()`.
    #[test]
    fn test_download_error_display() {
        // (error, exact expected `to_string()` output)
        let display_cases = [
            (
                DownloadError::NetworkError("connection refused".into()),
                "network error: connection refused",
            ),
            (
                DownloadError::VerificationFailed {
                    file: "test.onnx".into(),
                    expected: "abc".into(),
                    actual: "def".into(),
                },
                "verification failed for test.onnx: expected abc, got def",
            ),
            (DownloadError::Cancelled, "download cancelled"),
            (DownloadError::Timeout, "download timed out"),
            (
                DownloadError::HttpError {
                    status: 503,
                    message: "service unavailable".into(),
                },
                "HTTP error 503: service unavailable",
            ),
            // The "revision is not pinned" suffix appears only when `revision_unpinned`.
            (
                DownloadError::ManifestNotVerified {
                    model_id: "test-model".into(),
                    unverified_files: vec!["model.onnx".into(), "config.json".into()],
                    revision_unpinned: true,
                },
                "model 'test-model' is not production-ready: 2 file(s) have placeholder checksums and revision is not pinned",
            ),
            (
                DownloadError::ManifestNotVerified {
                    model_id: "test-model".into(),
                    unverified_files: vec!["model.onnx".into()],
                    revision_unpinned: false,
                },
                "model 'test-model' is not production-ready: 1 file(s) have placeholder checksums",
            ),
            (
                DownloadError::InvalidMirrorUrl {
                    url: "ftp://mirror.example/model.onnx".into(),
                    reason: "unsupported scheme".into(),
                },
                "invalid mirror URL 'ftp://mirror.example/model.onnx': unsupported scheme",
            ),
        ];

        for (err, expected) in display_cases {
            assert_eq!(err.to_string(), expected);
        }

        // `From<std::io::Error>` conversion keeps the original error as `source()`.
        let err: DownloadError = std::io::Error::other("disk full").into();

        assert_eq!(err.to_string(), "I/O error: disk full");
        let source = err.source().expect("I/O errors expose their source");
        assert_eq!(source.to_string(), "disk full");

        assert!(
            DownloadError::NetworkError("connection refused".into())
                .source()
                .is_none(),
            "non-source variants must not gain an error source"
        );
    }
2947
2948 #[test]
2949 fn test_manifest_production_ready_minilm() {
2950 let manifest = ModelManifest::minilm_v2();
2952 assert!(manifest.has_verified_checksums());
2953 assert!(manifest.has_pinned_revision());
2954 assert!(manifest.is_production_ready());
2955 }
2956
2957 #[test]
2958 fn test_all_bakeoff_candidates_production_ready() {
2959 let candidates = ModelManifest::bakeoff_candidates();
2961
2962 assert_eq!(candidates.len(), 3, "Expected 3 bake-off candidates");
2964
2965 for manifest in &candidates {
2967 assert!(
2968 manifest.is_production_ready(),
2969 "Model {} should be production-ready",
2970 manifest.id
2971 );
2972 assert!(
2973 manifest.has_verified_checksums(),
2974 "Model {} should have verified checksums",
2975 manifest.id
2976 );
2977 assert!(
2978 manifest.has_pinned_revision(),
2979 "Model {} should have pinned revision",
2980 manifest.id
2981 );
2982 }
2983
2984 assert!(
2986 candidates
2987 .iter()
2988 .any(|m| m.id == "snowflake-arctic-embed-s"),
2989 "Snowflake should be in candidates"
2990 );
2991 assert!(
2992 candidates.iter().any(|m| m.id == "nomic-embed-text-v1.5"),
2993 "Nomic should be in candidates"
2994 );
2995 assert!(
2996 candidates
2997 .iter()
2998 .any(|m| m.id == "jina-reranker-v1-turbo-en"),
2999 "Jina Turbo should be in candidates"
3000 );
3001 }
3002
3003 #[test]
3004 fn test_downloader_cancellation() {
3005 let tmp = tempfile::tempdir().unwrap();
3006 let downloader = ModelDownloader::new(tmp.path().join("model"));
3007
3008 assert!(!downloader.is_cancelled());
3009 downloader.cancel();
3010 assert!(downloader.is_cancelled());
3011 }
3012
    /// `prepare_temp_dir` keeps `model.onnx` as a resumable partial but removes
    /// other leftover files and subdirectories from the temp dir.
    #[test]
    fn test_prepare_temp_dir_prunes_stale_entries() {
        let tmp = tempfile::tempdir().unwrap();
        let downloader = ModelDownloader::new(tmp.path().join("model"));
        fs::create_dir_all(&downloader.temp_dir).unwrap();
        // Expected to survive; presumably because it matches a MiniLM manifest
        // entry's local name — confirm against `prepare_temp_dir`.
        fs::write(downloader.temp_dir.join("model.onnx"), b"partial").unwrap();
        // Leftovers that must be pruned: a stray file and a nested directory.
        fs::write(downloader.temp_dir.join("stale.bin"), b"stale").unwrap();
        fs::create_dir_all(downloader.temp_dir.join("nested")).unwrap();
        fs::write(
            downloader.temp_dir.join("nested").join("should-remove.txt"),
            b"stale",
        )
        .unwrap();

        downloader
            .prepare_temp_dir(&ModelManifest::minilm_v2())
            .unwrap();

        assert!(downloader.temp_dir.join("model.onnx").exists());
        assert!(!downloader.temp_dir.join("stale.bin").exists());
        assert!(!downloader.temp_dir.join("nested").exists());
    }
3035
    /// A symlink sitting where a resumable partial would be is removed rather
    /// than kept, and its target outside the temp dir is not touched.
    #[test]
    #[cfg(unix)]
    fn test_prepare_temp_dir_removes_symlink_entries() {
        use std::os::unix::fs::symlink;

        let tmp = tempfile::tempdir().unwrap();
        let downloader = ModelDownloader::new(tmp.path().join("model"));
        fs::create_dir_all(&downloader.temp_dir).unwrap();
        // A file outside the temp dir, linked in under a manifest file's name.
        let outside = tmp.path().join("outside.bin");
        fs::write(&outside, b"outside").unwrap();
        symlink(&outside, downloader.temp_dir.join("model.onnx")).unwrap();

        downloader
            .prepare_temp_dir(&ModelManifest::minilm_v2())
            .unwrap();

        // `symlink_metadata` does not follow links, so an error means the link
        // itself is gone.
        let metadata = fs::symlink_metadata(downloader.temp_dir.join("model.onnx"));
        assert!(metadata.is_err(), "symlink should be removed before resume");
        assert!(
            outside.exists(),
            "cleanup must not touch the symlink target"
        );
    }
3059
    /// If the temp dir itself is a symlink, `prepare_temp_dir` fails before
    /// pruning anything. The directory the link points at keeps its contents,
    /// and the link is left in place.
    #[test]
    #[cfg(unix)]
    fn test_prepare_temp_dir_rejects_symlinked_temp_dir_without_pruning_target() {
        use std::os::unix::fs::symlink;

        let tmp = tempfile::tempdir().unwrap();
        let downloader = ModelDownloader::new(tmp.path().join("model"));
        // Real directory outside the downloader's layout, holding a file that a
        // prune pass would delete if it followed the link.
        let outside = tmp.path().join("outside-download-cache");
        fs::create_dir_all(&outside).unwrap();
        fs::write(outside.join("stale.bin"), b"must remain").unwrap();
        symlink(&outside, &downloader.temp_dir).unwrap();

        let err = downloader
            .prepare_temp_dir(&ModelManifest::minilm_v2())
            .expect_err("symlinked temp dir must be rejected before pruning");

        assert!(
            err.to_string().contains("temp dir through symlink"),
            "unexpected symlink-temp-dir error: {err}"
        );
        assert_eq!(fs::read(outside.join("stale.bin")).unwrap(), b"must remain");
        assert!(
            fs::symlink_metadata(&downloader.temp_dir)
                .unwrap()
                .file_type()
                .is_symlink()
        );
    }
3088
    /// `cleanup_temp` does nothing destructive when the temp dir is a symlink.
    /// The linked directory's contents and the link itself both survive.
    #[test]
    #[cfg(unix)]
    fn test_cleanup_temp_skips_symlinked_temp_dir() {
        use std::os::unix::fs::symlink;

        let tmp = tempfile::tempdir().unwrap();
        let downloader = ModelDownloader::new(tmp.path().join("model"));
        let outside = tmp.path().join("outside-download-cache");
        fs::create_dir_all(&outside).unwrap();
        fs::write(outside.join("sentinel.bin"), b"must remain").unwrap();
        symlink(&outside, &downloader.temp_dir).unwrap();

        // Returns nothing; success is judged purely by the filesystem state below.
        downloader.cleanup_temp();

        assert_eq!(
            fs::read(outside.join("sentinel.bin")).unwrap(),
            b"must remain"
        );
        assert!(
            fs::symlink_metadata(&downloader.temp_dir)
                .unwrap()
                .file_type()
                .is_symlink()
        );
    }
3114
3115 #[test]
3116 fn test_retryable_error_classification() {
3117 let cases = [
3118 (DownloadError::NetworkError("boom".into()), true),
3119 (DownloadError::Timeout, true),
3120 (
3121 DownloadError::HttpError {
3122 status: 503,
3123 message: "unavailable".into(),
3124 },
3125 true,
3126 ),
3127 (
3128 DownloadError::HttpError {
3129 status: 404,
3130 message: "missing".into(),
3131 },
3132 false,
3133 ),
3134 (DownloadError::Cancelled, false),
3135 (
3136 DownloadError::VerificationFailed {
3137 file: "model.onnx".into(),
3138 expected: "a".into(),
3139 actual: "b".into(),
3140 },
3141 false,
3142 ),
3143 ];
3144
3145 for (err, expected) in cases {
3146 assert_eq!(
3147 err.is_retryable(),
3148 expected,
3149 "retryability mismatch for {err}"
3150 );
3151 }
3152 }
3153
3154 #[test]
3155 fn test_cleanup_temp_for_error_preserves_partial_downloads_on_cancelled() {
3156 let tmp = tempfile::tempdir().unwrap();
3157 let downloader = ModelDownloader::new(tmp.path().join("model"));
3158 fs::create_dir_all(&downloader.temp_dir).unwrap();
3159 let partial = downloader.temp_dir.join("model.onnx");
3160 fs::write(&partial, b"partial").unwrap();
3161
3162 downloader.cleanup_temp_for_error(&DownloadError::Cancelled);
3163
3164 assert!(
3165 partial.exists(),
3166 "cancelled downloads should keep partial files for a resumable retry"
3167 );
3168 }
3169
3170 #[test]
3171 fn test_fail_if_cancelled_preserves_partial_downloads() {
3172 let tmp = tempfile::tempdir().unwrap();
3173 let downloader = ModelDownloader::new(tmp.path().join("model"));
3174 fs::create_dir_all(&downloader.temp_dir).unwrap();
3175 let partial = downloader.temp_dir.join("model.onnx");
3176 fs::write(&partial, b"partial").unwrap();
3177 downloader.cancel();
3178
3179 let result = downloader.fail_if_cancelled();
3180
3181 assert!(matches!(result, Err(DownloadError::Cancelled)));
3182 assert!(
3183 partial.exists(),
3184 "early cancellation checks should not discard resumable partial files"
3185 );
3186 }
3187
3188 #[test]
3189 fn test_cleanup_temp_for_error_discards_temp_after_verification_failure() {
3190 let tmp = tempfile::tempdir().unwrap();
3191 let downloader = ModelDownloader::new(tmp.path().join("model"));
3192 fs::create_dir_all(&downloader.temp_dir).unwrap();
3193 let partial = downloader.temp_dir.join("model.onnx");
3194 fs::write(&partial, b"partial").unwrap();
3195
3196 downloader.cleanup_temp_for_error(&DownloadError::VerificationFailed {
3197 file: "model.onnx".into(),
3198 expected: "good".into(),
3199 actual: "bad".into(),
3200 });
3201
3202 assert!(
3203 !downloader.temp_dir.exists(),
3204 "verification failures should discard the temp directory to avoid reusing corrupt data"
3205 );
3206 }
3207
3208 #[test]
3209 fn test_download_with_mirror_installs_verified_model_from_http_mirror() {
3210 let files = [
3211 ("onnx/model.onnx", b"mirror-model".as_slice()),
3212 ("tokenizer.json", br#"{"tokenizer":"ok"}"#.as_slice()),
3213 ];
3214 let manifest = build_test_manifest("mirror/test-model", "rev123", &files);
3215 let route_prefix = "/cache";
3216 let routes: Vec<(String, MirrorRoute)> = manifest
3217 .files
3218 .iter()
3219 .zip(files.iter())
3220 .map(|(file, (_, body))| {
3221 (
3222 mirror_route_path(route_prefix, &manifest, file),
3223 MirrorRoute {
3224 body: body.to_vec(),
3225 content_type: "application/octet-stream",
3226 chunk_size: 64,
3227 chunk_delay: Duration::ZERO,
3228 },
3229 )
3230 })
3231 .collect();
3232 let server = start_mirror_fixture_server(routes);
3233 let tmp = tempfile::tempdir().unwrap();
3234 let downloader = ModelDownloader::new(tmp.path().join("model"));
3235 let mirror_base = format!("{}/cache/", server.base_url);
3236
3237 downloader
3238 .download_with_mirror(&manifest, Some(&mirror_base), None)
3239 .unwrap();
3240
3241 for (name, body) in files {
3242 let installed = downloader.target_dir.join(
3243 Path::new(name)
3244 .file_name()
3245 .unwrap()
3246 .to_string_lossy()
3247 .as_ref(),
3248 );
3249 assert_eq!(
3250 fs::read(installed).unwrap(),
3251 body,
3252 "mirror install should persist the downloaded payload"
3253 );
3254 }
3255 let marker = fs::read_to_string(downloader.target_dir.join(".verified")).unwrap();
3256 assert!(
3257 marker.contains("revision=rev123"),
3258 "verified marker should preserve manifest identity after mirror install"
3259 );
3260 assert!(
3261 marker.contains("source=mirror:"),
3262 "verified marker should record mirror source"
3263 );
3264
3265 let requests = server.requests();
3266 assert_eq!(
3267 requests.len(),
3268 manifest.files.len(),
3269 "expected one request per manifest file"
3270 );
3271 assert!(
3272 requests
3273 .iter()
3274 .all(|request| request.path.starts_with("/cache/")),
3275 "mirror requests should stay under the configured mirror prefix: {requests:?}"
3276 );
3277 }
3278
    /// A mirror that does not serve the requested artifact yields
    /// `DownloadError::HttpError { status: 404, .. }`. Exactly one request is
    /// made, and it targets the artifact's `/resolve/` path.
    #[test]
    fn test_download_with_mirror_reports_missing_artifact_from_http_mirror() {
        let file_body = b"mirror-model".as_slice();
        let manifest = build_test_manifest(
            "mirror/test-model",
            "rev404",
            &[("onnx/model.onnx", file_body)],
        );
        // No routes registered; the fixture server presumably answers 404 for
        // unknown paths (the assertion below relies on that).
        let server = start_mirror_fixture_server(Vec::new());
        let tmp = tempfile::tempdir().unwrap();
        let downloader = ModelDownloader::new(tmp.path().join("model"));
        let mirror_base = format!("{}/cache", server.base_url);

        let err = downloader
            .download_with_mirror(&manifest, Some(&mirror_base), None)
            .unwrap_err();

        assert!(
            matches!(err, DownloadError::HttpError { status: 404, .. }),
            "missing mirror artifacts should surface as HTTP 404, got: {err}"
        );
        // A single request also shows the 404 was not retried.
        let requests = server.requests();
        assert_eq!(requests.len(), 1);
        assert!(
            requests[0].path.contains("/resolve/"),
            "mirror request should target the resolved artifact path: {requests:?}"
        );
    }
3307
    /// A mirror serving bytes that differ from the manifest payload fails with
    /// `VerificationFailed`. The temp dir is discarded and nothing is installed
    /// at the target.
    #[test]
    fn test_download_with_mirror_discards_corrupt_payload_from_http_mirror() {
        let manifest = build_test_manifest(
            "mirror/test-model",
            "revbad",
            &[("onnx/model.onnx", b"expected-bytes".as_slice())],
        );
        let route_prefix = "/cache";
        // The route exists but serves the wrong payload.
        let server = start_mirror_fixture_server(vec![(
            mirror_route_path(route_prefix, &manifest, &manifest.files[0]),
            MirrorRoute {
                body: b"corrupt-bytes".to_vec(),
                content_type: "application/octet-stream",
                chunk_size: 64,
                chunk_delay: Duration::ZERO,
            },
        )]);
        let tmp = tempfile::tempdir().unwrap();
        let downloader = ModelDownloader::new(tmp.path().join("model"));
        let mirror_base = format!("{server_base}/cache", server_base = server.base_url);

        let err = downloader
            .download_with_mirror(&manifest, Some(&mirror_base), None)
            .unwrap_err();

        assert!(
            matches!(err, DownloadError::VerificationFailed { .. }),
            "corrupt mirror payloads must fail checksum verification, got: {err}"
        );
        assert!(
            !downloader.temp_dir.exists(),
            "verification failures should discard the temp directory so corrupt payloads are not reused"
        );
        assert!(
            !downloader.target_dir.exists(),
            "corrupt mirror payloads must not be promoted into the installed model directory"
        );
    }
3346
    /// Cancel a mirror download mid-stream, then rerun it with the same
    /// downloader. The partial file must survive the cancellation, and the rerun
    /// must resume it with a `Range` request starting at the partial's size and
    /// install the exact payload.
    #[test]
    fn test_download_with_mirror_resumes_after_cancelled_partial_download() {
        // 128 KiB served in 1 KiB chunks with a 2 ms delay between chunks; the
        // delay presumably keeps the transfer slow enough to cancel part-way.
        let large_payload = vec![b'x'; 128 * 1024];
        let manifest = build_test_manifest(
            "mirror/test-model",
            "revresume",
            &[("onnx/model.onnx", &large_payload)],
        );
        let route_prefix = "/cache";
        let server = start_mirror_fixture_server(vec![(
            mirror_route_path(route_prefix, &manifest, &manifest.files[0]),
            MirrorRoute {
                body: large_payload.clone(),
                content_type: "application/octet-stream",
                chunk_size: 1024,
                chunk_delay: Duration::from_millis(2),
            },
        )]);
        let tmp = tempfile::tempdir().unwrap();
        // Shared via `Arc` so the progress callback can cancel the same downloader.
        let downloader = Arc::new(ModelDownloader::new(tmp.path().join("model")));
        let mirror_base = format!("{server_base}/cache", server_base = server.base_url);
        // Ensures `cancel()` is issued at most once across progress callbacks.
        let cancel_once = Arc::new(AtomicBool::new(false));
        let canceller = Arc::clone(&downloader);
        let cancel_flag = Arc::clone(&cancel_once);

        // NOTE(review): the condition reads `progress.total_bytes`. If that field
        // is the expected total size (128 KiB here) rather than bytes received so
        // far, it is already true on the first progress callback. The test still
        // passes, but the 16 KiB threshold would then be meaningless — confirm
        // which field was intended.
        let cancelled = downloader.download_with_mirror(
            &manifest,
            Some(&mirror_base),
            Some(Arc::new(move |progress| {
                if progress.total_bytes >= 16 * 1024
                    && cancel_flag
                        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                        .is_ok()
                {
                    canceller.cancel();
                }
            })),
        );

        assert!(
            matches!(cancelled, Err(DownloadError::Cancelled)),
            "first mirror attempt should stop with a cancellation so we can verify resumable recovery"
        );
        let partial_path = downloader.temp_dir.join("model.onnx");
        let partial_size = fs::metadata(&partial_path).unwrap().len();
        assert!(
            partial_size > 0 && partial_size < large_payload.len() as u64,
            "cancelled run should preserve a partial download for resume; got {partial_size} bytes"
        );

        // Same (previously cancelled) downloader; the test relies on this second
        // call succeeding, so a new download presumably resets the cancel flag.
        downloader
            .download_with_mirror(&manifest, Some(&mirror_base), None)
            .unwrap();

        assert_eq!(
            fs::read(downloader.target_dir.join("model.onnx")).unwrap(),
            large_payload,
            "rerun after cancellation should finish the mirrored download and install the exact payload"
        );
        // The fixture server records each request's `Range` start (if any).
        let requests = server.requests();
        assert!(
            requests
                .iter()
                .any(|request| request.range_start == Some(partial_size)),
            "rerun should resume from the preserved partial via Range requests; saw requests: {requests:?}"
        );
    }
3414}