1use std::collections::HashSet;
20use std::fs::{self, File};
21use std::future::{Future, poll_fn};
22use std::io::{BufReader, Read, Write};
23use std::path::{Path, PathBuf};
24use std::pin::Pin;
25use std::sync::Arc;
26use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
27use std::sync::mpsc::TryRecvError;
28use std::time::{Duration, Instant};
29
30use asupersync::bytes::Buf;
31use asupersync::http::Body;
32use serde::{Deserialize, Serialize};
33use sha2::{Digest, Sha256};
34use thiserror::Error;
35use url::Url;
36
37use crate::search::policy::{ModelDownloadPolicy, SemanticPolicy};
38
/// User-facing lifecycle state of a semantic model.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelState {
    /// The model is not installed.
    NotInstalled,
    /// Acquisition is waiting on user consent.
    NeedsConsent,
    /// A download is in progress.
    Downloading {
        /// Completion percentage reported by the downloader.
        progress_pct: u8,
        /// Bytes received so far.
        bytes_downloaded: u64,
        /// Total bytes expected.
        total_bytes: u64,
    },
    /// Downloaded files are being verified.
    Verifying,
    /// The model is installed and usable.
    Ready,
    /// Semantic models are disabled; `reason` explains why.
    Disabled { reason: String },
    /// Verification of the model files failed; `reason` explains why.
    VerificationFailed { reason: String },
    /// A revision other than the installed one is available.
    UpdateAvailable {
        /// Revision currently installed.
        current_revision: String,
        /// Revision that could be installed instead.
        latest_revision: String,
    },
    /// The last acquisition was cancelled.
    Cancelled,
}

impl ModelState {
    /// Returns `true` only for [`ModelState::Ready`].
    pub fn is_ready(&self) -> bool {
        *self == Self::Ready
    }

    /// Returns `true` for any [`ModelState::Downloading`] value, whatever its progress.
    pub fn is_downloading(&self) -> bool {
        let Self::Downloading { .. } = self else {
            return false;
        };
        true
    }

    /// Returns `true` only for [`ModelState::NeedsConsent`].
    pub fn needs_consent(&self) -> bool {
        *self == Self::NeedsConsent
    }

    /// Short human-readable description, e.g. `"downloading (42%)"` or `"disabled: <reason>"`.
    pub fn summary(&self) -> String {
        // Payload-free states map to fixed text; the rest format their fields and return early.
        let fixed = match self {
            Self::NotInstalled => "not installed",
            Self::NeedsConsent => "needs consent",
            Self::Verifying => "verifying",
            Self::Ready => "ready",
            Self::Cancelled => "cancelled",
            Self::Downloading { progress_pct, .. } => {
                return format!("downloading ({progress_pct}%)");
            }
            Self::Disabled { reason } => return format!("disabled: {reason}"),
            Self::VerificationFailed { reason } => {
                return format!("verification failed: {reason}");
            }
            Self::UpdateAvailable {
                current_revision,
                latest_revision,
            } => {
                return format!("update available: {current_revision} -> {latest_revision}");
            }
        };
        fixed.to_owned()
    }
}
125
/// Settings that govern whether and how a semantic model may be acquired.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelAcquisitionPolicy {
    /// Whether model downloads are allowed at all.
    pub downloads_enabled: bool,
    /// Whether the user must consent before a download starts.
    pub requires_consent: bool,
    /// Whether the environment is offline. NOTE(review): consumers are not visible here.
    pub offline: bool,
    /// Whether the current connection is metered.
    pub metered: bool,
    /// Whether downloads are permitted over a metered connection.
    pub allow_metered: bool,
    /// Maximum model size in bytes, or `None` for no cap.
    pub max_model_bytes: Option<u64>,
    /// Optional mirror base URL to download from instead of the default registry.
    pub mirror_base_url: Option<String>,
    /// Where this policy came from (e.g. `"compiled_default"` or `"semantic_policy"`).
    pub config_source: String,
}
150
151impl Default for ModelAcquisitionPolicy {
152 fn default() -> Self {
153 Self {
154 downloads_enabled: true,
155 requires_consent: true,
156 offline: false,
157 metered: false,
158 allow_metered: false,
159 max_model_bytes: None,
160 mirror_base_url: None,
161 config_source: "compiled_default".to_string(),
162 }
163 }
164}
165
166impl ModelAcquisitionPolicy {
167 pub fn from_semantic_policy(policy: &SemanticPolicy) -> Self {
169 const MIB: u64 = 1_048_576;
170
171 Self {
172 downloads_enabled: policy.mode.should_build_semantic(),
173 requires_consent: matches!(policy.download_policy, ModelDownloadPolicy::OptIn),
174 max_model_bytes: Some(policy.max_model_size_mb.saturating_mul(MIB)),
175 config_source: "semantic_policy".to_string(),
176 ..Self::default()
177 }
178 }
179}
180
/// Machine-readable state of a local model cache.
///
/// Serialized with a `state` tag field whose value is the snake_case variant name.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "state")]
pub enum ModelCacheState {
    /// Required model files are missing.
    NotAcquired {
        /// Names of the files that are not present.
        missing_files: Vec<String>,
        /// Whether acquiring them first requires user consent.
        needs_consent: bool,
    },
    /// Files are being acquired into a staging directory.
    Acquiring {
        staging_dir: PathBuf,
        bytes_present: u64,
        total_bytes: u64,
    },
    /// The model cache is present and verified.
    Acquired { model_dir: PathBuf },
    /// A file's checksum did not match its expected value.
    ChecksumMismatch {
        file: String,
        expected: String,
        actual: String,
    },
    /// The cached revision differs from the expected revision.
    IncompatibleVersion {
        current_revision: String,
        expected_revision: String,
    },
    /// Policy disables model acquisition.
    DisabledByPolicy { reason: String },
    /// The model needs more bytes than the policy budget allows.
    BudgetBlocked { required_bytes: u64, max_bytes: u64 },
    /// The cache has been quarantined as corrupt (`marker_path` presumably points at the
    /// quarantine marker file — confirm at the producer).
    QuarantinedCorrupt {
        marker_path: PathBuf,
        reason: String,
    },
    /// A verified cache that was pre-seeded locally.
    PreseededLocal { model_dir: PathBuf },
    /// A verified cache whose files came from a mirror.
    MirrorSourced {
        model_dir: PathBuf,
        mirror_base_url: String,
    },
    /// Offline while required files are still missing.
    OfflineBlocked { missing_files: Vec<String> },
}
230
231impl ModelCacheState {
232 pub fn code(&self) -> &'static str {
234 match self {
235 Self::NotAcquired { .. } => "not_acquired",
236 Self::Acquiring { .. } => "acquiring",
237 Self::Acquired { .. } => "acquired",
238 Self::ChecksumMismatch { .. } => "checksum_mismatch",
239 Self::IncompatibleVersion { .. } => "incompatible_version",
240 Self::DisabledByPolicy { .. } => "disabled_by_policy",
241 Self::BudgetBlocked { .. } => "budget_blocked",
242 Self::QuarantinedCorrupt { .. } => "quarantined_corrupt",
243 Self::PreseededLocal { .. } => "preseeded_local",
244 Self::MirrorSourced { .. } => "mirror_sourced",
245 Self::OfflineBlocked { .. } => "offline_blocked",
246 }
247 }
248
249 pub fn summary(&self) -> String {
251 match self {
252 Self::NotAcquired {
253 missing_files,
254 needs_consent,
255 } => {
256 let action = if *needs_consent {
257 "user consent required"
258 } else {
259 "ready to acquire"
260 };
261 format!(
262 "model not acquired ({action}); missing {}",
263 missing_files.join(", ")
264 )
265 }
266 Self::Acquiring {
267 bytes_present,
268 total_bytes,
269 staging_dir,
270 } => format!(
271 "model acquisition in progress at {} ({bytes_present}/{total_bytes} bytes)",
272 staging_dir.display()
273 ),
274 Self::Acquired { .. } => "model cache acquired and verified".to_string(),
275 Self::ChecksumMismatch {
276 file,
277 expected,
278 actual,
279 } => format!("checksum mismatch for {file}: expected {expected}, got {actual}"),
280 Self::IncompatibleVersion {
281 current_revision,
282 expected_revision,
283 } => format!("model revision mismatch: {current_revision} != {expected_revision}"),
284 Self::DisabledByPolicy { reason } => format!("model acquisition disabled: {reason}"),
285 Self::BudgetBlocked {
286 required_bytes,
287 max_bytes,
288 } => {
289 format!("model requires {required_bytes} bytes but policy allows {max_bytes} bytes")
290 }
291 Self::QuarantinedCorrupt { reason, .. } => {
292 format!("model cache quarantined: {reason}")
293 }
294 Self::PreseededLocal { .. } => "preseeded local model cache verified".to_string(),
295 Self::MirrorSourced {
296 mirror_base_url, ..
297 } => {
298 format!("model cache verified from mirror {mirror_base_url}")
299 }
300 Self::OfflineBlocked { missing_files } => {
301 format!(
302 "offline and model is not acquired; missing {}",
303 missing_files.join(", ")
304 )
305 }
306 }
307 }
308
309 pub fn next_step(&self) -> Option<&'static str> {
311 match self {
312 Self::NotAcquired { .. } => {
313 Some("Run `cass models install`, or keep using lexical search.")
314 }
315 Self::Acquiring { .. } => {
316 Some("Wait for model acquisition to finish, or keep using lexical search.")
317 }
318 Self::Acquired { .. } | Self::PreseededLocal { .. } | Self::MirrorSourced { .. } => {
319 None
320 }
321 Self::ChecksumMismatch { .. } | Self::QuarantinedCorrupt { .. } => Some(
322 "Run `cass models verify --repair`, or reinstall with `cass models install -y`.",
323 ),
324 Self::IncompatibleVersion { .. } => {
325 Some("Run `cass models install -y` to refresh the model cache.")
326 }
327 Self::DisabledByPolicy { .. } => {
328 Some("Use lexical search or re-enable semantic model acquisition in policy.")
329 }
330 Self::BudgetBlocked { .. } => {
331 Some("Increase the semantic model budget or keep using lexical search.")
332 }
333 Self::OfflineBlocked { .. } => Some(
334 "Reconnect or install from local files with `cass models install --from-file`.",
335 ),
336 }
337 }
338
339 pub fn is_usable(&self) -> bool {
341 matches!(
342 self,
343 Self::Acquired { .. } | Self::PreseededLocal { .. } | Self::MirrorSourced { .. }
344 )
345 }
346}
347
/// Snapshot of a model's cache status for status and diagnostic output.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelCacheReport {
    /// Manifest identifier of the model.
    pub model_id: String,
    /// Directory where the model is (or would be) installed.
    pub model_dir: PathBuf,
    /// Current cache state.
    pub state: ModelCacheState,
    /// Bytes the model requires (presumably the manifest total — confirm at construction site).
    pub required_size_bytes: u64,
    /// Bytes currently present on disk.
    pub installed_size_bytes: u64,
    /// Origin of the acquisition policy used to evaluate the cache.
    pub policy_source: String,
}
358
359impl ModelCacheReport {
360 pub fn state_code(&self) -> &'static str {
362 self.state.code()
363 }
364
365 pub fn is_usable(&self) -> bool {
367 self.state.is_usable()
368 }
369}
370
/// One file belonging to a model manifest.
#[derive(Debug, Clone)]
pub struct ModelFile {
    /// Path of the file inside the model repository; may contain `/` (e.g. `onnx/model.onnx`).
    pub name: String,
    /// Expected SHA-256 digest, compared verbatim with the digest computed after download.
    pub sha256: String,
    /// Expected file size in bytes.
    pub size: u64,
}

