1use crate::aliases::{AliasEntry, AliasRegistry, ResolvedAlias};
27use crate::cache::{
28 format_bytes, CacheConfig, CacheEntry, CacheManager, CacheStats, DownloadProgress,
29 EvictionPolicy,
30};
31use crate::error::{PachaError, Result};
32use crate::format::{detect_format, ModelFormat, QuantType};
33use crate::resolver::ModelResolver;
34use crate::uri::ModelUri;
35use serde::{Deserialize, Serialize};
36use std::path::{Path, PathBuf};
37
/// Configuration for a [`ModelFetcher`].
///
/// Derives `Serialize`/`Deserialize`, so a configuration can be round-tripped
/// through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchConfig {
    /// Cache settings handed to `CacheManager::new` when a fetcher is built.
    pub cache: CacheConfig,
    /// Preferred quantization, if any.
    /// NOTE(review): not read by the fetcher code visible here — confirm where it is consumed.
    pub default_quant: Option<QuantType>,
    /// Auto-pull switch.
    /// NOTE(review): presumably controls pulling models that are not cached yet; no visible
    /// code reads it — confirm.
    pub auto_pull: bool,
    /// Concurrency limit.
    /// NOTE(review): presumably the maximum number of parallel downloads; no visible code
    /// reads it — confirm.
    pub max_concurrent: usize,
    /// Integrity-verification switch.
    /// NOTE(review): no visible code reads it — confirm where verification happens.
    pub verify_integrity: bool,
    /// Eviction policy applied to the cache manager via `with_policy`.
    pub eviction_policy: EvictionPolicy,
}
58
59impl Default for FetchConfig {
60 fn default() -> Self {
61 Self {
62 cache: CacheConfig::default(),
63 default_quant: Some(QuantType::Q4_K_M),
64 auto_pull: true,
65 max_concurrent: 2,
66 verify_integrity: true,
67 eviction_policy: EvictionPolicy::LRU,
68 }
69 }
70}
71
72impl FetchConfig {
73 #[must_use]
75 pub fn new() -> Self {
76 Self::default()
77 }
78
79 #[must_use]
81 pub fn with_cache(mut self, cache: CacheConfig) -> Self {
82 self.cache = cache;
83 self
84 }
85
86 #[must_use]
88 pub fn with_default_quant(mut self, quant: QuantType) -> Self {
89 self.default_quant = Some(quant);
90 self
91 }
92
93 #[must_use]
95 pub fn with_auto_pull(mut self, enabled: bool) -> Self {
96 contract_pre_configuration!();
97 self.auto_pull = enabled;
98 self
99 }
100
101 #[must_use]
103 pub fn with_eviction_policy(mut self, policy: EvictionPolicy) -> Self {
104 self.eviction_policy = policy;
105 self
106 }
107}
108
/// Outcome of a [`ModelFetcher::pull`] call.
#[derive(Debug)]
pub struct FetchResult {
    /// Location of the model file inside the cache directory.
    pub path: PathBuf,
    /// Model format: detected from the bytes on a fresh pull, inferred from
    /// the file extension on a cache hit.
    pub format: ModelFormat,
    /// Size of the model data in bytes.
    pub size_bytes: u64,
    /// `true` when the result was served from the cache without resolving.
    pub cache_hit: bool,
    /// The reference string exactly as the caller passed it.
    pub reference: String,
    /// The URI the reference resolved to after alias lookup.
    pub resolved_uri: String,
    /// Content hash: the hex BLAKE3 digest on a fresh pull, or the hash
    /// stored in the cache entry on a cache hit.
    pub hash: String,
}
131
132impl FetchResult {
133 #[must_use]
135 pub fn size_human(&self) -> String {
136 format_bytes(self.size_bytes)
137 }
138
139 #[must_use]
141 pub fn is_quantized(&self) -> bool {
142 match &self.format {
143 ModelFormat::Gguf(info) => info.quantization.is_some(),
144 _ => false,
145 }
146 }
147
148 #[must_use]
150 pub fn quant_type(&self) -> Option<QuantType> {
151 match &self.format {
152 ModelFormat::Gguf(info) => {
153 info.quantization.as_ref().and_then(|q| QuantType::from_str(q))
154 }
155 _ => None,
156 }
157 }
158}
159
/// Resolves model references (aliases or URIs), fetches the model data and
/// keeps it in an on-disk cache.
pub struct ModelFetcher {
    /// Configuration the fetcher was built with.
    config: FetchConfig,
    /// Alias registry, seeded with `AliasRegistry::with_defaults()`.
    aliases: AliasRegistry,
    /// In-memory index of cached models.
    cache: CacheManager,
    /// Resolver used to obtain model data; `None` when
    /// `ModelResolver::new_default()` failed at construction time.
    resolver: Option<ModelResolver>,
    /// Directory holding the cached model files and `manifest.json`.
    cache_dir: PathBuf,
}
177
178impl ModelFetcher {
179 pub fn new() -> Result<Self> {
181 Self::with_config(FetchConfig::default())
182 }
183
184 pub fn with_config(config: FetchConfig) -> Result<Self> {
186 let cache_dir = get_default_cache_dir();
187
188 std::fs::create_dir_all(&cache_dir).map_err(|e| {
190 PachaError::Io(std::io::Error::new(
191 e.kind(),
192 format!("Failed to create cache dir: {}", cache_dir.display()),
193 ))
194 })?;
195
196 let mut cache = CacheManager::new(config.cache.clone()).with_policy(config.eviction_policy);
198 Self::load_manifest(&cache_dir, &mut cache);
199
200 let resolver = ModelResolver::new_default().ok();
201
202 Ok(Self { config, aliases: AliasRegistry::with_defaults(), cache, resolver, cache_dir })
203 }
204
205 fn load_manifest(cache_dir: &Path, cache: &mut CacheManager) {
207 let manifest_path = cache_dir.join("manifest.json");
208 if let Ok(data) = std::fs::read_to_string(&manifest_path) {
209 if let Ok(entries) = serde_json::from_str::<Vec<CacheEntry>>(&data) {
210 for entry in entries {
211 if entry.path.exists() {
213 cache.add(entry);
214 }
215 }
216 }
217 }
218 }
219
220 fn save_manifest(&self) {
225 let manifest_path = self.cache_dir.join("manifest.json");
226 let tmp_path = self.cache_dir.join("manifest.json.tmp");
227 let entries: Vec<&CacheEntry> = self.cache.list();
228 if let Ok(data) = serde_json::to_string_pretty(&entries) {
229 if std::fs::write(&tmp_path, &data).is_ok() {
230 let _ = std::fs::rename(&tmp_path, &manifest_path);
231 }
232 }
233 }
234
235 pub fn with_cache_dir(cache_dir: PathBuf, config: FetchConfig) -> Result<Self> {
237 std::fs::create_dir_all(&cache_dir).map_err(|e| {
238 PachaError::Io(std::io::Error::new(
239 e.kind(),
240 format!("Failed to create cache dir: {}", cache_dir.display()),
241 ))
242 })?;
243
244 let cache = CacheManager::new(config.cache.clone()).with_policy(config.eviction_policy);
245
246 let resolver = ModelResolver::new_default().ok();
247
248 Ok(Self { config, aliases: AliasRegistry::with_defaults(), cache, resolver, cache_dir })
249 }
250
    /// Returns the configuration this fetcher was built with.
    #[must_use]
    pub fn config(&self) -> &FetchConfig {
        &self.config
    }
256
    /// Returns the alias registry used to resolve model references.
    #[must_use]
    pub fn aliases(&self) -> &AliasRegistry {
        &self.aliases
    }
262
263 pub fn add_alias(&mut self, alias: &str, uri: &str) -> Result<()> {
265 self.aliases.add(AliasEntry::new(alias, uri));
266 Ok(())
267 }
268
269 pub fn resolve_ref(&self, model_ref: &str) -> Result<ResolvedAlias> {
271 contract_pre_configuration!(model_ref.as_bytes());
272 let resolved = self.aliases.resolve(model_ref);
273 if resolved.is_alias || model_ref.contains("://") {
275 Ok(resolved)
276 } else {
277 Err(PachaError::NotFound {
279 kind: "alias".to_string(),
280 name: model_ref.to_string(),
281 version: "N/A".to_string(),
282 })
283 }
284 }
285
286 pub fn pull<F>(&mut self, model_ref: &str, progress_fn: F) -> Result<FetchResult>
288 where
289 F: Fn(&DownloadProgress),
290 {
291 contract_pre_configuration!(model_ref);
292 let resolved = self.aliases.resolve(model_ref);
294
295 let uri_str = resolved.uri;
297
298 let cache_key = Self::cache_key(&uri_str);
300 if let Some(entry) = self.cache.get(&cache_key, "1.0") {
301 let format = format_from_path(&entry.path);
302 return Ok(FetchResult {
303 path: entry.path.clone(),
304 format,
305 size_bytes: entry.size_bytes,
306 cache_hit: true,
307 reference: model_ref.to_string(),
308 resolved_uri: uri_str,
309 hash: entry.hash.clone(),
310 });
311 }
312
313 let uri = ModelUri::parse(&uri_str)?;
315 let resolver = self
316 .resolver
317 .as_ref()
318 .ok_or_else(|| PachaError::NotInitialized(PathBuf::from("~/.pacha")))?;
319
320 let mut progress = DownloadProgress::new(0); progress_fn(&progress);
323
324 let resolved_model = resolver.resolve(&uri)?;
326
327 progress = DownloadProgress::new(resolved_model.data.len() as u64);
329 progress.update(resolved_model.data.len() as u64);
330 progress_fn(&progress);
331
332 let format = detect_format(&resolved_model.data);
334
335 let hash = blake3::hash(&resolved_model.data).to_hex().to_string();
337
338 let extension = match &format {
340 ModelFormat::Gguf(_) => "gguf",
341 ModelFormat::SafeTensors(_) => "safetensors",
342 ModelFormat::Apr(_) => "apr",
343 ModelFormat::Onnx(_) => "onnx",
344 ModelFormat::PyTorch => "pt",
345 ModelFormat::Unknown => "bin",
346 };
347
348 let filename = format!("{}.{}", &hash[..16], extension);
350 let cache_path = self.cache_dir.join(&filename);
351
352 std::fs::write(&cache_path, &resolved_model.data).map_err(|e| {
353 PachaError::Io(std::io::Error::new(
354 e.kind(),
355 format!("Failed to write to cache: {}", cache_path.display()),
356 ))
357 })?;
358
359 let entry = CacheEntry::new(
361 &cache_key,
362 "1.0",
363 resolved_model.data.len() as u64,
364 &hash,
365 cache_path.clone(),
366 );
367 self.cache.add(entry);
368
369 self.save_manifest();
371
372 Ok(FetchResult {
373 path: cache_path,
374 format,
375 size_bytes: resolved_model.data.len() as u64,
376 cache_hit: false,
377 reference: model_ref.to_string(),
378 resolved_uri: uri_str,
379 hash,
380 })
381 }
382
383 pub fn pull_quiet(&mut self, model_ref: &str) -> Result<FetchResult> {
385 contract_pre_configuration!(model_ref);
386 self.pull(model_ref, |_| {})
387 }
388
389 #[must_use]
391 pub fn is_cached(&self, model_ref: &str) -> bool {
392 let resolved = self.aliases.resolve(model_ref);
393 let key = Self::cache_key(&resolved.uri);
394 self.cache.contains(&key, "1.0")
395 }
396
397 pub fn remove(&mut self, model_ref: &str) -> Result<bool> {
399 let resolved = self.aliases.resolve(model_ref);
400 let uri = resolved.uri;
401
402 let key = Self::cache_key(&uri);
403 if let Some(entry) = self.cache.remove(&key, "1.0") {
404 if entry.path.exists() {
406 std::fs::remove_file(&entry.path).ok();
407 }
408 Ok(true)
409 } else {
410 Ok(false)
411 }
412 }
413
414 #[must_use]
416 pub fn list(&self) -> Vec<CachedModel> {
417 self.cache
418 .list()
419 .iter()
420 .map(|e| {
421 let format = format_from_path(&e.path);
422 CachedModel {
423 name: e.name.clone(),
424 version: e.version.clone(),
425 size_bytes: e.size_bytes,
426 format,
427 path: e.path.clone(),
428 last_accessed: e.last_accessed,
429 access_count: e.access_count,
430 pinned: e.pinned,
431 }
432 })
433 .collect()
434 }
435
    /// Returns the statistics reported by the underlying cache manager.
    #[must_use]
    pub fn stats(&self) -> CacheStats {
        self.cache.stats()
    }
441
    /// Delegates to `CacheManager::cleanup_to_target` and returns its result.
    ///
    /// NOTE(review): the returned `u64` is presumably the number of bytes
    /// freed — confirm against `CacheManager`. This method itself neither
    /// deletes files nor rewrites the manifest.
    pub fn cleanup(&mut self) -> u64 {
        self.cache.cleanup_to_target()
    }
446
    /// Delegates to `CacheManager::cleanup_old_entries` and returns its result.
    ///
    /// NOTE(review): the returned `u64` is presumably the number of bytes
    /// freed — confirm against `CacheManager`. This method itself neither
    /// deletes files nor rewrites the manifest.
    pub fn cleanup_old(&mut self) -> u64 {
        self.cache.cleanup_old_entries()
    }
451
452 pub fn clear(&mut self) -> u64 {
454 for entry in self.cache.list() {
456 if entry.path.exists() {
457 std::fs::remove_file(&entry.path).ok();
458 }
459 }
460 self.cache.clear()
461 }
462
463 pub fn pin(&mut self, model_ref: &str) -> bool {
465 let key = Self::cache_key(model_ref);
466 self.cache.pin(&key, "1.0")
467 }
468
469 pub fn unpin(&mut self, model_ref: &str) -> bool {
471 let key = Self::cache_key(model_ref);
472 self.cache.unpin(&key, "1.0")
473 }
474
    /// Returns the directory where model files and the manifest are stored.
    #[must_use]
    pub fn cache_dir(&self) -> &PathBuf {
        &self.cache_dir
    }
480
481 fn cache_key(uri: &str) -> String {
483 uri.replace("://", "_").replace('/', "_").replace(':', "_")
485 }
486}
487
/// Snapshot of one cache entry, as returned by [`ModelFetcher::list`].
#[derive(Debug, Clone)]
pub struct CachedModel {
    /// Name stored in the cache entry.
    pub name: String,
    /// Version stored in the cache entry.
    pub version: String,
    /// Size of the cached file in bytes.
    pub size_bytes: u64,
    /// Model format; `ModelFetcher::list` infers it from the file extension.
    pub format: ModelFormat,
    /// Location of the cached file.
    pub path: PathBuf,
    /// Last-access timestamp copied from the cache entry.
    pub last_accessed: std::time::SystemTime,
    /// Access counter copied from the cache entry.
    pub access_count: u64,
    /// Pinned flag copied from the cache entry.
    pub pinned: bool,
}
512
513impl CachedModel {
514 #[must_use]
516 pub fn size_human(&self) -> String {
517 format_bytes(self.size_bytes)
518 }
519
520 #[must_use]
522 pub fn quant_type(&self) -> Option<QuantType> {
523 match &self.format {
524 ModelFormat::Gguf(info) => {
525 info.quantization.as_ref().and_then(|q| QuantType::from_str(q))
526 }
527 _ => None,
528 }
529 }
530}
531
532fn format_from_path(path: &Path) -> ModelFormat {
538 let ext = path.extension().and_then(|e| e.to_str()).map(|e| e.to_lowercase());
539
540 match ext.as_deref() {
541 Some("gguf") => ModelFormat::Gguf(Default::default()),
542 Some("safetensors") => ModelFormat::SafeTensors(Default::default()),
543 Some("apr") => ModelFormat::Apr(Default::default()),
544 Some("onnx") => ModelFormat::Onnx(Default::default()),
545 Some("pt") | Some("pth") => ModelFormat::PyTorch,
546 _ => ModelFormat::Unknown,
547 }
548}
549
/// Picks the default directory for cached models.
///
/// Candidates are tried in this order, and the first environment variable
/// that is set to a non-empty value wins:
///
/// 1. `$XDG_CACHE_HOME/pacha/models`
/// 2. `$HOME/.cache/pacha/models`
/// 3. `$LOCALAPPDATA/pacha/cache/models`
/// 4. the relative path `.cache/pacha/models`
///
/// An empty variable is treated as unset, so it cannot silently produce a
/// path relative to the current working directory. Values are read with
/// `var_os`, so a non-UTF-8 value is used rather than skipped.
fn get_default_cache_dir() -> PathBuf {
    // Reads an environment variable as a path, ignoring empty values.
    fn env_path(name: &str) -> Option<PathBuf> {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    if let Some(cache_home) = env_path("XDG_CACHE_HOME") {
        return cache_home.join("pacha").join("models");
    }

    if let Some(home) = env_path("HOME") {
        return home.join(".cache").join("pacha").join("models");
    }

    if let Some(local_app_data) = env_path("LOCALAPPDATA") {
        return local_app_data.join("pacha").join("cache").join("models");
    }

    PathBuf::from(".cache").join("pacha").join("models")
}
570
// Unit tests for the fetcher. Every fetcher is built on a fresh temporary
// directory via `with_cache_dir`, so the tests never touch the real cache.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // ---- FetchConfig: defaults and builder methods ----

    #[test]
    fn test_fetch_config_default() {
        let config = FetchConfig::default();
        assert!(config.auto_pull);
        assert_eq!(config.max_concurrent, 2);
        assert!(config.verify_integrity);
    }

    #[test]
    fn test_fetch_config_builder() {
        let config = FetchConfig::new()
            .with_default_quant(QuantType::Q8_0)
            .with_auto_pull(false)
            .with_eviction_policy(EvictionPolicy::LFU);

        assert_eq!(config.default_quant, Some(QuantType::Q8_0));
        assert!(!config.auto_pull);
        assert_eq!(config.eviction_policy, EvictionPolicy::LFU);
    }

    #[test]
    fn test_fetch_config_with_cache() {
        let cache_config = CacheConfig::new().with_max_size_gb(100.0);
        let config = FetchConfig::new().with_cache(cache_config.clone());

        assert_eq!(config.cache.max_size_bytes, cache_config.max_size_bytes);
    }

    // ---- ModelFetcher: construction ----

    #[test]
    fn test_fetcher_with_cache_dir() {
        let dir = TempDir::new().unwrap();
        let result = ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default());
        assert!(result.is_ok());
    }

    #[test]
    fn test_fetcher_cache_dir_created() {
        let dir = TempDir::new().unwrap();
        // A sub-directory that does not exist yet must be created on construction.
        let cache_dir = dir.path().join("models");

        let _ = ModelFetcher::with_cache_dir(cache_dir.clone(), FetchConfig::default()).unwrap();

        assert!(cache_dir.exists());
    }

    #[test]
    fn test_fetcher_config_access() {
        let dir = TempDir::new().unwrap();
        let config = FetchConfig::new().with_auto_pull(false);
        let fetcher = ModelFetcher::with_cache_dir(dir.path().to_path_buf(), config).unwrap();

        assert!(!fetcher.config().auto_pull);
    }

    // ---- ModelFetcher: aliases ----

    #[test]
    fn test_fetcher_has_default_aliases() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        let aliases = fetcher.aliases();
        assert!(aliases.get("llama3").is_some());
        assert!(aliases.get("mistral").is_some());
    }

    #[test]
    fn test_fetcher_add_alias() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        fetcher.add_alias("mymodel", "hf://my-org/my-model").unwrap();

        assert!(fetcher.aliases().get("mymodel").is_some());
    }

    #[test]
    fn test_fetcher_resolve_ref() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        let resolved = fetcher.resolve_ref("llama3");
        assert!(resolved.is_ok());
        let uri = resolved.unwrap().uri;
        assert!(uri.starts_with("hf://"), "Expected hf:// URI, got: {}", uri);
    }

    #[test]
    fn test_fetcher_resolve_ref_not_found() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        // Neither a known alias nor a URI, so resolution must fail.
        let resolved = fetcher.resolve_ref("nonexistent-model-xyz");
        assert!(resolved.is_err());
    }

    // ---- ModelFetcher: cache operations on an empty cache ----

    #[test]
    fn test_fetcher_is_cached_empty() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        assert!(!fetcher.is_cached("llama3"));
    }

    #[test]
    fn test_fetcher_stats_empty() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        let stats = fetcher.stats();
        assert_eq!(stats.model_count, 0);
        assert_eq!(stats.total_size_bytes, 0);
    }

    #[test]
    fn test_fetcher_list_empty() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        assert!(fetcher.list().is_empty());
    }

    #[test]
    fn test_fetcher_clear() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        let freed = fetcher.clear();
        assert_eq!(freed, 0);
    }

    #[test]
    fn test_fetcher_cleanup() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        let freed = fetcher.cleanup();
        assert_eq!(freed, 0);
    }

    // ---- Cache key derivation ----

    #[test]
    fn test_cache_key_generation() {
        let key1 = ModelFetcher::cache_key("hf://meta-llama/Llama-3-8B");
        let key2 = ModelFetcher::cache_key("pacha://model:1.0.0");

        // The scheme separator must not survive in the key.
        assert!(!key1.contains("://"));
        assert!(!key2.contains("://"));
    }

    #[test]
    fn test_cache_key_unique() {
        let key1 = ModelFetcher::cache_key("hf://org/model1");
        let key2 = ModelFetcher::cache_key("hf://org/model2");

        assert_ne!(key1, key2);
    }

    // ---- FetchResult helpers ----

    #[test]
    fn test_fetch_result_size_human() {
        let result = FetchResult {
            path: PathBuf::from("/cache/model.gguf"),
            format: ModelFormat::Unknown,
            size_bytes: 4 * 1024 * 1024 * 1024,
            cache_hit: true,
            reference: "llama3".to_string(),
            resolved_uri: "hf://meta-llama/Llama-3-8B".to_string(),
            hash: "abc123".to_string(),
        };

        assert!(result.size_human().contains("GB"));
    }

    #[test]
    fn test_fetch_result_not_quantized() {
        let result = FetchResult {
            path: PathBuf::from("/cache/model.safetensors"),
            format: ModelFormat::SafeTensors(Default::default()),
            size_bytes: 1000,
            cache_hit: false,
            reference: "test".to_string(),
            resolved_uri: "test".to_string(),
            hash: "hash".to_string(),
        };

        // Only GGUF models can report a quantization.
        assert!(!result.is_quantized());
        assert!(result.quant_type().is_none());
    }

    #[test]
    fn test_fetch_result_quantized_gguf() {
        use crate::format::GgufInfo;

        let result = FetchResult {
            path: PathBuf::from("/cache/model.gguf"),
            format: ModelFormat::Gguf(GgufInfo {
                version: 3,
                tensor_count: 100,
                metadata_count: 10,
                quantization: Some("Q4_K_M".to_string()),
                ..Default::default()
            }),
            size_bytes: 4_000_000_000,
            cache_hit: true,
            reference: "llama3:8b-q4_k_m".to_string(),
            resolved_uri: "hf://...".to_string(),
            hash: "hash".to_string(),
        };

        assert!(result.is_quantized());
        assert_eq!(result.quant_type(), Some(QuantType::Q4_K_M));
    }

    // ---- CachedModel helpers ----

    #[test]
    fn test_cached_model_size_human() {
        let model = CachedModel {
            name: "llama3".to_string(),
            version: "8b".to_string(),
            size_bytes: 4 * 1024 * 1024 * 1024,
            format: ModelFormat::Unknown,
            path: PathBuf::from("/cache"),
            last_accessed: std::time::SystemTime::now(),
            access_count: 5,
            pinned: false,
        };

        assert!(model.size_human().contains("GB"));
    }

    #[test]
    fn test_cached_model_quant_type() {
        use crate::format::GgufInfo;

        let model = CachedModel {
            name: "llama3".to_string(),
            version: "8b".to_string(),
            size_bytes: 4_000_000_000,
            format: ModelFormat::Gguf(GgufInfo {
                version: 3,
                tensor_count: 100,
                metadata_count: 10,
                quantization: Some("Q8_0".to_string()),
                ..Default::default()
            }),
            path: PathBuf::from("/cache/model.gguf"),
            last_accessed: std::time::SystemTime::now(),
            access_count: 1,
            pinned: true,
        };

        assert_eq!(model.quant_type(), Some(QuantType::Q8_0));
    }

    // ---- Pin / unpin on entries that do not exist ----

    #[test]
    fn test_fetcher_pin_nonexistent() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        assert!(!fetcher.pin("nonexistent"));
    }

    #[test]
    fn test_fetcher_unpin_nonexistent() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        assert!(!fetcher.unpin("nonexistent"));
    }

    // ---- Remove on an entry that does not exist ----

    #[test]
    fn test_fetcher_remove_nonexistent() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();

        // Removing an unknown model is not an error; it reports `false`.
        let result = fetcher.remove("nonexistent");
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    // ---- Serde round-trip ----

    #[test]
    fn test_fetch_config_serialization() {
        let config = FetchConfig::new().with_default_quant(QuantType::Q4_K_M).with_auto_pull(false);

        let json = serde_json::to_string(&config).unwrap();
        let parsed: FetchConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.default_quant, config.default_quant);
        assert_eq!(parsed.auto_pull, config.auto_pull);
    }
}