1use crate::error::{PachaError, Result};
26use crate::model::{Model, ModelVersion};
27use crate::registry::{Registry, RegistryConfig};
28use crate::remote::RegistryAuth;
29#[cfg(feature = "remote")]
30use crate::remote::RemoteRegistry;
31use crate::uri::{ModelUri, UriScheme};
32use std::fs;
33use std::path::Path;
34
/// A model whose bytes have been loaded, together with a record of where they came from.
#[derive(Debug)]
pub struct ResolvedModel {
    /// Raw model bytes as read from disk, the local registry, or the network.
    pub data: Vec<u8>,
    /// Where `data` was obtained from.
    pub source: ModelSource,
    /// Registry metadata. The resolver only fills this in for local `pacha://` resolutions;
    /// it is `None` for plain files, remote registry pulls, and Hugging Face downloads.
    pub model: Option<Model>,
}
45
/// Origin of a [`ResolvedModel`]'s bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSource {
    /// A file on the local filesystem; holds the path rendered as text.
    LocalFile(String),
    /// A model stored in the local Pacha registry.
    PachaLocal {
        /// Registry model name.
        name: String,
        /// Version string of the resolved model. The resolver records the concrete version
        /// here even when `latest` (or no version) was requested.
        version: String,
    },
    /// A model pulled from a remote Pacha registry.
    PachaRemote {
        /// Registry host; the resolver connects to it as `https://{host}`.
        host: String,
        /// Model name on the remote registry.
        name: String,
        /// Version as requested in the URI (may be `"latest"`).
        version: String,
    },
    /// A file downloaded from the Hugging Face Hub.
    HuggingFace {
        /// Repository id, e.g. `org/model`.
        repo_id: String,
        /// Revision used in the download URL, if known (the resolver defaults it to `main`).
        revision: Option<String>,
    },
}
75
/// Resolves [`ModelUri`]s (`file://` / bare paths, `pacha://`, `hf://`) into model bytes.
pub struct ModelResolver {
    /// Local registry; `None` when it could not be opened or for file-only resolvers.
    registry: Option<Registry>,
    /// Credentials forwarded to remote Pacha registries (not used for Hugging Face downloads).
    remote_auth: Option<RegistryAuth>,
}
83
84impl ModelResolver {
85 pub fn new_default() -> Result<Self> {
87 let registry = Registry::open_default().ok();
88 Ok(Self { registry, remote_auth: None })
89 }
90
91 pub fn new(registry_path: impl AsRef<Path>) -> Result<Self> {
93 let config = RegistryConfig::new(registry_path);
94 let registry = Registry::open(config).ok();
95 Ok(Self { registry, remote_auth: None })
96 }
97
98 #[must_use]
100 pub fn file_only() -> Self {
101 Self { registry: None, remote_auth: None }
102 }
103
104 #[must_use]
106 pub fn with_remote_auth(mut self, auth: RegistryAuth) -> Self {
107 self.remote_auth = Some(auth);
108 self
109 }
110
111 #[must_use]
113 pub fn has_registry(&self) -> bool {
114 self.registry.is_some()
115 }
116
117 #[must_use]
119 pub fn has_remote_auth(&self) -> bool {
120 self.remote_auth.is_some()
121 }
122
123 pub fn resolve(&self, uri: &ModelUri) -> Result<ResolvedModel> {
125 match uri.scheme {
126 UriScheme::File => self.resolve_file(uri),
127 UriScheme::Pacha => self.resolve_pacha(uri),
128 UriScheme::HuggingFace => self.resolve_huggingface(uri),
129 }
130 }
131
132 pub fn resolve_str(&self, uri: &str) -> Result<ResolvedModel> {
134 let parsed = ModelUri::parse(uri)?;
135 self.resolve(&parsed)
136 }
137
138 pub fn exists(&self, uri: &ModelUri) -> bool {
140 match uri.scheme {
141 UriScheme::File => uri.as_path().map_or(false, |p| p.exists()),
142 UriScheme::Pacha => {
143 if uri.is_remote() {
144 false
146 } else if let Some(ref registry) = self.registry {
147 let version = uri.version.as_deref().unwrap_or("latest");
148 if let Ok(version) = parse_version(version) {
149 registry.get_model(&uri.name, &version).is_ok()
150 } else {
151 registry.list_model_versions(&uri.name).map_or(false, |v| !v.is_empty())
153 }
154 } else {
155 false
156 }
157 }
158 UriScheme::HuggingFace => {
159 false
161 }
162 }
163 }
164
165 fn resolve_file(&self, uri: &ModelUri) -> Result<ResolvedModel> {
166 let path = uri
167 .as_path()
168 .ok_or_else(|| PachaError::InvalidUri("File URI has no path".to_string()))?;
169
170 if !path.exists() {
171 return Err(PachaError::NotFound {
172 kind: "file".to_string(),
173 name: path.display().to_string(),
174 version: "N/A".to_string(),
175 });
176 }
177
178 let data = fs::read(&path).map_err(|e| {
179 PachaError::Io(std::io::Error::new(
180 e.kind(),
181 format!("Failed to read {}: {}", path.display(), e),
182 ))
183 })?;
184
185 Ok(ResolvedModel {
186 data,
187 source: ModelSource::LocalFile(path.display().to_string()),
188 model: None,
189 })
190 }
191
192 fn resolve_pacha(&self, uri: &ModelUri) -> Result<ResolvedModel> {
193 if uri.is_remote() {
195 return self.resolve_pacha_remote(uri);
196 }
197
198 let registry = self
200 .registry
201 .as_ref()
202 .ok_or_else(|| PachaError::NotInitialized(std::path::PathBuf::from("~/.pacha")))?;
203
204 let version_str = uri.version.as_deref().unwrap_or("latest");
206
207 let version = if version_str == "latest" {
209 let versions = registry.list_model_versions(&uri.name)?;
210 if versions.is_empty() {
211 return Err(PachaError::NotFound {
212 kind: "model".to_string(),
213 name: uri.name.clone(),
214 version: "latest".to_string(),
215 });
216 }
217 versions.into_iter().max().ok_or_else(|| PachaError::NotFound {
219 kind: "model".to_string(),
220 name: uri.name.clone(),
221 version: "latest".to_string(),
222 })?
223 } else {
224 parse_version(version_str)?
225 };
226
227 let model = registry.get_model(&uri.name, &version)?;
229 let data = registry.get_model_artifact(&uri.name, &version)?;
230
231 Ok(ResolvedModel {
232 data,
233 source: ModelSource::PachaLocal {
234 name: uri.name.clone(),
235 version: version.to_string(),
236 },
237 model: Some(model),
238 })
239 }
240
241 #[cfg(feature = "remote")]
242 fn resolve_pacha_remote(&self, uri: &ModelUri) -> Result<ResolvedModel> {
243 let host = uri
244 .host
245 .as_ref()
246 .ok_or_else(|| PachaError::InvalidUri("Remote URI missing host".to_string()))?;
247
248 let version = uri.version.as_deref().unwrap_or("latest");
249 let base_url = format!("https://{host}");
251
252 let mut remote = RemoteRegistry::new(&base_url);
254 if let Some(ref auth) = self.remote_auth {
255 remote = remote.with_auth(auth.clone());
256 }
257
258 let rt = tokio::runtime::Builder::new_current_thread()
260 .enable_all()
261 .build()
262 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
263
264 let data = rt.block_on(remote.pull_model(&uri.name, version))?;
266
267 if let Some(ref registry) = self.registry {
269 let model_version =
270 parse_version(version).unwrap_or_else(|_| ModelVersion::new(0, 0, 0));
271 let _ = registry.register_model(
272 &uri.name,
273 &model_version,
274 &data,
275 crate::model::ModelCard::new(&format!("Pulled from {host}")),
276 );
277 }
278
279 Ok(ResolvedModel {
280 data,
281 source: ModelSource::PachaRemote {
282 host: host.clone(),
283 name: uri.name.clone(),
284 version: version.to_string(),
285 },
286 model: None,
287 })
288 }
289
290 #[cfg(not(feature = "remote"))]
291 fn resolve_pacha_remote(&self, uri: &ModelUri) -> Result<ResolvedModel> {
292 let host = uri
293 .host
294 .as_ref()
295 .ok_or_else(|| PachaError::InvalidUri("Remote URI missing host".to_string()))?;
296
297 Err(PachaError::UnsupportedOperation {
298 operation: "remote_registry".to_string(),
299 reason: format!(
300 "Remote feature not enabled. Rebuild with --features remote. Host: {}",
301 host
302 ),
303 })
304 }
305
306 #[cfg(feature = "remote")]
307 fn resolve_huggingface(&self, uri: &ModelUri) -> Result<ResolvedModel> {
308 let (repo_id, revision) = if uri.name.contains('@') {
310 let parts: Vec<&str> = uri.name.splitn(2, '@').collect();
311 (parts[0].to_string(), parts.get(1).map(|s| s.to_string()))
312 } else {
313 (uri.name.clone(), uri.version.clone())
314 };
315
316 let revision = revision.as_deref().unwrap_or("main");
317
318 let filename = uri.path.as_deref().unwrap_or("model.safetensors");
320
321 let url = format!("https://huggingface.co/{}/resolve/{}/{}", repo_id, revision, filename);
323
324 let rt = tokio::runtime::Builder::new_current_thread()
326 .enable_all()
327 .build()
328 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
329
330 let client = reqwest::Client::builder()
331 .user_agent(concat!("pacha/", env!("CARGO_PKG_VERSION")))
332 .connect_timeout(std::time::Duration::from_secs(30))
333 .timeout(std::time::Duration::from_secs(300))
334 .build()
335 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
336
337 let data = rt.block_on(async {
338 let response = client
339 .get(&url)
340 .send()
341 .await
342 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
343
344 if !response.status().is_success() {
345 return Err(PachaError::NotFound {
346 kind: "huggingface".to_string(),
347 name: repo_id.clone(),
348 version: revision.to_string(),
349 });
350 }
351
352 response
353 .bytes()
354 .await
355 .map(|b| b.to_vec())
356 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))
357 })?;
358
359 if let Some(ref registry) = self.registry {
361 let model_name = repo_id.replace('/', "-");
363 let model_version =
364 parse_version(revision).unwrap_or_else(|_| ModelVersion::new(0, 0, 0));
365 let _ = registry.register_model(
366 &model_name,
367 &model_version,
368 &data,
369 crate::model::ModelCard::new(&format!("Downloaded from HuggingFace: {repo_id}")),
370 );
371 }
372
373 Ok(ResolvedModel {
374 data,
375 source: ModelSource::HuggingFace { repo_id, revision: Some(revision.to_string()) },
376 model: None,
377 })
378 }
379
380 #[cfg(not(feature = "remote"))]
381 fn resolve_huggingface(&self, uri: &ModelUri) -> Result<ResolvedModel> {
382 Err(PachaError::UnsupportedOperation {
383 operation: "huggingface".to_string(),
384 reason: format!("HuggingFace Hub requires --features remote. Model: {}", uri.name),
385 })
386 }
387
388 pub fn list_models(&self) -> Result<Vec<String>> {
390 let registry = self
391 .registry
392 .as_ref()
393 .ok_or_else(|| PachaError::NotInitialized(std::path::PathBuf::from("~/.pacha")))?;
394 registry.list_models()
395 }
396
397 pub fn list_versions(&self, model_name: &str) -> Result<Vec<ModelVersion>> {
399 let registry = self
400 .registry
401 .as_ref()
402 .ok_or_else(|| PachaError::NotInitialized(std::path::PathBuf::from("~/.pacha")))?;
403 registry.list_model_versions(model_name)
404 }
405}
406
407fn parse_version(s: &str) -> Result<ModelVersion> {
409 let parts: Vec<&str> = s.split('.').collect();
411 if parts.len() == 3 {
412 let major: u32 = parts[0]
413 .parse()
414 .map_err(|_| PachaError::InvalidUri(format!("Invalid version: {s}")))?;
415 let minor: u32 = parts[1]
416 .parse()
417 .map_err(|_| PachaError::InvalidUri(format!("Invalid version: {s}")))?;
418 let patch: u32 = parts[2]
419 .parse()
420 .map_err(|_| PachaError::InvalidUri(format!("Invalid version: {s}")))?;
421 return Ok(ModelVersion::new(major, minor, patch));
422 }
423
424 if let Ok(major) = s.parse::<u32>() {
426 return Ok(ModelVersion::new(major, 0, 0));
427 }
428
429 Err(PachaError::InvalidUri(format!("Cannot parse version: {s}. Expected format: x.y.z")))
430}
431
// Unit tests for `ModelResolver`, `ModelSource`, and `parse_version`.
//
// Registry-backed tests use a throwaway `TempDir` registry (see `setup_registry`). Tests gated
// on `feature = "remote"` may attempt real network access.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::ModelCard;
    use std::io::Write;
    use tempfile::TempDir;

    /// Builds a temporary registry containing `test-model` at versions 1.0.0 and 1.1.0, plus a
    /// resolver opened on it.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for the whole test.
    fn setup_registry() -> (TempDir, ModelResolver) {
        let dir = TempDir::new().unwrap();
        let config = RegistryConfig::new(dir.path());
        let registry = Registry::open(config).unwrap();

        registry
            .register_model(
                "test-model",
                &ModelVersion::new(1, 0, 0),
                b"model data v1.0.0",
                ModelCard::new("Test model v1"),
            )
            .unwrap();

        registry
            .register_model(
                "test-model",
                &ModelVersion::new(1, 1, 0),
                b"model data v1.1.0",
                ModelCard::new("Test model v1.1"),
            )
            .unwrap();

        let resolver = ModelResolver::new(dir.path()).unwrap();
        (dir, resolver)
    }

    /// Writes `content` to `model.gguf` inside a fresh temp dir and returns (dir, file path).
    fn create_temp_file(content: &[u8]) -> (TempDir, std::path::PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("model.gguf");
        let mut file = std::fs::File::create(&path).unwrap();
        file.write_all(content).unwrap();
        (dir, path)
    }

    // --- file:// and bare-path resolution ---

    #[test]
    fn test_resolve_file() {
        let (_dir, path) = create_temp_file(b"GGUF model data");
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse(&format!("file://{}", path.display())).unwrap();
        let resolved = resolver.resolve(&uri).unwrap();

        assert_eq!(resolved.data, b"GGUF model data");
        assert!(matches!(resolved.source, ModelSource::LocalFile(_)));
        // File resolution never carries registry metadata.
        assert!(resolved.model.is_none());
    }

    #[test]
    fn test_resolve_bare_path() {
        let (_dir, path) = create_temp_file(b"model content");
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse(path.to_str().unwrap()).unwrap();
        let resolved = resolver.resolve(&uri).unwrap();

        assert_eq!(resolved.data, b"model content");
    }

    #[test]
    fn test_resolve_nonexistent_file() {
        let resolver = ModelResolver::file_only();
        let uri = ModelUri::parse("file:///nonexistent/model.gguf").unwrap();

        let result = resolver.resolve(&uri);
        assert!(matches!(result, Err(PachaError::NotFound { .. })));
    }

    #[test]
    fn test_exists_file() {
        let (_dir, path) = create_temp_file(b"data");
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse(path.to_str().unwrap()).unwrap();
        assert!(resolver.exists(&uri));

        let uri = ModelUri::parse("file:///nonexistent.gguf").unwrap();
        assert!(!resolver.exists(&uri));
    }

    // --- local pacha:// resolution ---

    #[test]
    fn test_resolve_pacha_with_version() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://test-model:1.0.0").unwrap();
        let resolved = resolver.resolve(&uri).unwrap();

        assert_eq!(resolved.data, b"model data v1.0.0");
        assert!(matches!(
            resolved.source,
            ModelSource::PachaLocal { ref name, ref version }
            if name == "test-model" && version == "1.0.0"
        ));
        assert!(resolved.model.is_some());
    }

    /// `:latest` must pick the highest registered version (1.1.0).
    #[test]
    fn test_resolve_pacha_latest() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://test-model:latest").unwrap();
        let resolved = resolver.resolve(&uri).unwrap();

        assert_eq!(resolved.data, b"model data v1.1.0");
    }

    /// Omitting the version behaves like `:latest`.
    #[test]
    fn test_resolve_pacha_no_version() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://test-model").unwrap();
        let resolved = resolver.resolve(&uri).unwrap();

        assert_eq!(resolved.data, b"model data v1.1.0");
    }

    #[test]
    fn test_resolve_pacha_not_found() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://nonexistent:1.0.0").unwrap();
        let result = resolver.resolve(&uri);

        assert!(matches!(result, Err(PachaError::NotFound { .. })));
    }

    #[test]
    fn test_resolve_pacha_no_registry() {
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse("pacha://test-model:1.0.0").unwrap();
        let result = resolver.resolve(&uri);

        assert!(matches!(result, Err(PachaError::NotInitialized(_))));
    }

    #[test]
    fn test_exists_pacha() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://test-model:1.0.0").unwrap();
        assert!(resolver.exists(&uri));

        let uri = ModelUri::parse("pacha://nonexistent:1.0.0").unwrap();
        assert!(!resolver.exists(&uri));
    }

    // --- remote pacha:// resolution (behavior depends on the `remote` feature) ---

    #[test]
    #[cfg(not(feature = "remote"))]
    fn test_resolve_pacha_remote_not_implemented() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://registry.example.com/model:1.0.0").unwrap();
        let result = resolver.resolve(&uri);

        assert!(matches!(result, Err(PachaError::UnsupportedOperation { .. })));
    }

    /// Uses the reserved `.invalid` TLD so the remote pull is expected to fail to connect.
    #[test]
    #[cfg(feature = "remote")]
    fn test_resolve_pacha_remote_connection_error() {
        let (_dir, resolver) = setup_registry();

        let uri = ModelUri::parse("pacha://nonexistent.invalid/model:1.0.0").unwrap();
        let result = resolver.resolve(&uri);

        assert!(result.is_err());
    }

    // --- hf:// resolution and URI parsing ---

    #[test]
    #[cfg(not(feature = "remote"))]
    fn test_resolve_huggingface_not_implemented() {
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse("hf://meta-llama/Llama-3-8B").unwrap();
        let result = resolver.resolve(&uri);

        assert!(matches!(result, Err(PachaError::UnsupportedOperation { .. })));
    }

    /// May hit the network; any failure (HTTP error or no connectivity) satisfies the assertion.
    #[test]
    #[cfg(feature = "remote")]
    fn test_resolve_huggingface_nonexistent_repo() {
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse("hf://nonexistent-user-12345/nonexistent-model-67890").unwrap();
        let result = resolver.resolve(&uri);

        assert!(result.is_err());
    }

    #[test]
    fn test_huggingface_uri_parsing() {
        let uri = ModelUri::parse("hf://meta-llama/Llama-3-8B").unwrap();
        assert_eq!(uri.name, "meta-llama/Llama-3-8B");
        assert_eq!(uri.scheme, UriScheme::HuggingFace);

        let uri = ModelUri::parse("hf://meta-llama/Llama-3-8B:main").unwrap();
        assert_eq!(uri.name, "meta-llama/Llama-3-8B");
        assert_eq!(uri.version, Some("main".to_string()));
    }

    /// Path segments after `org/repo` are parsed into `uri.path`.
    #[test]
    fn test_huggingface_uri_with_path() {
        let uri = ModelUri::parse("hf://meta-llama/Llama-3-8B/config.json").unwrap();
        assert_eq!(uri.name, "meta-llama/Llama-3-8B");
        assert_eq!(uri.path, Some("config.json".to_string()));
    }

    #[test]
    fn test_model_source_huggingface_clone() {
        let source = ModelSource::HuggingFace {
            repo_id: "meta-llama/Llama-3-8B".to_string(),
            revision: Some("main".to_string()),
        };
        let cloned = source.clone();
        assert_eq!(source, cloned);
    }

    #[test]
    fn test_model_source_huggingface_without_revision() {
        let source =
            ModelSource::HuggingFace { repo_id: "google/gemma-7b".to_string(), revision: None };
        assert!(matches!(source, ModelSource::HuggingFace { revision: None, .. }));
    }

    /// `exists` never checks the Hub, so HF URIs always report `false`.
    #[test]
    fn test_exists_huggingface() {
        let resolver = ModelResolver::file_only();

        let uri = ModelUri::parse("hf://meta-llama/Llama-3-8B").unwrap();
        assert!(!resolver.exists(&uri));
    }

    // --- resolve_str ---

    #[test]
    fn test_resolve_str() {
        let (_dir, path) = create_temp_file(b"test data");
        let resolver = ModelResolver::file_only();

        let resolved = resolver.resolve_str(path.to_str().unwrap()).unwrap();
        assert_eq!(resolved.data, b"test data");
    }

    #[test]
    fn test_resolve_str_invalid() {
        let resolver = ModelResolver::file_only();
        let result = resolver.resolve_str("invalid://uri");
        assert!(result.is_err());
    }

    // --- registry listing ---

    #[test]
    fn test_list_models() {
        let (_dir, resolver) = setup_registry();

        let models = resolver.list_models().unwrap();
        assert!(models.contains(&"test-model".to_string()));
    }

    #[test]
    fn test_list_versions() {
        let (_dir, resolver) = setup_registry();

        let versions = resolver.list_versions("test-model").unwrap();
        assert_eq!(versions.len(), 2);
    }

    #[test]
    fn test_list_models_no_registry() {
        let resolver = ModelResolver::file_only();
        let result = resolver.list_models();
        assert!(matches!(result, Err(PachaError::NotInitialized(_))));
    }

    // --- parse_version ---

    #[test]
    fn test_parse_version_semver() {
        let v = parse_version("1.2.3").unwrap();
        assert_eq!(v, ModelVersion::new(1, 2, 3));
    }

    /// A lone major number expands to `major.0.0`.
    #[test]
    fn test_parse_version_single() {
        let v = parse_version("2").unwrap();
        assert_eq!(v, ModelVersion::new(2, 0, 0));
    }

    #[test]
    fn test_parse_version_invalid() {
        assert!(parse_version("invalid").is_err());
        // Two-part versions are rejected; only x.y.z or a bare major are accepted.
        assert!(parse_version("1.2").is_err());
        assert!(parse_version("a.b.c").is_err());
    }

    // --- ModelSource value semantics ---

    #[test]
    fn test_model_source_equality() {
        let s1 = ModelSource::LocalFile("/path/to/model".to_string());
        let s2 = ModelSource::LocalFile("/path/to/model".to_string());
        let s3 = ModelSource::LocalFile("/other/path".to_string());

        assert_eq!(s1, s2);
        assert_ne!(s1, s3);
    }

    #[test]
    fn test_model_source_pacha_local() {
        let source =
            ModelSource::PachaLocal { name: "llama3".to_string(), version: "8b".to_string() };
        assert!(matches!(source, ModelSource::PachaLocal { .. }));
    }

    // --- resolver configuration accessors ---

    #[test]
    fn test_has_registry() {
        let (_dir, resolver) = setup_registry();
        assert!(resolver.has_registry());

        let resolver = ModelResolver::file_only();
        assert!(!resolver.has_registry());
    }

    #[test]
    fn test_with_remote_auth() {
        let resolver = ModelResolver::file_only()
            .with_remote_auth(RegistryAuth::Token("test-token".to_string()));

        assert!(resolver.has_remote_auth());
    }

    #[test]
    fn test_without_remote_auth() {
        let resolver = ModelResolver::file_only();
        assert!(!resolver.has_remote_auth());
    }

    #[test]
    fn test_remote_auth_basic() {
        let resolver = ModelResolver::file_only().with_remote_auth(RegistryAuth::Basic {
            username: "user".to_string(),
            password: "pass".to_string(),
        });

        assert!(resolver.has_remote_auth());
    }

    #[test]
    fn test_remote_auth_api_key() {
        let resolver = ModelResolver::file_only().with_remote_auth(RegistryAuth::ApiKey {
            header: "X-Api-Key".to_string(),
            key: "secret".to_string(),
        });

        assert!(resolver.has_remote_auth());
    }

    // --- remaining ModelSource variants ---

    #[test]
    fn test_model_source_pacha_remote() {
        let source = ModelSource::PachaRemote {
            host: "registry.example.com".to_string(),
            name: "llama3".to_string(),
            version: "1.0.0".to_string(),
        };

        assert!(matches!(source, ModelSource::PachaRemote { .. }));
    }

    #[test]
    fn test_model_source_huggingface() {
        let source = ModelSource::HuggingFace {
            repo_id: "meta-llama/Llama-3-8B".to_string(),
            revision: Some("main".to_string()),
        };

        assert!(matches!(source, ModelSource::HuggingFace { .. }));
    }
}