impl ModelFile {
    /// File name used on local disk: everything after the last `/` in [`ModelFile::name`],
    /// or the whole name when it contains no `/`.
    pub fn local_name(&self) -> &str {
        match self.name.rfind('/') {
            // '/' is ASCII, so `slash + 1` is always a char boundary.
            Some(slash) => &self.name[slash + 1..],
            None => self.name.as_str(),
        }
    }
}
393
/// Description of a downloadable model: its repository, revision, and the files
/// (with checksums and sizes) that make it up.
#[derive(Debug, Clone)]
pub struct ModelManifest {
    /// Short identifier used in CLI commands and generated air-gap scripts.
    pub id: String,
    /// Repository path (e.g. `org/name`) placed after the download base URL.
    pub repo: String,
    /// Revision used in `/resolve/<revision>/` download URLs; `"main"` counts as unpinned.
    pub revision: String,
    /// Files that make up the model.
    pub files: Vec<ModelFile>,
    /// License identifier of the model (e.g. `Apache-2.0`).
    pub license: String,
}
412
/// Sentinel stored in [`ModelFile::sha256`] when a file's real digest has not been
/// pinned yet. Manifests containing it fail `ModelManifest::has_verified_checksums`
/// and are refused by the downloader.
pub const PLACEHOLDER_CHECKSUM: &str = "PLACEHOLDER_VERIFY_AFTER_DOWNLOAD";
418
419pub fn normalize_mirror_base_url(base_url: &str) -> Result<String, DownloadError> {
424 let trimmed = base_url.trim();
425 if trimmed.is_empty() {
426 return Err(invalid_mirror_url(base_url, "mirror URL cannot be empty"));
427 }
428
429 let parsed = Url::parse(trimmed).map_err(|err| invalid_mirror_url(trimmed, err.to_string()))?;
430
431 match parsed.scheme() {
432 "http" | "https" => {}
433 scheme => {
434 return Err(invalid_mirror_url(
435 trimmed,
436 format!("unsupported URL scheme '{scheme}' (expected http or https)"),
437 ));
438 }
439 }
440
441 if parsed.host_str().is_none() {
442 return Err(invalid_mirror_url(
443 trimmed,
444 "mirror URL must include a host",
445 ));
446 }
447
448 if parsed.query().is_some() || parsed.fragment().is_some() {
449 return Err(invalid_mirror_url(
450 trimmed,
451 "mirror URL must not include query or fragment components",
452 ));
453 }
454
455 Ok(parsed.to_string().trim_end_matches('/').to_string())
456}
457
458fn invalid_mirror_url(url: impl Into<String>, reason: impl Into<String>) -> DownloadError {
459 DownloadError::InvalidMirrorUrl {
460 url: url.into(),
461 reason: reason.into(),
462 }
463}
464
465impl ModelManifest {
466 pub fn has_verified_checksums(&self) -> bool {
471 self.files.iter().all(|f| f.sha256 != PLACEHOLDER_CHECKSUM)
472 }
473
474 pub fn has_pinned_revision(&self) -> bool {
479 self.revision != "main"
480 }
481
482 pub fn is_production_ready(&self) -> bool {
488 self.has_verified_checksums() && self.has_pinned_revision()
489 }
490
491 pub fn minilm_v2() -> Self {
496 Self {
497 id: "all-minilm-l6-v2".into(),
498 repo: "sentence-transformers/all-MiniLM-L6-v2".into(),
499 revision: "c9745ed1d9f207416be6d2e6f8de32d1f16199bf".into(),
501 files: vec![
502 ModelFile {
503 name: "onnx/model.onnx".into(),
505 sha256: "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452"
506 .into(),
507 size: 90405214,
508 },
509 ModelFile {
510 name: "tokenizer.json".into(),
511 sha256: "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037"
512 .into(),
513 size: 466247,
514 },
515 ModelFile {
516 name: "config.json".into(),
517 sha256: "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41"
518 .into(),
519 size: 612,
520 },
521 ModelFile {
525 name: "special_tokens_map.json".into(),
526 sha256: "303df45a03609e4ead04bc3dc1536d0ab19b5358db685b6f3da123d05ec200e3"
527 .into(),
528 size: 112,
529 },
530 ModelFile {
531 name: "tokenizer_config.json".into(),
532 sha256: "acb92769e8195aabd29b7b2137a9e6d6e25c476a4f15aa4355c233426c61576b"
533 .into(),
534 size: 350,
535 },
536 ],
537 license: "Apache-2.0".into(),
538 }
539 }
540
541 pub fn snowflake_arctic_s() -> Self {
555 Self {
556 id: "snowflake-arctic-embed-s".into(),
557 repo: "Snowflake/snowflake-arctic-embed-s".into(),
558 revision: "e596f507467533e48a2e17c007f0e1dacc837b33".into(),
559 files: vec![
560 ModelFile {
561 name: "onnx/model.onnx".into(),
562 sha256: "579c1f1778a0993eb0d2a1403340ffb491c769247fb46acc4f5cf8ac5b89c1e1"
563 .into(),
564 size: 133_093_492,
565 },
566 ModelFile {
567 name: "tokenizer.json".into(),
568 sha256: "91f1def9b9391fdabe028cd3f3fcc4efd34e5d1f08c3bf2de513ebb5911a1854"
569 .into(),
570 size: 711_649,
571 },
572 ModelFile {
573 name: "config.json".into(),
574 sha256: "4e519aa92ec40943356032afe458c8829d70c5766b109e4a57490b82f72dcfb7"
575 .into(),
576 size: 703,
577 },
578 ModelFile {
579 name: "special_tokens_map.json".into(),
580 sha256: "5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a"
581 .into(),
582 size: 695,
583 },
584 ModelFile {
585 name: "tokenizer_config.json".into(),
586 sha256: "9ca59277519f6e3692c8685e26b94d4afca2d5438deff66483db495e48735810"
587 .into(),
588 size: 1_433,
589 },
590 ],
591 license: "Apache-2.0".into(),
592 }
593 }
594
595 pub fn nomic_embed() -> Self {
603 Self {
604 id: "nomic-embed-text-v1.5".into(),
605 repo: "nomic-ai/nomic-embed-text-v1.5".into(),
606 revision: "e5cf08aadaa33385f5990def41f7a23405aec398".into(),
607 files: vec![
608 ModelFile {
609 name: "onnx/model.onnx".into(),
610 sha256: "147d5aa88c2101237358e17796cf3a227cead1ec304ec34b465bb08e9d952965"
611 .into(),
612 size: 547_310_275,
613 },
614 ModelFile {
615 name: "tokenizer.json".into(),
616 sha256: "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
617 .into(),
618 size: 711_396,
619 },
620 ModelFile {
621 name: "config.json".into(),
622 sha256: "0168e0883705b0bf8f2b381e10f45a9f3e1ef4b13869b43c160e4c8a70ddf442"
623 .into(),
624 size: 2_331,
625 },
626 ModelFile {
627 name: "special_tokens_map.json".into(),
628 sha256: "5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a"
629 .into(),
630 size: 695,
631 },
632 ModelFile {
633 name: "tokenizer_config.json".into(),
634 sha256: "d7e0000bcc80134debd2222220427e6bf5fa20a669f40a0d0d1409cc18e0a9bc"
635 .into(),
636 size: 1_191,
637 },
638 ],
639 license: "Apache-2.0".into(),
640 }
641 }
642
643 pub fn msmarco_reranker() -> Self {
650 Self {
651 id: "ms-marco-MiniLM-L6-v2".into(),
652 repo: "cross-encoder/ms-marco-MiniLM-L6-v2".into(),
653 revision: "c5ee24cb16019beea0893ab7796b1df96625c6b8".into(),
654 files: vec![
655 ModelFile {
656 name: "onnx/model.onnx".into(),
657 sha256: "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
658 .into(),
659 size: 91_011_230,
660 },
661 ModelFile {
662 name: "tokenizer.json".into(),
663 sha256: "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
664 .into(),
665 size: 711_396,
666 },
667 ModelFile {
668 name: "config.json".into(),
669 sha256: "380e02c93f431831be65d99a4e7e5f67c133985bf2e77d9d4eba46847190bacc"
670 .into(),
671 size: 794,
672 },
673 ModelFile {
674 name: "special_tokens_map.json".into(),
675 sha256: "3c3507f36dff57bce437223db3b3081d1e2b52ec3e56ee55438193ecb2c94dd6"
676 .into(),
677 size: 132,
678 },
679 ModelFile {
680 name: "tokenizer_config.json".into(),
681 sha256: "a5c2e5a7b1a29a0702cd28c08a399b5ecc110c263009d17f7e3b415f25905fd8"
682 .into(),
683 size: 1_330,
684 },
685 ],
686 license: "Apache-2.0".into(),
687 }
688 }
689
690 pub fn jina_reranker_turbo() -> Self {
697 Self {
698 id: "jina-reranker-v1-turbo-en".into(),
699 repo: "jinaai/jina-reranker-v1-turbo-en".into(),
700 revision: "b8c14f4e723d9e0aab4732a7b7b93741eeeb77c2".into(),
701 files: vec![
702 ModelFile {
703 name: "onnx/model.onnx".into(),
704 sha256: "c1296c66c119de645fa9cdee536d8637740efe85224cfa270281e50f213aa565"
705 .into(),
706 size: 151_296_975,
707 },
708 ModelFile {
709 name: "tokenizer.json".into(),
710 sha256: "0046da43cc8c424b317f56b092b0512aaaa65c4f925d2f16af9d9eeb4d0ef902"
711 .into(),
712 size: 2_030_772,
713 },
714 ModelFile {
715 name: "config.json".into(),
716 sha256: "e050ff6a15ae9295e84882fa0e98051bd8754856cd5201395ebf00ce9f2d609b"
717 .into(),
718 size: 1_206,
719 },
720 ModelFile {
721 name: "special_tokens_map.json".into(),
722 sha256: "06e405a36dfe4b9604f484f6a1e619af1a7f7d09e34a8555eb0b77b66318067f"
723 .into(),
724 size: 280,
725 },
726 ModelFile {
727 name: "tokenizer_config.json".into(),
728 sha256: "d291c6652d96d56ffdbcf1ea19d9bae5ed79003f7648c627e725a619227ce8fa"
729 .into(),
730 size: 1_215,
731 },
732 ],
733 license: "Apache-2.0".into(),
734 }
735 }
736
737 pub fn for_embedder(name: &str) -> Option<Self> {
741 match name {
742 "minilm" => Some(Self::minilm_v2()),
743 "snowflake-arctic-s" => Some(Self::snowflake_arctic_s()),
744 "nomic-embed" => Some(Self::nomic_embed()),
745 _ => None,
746 }
747 }
748
749 pub fn for_reranker(name: &str) -> Option<Self> {
751 match name {
752 "ms-marco" => Some(Self::msmarco_reranker()),
753 "jina-reranker-turbo" => Some(Self::jina_reranker_turbo()),
754 _ => None,
755 }
756 }
757
758 pub fn bakeoff_embedder_candidates() -> Vec<Self> {
762 vec![Self::snowflake_arctic_s(), Self::nomic_embed()]
763 }
764
765 pub fn bakeoff_reranker_candidates() -> Vec<Self> {
769 vec![Self::jina_reranker_turbo()]
770 }
771
772 pub fn bakeoff_candidates() -> Vec<Self> {
776 let mut candidates = Self::bakeoff_embedder_candidates();
777 candidates.extend(Self::bakeoff_reranker_candidates());
778 candidates
779 }
780
781 pub fn total_size(&self) -> u64 {
783 self.files.iter().map(|f| f.size).sum()
784 }
785
786 pub fn download_url_with_base(&self, file: &ModelFile, base_url: Option<&str>) -> String {
788 let root = base_url.unwrap_or("https://huggingface.co");
789 format!(
790 "{}/{}/resolve/{}/{}",
791 root.trim_end_matches('/'),
792 self.repo.trim_start_matches('/'),
793 self.revision,
794 file.name.trim_start_matches('/')
795 )
796 }
797
798 pub fn download_url(&self, file: &ModelFile) -> String {
800 self.download_url_with_base(file, None)
801 }
802
803 pub fn air_gap_bash_script(&self, base_url: Option<&str>) -> String {
810 fn quote_url(url: &str) -> String {
816 debug_assert!(
817 !url.contains('\''),
818 "model download URL unexpectedly contains a single quote: {url}"
819 );
820 format!("'{url}'")
821 }
822
823 let mut out = String::new();
824 out.push_str("# Air-gap model install (bash / Git Bash / MSYS2)\n");
825 out.push_str(
826 "# Run these commands, then re-run `cass models install --from-file \"$DIR\"`.\n",
827 );
828 out.push_str("set -euo pipefail\n");
829 out.push_str(&format!("DIR=\"${{DIR:-./{}_files}}\"\n", self.id));
830 out.push_str("mkdir -p \"$DIR\"\n");
831 for file in &self.files {
832 let url = self.download_url_with_base(file, base_url);
837 out.push_str(&format!(
838 "curl -fL --retry 3 {} -o \"$DIR/{}\" # {} bytes\n",
839 quote_url(&url),
840 file.local_name(),
841 file.size,
842 ));
843 }
844 out.push_str(&format!(
845 "cass models install {} --from-file \"$DIR\" -y\n",
846 self.id
847 ));
848 out
849 }
850
851 pub fn air_gap_powershell_script(&self, base_url: Option<&str>) -> String {
854 fn quote_url_ps(url: &str) -> String {
856 debug_assert!(
857 !url.contains('\''),
858 "model download URL unexpectedly contains a single quote: {url}"
859 );
860 format!("'{url}'")
861 }
862
863 let mut out = String::new();
864 out.push_str("# Air-gap model install (PowerShell 5.1+ and 7+)\n");
865 out.push_str("$ErrorActionPreference = 'Stop'\n");
866 out.push_str(
869 "[System.Net.ServicePointManager]::SecurityProtocol = \
870 [System.Net.ServicePointManager]::SecurityProtocol -bor \
871 [System.Net.SecurityProtocolType]::Tls12\n",
872 );
873 out.push_str(&format!("$dir = \"{}_files\"\n", self.id));
874 out.push_str("New-Item -ItemType Directory -Force -Path $dir | Out-Null\n");
875 for file in &self.files {
876 let url = self.download_url_with_base(file, base_url);
877 out.push_str(&format!(
880 "Invoke-WebRequest -UseBasicParsing -Uri {} -OutFile (Join-Path $dir '{}') # {} bytes\n",
881 quote_url_ps(&url),
882 file.local_name(),
883 file.size,
884 ));
885 }
886 out.push_str(&format!(
887 "cass models install {} --from-file $dir -y\n",
888 self.id
889 ));
890 out
891 }
892}
893
/// Callback invoked with a [`DownloadProgress`] snapshot each time a chunk is written.
pub type ProgressCallback = Arc<dyn Fn(DownloadProgress) + Send + Sync>;

