1use crate::error::{PachaError, Result};
25use crate::model::{Model, ModelVersion};
26use serde::{Deserialize, Serialize};
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ListModelsResponse {
35 pub models: Vec<String>,
37 pub total: usize,
39 pub next_cursor: Option<String>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ListVersionsResponse {
46 pub model: String,
48 pub versions: Vec<VersionInfo>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct VersionInfo {
55 pub version: String,
57 pub hash: String,
59 pub size: u64,
61 pub created_at: String,
63 pub stage: String,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ModelMetadataResponse {
70 pub name: String,
72 pub version: String,
74 pub hash: String,
76 pub size: u64,
78 pub card: Option<serde_json::Value>,
80 pub lineage: Option<LineageInfo>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct LineageInfo {
87 pub parent: Option<String>,
89 pub dataset: Option<String>,
91 pub recipe: Option<String>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct PushRequest {
98 pub name: String,
100 pub version: String,
102 pub hash: String,
104 pub card: Option<serde_json::Value>,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct PushResponse {
111 pub upload_url: String,
113 pub upload_id: String,
115}
116
/// Credentials a [`RemoteRegistry`] attaches to its requests.
///
/// `Debug` is implemented by hand rather than derived so that secrets (tokens,
/// passwords, API keys) never end up in logs or panic messages. Only the
/// scheme and non-secret identifiers (username, header name) are printed.
#[derive(Clone)]
pub enum RegistryAuth {
    /// Anonymous access: no credentials are sent.
    None,
    /// Bearer token, sent as `Authorization: Bearer <token>`.
    Token(String),
    /// HTTP Basic authentication.
    Basic {
        /// Account name.
        username: String,
        /// Account password (secret).
        password: String,
    },
    /// API key carried in a custom request header.
    ApiKey {
        /// Name of the header that carries the key (e.g. `X-Api-Key`).
        header: String,
        /// Key value (secret).
        key: String,
    },
}

impl std::fmt::Debug for RegistryAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder printed in place of any secret value.
        struct Redacted;

        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("<redacted>")
            }
        }

        match self {
            Self::None => f.write_str("None"),
            Self::Token(_) => f.debug_tuple("Token").field(&Redacted).finish(),
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &Redacted)
                .finish(),
            Self::ApiKey { header, .. } => f
                .debug_struct("ApiKey")
                .field("header", header)
                .field("key", &Redacted)
                .finish(),
        }
    }
}
139
140impl Default for RegistryAuth {
141 fn default() -> Self {
142 Self::None
143 }
144}
145
146#[derive(Debug)]
152pub struct RemoteRegistry {
153 base_url: String,
155 auth: RegistryAuth,
157 #[cfg(feature = "remote")]
159 client: reqwest::Client,
160}
161
162impl RemoteRegistry {
163 #[must_use]
165 pub fn new(base_url: impl Into<String>) -> Self {
166 let base_url = base_url.into().trim_end_matches('/').to_string();
167
168 Self {
169 base_url,
170 auth: RegistryAuth::None,
171 #[cfg(feature = "remote")]
172 client: reqwest::Client::builder()
173 .user_agent(concat!("pacha/", env!("CARGO_PKG_VERSION")))
174 .connect_timeout(std::time::Duration::from_secs(30))
175 .timeout(std::time::Duration::from_secs(300))
176 .build()
177 .expect("Failed to create HTTP client"),
178 }
179 }
180
181 #[must_use]
183 pub fn with_auth(mut self, auth: RegistryAuth) -> Self {
184 self.auth = auth;
185 self
186 }
187
188 #[must_use]
190 pub fn base_url(&self) -> &str {
191 &self.base_url
192 }
193
194 #[must_use]
196 pub fn has_auth(&self) -> bool {
197 !matches!(self.auth, RegistryAuth::None)
198 }
199
200 #[cfg(feature = "remote")]
206 pub async fn list_models(&self) -> Result<ListModelsResponse> {
207 let url = format!("{}/api/v1/models", self.base_url);
208 let response = self
209 .build_request(reqwest::Method::GET, &url)
210 .send()
211 .await
212 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
213
214 self.handle_response(response).await
215 }
216
217 #[cfg(not(feature = "remote"))]
219 pub async fn list_models(&self) -> Result<ListModelsResponse> {
220 Err(PachaError::UnsupportedOperation {
221 operation: "list_models".to_string(),
222 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
223 })
224 }
225
226 #[cfg(feature = "remote")]
228 pub async fn list_versions(&self, model: &str) -> Result<ListVersionsResponse> {
229 let url = format!("{}/api/v1/models/{}/versions", self.base_url, model);
230 let response = self
231 .build_request(reqwest::Method::GET, &url)
232 .send()
233 .await
234 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
235
236 self.handle_response(response).await
237 }
238
239 #[cfg(not(feature = "remote"))]
241 pub async fn list_versions(&self, _model: &str) -> Result<ListVersionsResponse> {
242 Err(PachaError::UnsupportedOperation {
243 operation: "list_versions".to_string(),
244 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
245 })
246 }
247
248 #[cfg(feature = "remote")]
250 pub async fn get_metadata(&self, model: &str, version: &str) -> Result<ModelMetadataResponse> {
251 let url = format!("{}/api/v1/models/{}/versions/{}", self.base_url, model, version);
252 let response = self
253 .build_request(reqwest::Method::GET, &url)
254 .send()
255 .await
256 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
257
258 self.handle_response(response).await
259 }
260
261 #[cfg(not(feature = "remote"))]
263 pub async fn get_metadata(
264 &self,
265 _model: &str,
266 _version: &str,
267 ) -> Result<ModelMetadataResponse> {
268 Err(PachaError::UnsupportedOperation {
269 operation: "get_metadata".to_string(),
270 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
271 })
272 }
273
274 #[cfg(feature = "remote")]
276 pub async fn pull_model(&self, model: &str, version: &str) -> Result<Vec<u8>> {
277 let url =
278 format!("{}/api/v1/models/{}/versions/{}/artifact", self.base_url, model, version);
279 let response = self
280 .build_request(reqwest::Method::GET, &url)
281 .send()
282 .await
283 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
284
285 if !response.status().is_success() {
286 return Err(self.handle_error_response(response).await);
287 }
288
289 response
290 .bytes()
291 .await
292 .map(|b| b.to_vec())
293 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))
294 }
295
296 #[cfg(not(feature = "remote"))]
298 pub async fn pull_model(&self, _model: &str, _version: &str) -> Result<Vec<u8>> {
299 Err(PachaError::UnsupportedOperation {
300 operation: "pull_model".to_string(),
301 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
302 })
303 }
304
305 #[cfg(feature = "remote")]
311 pub async fn init_push(&self, request: &PushRequest) -> Result<PushResponse> {
312 let url = format!("{}/api/v1/models/{}/versions", self.base_url, request.name);
313 let response = self
314 .build_request(reqwest::Method::POST, &url)
315 .json(request)
316 .send()
317 .await
318 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
319
320 self.handle_response(response).await
321 }
322
323 #[cfg(not(feature = "remote"))]
325 pub async fn init_push(&self, _request: &PushRequest) -> Result<PushResponse> {
326 Err(PachaError::UnsupportedOperation {
327 operation: "init_push".to_string(),
328 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
329 })
330 }
331
332 #[cfg(feature = "remote")]
334 pub async fn upload_artifact(&self, upload_url: &str, data: Vec<u8>) -> Result<()> {
335 let response = self
336 .build_request(reqwest::Method::PUT, upload_url)
337 .body(data)
338 .send()
339 .await
340 .map_err(|e| PachaError::Io(std::io::Error::other(e.to_string())))?;
341
342 if !response.status().is_success() {
343 return Err(self.handle_error_response(response).await);
344 }
345
346 Ok(())
347 }
348
349 #[cfg(not(feature = "remote"))]
351 pub async fn upload_artifact(&self, _upload_url: &str, _data: Vec<u8>) -> Result<()> {
352 Err(PachaError::UnsupportedOperation {
353 operation: "upload_artifact".to_string(),
354 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
355 })
356 }
357
358 #[cfg(feature = "remote")]
360 pub async fn push_model(
361 &self,
362 name: &str,
363 version: &ModelVersion,
364 data: &[u8],
365 card: Option<serde_json::Value>,
366 ) -> Result<()> {
367 let hash = blake3::hash(data).to_hex().to_string();
368
369 let request =
370 PushRequest { name: name.to_string(), version: version.to_string(), hash, card };
371
372 let response = self.init_push(&request).await?;
373 self.upload_artifact(&response.upload_url, data.to_vec()).await
374 }
375
376 #[cfg(not(feature = "remote"))]
378 pub async fn push_model(
379 &self,
380 _name: &str,
381 _version: &ModelVersion,
382 _data: &[u8],
383 _card: Option<serde_json::Value>,
384 ) -> Result<()> {
385 Err(PachaError::UnsupportedOperation {
386 operation: "push_model".to_string(),
387 reason: "Remote feature not enabled. Rebuild with --features remote".to_string(),
388 })
389 }
390
391 #[cfg(feature = "remote")]
396 fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
397 let mut request = self.client.request(method, url);
398
399 match &self.auth {
400 RegistryAuth::None => {}
401 RegistryAuth::Token(token) => {
402 request = request.bearer_auth(token);
403 }
404 RegistryAuth::Basic { username, password } => {
405 request = request.basic_auth(username, Some(password));
406 }
407 RegistryAuth::ApiKey { header, key } => {
408 request = request.header(header.as_str(), key.as_str());
409 }
410 }
411
412 request
413 }
414
415 #[cfg(feature = "remote")]
416 async fn handle_response<T: serde::de::DeserializeOwned>(
417 &self,
418 response: reqwest::Response,
419 ) -> Result<T> {
420 if !response.status().is_success() {
421 return Err(self.handle_error_response(response).await);
422 }
423
424 response.json().await.map_err(|e| {
425 PachaError::Json(serde_json::Error::io(std::io::Error::other(e.to_string())))
426 })
427 }
428
429 #[cfg(feature = "remote")]
430 async fn handle_error_response(&self, response: reqwest::Response) -> PachaError {
431 let status = response.status();
432 let body = response.text().await.unwrap_or_default();
433
434 if status == reqwest::StatusCode::NOT_FOUND {
435 PachaError::NotFound {
436 kind: "remote".to_string(),
437 name: body,
438 version: "unknown".to_string(),
439 }
440 } else if status == reqwest::StatusCode::UNAUTHORIZED
441 || status == reqwest::StatusCode::FORBIDDEN
442 {
443 PachaError::Validation(format!("Authentication failed: {body}"))
444 } else {
445 PachaError::Io(std::io::Error::other(format!("HTTP {}: {}", status, body)))
446 }
447 }
448}
449
450pub async fn pull_to_local(
456 remote: &RemoteRegistry,
457 local: &crate::registry::Registry,
458 model: &str,
459 version: &str,
460) -> Result<Model> {
461 let metadata = remote.get_metadata(model, version).await?;
463
464 let model_version = parse_version(&metadata.version)?;
466
467 if let Ok(local_model) = local.get_model(model, &model_version) {
469 let local_artifact = local.get_model_artifact(model, &model_version)?;
471 let local_hash = blake3::hash(&local_artifact).to_hex().to_string();
472
473 if local_hash == metadata.hash {
474 return Ok(local_model);
475 }
476 }
478
479 let data = remote.pull_model(model, version).await?;
481
482 let hash = blake3::hash(&data).to_hex().to_string();
484 if hash != metadata.hash {
485 return Err(PachaError::HashMismatch { expected: metadata.hash, actual: hash });
486 }
487
488 let card = metadata
490 .card
491 .and_then(|v| serde_json::from_value(v).ok())
492 .unwrap_or_else(|| crate::model::ModelCard::new("Pulled from remote registry"));
493
494 local.register_model(model, &model_version, &data, card)?;
495
496 local.get_model(model, &model_version)
497}
498
499pub async fn push_to_remote(
501 local: &crate::registry::Registry,
502 remote: &RemoteRegistry,
503 model: &str,
504 version: &ModelVersion,
505) -> Result<()> {
506 let local_model = local.get_model(model, version)?;
508 let data = local.get_model_artifact(model, version)?;
509
510 let card = serde_json::to_value(&local_model.card).ok();
512
513 remote.push_model(model, version, &data, card).await
515}
516
517fn parse_version(s: &str) -> Result<ModelVersion> {
519 let parts: Vec<&str> = s.split('.').collect();
520 if parts.len() == 3 {
521 let major: u32 = parts[0].parse().map_err(|_| PachaError::InvalidVersion(s.to_string()))?;
522 let minor: u32 = parts[1].parse().map_err(|_| PachaError::InvalidVersion(s.to_string()))?;
523 let patch: u32 = parts[2].parse().map_err(|_| PachaError::InvalidVersion(s.to_string()))?;
524 return Ok(ModelVersion::new(major, minor, patch));
525 }
526 Err(PachaError::InvalidVersion(s.to_string()))
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a [`VersionInfo`] fixture.
    fn version_info(
        version: &str,
        hash: &str,
        size: u64,
        created_at: &str,
        stage: &str,
    ) -> VersionInfo {
        VersionInfo {
            version: version.to_string(),
            hash: hash.to_string(),
            size,
            created_at: created_at.to_string(),
            stage: stage.to_string(),
        }
    }

    // Wire-format types: the tests pin the full JSON shape rather than checking
    // substrings, so a renamed or dropped field (a protocol break) fails loudly.

    #[test]
    fn test_list_models_response_serialize() {
        let response = ListModelsResponse {
            models: vec!["llama3".to_string(), "mistral".to_string()],
            total: 2,
            next_cursor: None,
        };

        assert_eq!(
            serde_json::to_value(&response).unwrap(),
            json!({"models": ["llama3", "mistral"], "total": 2, "next_cursor": null})
        );
    }

    #[test]
    fn test_list_models_response_deserialize() {
        let encoded = r#"{"models":["llama3"],"total":1,"next_cursor":null}"#;
        let response: ListModelsResponse = serde_json::from_str(encoded).unwrap();

        assert_eq!(response.models, vec!["llama3".to_string()]);
        assert_eq!(response.total, 1);
        assert!(response.next_cursor.is_none());
    }

    #[test]
    fn test_list_models_response_missing_cursor_is_none() {
        // A server may omit the cursor field entirely.
        let response: ListModelsResponse =
            serde_json::from_str(r#"{"models":[],"total":0}"#).unwrap();

        assert!(response.next_cursor.is_none());
    }

    #[test]
    fn test_version_info_serialize() {
        let info = version_info("1.0.0", "abc123", 1024, "2024-01-01T00:00:00Z", "production");

        assert_eq!(
            serde_json::to_value(&info).unwrap(),
            json!({
                "version": "1.0.0",
                "hash": "abc123",
                "size": 1024,
                "created_at": "2024-01-01T00:00:00Z",
                "stage": "production"
            })
        );
    }

    #[test]
    fn test_version_info_deserialize() {
        let encoded = r#"{"version":"2.0.0","hash":"def456","size":2048,"created_at":"2024-06-01T00:00:00Z","stage":"staging"}"#;
        let info: VersionInfo = serde_json::from_str(encoded).unwrap();

        assert_eq!(info.version, "2.0.0");
        assert_eq!(info.hash, "def456");
        assert_eq!(info.size, 2048);
        assert_eq!(info.created_at, "2024-06-01T00:00:00Z");
        assert_eq!(info.stage, "staging");
    }

    #[test]
    fn test_model_metadata_response() {
        let response = ModelMetadataResponse {
            name: "test-model".to_string(),
            version: "1.2.3".to_string(),
            hash: "hash123".to_string(),
            size: 4096,
            card: Some(json!({"description": "Test model"})),
            lineage: None,
        };

        let encoded = serde_json::to_string(&response).unwrap();
        let parsed: ModelMetadataResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(parsed.name, "test-model");
        assert_eq!(parsed.version, "1.2.3");
        assert_eq!(parsed.hash, "hash123");
        assert_eq!(parsed.size, 4096);
        assert_eq!(parsed.card, Some(json!({"description": "Test model"})));
        assert!(parsed.lineage.is_none());
        // Every field survives the round trip unchanged.
        assert_eq!(
            serde_json::to_value(&parsed).unwrap(),
            serde_json::to_value(&response).unwrap()
        );
    }

    #[test]
    fn test_lineage_info() {
        let lineage = LineageInfo {
            parent: Some("base-model:1.0.0".to_string()),
            dataset: Some("training-data:1.0.0".to_string()),
            recipe: Some("fine-tune-recipe:1.0.0".to_string()),
        };

        assert_eq!(
            serde_json::to_value(&lineage).unwrap(),
            json!({
                "parent": "base-model:1.0.0",
                "dataset": "training-data:1.0.0",
                "recipe": "fine-tune-recipe:1.0.0"
            })
        );
    }

    #[test]
    fn test_push_request() {
        let request = PushRequest {
            name: "new-model".to_string(),
            version: "0.1.0".to_string(),
            hash: "newhash".to_string(),
            card: None,
        };

        assert_eq!(
            serde_json::to_value(&request).unwrap(),
            json!({"name": "new-model", "version": "0.1.0", "hash": "newhash", "card": null})
        );
    }

    #[test]
    fn test_push_response() {
        let encoded =
            r#"{"upload_url":"https://storage.example.com/upload/123","upload_id":"upload-123"}"#;
        let response: PushResponse = serde_json::from_str(encoded).unwrap();

        assert_eq!(response.upload_url, "https://storage.example.com/upload/123");
        assert_eq!(response.upload_id, "upload-123");
    }

    // Authentication configuration.

    #[test]
    fn test_registry_auth_default() {
        let auth = RegistryAuth::default();
        assert!(matches!(auth, RegistryAuth::None));
    }

    #[test]
    fn test_registry_auth_token() {
        let auth = RegistryAuth::Token("my-token".to_string());
        let RegistryAuth::Token(token) = auth else { panic!("expected Token variant") };
        assert_eq!(token, "my-token");
    }

    #[test]
    fn test_registry_auth_basic() {
        let auth =
            RegistryAuth::Basic { username: "user".to_string(), password: "pass".to_string() };
        let RegistryAuth::Basic { username, password } = auth else {
            panic!("expected Basic variant")
        };
        assert_eq!((username.as_str(), password.as_str()), ("user", "pass"));
    }

    #[test]
    fn test_registry_auth_api_key() {
        let auth =
            RegistryAuth::ApiKey { header: "X-Api-Key".to_string(), key: "secret-key".to_string() };
        let RegistryAuth::ApiKey { header, key } = auth else {
            panic!("expected ApiKey variant")
        };
        assert_eq!((header.as_str(), key.as_str()), ("X-Api-Key", "secret-key"));
    }

    // RemoteRegistry construction.

    #[test]
    fn test_remote_registry_new() {
        let registry = RemoteRegistry::new("https://registry.example.com");
        assert_eq!(registry.base_url(), "https://registry.example.com");
        assert!(!registry.has_auth());
    }

    #[test]
    fn test_remote_registry_trailing_slash() {
        let registry = RemoteRegistry::new("https://registry.example.com/");
        assert_eq!(registry.base_url(), "https://registry.example.com");
    }

    #[test]
    fn test_remote_registry_multiple_trailing_slashes() {
        let registry = RemoteRegistry::new("https://registry.example.com///");
        assert_eq!(registry.base_url(), "https://registry.example.com");
    }

    #[test]
    fn test_remote_registry_with_auth() {
        let registry = RemoteRegistry::new("https://registry.example.com")
            .with_auth(RegistryAuth::Token("token".to_string()));
        assert!(registry.has_auth());
    }

    #[test]
    fn test_remote_registry_no_auth() {
        let registry =
            RemoteRegistry::new("https://registry.example.com").with_auth(RegistryAuth::None);
        assert!(!registry.has_auth());
    }

    // Version parsing.

    #[test]
    fn test_parse_version_valid() {
        let v = parse_version("1.2.3").unwrap();
        assert_eq!(v, ModelVersion::new(1, 2, 3));
    }

    #[test]
    fn test_parse_version_zeros() {
        let v = parse_version("0.0.0").unwrap();
        assert_eq!(v, ModelVersion::new(0, 0, 0));
    }

    #[test]
    fn test_parse_version_large() {
        let v = parse_version("100.200.300").unwrap();
        assert_eq!(v, ModelVersion::new(100, 200, 300));
    }

    #[test]
    fn test_parse_version_invalid_format() {
        assert!(parse_version("1.2").is_err());
        assert!(parse_version("1").is_err());
        assert!(parse_version("1.2.3.4").is_err());
        assert!(parse_version("").is_err());
        assert!(parse_version("1..3").is_err());
    }

    #[test]
    fn test_parse_version_non_numeric() {
        assert!(parse_version("a.b.c").is_err());
        assert!(parse_version("1.x.0").is_err());
        assert!(parse_version("-1.0.0").is_err());
        assert!(parse_version("4294967296.0.0").is_err());
    }

    // Round trips of composite responses.

    #[test]
    fn test_list_versions_response_roundtrip() {
        let response = ListVersionsResponse {
            model: "test".to_string(),
            versions: vec![
                version_info("1.0.0", "hash1", 100, "2024-01-01T00:00:00Z", "production"),
                version_info("2.0.0", "hash2", 200, "2024-06-01T00:00:00Z", "staging"),
            ],
        };

        let encoded = serde_json::to_string(&response).unwrap();
        let parsed: ListVersionsResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(parsed.model, "test");
        assert_eq!(parsed.versions.len(), 2);
        assert_eq!(parsed.versions[0].version, "1.0.0");
        assert_eq!(parsed.versions[1].stage, "staging");
        assert_eq!(
            serde_json::to_value(&parsed).unwrap(),
            serde_json::to_value(&response).unwrap()
        );
    }

    #[test]
    fn test_metadata_with_lineage_roundtrip() {
        let response = ModelMetadataResponse {
            name: "derived-model".to_string(),
            version: "1.0.0".to_string(),
            hash: "hash".to_string(),
            size: 1000,
            card: Some(json!({"description": "A derived model"})),
            lineage: Some(LineageInfo {
                parent: Some("base:1.0.0".to_string()),
                dataset: Some("data:1.0.0".to_string()),
                recipe: None,
            }),
        };

        let encoded = serde_json::to_string(&response).unwrap();
        let parsed: ModelMetadataResponse = serde_json::from_str(&encoded).unwrap();

        let lineage = parsed.lineage.expect("lineage should survive the round trip");
        assert_eq!(lineage.parent.as_deref(), Some("base:1.0.0"));
        assert_eq!(lineage.dataset.as_deref(), Some("data:1.0.0"));
        assert!(lineage.recipe.is_none());
    }

    // Edge cases.

    #[test]
    fn test_empty_models_list() {
        let response = ListModelsResponse { models: vec![], total: 0, next_cursor: None };

        let encoded = serde_json::to_string(&response).unwrap();
        let parsed: ListModelsResponse = serde_json::from_str(&encoded).unwrap();

        assert!(parsed.models.is_empty());
        assert_eq!(parsed.total, 0);
        assert!(parsed.next_cursor.is_none());
    }

    #[test]
    fn test_pagination_cursor() {
        let response = ListModelsResponse {
            models: vec!["model1".to_string()],
            total: 100,
            next_cursor: Some("cursor-abc".to_string()),
        };

        let encoded = serde_json::to_string(&response).unwrap();
        let parsed: ListModelsResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(parsed.next_cursor.as_deref(), Some("cursor-abc"));
        assert_eq!(parsed.total, 100);
    }

    #[test]
    fn test_push_request_with_card() {
        let card = json!({
            "description": "Test model",
            "metrics": {"accuracy": 0.95}
        });
        let request = PushRequest {
            name: "model".to_string(),
            version: "1.0.0".to_string(),
            hash: "hash".to_string(),
            card: Some(card.clone()),
        };

        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["card"], card);
        assert_eq!(value["card"]["metrics"]["accuracy"], json!(0.95));
    }
}