/// Progress snapshot passed to a [`ProgressCallback`].
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Local file name (last path component) of the file being downloaded.
    pub current_file: String,
    /// 1-based index of the current file within the manifest.
    pub file_index: usize,
    /// Number of files in the manifest.
    pub total_files: usize,
    /// Bytes of the current file written so far, including any resumed prefix.
    pub file_bytes: u64,
    /// Expected size of the current file.
    pub file_total: u64,
    /// Bytes counted across all files so far.
    pub total_bytes: u64,
    /// Sum of all expected file sizes in the manifest.
    pub grand_total: u64,
    /// `total_bytes / grand_total` as a percentage, capped at 100 (0 when `grand_total` is 0).
    pub progress_pct: u8,
}
917
/// Errors produced while acquiring model files.
#[derive(Debug, Error)]
pub enum DownloadError {
    /// Transport failure, or failure to set up the download runtime or task.
    #[error("network error: {0}")]
    NetworkError(String),
    /// Local filesystem error.
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),
    /// A downloaded file's digest did not match the manifest.
    #[error("verification failed for {file}: expected {expected}, got {actual}")]
    VerificationFailed {
        file: String,
        expected: String,
        actual: String,
    },
    /// The download was cancelled through the cancellation flag.
    #[error("download cancelled")]
    Cancelled,
    /// The request or body-streaming timeout elapsed.
    #[error("download timed out")]
    Timeout,
    /// The server answered with an HTTP status of 400 or above.
    #[error("HTTP error {status}: {message}")]
    HttpError { status: u16, message: String },
    /// The manifest still has placeholder checksums and/or an unpinned revision.
    #[error(
        "model '{model_id}' is not production-ready: {} file(s) have placeholder checksums{}",
        unverified_files.len(),
        if *revision_unpinned {
            " and revision is not pinned"
        } else {
            ""
        }
    )]
    ManifestNotVerified {
        model_id: String,
        unverified_files: Vec<String>,
        revision_unpinned: bool,
    },
    /// A mirror base URL failed validation.
    #[error("invalid mirror URL '{url}': {reason}")]
    InvalidMirrorUrl { url: String, reason: String },
}
967
968impl DownloadError {
969 fn is_retryable(&self) -> bool {
970 match self {
971 DownloadError::NetworkError(_) | DownloadError::IoError(_) | DownloadError::Timeout => {
972 true
973 }
974 DownloadError::HttpError { status, .. } => {
975 *status == 408 || *status == 429 || (500..=599).contains(status)
976 }
977 DownloadError::VerificationFailed { .. }
978 | DownloadError::Cancelled
979 | DownloadError::ManifestNotVerified { .. }
980 | DownloadError::InvalidMirrorUrl { .. } => false,
981 }
982 }
983
984 fn should_discard_temp(&self) -> bool {
985 matches!(self, DownloadError::VerificationFailed { .. })
986 }
987}
988
/// Runs an async download job to completion on a dedicated current-thread
/// asupersync runtime and returns its result synchronously.
///
/// `f` receives the spawned task's `asupersync::Cx`. Its result is sent back over an
/// `std::sync::mpsc` channel that the `block_on` future polls.
///
/// # Errors
///
/// Returns `DownloadError::NetworkError` if any of these happens:
/// - the runtime cannot be built;
/// - the runtime handle is unavailable;
/// - the task cannot be spawned;
/// - the task drops its sender without sending a result.
///
/// Otherwise it returns whatever `f`'s future produced.
fn run_download_with_cx<T, F, Fut>(f: F) -> Result<T, DownloadError>
where
    T: Send + 'static,
    F: FnOnce(asupersync::Cx) -> Fut + Send + 'static,
    Fut: Future<Output = Result<T, DownloadError>> + Send + 'static,
{
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .map_err(|e| {
            DownloadError::NetworkError(format!("failed to build download runtime: {e}"))
        })?;

    runtime.block_on(async move {
        let handle = asupersync::runtime::Runtime::current_handle().ok_or_else(|| {
            DownloadError::NetworkError("download runtime handle unavailable".into())
        })?;
        let (tx, rx) = std::sync::mpsc::channel();
        handle
            .try_spawn_with_cx(move |cx| async move {
                // A send error only means the receiver below is gone; nothing to report to.
                let _ = tx.send(f(cx).await);
            })
            .map_err(|e| {
                DownloadError::NetworkError(format!("failed to spawn download task: {e}"))
            })?;

        // Poll without blocking the thread. The runtime is built with `current_thread()`, so
        // presumably the spawned task only makes progress while this future yields.
        // NOTE(review): this is a yield-based busy poll — confirm asupersync has no awaitable
        // join handle or async oneshot that could replace it.
        loop {
            match rx.try_recv() {
                Ok(result) => return result,
                Err(TryRecvError::Empty) => asupersync::runtime::yield_now().await,
                Err(TryRecvError::Disconnected) => {
                    return Err(DownloadError::NetworkError(
                        "download task exited before returning a result".into(),
                    ));
                }
            }
        }
    })
}
1027
/// Downloads a model's files into a staging directory, verifies their checksums,
/// and installs them into `target_dir`.
pub struct ModelDownloader {
    /// Final install location for the model files.
    target_dir: PathBuf,
    /// Staging directory where files are downloaded before installation
    /// (`<target name>.downloading` beside `target_dir`, as built by `new`).
    temp_dir: PathBuf,
    /// Shared cancellation flag, checked between files, between attempts, and while
    /// streaming response bodies.
    cancelled: Arc<AtomicBool>,
    /// Timeout wrapped around issuing the request (`request_streaming`).
    connect_timeout: Duration,
    /// Total time budget for streaming one file's response body.
    file_timeout: Duration,
    /// Attempts per file. The loop runs `0..max_retries`, so this includes the first attempt.
    max_retries: u32,
}
1043
1044impl ModelDownloader {
1045 pub fn new(target_dir: PathBuf) -> Self {
1047 let temp_dir = if let Some(parent) = target_dir.parent() {
1050 let dir_name = target_dir
1051 .file_name()
1052 .and_then(|n| n.to_str())
1053 .unwrap_or("model");
1054 parent.join(format!("{}.downloading", dir_name))
1055 } else {
1056 target_dir.with_extension("downloading")
1058 };
1059 Self {
1060 target_dir,
1061 temp_dir,
1062 cancelled: Arc::new(AtomicBool::new(false)),
1063 connect_timeout: Duration::from_secs(30),
1064 file_timeout: Duration::from_secs(300), max_retries: 3,
1066 }
1067 }
1068
1069 pub fn cancellation_handle(&self) -> Arc<AtomicBool> {
1071 Arc::clone(&self.cancelled)
1072 }
1073
1074 pub fn cancel(&self) {
1076 self.cancelled.store(true, Ordering::SeqCst);
1077 }
1078
1079 pub fn is_cancelled(&self) -> bool {
1081 self.cancelled.load(Ordering::SeqCst)
1082 }
1083
1084 pub fn download(
1101 &self,
1102 manifest: &ModelManifest,
1103 on_progress: Option<ProgressCallback>,
1104 ) -> Result<(), DownloadError> {
1105 self.download_with_mirror(manifest, None, on_progress)
1106 }
1107
1108 pub fn download_with_mirror(
1110 &self,
1111 manifest: &ModelManifest,
1112 mirror_base_url: Option<&str>,
1113 on_progress: Option<ProgressCallback>,
1114 ) -> Result<(), DownloadError> {
1115 if !manifest.is_production_ready() {
1118 let unverified_files: Vec<String> = manifest
1119 .files
1120 .iter()
1121 .filter(|f| f.sha256 == PLACEHOLDER_CHECKSUM)
1122 .map(|f| f.name.clone())
1123 .collect();
1124 return Err(DownloadError::ManifestNotVerified {
1125 model_id: manifest.id.clone(),
1126 unverified_files,
1127 revision_unpinned: !manifest.has_pinned_revision(),
1128 });
1129 }
1130
1131 self.cancelled.store(false, Ordering::SeqCst);
1133
1134 self.prepare_temp_dir(manifest)?;
1137
1138 let grand_total = manifest.total_size();
1139 let total_files = manifest.files.len();
1140 let bytes_downloaded = Arc::new(AtomicU64::new(0));
1141
1142 for (idx, file) in manifest.files.iter().enumerate() {
1143 self.fail_if_cancelled()?;
1144
1145 let file_path = self.temp_dir.join(file.local_name());
1147 let url = manifest.download_url_with_base(file, mirror_base_url);
1148
1149 let bytes_before_file = bytes_downloaded.load(Ordering::SeqCst);
1151
1152 let mut last_error = None;
1154 for attempt in 0..self.max_retries {
1155 self.fail_if_cancelled()?;
1156
1157 if attempt > 0 {
1159 bytes_downloaded.store(bytes_before_file, Ordering::SeqCst);
1160 }
1161
1162 if attempt > 0 {
1164 let delay = Duration::from_secs(5 * (1 << (attempt - 1)));
1165 std::thread::sleep(delay);
1166 }
1167
1168 match self.download_file(
1169 &url,
1170 &file_path,
1171 file.size,
1172 idx,
1173 total_files,
1174 &bytes_downloaded,
1175 grand_total,
1176 on_progress.as_ref(),
1177 ) {
1178 Ok(()) => {
1179 last_error = None;
1180 break;
1181 }
1182 Err(DownloadError::Cancelled) => {
1183 return Err(DownloadError::Cancelled);
1184 }
1185 Err(e) => {
1186 if !e.is_retryable() {
1187 self.cleanup_temp_for_error(&e);
1188 return Err(e);
1189 }
1190 last_error = Some(e);
1191 }
1192 }
1193 }
1194
1195 if let Some(err) = last_error {
1196 self.cleanup_temp_for_error(&err);
1197 return Err(err);
1198 }
1199
1200 self.fail_if_cancelled()?;
1202
1203 let actual_hash = compute_sha256(&file_path)?;
1204 if actual_hash != file.sha256 {
1205 let err = DownloadError::VerificationFailed {
1206 file: file.name.clone(),
1207 expected: file.sha256.clone(),
1208 actual: actual_hash,
1209 };
1210 self.cleanup_temp_for_error(&err);
1211 return Err(err);
1212 }
1213 }
1214
1215 self.atomic_install()?;
1217
1218 self.write_verified_marker(manifest, mirror_base_url)?;
1220
1221 Ok(())
1222 }
1223
1224 fn prepare_temp_dir(&self, manifest: &ModelManifest) -> Result<(), DownloadError> {
1225 ensure_model_download_temp_dir(&self.temp_dir)?;
1226
1227 let expected_files: HashSet<String> = manifest
1228 .files
1229 .iter()
1230 .map(|file| file.local_name().to_string())
1231 .collect();
1232
1233 for entry in fs::read_dir(&self.temp_dir)? {
1234 let entry = entry?;
1235 let entry_type = entry.file_type()?;
1236 let entry_name = entry.file_name();
1237 let keep_entry = entry_type.is_file()
1238 && entry_name
1239 .to_str()
1240 .is_some_and(|name| expected_files.contains(name));
1241
1242 if keep_entry {
1243 continue;
1244 }
1245
1246 let entry_path = entry.path();
1247 if entry_type.is_dir() {
1248 fs::remove_dir_all(entry_path)?;
1249 } else {
1250 fs::remove_file(entry_path)?;
1251 }
1252 }
1253
1254 Ok(())
1255 }
1256
1257 #[allow(clippy::too_many_arguments)]
1259 fn download_file(
1260 &self,
1261 url: &str,
1262 path: &Path,
1263 expected_size: u64,
1264 file_idx: usize,
1265 total_files: usize,
1266 bytes_downloaded: &Arc<AtomicU64>,
1267 grand_total: u64,
1268 on_progress: Option<&ProgressCallback>,
1269 ) -> Result<(), DownloadError> {
1270 let mut existing_size = if path.exists() {
1272 fs::metadata(path).map(|m| m.len()).unwrap_or(0)
1273 } else {
1274 0
1275 };
1276
1277 if existing_size > expected_size {
1279 let _ = fs::remove_file(path);
1280 existing_size = 0;
1281 }
1282
1283 if existing_size == expected_size {
1285 bytes_downloaded.fetch_add(expected_size, Ordering::SeqCst);
1286 return Ok(());
1287 }
1288
1289 let url = url.to_string();
1290 let path = path.to_path_buf();
1291 let bytes_downloaded = Arc::clone(bytes_downloaded);
1292 let cancelled = Arc::clone(&self.cancelled);
1293 let progress_callback = on_progress.cloned();
1294 let connect_timeout = self.connect_timeout;
1295 let file_timeout = self.file_timeout;
1296
1297 run_download_with_cx(move |cx| async move {
1298 const MODEL_MAX_BODY_SIZE: usize = 500 * 1024 * 1024;
1302
1303 let client = asupersync::http::h1::HttpClient::builder()
1304 .user_agent(concat!(
1305 "cass/",
1306 env!("CARGO_PKG_VERSION"),
1307 " (model-download)"
1308 ))
1309 .max_body_size(MODEL_MAX_BODY_SIZE)
1310 .build();
1311 let mut headers = vec![("Accept".to_string(), "application/octet-stream".to_string())];
1312
1313 if existing_size > 0 {
1314 headers.push(("Range".to_string(), format!("bytes={existing_size}-")));
1315 bytes_downloaded.fetch_add(existing_size, Ordering::SeqCst);
1316 }
1317
1318 let mut response = asupersync::time::timeout(
1319 cx.now(),
1320 connect_timeout,
1321 client.request_streaming(
1322 &cx,
1323 asupersync::http::h1::Method::Get,
1324 &url,
1325 headers,
1326 Vec::new(),
1327 ),
1328 )
1329 .await
1330 .map_err(|_| DownloadError::Timeout)?
1331 .map_err(|e| DownloadError::NetworkError(e.to_string()))?;
1332
1333 let status = response.head.status;
1334 if status >= 400 {
1335 return Err(DownloadError::HttpError {
1336 status,
1337 message: if response.head.reason.is_empty() {
1338 status.to_string()
1339 } else {
1340 format!("{} {}", status, response.head.reason)
1341 },
1342 });
1343 }
1344
1345 let actually_resuming = existing_size > 0 && status == 206;
1347 if existing_size > 0 && status == 200 {
1348 bytes_downloaded.fetch_sub(existing_size, Ordering::SeqCst);
1349 existing_size = 0;
1350 }
1351
1352 let mut file = fs::OpenOptions::new()
1353 .create(true)
1354 .append(actually_resuming)
1355 .write(true)
1356 .truncate(!actually_resuming)
1357 .open(&path)?;
1358
1359 let file_name = path
1360 .file_name()
1361 .and_then(|n| n.to_str())
1362 .unwrap_or("unknown")
1363 .to_string();
1364 let start = Instant::now();
1365 let mut file_bytes = if actually_resuming { existing_size } else { 0 };
1366
1367 loop {
1368 if cancelled.load(Ordering::SeqCst) {
1369 return Err(DownloadError::Cancelled);
1370 }
1371
1372 let remaining = file_timeout.saturating_sub(start.elapsed());
1373 if remaining.is_zero() {
1374 return Err(DownloadError::Timeout);
1375 }
1376
1377 let frame = asupersync::time::timeout(
1378 cx.now(),
1379 remaining,
1380 poll_fn(|task_cx| Pin::new(&mut response.body).poll_frame(task_cx)),
1381 )
1382 .await
1383 .map_err(|_| DownloadError::Timeout)?;
1384
1385 let Some(frame) = frame else {
1386 break;
1387 };
1388
1389 match frame.map_err(|e| DownloadError::NetworkError(e.to_string()))? {
1390 asupersync::http::body::Frame::Data(mut buf) => {
1391 while buf.has_remaining() {
1392 let chunk = buf.chunk();
1393 if chunk.is_empty() {
1394 break;
1395 }
1396 file.write_all(chunk)?;
1397 let chunk_len = chunk.len();
1398 buf.advance(chunk_len);
1399 file_bytes = file_bytes.saturating_add(chunk_len as u64);
1400 bytes_downloaded.fetch_add(chunk_len as u64, Ordering::SeqCst);
1401
1402 if let Some(callback) = progress_callback.as_ref() {
1403 let total_downloaded = bytes_downloaded.load(Ordering::SeqCst);
1404 let progress_pct = if grand_total > 0 {
1405 ((total_downloaded as f64 / grand_total as f64) * 100.0)
1406 .min(100.0) as u8
1407 } else {
1408 0
1409 };
1410
1411 callback(DownloadProgress {
1412 current_file: file_name.clone(),
1413 file_index: file_idx + 1,
1414 total_files,
1415 file_bytes,
1416 file_total: expected_size,
1417 total_bytes: total_downloaded,
1418 grand_total,
1419 progress_pct,
1420 });
1421 }
1422 }
1423 }
1424 asupersync::http::body::Frame::Trailers(_) => {}
1425 }
1426 }
1427
1428 file.sync_all()?;
1429 Ok(())
1430 })
1431 }
1432
1433 fn atomic_install(&self) -> Result<(), DownloadError> {
1440 let backup_dir = unique_model_backup_dir(&self.target_dir);
1441 sync_tree(&self.temp_dir)?;
1442
1443 let had_existing = if ensure_replaceable_model_dir(&self.target_dir)? {
1445 fs::rename(&self.target_dir, &backup_dir)?;
1446 true
1447 } else {
1448 false
1449 };
1450
1451 match fs::rename(&self.temp_dir, &self.target_dir) {
1453 Ok(()) => {
1454 sync_parent_directory(&self.target_dir)?;
1455 if had_existing {
1457 let _ = fs::remove_dir_all(&backup_dir);
1458 sync_parent_directory(&self.target_dir)?;
1459 }
1460 }
1461 Err(e) => {
1462 if had_existing && backup_dir.exists() {
1464 match fs::rename(&backup_dir, &self.target_dir) {
1465 Ok(()) => {
1466 sync_parent_directory(&self.target_dir)?;
1467 return Err(std::io::Error::other(format!(
1468 "failed installing {} from {}: {e}; restored original model",
1469 self.target_dir.display(),
1470 self.temp_dir.display()
1471 ))
1472 .into());
1473 }
1474 Err(restore_err) => {
1475 return Err(std::io::Error::other(format!(
1476 "failed installing {} from {}: {e}; restore error: {restore_err}; temp model retained at {}",
1477 self.target_dir.display(),
1478 self.temp_dir.display(),
1479 self.temp_dir.display()
1480 ))
1481 .into());
1482 }
1483 }
1484 }
1485 return Err(e.into());
1486 }
1487 }
1488
1489 Ok(())
1490 }
1491
1492 fn write_verified_marker(
1494 &self,
1495 manifest: &ModelManifest,
1496 mirror_base_url: Option<&str>,
1497 ) -> Result<(), DownloadError> {
1498 let marker_path = self.target_dir.join(".verified");
1499 let source = mirror_base_url
1500 .map(|url| format!("mirror:{url}"))
1501 .unwrap_or_else(|| "registry".to_string());
1502 let content = format!(
1503 "revision={}\nverified_at={}\nsource={}\n",
1504 manifest.revision,
1505 chrono::Utc::now().to_rfc3339(),
1506 source
1507 );
1508 let temp_path = unique_model_sidecar_path(&marker_path, "tmp", ".verified");
1509 let mut file = File::create(&temp_path)?;
1510 file.write_all(content.as_bytes())?;
1511 file.sync_all()?;
1512 replace_file_from_temp(&temp_path, &marker_path)?;
1513 sync_parent_directory(&marker_path)?;
1514 Ok(())
1515 }
1516
1517 fn cleanup_temp(&self) {
1519 if model_dir_is_real_directory(&self.temp_dir).unwrap_or(false) {
1520 let _ = fs::remove_dir_all(&self.temp_dir);
1521 }
1522 }
1523
1524 fn cleanup_temp_for_error(&self, err: &DownloadError) {
1525 if err.should_discard_temp() {
1526 self.cleanup_temp();
1527 }
1528 }
1529
1530 fn fail_if_cancelled(&self) -> Result<(), DownloadError> {
1531 if self.is_cancelled() {
1532 Err(DownloadError::Cancelled)
1533 } else {
1534 Ok(())
1535 }
1536 }
1537}
1538
1539pub fn compute_sha256(path: &Path) -> Result<String, DownloadError> {
1541 let file = File::open(path)?;
1542 let mut reader = BufReader::new(file);
1543 let mut hasher = Sha256::new();
1544
1545 let mut buffer = [0u8; 8192];
1546 loop {
1547 let n = reader.read(&mut buffer)?;
1548 if n == 0 {
1549 break;
1550 }
1551 hasher.update(&buffer[..n]);
1552 }
1553
1554 let hash = hasher.finalize();
1555 Ok(hex::encode(hash))
1556}
1557
1558pub fn classify_model_cache(
1564 model_dir: &Path,
1565 manifest: &ModelManifest,
1566 policy: &ModelAcquisitionPolicy,
1567) -> ModelCacheReport {
1568 classify_model_cache_with_integrity(model_dir, manifest, policy, ModelCacheIntegrity::Full)
1569}
1570
1571pub(crate) fn classify_model_cache_metadata(
1578 model_dir: &Path,
1579 manifest: &ModelManifest,
1580 policy: &ModelAcquisitionPolicy,
1581) -> ModelCacheReport {
1582 classify_model_cache_with_integrity(model_dir, manifest, policy, ModelCacheIntegrity::Metadata)
1583}
1584
/// How thoroughly cache classification validates installed model files.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ModelCacheIntegrity {
    /// Re-hash every present manifest file and compare with its `sha256`.
    Full,
    /// Only inspect file presence, markers and policy; skip hashing.
    Metadata,
}
1590
1591fn classify_model_cache_with_integrity(
1592 model_dir: &Path,
1593 manifest: &ModelManifest,
1594 policy: &ModelAcquisitionPolicy,
1595 integrity: ModelCacheIntegrity,
1596) -> ModelCacheReport {
1597 let required_size_bytes = manifest.total_size();
1598 let installed_size_bytes = installed_manifest_size(model_dir, manifest);
1599 let missing_files = missing_manifest_files(model_dir, manifest);
1600 let state = classify_model_cache_state(model_dir, manifest, policy, &missing_files, integrity);
1601
1602 ModelCacheReport {
1603 model_id: manifest.id.clone(),
1604 model_dir: model_dir.to_path_buf(),
1605 state,
1606 required_size_bytes,
1607 installed_size_bytes,
1608 policy_source: policy.config_source.clone(),
1609 }
1610}
1611
/// Decide the single `ModelCacheState` that describes `model_dir`.
///
/// Checks run in a fixed precedence order and the first one that matches wins:
/// 1. `policy.downloads_enabled` is false → `DisabledByPolicy` (even when all
///    files are present).
/// 2. A `.quarantined` file exists in `model_dir` → `QuarantinedCorrupt`; the
///    trimmed file contents are the reason when readable and non-empty.
/// 3. The staging directory from `model_download_temp_dir` exists →
///    `Acquiring`, with the bytes currently staged and the manifest total.
/// 4. `missing_files` is non-empty → `OfflineBlocked` when offline,
///    `DisabledByPolicy` on a metered network that is not allowed,
///    `BudgetBlocked` when the manifest exceeds `policy.max_model_bytes`,
///    otherwise `NotAcquired`.
/// 5. With `ModelCacheIntegrity::Full`, each present file is re-hashed: a
///    digest mismatch → `ChecksumMismatch`; a hashing error →
///    `QuarantinedCorrupt` whose `marker_path` is the offending file.
/// 6. No `.verified` marker → `PreseededLocal`; an unreadable marker →
///    `QuarantinedCorrupt`; a marker `revision` that differs from the
///    manifest (absent is reported as `<unknown>`) → `IncompatibleVersion`.
/// 7. Otherwise the marker's `source` picks `PreseededLocal`
///    (`preseeded_local`), `MirrorSourced` (`mirror:<url>`) or `Acquired`.
///
/// NOTE(review): `missing_files` is caller-supplied and is presumably the
/// output of `missing_manifest_files` for the same directory and manifest.
fn classify_model_cache_state(
    model_dir: &Path,
    manifest: &ModelManifest,
    policy: &ModelAcquisitionPolicy,
    missing_files: &[String],
    integrity: ModelCacheIntegrity,
) -> ModelCacheState {
    // (1) Policy kill-switch overrides everything below.
    if !policy.downloads_enabled {
        return ModelCacheState::DisabledByPolicy {
            reason: "semantic model downloads disabled by policy".to_string(),
        };
    }

    // (2) Explicit quarantine marker.
    let quarantine_marker = model_dir.join(".quarantined");
    if quarantine_marker.is_file() {
        let reason = fs::read_to_string(&quarantine_marker)
            .ok()
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .unwrap_or_else(|| "model cache quarantined after integrity failure".to_string());
        return ModelCacheState::QuarantinedCorrupt {
            marker_path: quarantine_marker,
            reason,
        };
    }

    // (3) A staging directory means an acquisition is in progress or was
    // interrupted. `bytes_present` is computed before `staging_dir` is moved
    // into the struct (field initialisers run in the order written).
    let staging_dir = model_download_temp_dir(model_dir);
    if staging_dir.is_dir() {
        return ModelCacheState::Acquiring {
            bytes_present: directory_size_bytes(&staging_dir),
            staging_dir,
            total_bytes: manifest.total_size(),
        };
    }

    // (4) Some files missing: explain why acquisition can or cannot proceed.
    if !missing_files.is_empty() {
        if policy.offline {
            return ModelCacheState::OfflineBlocked {
                missing_files: missing_files.to_vec(),
            };
        }

        if policy.metered && !policy.allow_metered {
            return ModelCacheState::DisabledByPolicy {
                reason: "metered network disallows model acquisition".to_string(),
            };
        }

        if let Some(max_bytes) = policy.max_model_bytes
            && manifest.total_size() > max_bytes
        {
            return ModelCacheState::BudgetBlocked {
                required_bytes: manifest.total_size(),
                max_bytes,
            };
        }

        return ModelCacheState::NotAcquired {
            missing_files: missing_files.to_vec(),
            needs_consent: policy.requires_consent,
        };
    }

    // (5) Optional full content verification.
    if integrity == ModelCacheIntegrity::Full {
        for file in &manifest.files {
            // Defensive: skip a file that vanished since `missing_files` was built.
            let Some(path) = model_file_path(model_dir, file) else {
                continue;
            };
            match compute_sha256(&path) {
                Ok(actual) if actual == file.sha256 => {}
                Ok(actual) => {
                    return ModelCacheState::ChecksumMismatch {
                        file: file.local_name().to_string(),
                        expected: file.sha256.clone(),
                        actual,
                    };
                }
                Err(err) => {
                    return ModelCacheState::QuarantinedCorrupt {
                        marker_path: path,
                        reason: format!("unable to hash model file {}: {err}", file.local_name()),
                    };
                }
            }
        }
    }

    // (6) No `.verified` marker: the files are presumably placed manually
    // rather than installed by the downloader.
    let verified_marker = model_dir.join(".verified");
    if !verified_marker.is_file() {
        return ModelCacheState::PreseededLocal {
            model_dir: model_dir.to_path_buf(),
        };
    }

    let marker = match fs::read_to_string(&verified_marker) {
        Ok(marker) => marker,
        Err(err) => {
            return ModelCacheState::QuarantinedCorrupt {
                marker_path: verified_marker,
                reason: format!("unable to read verified marker: {err}"),
            };
        }
    };

    // A missing or empty `revision=` line never matches, so it is reported
    // as an incompatible version.
    let current_revision =
        marker_field(&marker, "revision").unwrap_or_else(|| "<unknown>".to_string());
    if current_revision != manifest.revision {
        return ModelCacheState::IncompatibleVersion {
            current_revision,
            expected_revision: manifest.revision.clone(),
        };
    }

    // (7) Provenance recorded in the marker's `source=` line.
    match marker_field(&marker, "source") {
        Some(source) if source == "preseeded_local" => ModelCacheState::PreseededLocal {
            model_dir: model_dir.to_path_buf(),
        },
        Some(source) if source.starts_with("mirror:") => ModelCacheState::MirrorSourced {
            model_dir: model_dir.to_path_buf(),
            mirror_base_url: source.trim_start_matches("mirror:").to_string(),
        },
        _ => ModelCacheState::Acquired {
            model_dir: model_dir.to_path_buf(),
        },
    }
}
1738
1739pub fn check_model_installed(model_dir: &Path, manifest: &ModelManifest) -> ModelState {
1749 if !model_dir.is_dir() {
1750 return ModelState::NotInstalled;
1751 }
1752
1753 let verified_marker = model_dir.join(".verified");
1754 if !verified_marker.is_file() {
1755 return ModelState::NotInstalled;
1756 }
1757
1758 for file in &manifest.files {
1762 if model_file_path(model_dir, file).is_none() {
1763 return ModelState::NotInstalled;
1764 }
1765 }
1766
1767 ModelState::Ready
1768}
1769
1770pub fn check_version_mismatch(model_dir: &Path, manifest: &ModelManifest) -> Option<ModelState> {
1772 let verified_marker = model_dir.join(".verified");
1773 if !verified_marker.is_file() {
1774 return None;
1775 }
1776
1777 let content = fs::read_to_string(&verified_marker).ok()?;
1779 let installed_revision = content
1780 .lines()
1781 .find(|l| l.starts_with("revision="))
1782 .map(|l| l.trim_start_matches("revision=").to_string())?;
1783
1784 if installed_revision != manifest.revision {
1785 Some(ModelState::UpdateAvailable {
1786 current_revision: installed_revision,
1787 latest_revision: manifest.revision.clone(),
1788 })
1789 } else {
1790 None
1791 }
1792}
1793
1794fn ensure_replaceable_model_dir(path: &Path) -> Result<bool, DownloadError> {
1795 match fs::symlink_metadata(path) {
1796 Ok(metadata) => {
1797 ensure_real_model_directory_metadata(
1798 path,
1799 &metadata,
1800 "refusing to install model through symlink",
1801 "refusing to replace model target because it is not a directory",
1802 )?;
1803 Ok(true)
1804 }
1805 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
1806 Err(err) => Err(std::io::Error::new(
1807 err.kind(),
1808 format!(
1809 "failed inspecting model target before install {}: {err}",
1810 path.display()
1811 ),
1812 )
1813 .into()),
1814 }
1815}
1816
1817fn ensure_model_download_temp_dir(path: &Path) -> Result<(), DownloadError> {
1818 match fs::symlink_metadata(path) {
1819 Ok(metadata) => {
1820 ensure_real_model_directory_metadata(
1821 path,
1822 &metadata,
1823 "refusing to prepare model download temp dir through symlink",
1824 "refusing to prepare model download temp dir because it is not a directory",
1825 )?;
1826 }
1827 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
1828 fs::create_dir_all(path)?;
1829 let metadata = fs::symlink_metadata(path).map_err(|err| {
1830 std::io::Error::new(
1831 err.kind(),
1832 format!(
1833 "failed inspecting model download temp dir after create {}: {err}",
1834 path.display()
1835 ),
1836 )
1837 })?;
1838 ensure_real_model_directory_metadata(
1839 path,
1840 &metadata,
1841 "refusing to prepare model download temp dir through symlink",
1842 "refusing to prepare model download temp dir because it is not a directory",
1843 )?;
1844 }
1845 Err(err) => {
1846 return Err(std::io::Error::new(
1847 err.kind(),
1848 format!(
1849 "failed inspecting model download temp dir before prepare {}: {err}",
1850 path.display()
1851 ),
1852 )
1853 .into());
1854 }
1855 }
1856 Ok(())
1857}
1858
1859fn model_dir_is_real_directory(path: &Path) -> Result<bool, DownloadError> {
1860 match fs::symlink_metadata(path) {
1861 Ok(metadata) => {
1862 let file_type = metadata.file_type();
1863 Ok(file_type.is_dir() && !file_type.is_symlink())
1864 }
1865 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
1866 Err(err) => Err(err.into()),
1867 }
1868}
1869
1870fn ensure_real_model_directory_metadata(
1871 path: &Path,
1872 metadata: &fs::Metadata,
1873 symlink_message: &str,
1874 non_dir_message: &str,
1875) -> Result<(), DownloadError> {
1876 let file_type = metadata.file_type();
1877 if file_type.is_symlink() {
1878 return Err(std::io::Error::other(format!("{symlink_message}: {}", path.display())).into());
1879 }
1880 if !file_type.is_dir() {
1881 return Err(std::io::Error::other(format!("{non_dir_message}: {}", path.display())).into());
1882 }
1883 Ok(())
1884}
1885
/// Staging directory used while downloading into `target_dir`.
///
/// This is the sibling `<name>.downloading` next to `target_dir`. The name is
/// handled as an `OsStr`, so non-UTF-8 directory names keep their identity
/// instead of collapsing onto a shared `model.downloading`. If `target_dir`
/// has no file name (e.g. ends in `..`), `model` is used. If it has no parent
/// (a root or prefix), the extension is replaced with `downloading`.
fn model_download_temp_dir(target_dir: &Path) -> PathBuf {
    let Some(parent) = target_dir.parent() else {
        return target_dir.with_extension("downloading");
    };
    let mut dir_name = target_dir
        .file_name()
        .map(std::ffi::OsStr::to_os_string)
        .unwrap_or_else(|| std::ffi::OsString::from("model"));
    dir_name.push(".downloading");
    parent.join(dir_name)
}
1897
1898pub fn model_file_path(model_dir: &Path, file: &ModelFile) -> Option<PathBuf> {
1903 let canonical = model_dir.join(&file.name);
1904 if canonical.is_file() {
1905 return Some(canonical);
1906 }
1907
1908 let local = model_dir.join(file.local_name());
1909 if local.is_file() {
1910 return Some(local);
1911 }
1912
1913 None
1914}
1915
1916fn missing_manifest_files(model_dir: &Path, manifest: &ModelManifest) -> Vec<String> {
1917 manifest
1918 .files
1919 .iter()
1920 .filter(|file| model_file_path(model_dir, file).is_none())
1921 .map(|file| file.local_name().to_string())
1922 .collect()
1923}
1924
1925fn installed_manifest_size(model_dir: &Path, manifest: &ModelManifest) -> u64 {
1926 manifest
1927 .files
1928 .iter()
1929 .filter_map(|file| model_file_path(model_dir, file))
1930 .filter_map(|path| path.metadata().ok())
1931 .map(|metadata| metadata.len())
1932 .sum()
1933}
1934
/// Recursive sum of regular-file sizes under `path`.
///
/// Symlinks are not followed (`DirEntry::file_type` does not traverse them)
/// and contribute 0. Unreadable directories and entries are counted as 0
/// rather than reported.
fn directory_size_bytes(path: &Path) -> u64 {
    let entries = match fs::read_dir(path) {
        Ok(entries) => entries,
        Err(_) => return 0,
    };

    let mut total = 0;
    for entry in entries.flatten() {
        total += match entry.file_type() {
            Ok(kind) if kind.is_file() => entry.metadata().map_or(0, |metadata| metadata.len()),
            Ok(kind) if kind.is_dir() => directory_size_bytes(&entry.path()),
            _ => 0,
        };
    }
    total
}
1954
/// Value of the first `field=` line in a marker file body.
///
/// The value is whitespace-trimmed. An empty value yields `None`, even if a
/// later line repeats the same key with content.
fn marker_field(content: &str, field: &str) -> Option<String> {
    let key = format!("{field}=");
    for line in content.lines() {
        if let Some(raw) = line.strip_prefix(key.as_str()) {
            let value = raw.trim();
            return (!value.is_empty()).then(|| value.to_owned());
        }
    }
    None
}
1963
1964fn unique_model_backup_dir(path: &Path) -> PathBuf {
1965 unique_model_sidecar_path(path, "bak", "model")
1966}
1967
/// Build a hidden sibling path of `path` that is unique per call.
///
/// Format: `.<file_name>.<suffix>.<pid>.<unix_nanos>.<nonce>`. The nonce is a
/// process-wide counter, so two calls in the same process never collide even
/// when the clock does not advance. `fallback_name` replaces a missing or
/// non-UTF-8 file name. A clock before the Unix epoch yields 0 nanoseconds.
fn unique_model_sidecar_path(path: &Path, suffix: &str, fallback_name: &str) -> PathBuf {
    static NEXT_NONCE: AtomicU64 = AtomicU64::new(0);

    let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let nonce = NEXT_NONCE.fetch_add(1, Ordering::Relaxed);
    let base = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => fallback_name,
    };
    let pid = std::process::id();

    path.with_file_name(format!(".{base}.{suffix}.{pid}.{nanos}.{nonce}"))
}
1988
/// Replace `final_path` with the already-written file at `temp_path`.
///
/// Non-Windows: a plain `fs::rename` followed by an fsync of the parent
/// directory (`sync_parent_directory`).
///
/// Windows: the rename is tried first. If it fails with `AlreadyExists` or
/// `PermissionDenied` while `final_path` exists, the existing file is moved
/// to a unique hidden sidecar (named via `unique_model_backup_dir`) and the
/// rename is retried:
/// - retry succeeds → the sidecar is removed (best effort);
/// - retry fails → the sidecar is renamed back, `temp_path` is removed, and an
///   error is returned that says the original was restored;
/// - the restore also fails → `temp_path` is kept and its location is
///   included in the error.
/// If moving the existing file aside fails, `temp_path` is removed and an
/// error describing both failures is returned. Any other first-rename error
/// is returned unchanged.
///
/// # Errors
/// Returns `DownloadError` for each failure path described above.
fn replace_file_from_temp(temp_path: &Path, final_path: &Path) -> Result<(), DownloadError> {
    #[cfg(windows)]
    {
        match fs::rename(temp_path, final_path) {
            Ok(()) => sync_parent_directory(final_path),
            // Only fall back when the destination exists and the failure looks
            // like "cannot overwrite" rather than some unrelated I/O error.
            Err(first_err)
                if final_path.exists()
                    && matches!(
                        first_err.kind(),
                        std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::PermissionDenied
                    ) =>
            {
                let backup_path = unique_model_backup_dir(final_path);
                // Move the current file aside; on failure, drop the new temp file.
                fs::rename(final_path, &backup_path).map_err(|backup_err| {
                    let _ = fs::remove_file(temp_path);
                    DownloadError::IoError(std::io::Error::other(format!(
                        "failed preparing backup {} before replacing {}: first error: {first_err}; backup error: {backup_err}",
                        backup_path.display(),
                        final_path.display()
                    )))
                })?;
                match fs::rename(temp_path, final_path) {
                    Ok(()) => {
                        let _ = fs::remove_file(&backup_path);
                        sync_parent_directory(final_path)
                    }
                    // Second attempt failed: put the original file back.
                    Err(second_err) => match fs::rename(&backup_path, final_path) {
                        Ok(()) => {
                            let _ = fs::remove_file(temp_path);
                            sync_parent_directory(final_path)?;
                            Err(std::io::Error::other(format!(
                                "failed replacing {} with {}: first error: {first_err}; second error: {second_err}; restored original file",
                                final_path.display(),
                                temp_path.display()
                            ))
                            .into())
                        }
                        // Restore failed too: keep the temp file so nothing is lost.
                        Err(restore_err) => Err(std::io::Error::other(format!(
                            "failed replacing {} with {}: first error: {first_err}; second error: {second_err}; restore error: {restore_err}; temp file retained at {}",
                            final_path.display(),
                            temp_path.display(),
                            temp_path.display()
                        ))
                        .into()),
                    },
                }
            }
            Err(rename_err) => Err(rename_err.into()),
        }
    }

    #[cfg(not(windows))]
    {
        fs::rename(temp_path, final_path)?;
        sync_parent_directory(final_path)
    }
}
2046
2047#[cfg(not(windows))]
2048fn sync_tree(path: &Path) -> Result<(), DownloadError> {
2049 sync_tree_inner(path)?;
2050 sync_parent_directory(path)
2051}
2052
2053#[cfg(not(windows))]
2054fn sync_tree_inner(path: &Path) -> Result<(), DownloadError> {
2055 let metadata = fs::metadata(path)?;
2056 if metadata.is_dir() {
2057 for entry in fs::read_dir(path)? {
2058 let entry = entry?;
2059 sync_tree_inner(&entry.path())?;
2060 }
2061 File::open(path)?.sync_all()?;
2062 } else if metadata.is_file() {
2063 File::open(path)?.sync_all()?;
2064 }
2065 Ok(())
2066}
2067
/// No-op on Windows: the staged tree is not fsynced there.
///
/// NOTE(review): presumably because the Unix approach
/// (`File::open(dir)?.sync_all()`) cannot open directories on Windows.
#[cfg(windows)]
fn sync_tree(_path: &Path) -> Result<(), DownloadError> {
    Ok(())
}
2072
2073#[cfg(not(windows))]
2074fn sync_parent_directory(path: &Path) -> Result<(), DownloadError> {
2075 let Some(parent) = path.parent() else {
2076 return Ok(());
2077 };
2078 File::open(parent)?.sync_all()?;
2079 Ok(())
2080}
2081
/// No-op on Windows: parent-directory fsync is not performed there.
///
/// NOTE(review): presumably because directories cannot be opened with
/// `File::open` for syncing on Windows.
#[cfg(windows)]
fn sync_parent_directory(_path: &Path) -> Result<(), DownloadError> {
    Ok(())
}
2086
2087#[cfg(test)]
2088mod tests {
2089 use super::*;
2090 use std::collections::BTreeMap;
2091 use std::error::Error as _;
2092 use std::io::{Read, Write};
2093 use std::net::{Shutdown, TcpListener, TcpStream};
2094 use std::sync::atomic::{AtomicBool, Ordering};
2095 use std::sync::{Arc, Mutex};
2096 use std::thread;
2097 use std::time::Duration;
2098
2099 fn copy_model_fixtures(target_dir: &Path) -> std::io::Result<()> {
2102 let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/models");
2103 fs::create_dir_all(target_dir)?;
2104
2105 fs::copy(
2107 fixture_dir.join("model.onnx"),
2108 target_dir.join("model.onnx"),
2109 )?;
2110
2111 for file in &[
2113 "tokenizer.json",
2114 "config.json",
2115 "special_tokens_map.json",
2116 "tokenizer_config.json",
2117 ] {
2118 fs::copy(fixture_dir.join(file), target_dir.join(file))?;
2119 }
2120
2121 Ok(())
2122 }
2123
2124 #[derive(Clone, Debug)]
2125 struct MirrorRequest {
2126 path: String,
2127 range_start: Option<u64>,
2128 }
2129
2130 #[derive(Clone)]
2131 struct MirrorRoute {
2132 body: Vec<u8>,
2133 content_type: &'static str,
2134 chunk_size: usize,
2135 chunk_delay: Duration,
2136 }
2137
2138 struct MirrorFixtureServer {
2139 base_url: String,
2140 stop: Arc<AtomicBool>,
2141 wake_addr: String,
2142 requests: Arc<Mutex<Vec<MirrorRequest>>>,
2143 handle: Option<std::thread::JoinHandle<()>>,
2144 }
2145
2146 impl MirrorFixtureServer {
2147 fn requests(&self) -> Vec<MirrorRequest> {
2148 self.requests.lock().unwrap().clone()
2149 }
2150 }
2151
2152 impl Drop for MirrorFixtureServer {
2153 fn drop(&mut self) {
2154 self.stop.store(true, Ordering::SeqCst);
2155 if let Ok(stream) = TcpStream::connect(&self.wake_addr) {
2156 let _ = stream.shutdown(Shutdown::Both);
2157 }
2158 if let Some(handle) = self.handle.take() {
2159 let _ = handle.join();
2160 }
2161 }
2162 }
2163
2164 fn start_mirror_fixture_server(routes: Vec<(String, MirrorRoute)>) -> MirrorFixtureServer {
2165 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test mirror server");
2166 listener
2167 .set_nonblocking(true)
2168 .expect("set test mirror server nonblocking");
2169 let addr = listener.local_addr().expect("read server address");
2170 let wake_addr = addr.to_string();
2171 let base_url = format!("http://{wake_addr}");
2172 let stop = Arc::new(AtomicBool::new(false));
2173 let stop_flag = Arc::clone(&stop);
2174 let requests = Arc::new(Mutex::new(Vec::new()));
2175 let request_log = Arc::clone(&requests);
2176 let route_map: BTreeMap<String, MirrorRoute> = routes.into_iter().collect();
2177 let handle = thread::spawn(move || {
2178 while !stop_flag.load(Ordering::SeqCst) {
2179 match listener.accept() {
2180 Ok((stream, _)) => {
2181 handle_mirror_request(stream, &route_map, &request_log);
2182 }
2183 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
2184 thread::sleep(Duration::from_millis(10));
2185 }
2186 Err(_) => break,
2187 }
2188 }
2189 });
2190 MirrorFixtureServer {
2191 base_url,
2192 stop,
2193 wake_addr,
2194 requests,
2195 handle: Some(handle),
2196 }
2197 }
2198
2199 fn handle_mirror_request(
2200 mut stream: TcpStream,
2201 routes: &BTreeMap<String, MirrorRoute>,
2202 request_log: &Arc<Mutex<Vec<MirrorRequest>>>,
2203 ) {
2204 let mut buffer = [0_u8; 8192];
2205 let read = match stream.read(&mut buffer) {
2206 Ok(read) => read,
2207 Err(_) => return,
2208 };
2209 let request = String::from_utf8_lossy(&buffer[..read]);
2210 let mut lines = request.lines();
2211 let target = lines
2212 .next()
2213 .and_then(|line| line.split_whitespace().nth(1))
2214 .unwrap_or("/");
2215 let path = target
2216 .split_once('?')
2217 .map(|(path, _)| path)
2218 .unwrap_or(target)
2219 .split_once('#')
2220 .map(|(path, _)| path)
2221 .unwrap_or(target)
2222 .to_string();
2223 let range_start = lines.find_map(parse_range_start_header);
2224 request_log.lock().unwrap().push(MirrorRequest {
2225 path: path.clone(),
2226 range_start,
2227 });
2228
2229 let Some(route) = routes.get(&path) else {
2230 let response = concat!(
2231 "HTTP/1.1 404 Not Found\r\n",
2232 "Content-Length: 9\r\n",
2233 "Content-Type: text/plain\r\n",
2234 "Connection: close\r\n\r\n",
2235 "not found"
2236 );
2237 let _ = stream.write_all(response.as_bytes());
2238 let _ = stream.flush();
2239 return;
2240 };
2241
2242 let start = range_start.unwrap_or(0) as usize;
2243 let mut status = "200 OK";
2244 let mut content_range = None;
2245 let body = if start >= route.body.len() {
2246 status = "416 Range Not Satisfiable";
2247 &[][..]
2248 } else if start > 0 {
2249 status = "206 Partial Content";
2250 content_range = Some(format!(
2251 "bytes {start}-{}/{}",
2252 route.body.len().saturating_sub(1),
2253 route.body.len()
2254 ));
2255 &route.body[start..]
2256 } else {
2257 route.body.as_slice()
2258 };
2259
2260 let mut response = format!(
2261 "HTTP/1.1 {status}\r\nContent-Length: {}\r\nContent-Type: {}\r\nConnection: close\r\n",
2262 body.len(),
2263 route.content_type
2264 );
2265 if let Some(content_range) = content_range {
2266 response.push_str(&format!("Content-Range: {content_range}\r\n"));
2267 }
2268 response.push_str("\r\n");
2269 let _ = stream.write_all(response.as_bytes());
2270 for chunk in body.chunks(route.chunk_size.max(1)) {
2271 if stream.write_all(chunk).is_err() {
2272 return;
2273 }
2274 let _ = stream.flush();
2275 if !route.chunk_delay.is_zero() {
2276 thread::sleep(route.chunk_delay);
2277 }
2278 }
2279 }
2280
2281 fn parse_range_start_header(line: &str) -> Option<u64> {
2282 let (name, value) = line.split_once(':')?;
2283 if !name.eq_ignore_ascii_case("range") {
2284 return None;
2285 }
2286 let value = value.trim();
2287 let value = value.strip_prefix("bytes=")?;
2288 let (start, _) = value.split_once('-')?;
2289 start.parse().ok()
2290 }
2291
2292 fn build_test_manifest(repo: &str, revision: &str, files: &[(&str, &[u8])]) -> ModelManifest {
2293 ModelManifest {
2294 id: "mirror-test-model".into(),
2295 repo: repo.into(),
2296 revision: revision.into(),
2297 files: files
2298 .iter()
2299 .map(|(name, body)| ModelFile {
2300 name: (*name).into(),
2301 sha256: hex::encode(Sha256::digest(body)),
2302 size: body.len() as u64,
2303 })
2304 .collect(),
2305 license: "Apache-2.0".into(),
2306 }
2307 }
2308
2309 fn mirror_route_path(prefix: &str, manifest: &ModelManifest, file: &ModelFile) -> String {
2310 format!(
2311 "{}/{}/resolve/{}/{}",
2312 prefix.trim_end_matches('/'),
2313 manifest.repo.trim_start_matches('/'),
2314 manifest.revision,
2315 file.name.trim_start_matches('/')
2316 )
2317 }
2318
2319 #[test]
2320 fn test_model_state_summary() {
2321 assert_eq!(ModelState::NotInstalled.summary(), "not installed");
2322 assert_eq!(ModelState::NeedsConsent.summary(), "needs consent");
2323 assert_eq!(ModelState::Ready.summary(), "ready");
2324 assert_eq!(
2325 ModelState::Downloading {
2326 progress_pct: 50,
2327 bytes_downloaded: 1000,
2328 total_bytes: 2000
2329 }
2330 .summary(),
2331 "downloading (50%)"
2332 );
2333 }
2334
2335 #[test]
2336 fn test_model_state_is_ready() {
2337 assert!(ModelState::Ready.is_ready());
2338 assert!(!ModelState::NotInstalled.is_ready());
2339 assert!(!ModelState::NeedsConsent.is_ready());
2340 assert!(
2341 !ModelState::Downloading {
2342 progress_pct: 0,
2343 bytes_downloaded: 0,
2344 total_bytes: 0
2345 }
2346 .is_ready()
2347 );
2348 }
2349
2350 #[test]
2351 fn test_model_manifest_total_size() {
2352 let manifest = ModelManifest::minilm_v2();
2353 assert!(manifest.total_size() > 20_000_000); }
2355
2356 #[test]
2357 fn test_model_manifest_download_url() {
2358 let manifest = ModelManifest::minilm_v2();
2359 let url = manifest.download_url(&manifest.files[0]);
2360 assert!(url.contains("huggingface.co"));
2361 assert!(url.contains("sentence-transformers/all-MiniLM-L6-v2"));
2362 assert!(url.contains("model.onnx"));
2363 }
2364
2365 #[test]
2366 fn test_model_manifest_download_url_with_mirror_base() {
2367 let manifest = ModelManifest::minilm_v2();
2368 let url = manifest
2369 .download_url_with_base(&manifest.files[0], Some("https://mirror.example/cache/"));
2370 assert_eq!(
2371 url,
2372 format!(
2373 "https://mirror.example/cache/{}/resolve/{}/{}",
2374 manifest.repo, manifest.revision, manifest.files[0].name
2375 )
2376 );
2377 }
2378
2379 #[test]
2380 fn air_gap_bash_script_uses_explicit_output_filenames() {
2381 let manifest = ModelManifest::minilm_v2();
2387 let script = manifest.air_gap_bash_script(None);
2388 assert!(script.contains("set -euo pipefail"));
2389 assert!(script.contains("DIR=\"${DIR:-./all-minilm-l6-v2_files}\""));
2390 for file in &manifest.files {
2391 let local = file.local_name();
2392 assert!(
2393 script.contains(&format!("-o \"$DIR/{local}\"")),
2394 "bash script must write {local} via explicit -o, got:\n{script}"
2395 );
2396 }
2397 assert!(
2398 script.contains("cass models install all-minilm-l6-v2 --from-file \"$DIR\" -y"),
2399 "bash script must invoke install with --from-file"
2400 );
2401 }
2402
2403 #[test]
2404 fn air_gap_bash_script_quotes_urls_with_single_quotes() {
2405 let manifest = ModelManifest::minilm_v2();
2407 let script = manifest.air_gap_bash_script(None);
2408 let sample_url = manifest.download_url(&manifest.files[0]);
2409 assert!(script.contains(&format!("'{sample_url}'")));
2410 }
2411
2412 #[test]
2413 fn air_gap_powershell_script_forces_tls12_and_basic_parsing() {
2414 let manifest = ModelManifest::minilm_v2();
2415 let script = manifest.air_gap_powershell_script(None);
2416 assert!(
2417 script.contains("SecurityProtocolType]::Tls12"),
2418 "PowerShell script must opt into TLS 1.2 for Windows PowerShell 5.1 compat"
2419 );
2420 assert!(
2421 script.contains("Invoke-WebRequest -UseBasicParsing"),
2422 "PowerShell script must use -UseBasicParsing for PS 5.1 compat"
2423 );
2424 for file in &manifest.files {
2425 let local = file.local_name();
2426 assert!(
2427 script.contains(&format!("(Join-Path $dir '{local}')")),
2428 "PowerShell script must materialize output path for {local}, got:\n{script}"
2429 );
2430 }
2431 assert!(
2432 script.contains("cass models install all-minilm-l6-v2 --from-file $dir -y"),
2433 "PowerShell script must invoke install with --from-file"
2434 );
2435 }
2436
2437 #[test]
2438 fn air_gap_scripts_honor_mirror_base_url() {
2439 let manifest = ModelManifest::minilm_v2();
2440 let mirror = Some("https://mirror.example/cache");
2441 let bash = manifest.air_gap_bash_script(mirror);
2442 let ps = manifest.air_gap_powershell_script(mirror);
2443 assert!(bash.contains("https://mirror.example/cache"));
2444 assert!(!bash.contains("huggingface.co"));
2445 assert!(ps.contains("https://mirror.example/cache"));
2446 assert!(!ps.contains("huggingface.co"));
2447 }
2448
2449 #[test]
2450 fn test_normalize_mirror_base_url_trims_trailing_slash() {
2451 let normalized = normalize_mirror_base_url("https://mirror.example/cache/").unwrap();
2452 assert_eq!(normalized, "https://mirror.example/cache");
2453 }
2454
2455 #[test]
2456 fn test_normalize_mirror_base_url_rejects_invalid_values() {
2457 let cases = [
2458 ("mirror.example", "invalid mirror URL"),
2459 ("file:///tmp/mirror", "unsupported URL scheme"),
2460 (
2461 "https://mirror.example/cache?trace=abc",
2462 "must not include query or fragment",
2463 ),
2464 ];
2465
2466 for (input, expected_fragment) in cases {
2467 let err = normalize_mirror_base_url(input).unwrap_err();
2468 let message = err.to_string();
2469 assert!(
2470 message.contains(expected_fragment),
2471 "expected error for {input:?} to contain {expected_fragment:?}, got {message:?}"
2472 );
2473 }
2474 }
2475
2476 #[test]
2477 fn test_invalid_mirror_url_helper_shape() {
2478 let err = invalid_mirror_url("ftp://mirror.example/model.onnx", "unsupported scheme");
2479
2480 assert!(matches!(
2481 &err,
2482 DownloadError::InvalidMirrorUrl { url, reason }
2483 if url == "ftp://mirror.example/model.onnx" && reason == "unsupported scheme"
2484 ));
2485 assert_eq!(
2486 err.to_string(),
2487 "invalid mirror URL 'ftp://mirror.example/model.onnx': unsupported scheme"
2488 );
2489 assert!(!err.is_retryable());
2490 }
2491
2492 #[test]
2493 fn test_check_model_installed_missing() {
2494 let tmp = tempfile::tempdir().unwrap();
2495 let model_dir = tmp.path().join("nonexistent");
2496 assert_eq!(
2497 check_model_installed(&model_dir, &ModelManifest::minilm_v2()),
2498 ModelState::NotInstalled
2499 );
2500 }
2501
2502 #[test]
2503 fn test_check_model_installed_no_marker() {
2504 let tmp = tempfile::tempdir().unwrap();
2505 let model_dir = tmp.path().join("model");
2506 let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/models");
2508 fs::create_dir_all(&model_dir).unwrap();
2509 fs::copy(fixture_dir.join("model.onnx"), model_dir.join("model.onnx")).unwrap();
2510 assert_eq!(
2511 check_model_installed(&model_dir, &ModelManifest::minilm_v2()),
2512 ModelState::NotInstalled
2513 );
2514 }
2515
2516 #[test]
2517 fn test_check_model_installed_ready() {
2518 let tmp = tempfile::tempdir().unwrap();
2519 let model_dir = tmp.path().join("model");
2520 copy_model_fixtures(&model_dir).unwrap();
2522 fs::write(model_dir.join(".verified"), "revision=test\n").unwrap();
2523 assert_eq!(
2524 check_model_installed(&model_dir, &ModelManifest::minilm_v2()),
2525 ModelState::Ready
2526 );
2527 }
2528
2529 #[test]
2530 fn classify_cache_policy_disabled_takes_precedence_over_missing() {
2531 let tmp = tempfile::tempdir().unwrap();
2532 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2533 let policy = ModelAcquisitionPolicy {
2534 downloads_enabled: false,
2535 offline: true,
2536 max_model_bytes: Some(1),
2537 ..ModelAcquisitionPolicy::default()
2538 };
2539
2540 let report = classify_model_cache(tmp.path(), &manifest, &policy);
2541 assert_eq!(report.state_code(), "disabled_by_policy");
2542 assert!(matches!(
2543 report.state,
2544 ModelCacheState::DisabledByPolicy { .. }
2545 ));
2546 }
2547
2548 #[test]
2549 fn classify_cache_detects_resume_stage_before_missing() {
2550 let tmp = tempfile::tempdir().unwrap();
2551 let model_dir = tmp.path().join("model");
2552 let staging_dir = tmp.path().join("model.downloading");
2553 fs::create_dir_all(&staging_dir).unwrap();
2554 fs::write(staging_dir.join("model.onnx"), b"partial").unwrap();
2555 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2556
2557 let report =
2558 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2559 assert_eq!(report.state_code(), "acquiring");
2560 assert!(matches!(
2561 report.state,
2562 ModelCacheState::Acquiring {
2563 bytes_present: 7,
2564 total_bytes: 5,
2565 ..
2566 }
2567 ));
2568 }
2569
2570 #[test]
2571 fn classify_cache_distinguishes_offline_and_budget_blocks() {
2572 let tmp = tempfile::tempdir().unwrap();
2573 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2574
2575 let offline = ModelAcquisitionPolicy {
2576 offline: true,
2577 ..ModelAcquisitionPolicy::default()
2578 };
2579 let report = classify_model_cache(tmp.path(), &manifest, &offline);
2580 assert_eq!(report.state_code(), "offline_blocked");
2581
2582 let budget = ModelAcquisitionPolicy {
2583 max_model_bytes: Some(1),
2584 ..ModelAcquisitionPolicy::default()
2585 };
2586 let report = classify_model_cache(tmp.path(), &manifest, &budget);
2587 assert_eq!(report.state_code(), "budget_blocked");
2588 }
2589
2590 #[test]
2591 fn classify_cache_accepts_preseeded_local_manifest_files() {
2592 let tmp = tempfile::tempdir().unwrap();
2593 let model_dir = tmp.path().join("model");
2594 fs::create_dir_all(model_dir.join("onnx")).unwrap();
2595 fs::write(model_dir.join("onnx/model.onnx"), b"model").unwrap();
2596 fs::write(model_dir.join("tokenizer.json"), b"tok").unwrap();
2597 let manifest = build_test_manifest(
2598 "repo/model",
2599 "rev1",
2600 &[("onnx/model.onnx", b"model"), ("tokenizer.json", b"tok")],
2601 );
2602
2603 let report =
2604 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2605 assert_eq!(report.state_code(), "preseeded_local");
2606 assert!(report.is_usable());
2607 }
2608
2609 #[test]
2610 fn classify_cache_detects_checksum_mismatch() {
2611 let tmp = tempfile::tempdir().unwrap();
2612 let model_dir = tmp.path().join("model");
2613 fs::create_dir_all(&model_dir).unwrap();
2614 fs::write(model_dir.join("model.onnx"), b"wrong").unwrap();
2615 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2616
2617 let report =
2618 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2619 assert_eq!(report.state_code(), "checksum_mismatch");
2620 assert!(matches!(
2621 report.state,
2622 ModelCacheState::ChecksumMismatch { .. }
2623 ));
2624 }
2625
2626 #[test]
2627 fn classify_cache_metadata_trusts_verified_marker_without_hashing_payload() {
2628 let tmp = tempfile::tempdir().unwrap();
2629 let model_dir = tmp.path().join("model");
2630 fs::create_dir_all(&model_dir).unwrap();
2631 fs::write(model_dir.join("model.onnx"), b"m0del").unwrap();
2632 fs::write(
2633 model_dir.join(".verified"),
2634 "revision=rev1\nsource=registry\n",
2635 )
2636 .unwrap();
2637 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2638
2639 let metadata_report = classify_model_cache_metadata(
2640 &model_dir,
2641 &manifest,
2642 &ModelAcquisitionPolicy::default(),
2643 );
2644 assert_eq!(metadata_report.state_code(), "acquired");
2645 assert!(metadata_report.is_usable());
2646
2647 let full_report =
2648 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2649 assert_eq!(full_report.state_code(), "checksum_mismatch");
2650 }
2651
2652 #[test]
2653 fn classify_cache_detects_incompatible_revision() {
2654 let tmp = tempfile::tempdir().unwrap();
2655 let model_dir = tmp.path().join("model");
2656 fs::create_dir_all(&model_dir).unwrap();
2657 fs::write(model_dir.join("model.onnx"), b"model").unwrap();
2658 fs::write(model_dir.join(".verified"), "revision=old\n").unwrap();
2659 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2660
2661 let report =
2662 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2663 assert_eq!(report.state_code(), "incompatible_version");
2664 assert!(matches!(
2665 report.state,
2666 ModelCacheState::IncompatibleVersion {
2667 current_revision,
2668 expected_revision
2669 } if current_revision == "old" && expected_revision == "rev1"
2670 ));
2671 }
2672
2673 #[test]
2674 fn classify_cache_reports_mirror_sourced_marker() {
2675 let tmp = tempfile::tempdir().unwrap();
2676 let model_dir = tmp.path().join("model");
2677 fs::create_dir_all(&model_dir).unwrap();
2678 fs::write(model_dir.join("model.onnx"), b"model").unwrap();
2679 fs::write(
2680 model_dir.join(".verified"),
2681 "revision=rev1\nsource=mirror:https://mirror.example/cache\n",
2682 )
2683 .unwrap();
2684 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2685
2686 let report =
2687 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2688 assert_eq!(report.state_code(), "mirror_sourced");
2689 assert!(matches!(
2690 report.state,
2691 ModelCacheState::MirrorSourced {
2692 mirror_base_url,
2693 ..
2694 } if mirror_base_url == "https://mirror.example/cache"
2695 ));
2696 }
2697
2698 #[test]
2699 fn classify_cache_reports_quarantine_marker() {
2700 let tmp = tempfile::tempdir().unwrap();
2701 let model_dir = tmp.path().join("model");
2702 fs::create_dir_all(&model_dir).unwrap();
2703 fs::write(model_dir.join(".quarantined"), "bad checksum\n").unwrap();
2704 let manifest = build_test_manifest("repo/model", "rev1", &[("model.onnx", b"model")]);
2705
2706 let report =
2707 classify_model_cache(&model_dir, &manifest, &ModelAcquisitionPolicy::default());
2708 assert_eq!(report.state_code(), "quarantined_corrupt");
2709 assert!(matches!(
2710 report.state,
2711 ModelCacheState::QuarantinedCorrupt { reason, .. } if reason == "bad checksum"
2712 ));
2713 }
2714
2715 #[test]
2716 fn test_compute_sha256() {
2717 let tmp = tempfile::tempdir().unwrap();
2718 let file_path = tmp.path().join("test.txt");
2719 fs::write(&file_path, b"hello world").unwrap();
2720 let hash = compute_sha256(&file_path).unwrap();
2721 assert_eq!(
2723 hash,
2724 "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
2725 );
2726 }
2727
2728 #[test]
2729 fn test_check_version_mismatch_none() {
2730 let tmp = tempfile::tempdir().unwrap();
2731 let model_dir = tmp.path().join("model");
2732 fs::create_dir_all(&model_dir).unwrap();
2733 let manifest = ModelManifest::minilm_v2();
2735 fs::write(
2736 model_dir.join(".verified"),
2737 format!("revision={}\n", manifest.revision),
2738 )
2739 .unwrap();
2740
2741 let result = check_version_mismatch(&model_dir, &manifest);
2742 assert!(result.is_none());
2743 }
2744
2745 #[test]
2746 fn test_model_file_local_name() {
2747 let file = ModelFile {
2749 name: "onnx/model.onnx".into(),
2750 sha256: "abc123".into(),
2751 size: 1000,
2752 };
2753 assert_eq!(file.local_name(), "model.onnx");
2754
2755 let file2 = ModelFile {
2757 name: "tokenizer.json".into(),
2758 sha256: "def456".into(),
2759 size: 500,
2760 };
2761 assert_eq!(file2.local_name(), "tokenizer.json");
2762
2763 let file3 = ModelFile {
2765 name: "path/to/deep/model.bin".into(),
2766 sha256: "ghi789".into(),
2767 size: 2000,
2768 };
2769 assert_eq!(file3.local_name(), "model.bin");
2770 }
2771
2772 #[test]
2773 fn test_check_version_mismatch_found() {
2774 let tmp = tempfile::tempdir().unwrap();
2775 let model_dir = tmp.path().join("model");
2776 fs::create_dir_all(&model_dir).unwrap();
2777 fs::write(model_dir.join(".verified"), "revision=old_version\n").unwrap();
2778
2779 let manifest = ModelManifest::minilm_v2();
2780 let result = check_version_mismatch(&model_dir, &manifest);
2781 assert!(matches!(result, Some(ModelState::UpdateAvailable { .. })));
2782 }
2783
2784 #[test]
2785 fn test_atomic_install_preserves_preexisting_legacy_backup_dir() {
2786 let tmp = tempfile::tempdir().unwrap();
2787 let target_dir = tmp.path().join("model");
2788 copy_model_fixtures(&target_dir).unwrap();
2789 fs::write(target_dir.join(".verified"), "revision=old\n").unwrap();
2790
2791 let legacy_backup_dir = tmp.path().join("model.bak");
2792 fs::create_dir_all(&legacy_backup_dir).unwrap();
2793 fs::write(legacy_backup_dir.join("sentinel.txt"), "keep me").unwrap();
2794
2795 let downloader = ModelDownloader::new(target_dir.clone());
2796 copy_model_fixtures(&downloader.temp_dir).unwrap();
2797 fs::write(downloader.temp_dir.join(".verified"), "revision=new\n").unwrap();
2798
2799 downloader.atomic_install().unwrap();
2800
2801 assert_eq!(
2802 fs::read_to_string(legacy_backup_dir.join("sentinel.txt")).unwrap(),
2803 "keep me"
2804 );
2805 assert_eq!(
2806 fs::read_to_string(target_dir.join(".verified")).unwrap(),
2807 "revision=new\n"
2808 );
2809 }
2810
2811 #[test]
2812 fn test_atomic_install_rejects_file_target() {
2813 let tmp = tempfile::tempdir().unwrap();
2814 let target_dir = tmp.path().join("model");
2815 fs::write(&target_dir, "not a directory").unwrap();
2816
2817 let downloader = ModelDownloader::new(target_dir.clone());
2818 copy_model_fixtures(&downloader.temp_dir).unwrap();
2819
2820 let err = downloader.atomic_install().unwrap_err();
2821
2822 assert!(
2823 err.to_string().contains("not a directory"),
2824 "unexpected error: {err}"
2825 );
2826 assert!(downloader.temp_dir.exists());
2827 assert_eq!(fs::read_to_string(&target_dir).unwrap(), "not a directory");
2828 }
2829
2830 #[test]
2831 #[cfg(unix)]
2832 fn test_atomic_install_rejects_dangling_symlink_target() {
2833 use std::os::unix::fs::symlink;
2834
2835 let tmp = tempfile::tempdir().unwrap();
2836 let target_dir = tmp.path().join("model");
2837 let missing_target = tmp.path().join("missing-model");
2838 symlink(&missing_target, &target_dir).unwrap();
2839
2840 let downloader = ModelDownloader::new(target_dir.clone());
2841 copy_model_fixtures(&downloader.temp_dir).unwrap();
2842
2843 let err = downloader.atomic_install().unwrap_err();
2844
2845 assert!(
2846 err.to_string().contains("through symlink"),
2847 "unexpected error: {err}"
2848 );
2849 assert!(downloader.temp_dir.exists());
2850 assert!(
2851 fs::symlink_metadata(&target_dir)
2852 .unwrap()
2853 .file_type()
2854 .is_symlink()
2855 );
2856 assert!(!missing_target.exists());
2857 }
2858
2859 #[test]
2860 fn test_write_verified_marker_overwrites_existing_marker() {
2861 let tmp = tempfile::tempdir().unwrap();
2862 let target_dir = tmp.path().join("model");
2863 fs::create_dir_all(&target_dir).unwrap();
2864 fs::write(target_dir.join(".verified"), "revision=old\n").unwrap();
2865
2866 let downloader = ModelDownloader::new(target_dir.clone());
2867 let manifest = ModelManifest::minilm_v2();
2868 downloader.write_verified_marker(&manifest, None).unwrap();
2869
2870 let marker = fs::read_to_string(target_dir.join(".verified")).unwrap();
2871 assert!(marker.contains(&format!("revision={}", manifest.revision)));
2872 assert!(marker.contains("verified_at="));
2873 assert!(marker.contains("source=registry"));
2874 }
2875
2876 #[test]
2877 fn test_download_error_display() {
2878 let display_cases = [
2879 (
2880 DownloadError::NetworkError("connection refused".into()),
2881 "network error: connection refused",
2882 ),
2883 (
2884 DownloadError::VerificationFailed {
2885 file: "test.onnx".into(),
2886 expected: "abc".into(),
2887 actual: "def".into(),
2888 },
2889 "verification failed for test.onnx: expected abc, got def",
2890 ),
2891 (DownloadError::Cancelled, "download cancelled"),
2892 (DownloadError::Timeout, "download timed out"),
2893 (
2894 DownloadError::HttpError {
2895 status: 503,
2896 message: "service unavailable".into(),
2897 },
2898 "HTTP error 503: service unavailable",
2899 ),
2900 (
2901 DownloadError::ManifestNotVerified {
2902 model_id: "test-model".into(),
2903 unverified_files: vec!["model.onnx".into(), "config.json".into()],
2904 revision_unpinned: true,
2905 },
2906 "model 'test-model' is not production-ready: 2 file(s) have placeholder checksums and revision is not pinned",
2907 ),
2908 (
2909 DownloadError::ManifestNotVerified {
2910 model_id: "test-model".into(),
2911 unverified_files: vec!["model.onnx".into()],
2912 revision_unpinned: false,
2913 },
2914 "model 'test-model' is not production-ready: 1 file(s) have placeholder checksums",
2915 ),
2916 (
2917 DownloadError::InvalidMirrorUrl {
2918 url: "ftp://mirror.example/model.onnx".into(),
2919 reason: "unsupported scheme".into(),
2920 },
2921 "invalid mirror URL 'ftp://mirror.example/model.onnx': unsupported scheme",
2922 ),
2923 ];
2924
2925 for (err, expected) in display_cases {
2926 assert_eq!(err.to_string(), expected);
2927 }
2928
2929 let err: DownloadError = std::io::Error::other("disk full").into();
2930
2931 assert_eq!(err.to_string(), "I/O error: disk full");
2932 let source = err.source().expect("I/O errors expose their source");
2933 assert_eq!(source.to_string(), "disk full");
2934
2935 assert!(
2936 DownloadError::NetworkError("connection refused".into())
2937 .source()
2938 .is_none(),
2939 "non-source variants must not gain an error source"
2940 );
2941 }
2942
2943 #[test]
2944 fn test_manifest_production_ready_minilm() {
2945 let manifest = ModelManifest::minilm_v2();
2947 assert!(manifest.has_verified_checksums());
2948 assert!(manifest.has_pinned_revision());
2949 assert!(manifest.is_production_ready());
2950 }
2951
2952 #[test]
2953 fn test_all_bakeoff_candidates_production_ready() {
2954 let candidates = ModelManifest::bakeoff_candidates();
2956
2957 assert_eq!(candidates.len(), 3, "Expected 3 bake-off candidates");
2959
2960 for manifest in &candidates {
2962 assert!(
2963 manifest.is_production_ready(),
2964 "Model {} should be production-ready",
2965 manifest.id
2966 );
2967 assert!(
2968 manifest.has_verified_checksums(),
2969 "Model {} should have verified checksums",
2970 manifest.id
2971 );
2972 assert!(
2973 manifest.has_pinned_revision(),
2974 "Model {} should have pinned revision",
2975 manifest.id
2976 );
2977 }
2978
2979 assert!(
2981 candidates
2982 .iter()
2983 .any(|m| m.id == "snowflake-arctic-embed-s"),
2984 "Snowflake should be in candidates"
2985 );
2986 assert!(
2987 candidates.iter().any(|m| m.id == "nomic-embed-text-v1.5"),
2988 "Nomic should be in candidates"
2989 );
2990 assert!(
2991 candidates
2992 .iter()
2993 .any(|m| m.id == "jina-reranker-v1-turbo-en"),
2994 "Jina Turbo should be in candidates"
2995 );
2996 }
2997
2998 #[test]
2999 fn test_downloader_cancellation() {
3000 let tmp = tempfile::tempdir().unwrap();
3001 let downloader = ModelDownloader::new(tmp.path().join("model"));
3002
3003 assert!(!downloader.is_cancelled());
3004 downloader.cancel();
3005 assert!(downloader.is_cancelled());
3006 }
3007
3008 #[test]
3009 fn test_prepare_temp_dir_prunes_stale_entries() {
3010 let tmp = tempfile::tempdir().unwrap();
3011 let downloader = ModelDownloader::new(tmp.path().join("model"));
3012 fs::create_dir_all(&downloader.temp_dir).unwrap();
3013 fs::write(downloader.temp_dir.join("model.onnx"), b"partial").unwrap();
3014 fs::write(downloader.temp_dir.join("stale.bin"), b"stale").unwrap();
3015 fs::create_dir_all(downloader.temp_dir.join("nested")).unwrap();
3016 fs::write(
3017 downloader.temp_dir.join("nested").join("should-remove.txt"),
3018 b"stale",
3019 )
3020 .unwrap();
3021
3022 downloader
3023 .prepare_temp_dir(&ModelManifest::minilm_v2())
3024 .unwrap();
3025
3026 assert!(downloader.temp_dir.join("model.onnx").exists());
3027 assert!(!downloader.temp_dir.join("stale.bin").exists());
3028 assert!(!downloader.temp_dir.join("nested").exists());
3029 }
3030
3031 #[test]
3032 #[cfg(unix)]
3033 fn test_prepare_temp_dir_removes_symlink_entries() {
3034 use std::os::unix::fs::symlink;
3035
3036 let tmp = tempfile::tempdir().unwrap();
3037 let downloader = ModelDownloader::new(tmp.path().join("model"));
3038 fs::create_dir_all(&downloader.temp_dir).unwrap();
3039 let outside = tmp.path().join("outside.bin");
3040 fs::write(&outside, b"outside").unwrap();
3041 symlink(&outside, downloader.temp_dir.join("model.onnx")).unwrap();
3042
3043 downloader
3044 .prepare_temp_dir(&ModelManifest::minilm_v2())
3045 .unwrap();
3046
3047 let metadata = fs::symlink_metadata(downloader.temp_dir.join("model.onnx"));
3048 assert!(metadata.is_err(), "symlink should be removed before resume");
3049 assert!(
3050 outside.exists(),
3051 "cleanup must not touch the symlink target"
3052 );
3053 }
3054
3055 #[test]
3056 #[cfg(unix)]
3057 fn test_prepare_temp_dir_rejects_symlinked_temp_dir_without_pruning_target() {
3058 use std::os::unix::fs::symlink;
3059
3060 let tmp = tempfile::tempdir().unwrap();
3061 let downloader = ModelDownloader::new(tmp.path().join("model"));
3062 let outside = tmp.path().join("outside-download-cache");
3063 fs::create_dir_all(&outside).unwrap();
3064 fs::write(outside.join("stale.bin"), b"must remain").unwrap();
3065 symlink(&outside, &downloader.temp_dir).unwrap();
3066
3067 let err = downloader
3068 .prepare_temp_dir(&ModelManifest::minilm_v2())
3069 .expect_err("symlinked temp dir must be rejected before pruning");
3070
3071 assert!(
3072 err.to_string().contains("temp dir through symlink"),
3073 "unexpected symlink-temp-dir error: {err}"
3074 );
3075 assert_eq!(fs::read(outside.join("stale.bin")).unwrap(), b"must remain");
3076 assert!(
3077 fs::symlink_metadata(&downloader.temp_dir)
3078 .unwrap()
3079 .file_type()
3080 .is_symlink()
3081 );
3082 }
3083
3084 #[test]
3085 #[cfg(unix)]
3086 fn test_cleanup_temp_skips_symlinked_temp_dir() {
3087 use std::os::unix::fs::symlink;
3088
3089 let tmp = tempfile::tempdir().unwrap();
3090 let downloader = ModelDownloader::new(tmp.path().join("model"));
3091 let outside = tmp.path().join("outside-download-cache");
3092 fs::create_dir_all(&outside).unwrap();
3093 fs::write(outside.join("sentinel.bin"), b"must remain").unwrap();
3094 symlink(&outside, &downloader.temp_dir).unwrap();
3095
3096 downloader.cleanup_temp();
3097
3098 assert_eq!(
3099 fs::read(outside.join("sentinel.bin")).unwrap(),
3100 b"must remain"
3101 );
3102 assert!(
3103 fs::symlink_metadata(&downloader.temp_dir)
3104 .unwrap()
3105 .file_type()
3106 .is_symlink()
3107 );
3108 }
3109
3110 #[test]
3111 fn test_retryable_error_classification() {
3112 let cases = [
3113 (DownloadError::NetworkError("boom".into()), true),
3114 (DownloadError::Timeout, true),
3115 (
3116 DownloadError::HttpError {
3117 status: 503,
3118 message: "unavailable".into(),
3119 },
3120 true,
3121 ),
3122 (
3123 DownloadError::HttpError {
3124 status: 404,
3125 message: "missing".into(),
3126 },
3127 false,
3128 ),
3129 (DownloadError::Cancelled, false),
3130 (
3131 DownloadError::VerificationFailed {
3132 file: "model.onnx".into(),
3133 expected: "a".into(),
3134 actual: "b".into(),
3135 },
3136 false,
3137 ),
3138 ];
3139
3140 for (err, expected) in cases {
3141 assert_eq!(
3142 err.is_retryable(),
3143 expected,
3144 "retryability mismatch for {err}"
3145 );
3146 }
3147 }
3148
3149 #[test]
3150 fn test_cleanup_temp_for_error_preserves_partial_downloads_on_cancelled() {
3151 let tmp = tempfile::tempdir().unwrap();
3152 let downloader = ModelDownloader::new(tmp.path().join("model"));
3153 fs::create_dir_all(&downloader.temp_dir).unwrap();
3154 let partial = downloader.temp_dir.join("model.onnx");
3155 fs::write(&partial, b"partial").unwrap();
3156
3157 downloader.cleanup_temp_for_error(&DownloadError::Cancelled);
3158
3159 assert!(
3160 partial.exists(),
3161 "cancelled downloads should keep partial files for a resumable retry"
3162 );
3163 }
3164
3165 #[test]
3166 fn test_fail_if_cancelled_preserves_partial_downloads() {
3167 let tmp = tempfile::tempdir().unwrap();
3168 let downloader = ModelDownloader::new(tmp.path().join("model"));
3169 fs::create_dir_all(&downloader.temp_dir).unwrap();
3170 let partial = downloader.temp_dir.join("model.onnx");
3171 fs::write(&partial, b"partial").unwrap();
3172 downloader.cancel();
3173
3174 let result = downloader.fail_if_cancelled();
3175
3176 assert!(matches!(result, Err(DownloadError::Cancelled)));
3177 assert!(
3178 partial.exists(),
3179 "early cancellation checks should not discard resumable partial files"
3180 );
3181 }
3182
3183 #[test]
3184 fn test_cleanup_temp_for_error_discards_temp_after_verification_failure() {
3185 let tmp = tempfile::tempdir().unwrap();
3186 let downloader = ModelDownloader::new(tmp.path().join("model"));
3187 fs::create_dir_all(&downloader.temp_dir).unwrap();
3188 let partial = downloader.temp_dir.join("model.onnx");
3189 fs::write(&partial, b"partial").unwrap();
3190
3191 downloader.cleanup_temp_for_error(&DownloadError::VerificationFailed {
3192 file: "model.onnx".into(),
3193 expected: "good".into(),
3194 actual: "bad".into(),
3195 });
3196
3197 assert!(
3198 !downloader.temp_dir.exists(),
3199 "verification failures should discard the temp directory to avoid reusing corrupt data"
3200 );
3201 }
3202
3203 #[test]
3204 fn test_download_with_mirror_installs_verified_model_from_http_mirror() {
3205 let files = [
3206 ("onnx/model.onnx", b"mirror-model".as_slice()),
3207 ("tokenizer.json", br#"{"tokenizer":"ok"}"#.as_slice()),
3208 ];
3209 let manifest = build_test_manifest("mirror/test-model", "rev123", &files);
3210 let route_prefix = "/cache";
3211 let routes: Vec<(String, MirrorRoute)> = manifest
3212 .files
3213 .iter()
3214 .zip(files.iter())
3215 .map(|(file, (_, body))| {
3216 (
3217 mirror_route_path(route_prefix, &manifest, file),
3218 MirrorRoute {
3219 body: body.to_vec(),
3220 content_type: "application/octet-stream",
3221 chunk_size: 64,
3222 chunk_delay: Duration::ZERO,
3223 },
3224 )
3225 })
3226 .collect();
3227 let server = start_mirror_fixture_server(routes);
3228 let tmp = tempfile::tempdir().unwrap();
3229 let downloader = ModelDownloader::new(tmp.path().join("model"));
3230 let mirror_base = format!("{}/cache/", server.base_url);
3231
3232 downloader
3233 .download_with_mirror(&manifest, Some(&mirror_base), None)
3234 .unwrap();
3235
3236 for (name, body) in files {
3237 let installed = downloader.target_dir.join(
3238 Path::new(name)
3239 .file_name()
3240 .unwrap()
3241 .to_string_lossy()
3242 .as_ref(),
3243 );
3244 assert_eq!(
3245 fs::read(installed).unwrap(),
3246 body,
3247 "mirror install should persist the downloaded payload"
3248 );
3249 }
3250 let marker = fs::read_to_string(downloader.target_dir.join(".verified")).unwrap();
3251 assert!(
3252 marker.contains("revision=rev123"),
3253 "verified marker should preserve manifest identity after mirror install"
3254 );
3255 assert!(
3256 marker.contains("source=mirror:"),
3257 "verified marker should record mirror source"
3258 );
3259
3260 let requests = server.requests();
3261 assert_eq!(
3262 requests.len(),
3263 manifest.files.len(),
3264 "expected one request per manifest file"
3265 );
3266 assert!(
3267 requests
3268 .iter()
3269 .all(|request| request.path.starts_with("/cache/")),
3270 "mirror requests should stay under the configured mirror prefix: {requests:?}"
3271 );
3272 }
3273
3274 #[test]
3275 fn test_download_with_mirror_reports_missing_artifact_from_http_mirror() {
3276 let file_body = b"mirror-model".as_slice();
3277 let manifest = build_test_manifest(
3278 "mirror/test-model",
3279 "rev404",
3280 &[("onnx/model.onnx", file_body)],
3281 );
3282 let server = start_mirror_fixture_server(Vec::new());
3283 let tmp = tempfile::tempdir().unwrap();
3284 let downloader = ModelDownloader::new(tmp.path().join("model"));
3285 let mirror_base = format!("{}/cache", server.base_url);
3286
3287 let err = downloader
3288 .download_with_mirror(&manifest, Some(&mirror_base), None)
3289 .unwrap_err();
3290
3291 assert!(
3292 matches!(err, DownloadError::HttpError { status: 404, .. }),
3293 "missing mirror artifacts should surface as HTTP 404, got: {err}"
3294 );
3295 let requests = server.requests();
3296 assert_eq!(requests.len(), 1);
3297 assert!(
3298 requests[0].path.contains("/resolve/"),
3299 "mirror request should target the resolved artifact path: {requests:?}"
3300 );
3301 }
3302
3303 #[test]
3304 fn test_download_with_mirror_discards_corrupt_payload_from_http_mirror() {
3305 let manifest = build_test_manifest(
3306 "mirror/test-model",
3307 "revbad",
3308 &[("onnx/model.onnx", b"expected-bytes".as_slice())],
3309 );
3310 let route_prefix = "/cache";
3311 let server = start_mirror_fixture_server(vec![(
3312 mirror_route_path(route_prefix, &manifest, &manifest.files[0]),
3313 MirrorRoute {
3314 body: b"corrupt-bytes".to_vec(),
3315 content_type: "application/octet-stream",
3316 chunk_size: 64,
3317 chunk_delay: Duration::ZERO,
3318 },
3319 )]);
3320 let tmp = tempfile::tempdir().unwrap();
3321 let downloader = ModelDownloader::new(tmp.path().join("model"));
3322 let mirror_base = format!("{server_base}/cache", server_base = server.base_url);
3323
3324 let err = downloader
3325 .download_with_mirror(&manifest, Some(&mirror_base), None)
3326 .unwrap_err();
3327
3328 assert!(
3329 matches!(err, DownloadError::VerificationFailed { .. }),
3330 "corrupt mirror payloads must fail checksum verification, got: {err}"
3331 );
3332 assert!(
3333 !downloader.temp_dir.exists(),
3334 "verification failures should discard the temp directory so corrupt payloads are not reused"
3335 );
3336 assert!(
3337 !downloader.target_dir.exists(),
3338 "corrupt mirror payloads must not be promoted into the installed model directory"
3339 );
3340 }
3341
3342 #[test]
3343 fn test_download_with_mirror_resumes_after_cancelled_partial_download() {
3344 let large_payload = vec![b'x'; 128 * 1024];
3345 let manifest = build_test_manifest(
3346 "mirror/test-model",
3347 "revresume",
3348 &[("onnx/model.onnx", &large_payload)],
3349 );
3350 let route_prefix = "/cache";
3351 let server = start_mirror_fixture_server(vec![(
3352 mirror_route_path(route_prefix, &manifest, &manifest.files[0]),
3353 MirrorRoute {
3354 body: large_payload.clone(),
3355 content_type: "application/octet-stream",
3356 chunk_size: 1024,
3357 chunk_delay: Duration::from_millis(2),
3358 },
3359 )]);
3360 let tmp = tempfile::tempdir().unwrap();
3361 let downloader = Arc::new(ModelDownloader::new(tmp.path().join("model")));
3362 let mirror_base = format!("{server_base}/cache", server_base = server.base_url);
3363 let cancel_once = Arc::new(AtomicBool::new(false));
3364 let canceller = Arc::clone(&downloader);
3365 let cancel_flag = Arc::clone(&cancel_once);
3366
3367 let cancelled = downloader.download_with_mirror(
3368 &manifest,
3369 Some(&mirror_base),
3370 Some(Arc::new(move |progress| {
3371 if progress.total_bytes >= 16 * 1024
3372 && cancel_flag
3373 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3374 .is_ok()
3375 {
3376 canceller.cancel();
3377 }
3378 })),
3379 );
3380
3381 assert!(
3382 matches!(cancelled, Err(DownloadError::Cancelled)),
3383 "first mirror attempt should stop with a cancellation so we can verify resumable recovery"
3384 );
3385 let partial_path = downloader.temp_dir.join("model.onnx");
3386 let partial_size = fs::metadata(&partial_path).unwrap().len();
3387 assert!(
3388 partial_size > 0 && partial_size < large_payload.len() as u64,
3389 "cancelled run should preserve a partial download for resume; got {partial_size} bytes"
3390 );
3391
3392 downloader
3393 .download_with_mirror(&manifest, Some(&mirror_base), None)
3394 .unwrap();
3395
3396 assert_eq!(
3397 fs::read(downloader.target_dir.join("model.onnx")).unwrap(),
3398 large_payload,
3399 "rerun after cancellation should finish the mirrored download and install the exact payload"
3400 );
3401 let requests = server.requests();
3402 assert!(
3403 requests
3404 .iter()
3405 .any(|request| request.range_start == Some(partial_size)),
3406 "rerun should resume from the preserved partial via Range requests; saw requests: {requests:?}"
3407 );
3408 }
3409}