1use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::fmt;
12use std::path::{Path, PathBuf};
13
/// Errors produced while parsing, fetching, caching or locking external sources.
#[derive(Debug, thiserror::Error)]
pub enum SourceError {
    /// The input string did not match any supported source syntax
    /// (returned by `ExternalSource::parse`). `input` is the trimmed input.
    #[error("invalid source string: {input}")]
    InvalidSource { input: String },

    /// Retrieving content from `origin` failed; `reason` describes why.
    #[error("failed to fetch from {origin}: {reason}")]
    FetchFailed { origin: String, reason: String },

    /// A cache or lockfile operation failed (directory creation, writing,
    /// (de)serialization); `reason` carries the path and underlying error text.
    #[error("cache error: {reason}")]
    CacheError { reason: String },

    /// The version found for package `name` differs from the `expected` one.
    #[error("version mismatch for {name}: expected {expected}, got {actual}")]
    VersionMismatch {
        name: String,
        expected: String,
        actual: String,
    },

    /// Raw I/O failure. `#[from]` lets `?` convert `std::io::Error` into this
    /// variant; `transparent` forwards `Display` and `source()` to the inner error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
45
/// A location an external workflow package can be fetched from.
///
/// Built from text by `ExternalSource::parse` and rendered back to that
/// textual form by its `Display` impl. With serde, the variant name is stored
/// in a `kind` field, renamed to snake_case.
///
/// NOTE(review): snake_case renames `GitHub` to `git_hub` (not `github`);
/// confirm this tag is intended before documenting it for users.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ExternalSource {
    /// Package in the registry, written `registry:<org>/<name>`.
    Registry { org: String, name: String },

    /// File in a GitHub repository, written `gh:<org>/<repo>[/<path>][@<ref>]`.
    GitHub {
        /// Repository owner.
        org: String,
        /// Repository name.
        repo: String,
        /// File path inside the repository; `fetch_url` falls back to
        /// `workflow-package.yaml` when absent.
        path: Option<String>,
        /// Git ref (e.g. a branch or tag) placed in the raw URL; `fetch_url`
        /// falls back to `main` when absent. Serialized under the key `ref`.
        #[serde(rename = "ref")]
        ref_: Option<String>,
    },

    /// An `http://` or `https://` URL, fetched verbatim.
    Url { url: String },
}
69
70impl ExternalSource {
71 pub fn parse(source: &str) -> Result<Self, SourceError> {
78 let source = source.trim();
79
80 if let Some(rest) = source.strip_prefix("registry:") {
81 let parts: Vec<&str> = rest.splitn(2, '/').collect();
82 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
83 return Err(SourceError::InvalidSource {
84 input: source.to_string(),
85 });
86 }
87 return Ok(ExternalSource::Registry {
88 org: parts[0].to_string(),
89 name: parts[1].to_string(),
90 });
91 }
92
93 if let Some(rest) = source.strip_prefix("gh:") {
94 let (path_part, ref_) = if let Some(idx) = rest.find('@') {
96 (&rest[..idx], Some(rest[idx + 1..].to_string()))
97 } else {
98 (rest, None)
99 };
100
101 let segments: Vec<&str> = path_part.splitn(3, '/').collect();
102 if segments.len() < 2 || segments[0].is_empty() || segments[1].is_empty() {
103 return Err(SourceError::InvalidSource {
104 input: source.to_string(),
105 });
106 }
107 let path = if segments.len() == 3 && !segments[2].is_empty() {
108 Some(segments[2].to_string())
109 } else {
110 None
111 };
112 return Ok(ExternalSource::GitHub {
113 org: segments[0].to_string(),
114 repo: segments[1].to_string(),
115 path,
116 ref_,
117 });
118 }
119
120 if source.starts_with("https://") || source.starts_with("http://") {
121 return Ok(ExternalSource::Url {
122 url: source.to_string(),
123 });
124 }
125
126 Err(SourceError::InvalidSource {
127 input: source.to_string(),
128 })
129 }
130
131 pub fn fetch_url(&self) -> String {
133 match self {
134 ExternalSource::Registry { org, name } => {
135 format!(
136 "https://registry.trustedautonomy.dev/v1/{}/{}.yaml",
137 org, name
138 )
139 }
140 ExternalSource::GitHub {
141 org,
142 repo,
143 path,
144 ref_,
145 } => {
146 let branch = ref_.as_deref().unwrap_or("main");
147 let file_path = path.as_deref().unwrap_or("workflow-package.yaml");
148 format!(
149 "https://raw.githubusercontent.com/{}/{}/{}/{}",
150 org, repo, branch, file_path
151 )
152 }
153 ExternalSource::Url { url } => url.clone(),
154 }
155 }
156}
157
158impl fmt::Display for ExternalSource {
159 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160 match self {
161 ExternalSource::Registry { org, name } => write!(f, "registry:{}/{}", org, name),
162 ExternalSource::GitHub {
163 org,
164 repo,
165 path,
166 ref_,
167 } => {
168 write!(f, "gh:{}/{}", org, repo)?;
169 if let Some(p) = path {
170 write!(f, "/{}", p)?;
171 }
172 if let Some(r) = ref_ {
173 write!(f, "@{}", r)?;
174 }
175 Ok(())
176 }
177 ExternalSource::Url { url } => write!(f, "{}", url),
178 }
179 }
180}
181
/// Metadata describing a distributable workflow package (serde-serializable,
/// e.g. from/to YAML).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PackageManifest {
    /// Package name.
    pub name: String,
    /// Package version string; not parsed or validated in this file.
    pub version: String,
    /// Optional author.
    pub author: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional required-tool-version expression such as `>=0.9.8`.
    /// NOTE(review): not interpreted in this file — confirm where it is enforced.
    pub ta_version: Option<String>,
    /// Files shipped in the package. Required: there is no `#[serde(default)]`,
    /// so a manifest without `files` fails to deserialize.
    /// NOTE(review): presumably paths relative to the package root — confirm.
    pub files: Vec<String>,
}
204
/// Metadata for one entry in a `SourceCache`, persisted as `<name>.meta.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedItem {
    /// Name the entry was stored under.
    pub name: String,
    /// Version string supplied when the entry was stored.
    pub version: String,
    /// `Display` form of the `ExternalSource` the content came from
    /// (e.g. `registry:org/name`).
    pub source: String,
    /// RFC 3339 timestamp (UTC) of when the entry was stored.
    pub cached_at: String,
    /// Path of the cached YAML content file.
    pub file_path: PathBuf,
}
223
/// File-based cache of fetched source content.
///
/// Each entry `<name>` is stored as two files directly inside `cache_dir`:
/// `<name>.yaml` (the content) and `<name>.meta.json` (a serialized
/// `CachedItem`). `SourceCache::new` places the directory at
/// `<home>/.ta/cache/<kind>`; `SourceCache::with_dir` accepts any directory.
pub struct SourceCache {
    /// Directory holding the cache files; created on demand by `store`.
    cache_dir: PathBuf,
}
235
236impl SourceCache {
237 pub fn new(kind: &str) -> Self {
241 let home = std::env::var("HOME")
242 .or_else(|_| std::env::var("USERPROFILE"))
243 .unwrap_or_else(|_| "/tmp".to_string());
244 Self {
245 cache_dir: PathBuf::from(home).join(".ta").join("cache").join(kind),
246 }
247 }
248
249 pub fn with_dir(dir: PathBuf) -> Self {
251 Self { cache_dir: dir }
252 }
253
254 pub fn cache_dir(&self) -> &Path {
256 &self.cache_dir
257 }
258
259 fn yaml_path(&self, name: &str) -> PathBuf {
260 self.cache_dir.join(format!("{}.yaml", name))
261 }
262
263 fn meta_path(&self, name: &str) -> PathBuf {
264 self.cache_dir.join(format!("{}.meta.json", name))
265 }
266
267 pub fn get(&self, name: &str) -> Option<CachedItem> {
269 let meta_path = self.meta_path(name);
270 let data = std::fs::read_to_string(&meta_path).ok()?;
271 serde_json::from_str(&data).ok()
272 }
273
274 pub fn store(
276 &self,
277 name: &str,
278 content: &str,
279 source: &ExternalSource,
280 version: &str,
281 ) -> Result<CachedItem, SourceError> {
282 std::fs::create_dir_all(&self.cache_dir).map_err(|e| SourceError::CacheError {
283 reason: format!(
284 "failed to create cache directory {}: {}",
285 self.cache_dir.display(),
286 e
287 ),
288 })?;
289
290 let yaml_path = self.yaml_path(name);
291 std::fs::write(&yaml_path, content).map_err(|e| SourceError::CacheError {
292 reason: format!("failed to write {}: {}", yaml_path.display(), e),
293 })?;
294
295 let item = CachedItem {
296 name: name.to_string(),
297 version: version.to_string(),
298 source: source.to_string(),
299 cached_at: chrono::Utc::now().to_rfc3339(),
300 file_path: yaml_path,
301 };
302
303 let meta_path = self.meta_path(name);
304 let meta_json =
305 serde_json::to_string_pretty(&item).map_err(|e| SourceError::CacheError {
306 reason: format!("failed to serialize metadata: {}", e),
307 })?;
308 std::fs::write(&meta_path, meta_json).map_err(|e| SourceError::CacheError {
309 reason: format!("failed to write {}: {}", meta_path.display(), e),
310 })?;
311
312 Ok(item)
313 }
314
315 pub fn remove(&self, name: &str) -> Result<bool, SourceError> {
317 let yaml_path = self.yaml_path(name);
318 let meta_path = self.meta_path(name);
319
320 let existed = yaml_path.exists() || meta_path.exists();
321 if yaml_path.exists() {
322 std::fs::remove_file(&yaml_path)?;
323 }
324 if meta_path.exists() {
325 std::fs::remove_file(&meta_path)?;
326 }
327 Ok(existed)
328 }
329
330 pub fn list(&self) -> Vec<CachedItem> {
332 let mut items = Vec::new();
333 let entries = match std::fs::read_dir(&self.cache_dir) {
334 Ok(entries) => entries,
335 Err(_) => return items,
336 };
337 for entry in entries.flatten() {
338 let path = entry.path();
339 if path.extension().is_some_and(|ext| ext == "json")
340 && path
341 .file_name()
342 .and_then(|n| n.to_str())
343 .is_some_and(|n| n.ends_with(".meta.json"))
344 {
345 if let Ok(data) = std::fs::read_to_string(&path) {
346 if let Ok(item) = serde_json::from_str::<CachedItem>(&data) {
347 items.push(item);
348 }
349 }
350 }
351 }
352 items.sort_by(|a, b| a.name.cmp(&b.name));
353 items
354 }
355}
356
/// One pinned package in a `Lockfile`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LockEntry {
    /// Package name; `Lockfile` keeps at most one entry per name.
    pub name: String,
    /// Pinned version string.
    pub version: String,
    /// Textual source specifier, e.g. `registry:ta/ci` or `gh:ta/review`.
    /// NOTE(review): presumably the `Display` form of `ExternalSource` — confirm.
    pub source: String,
    /// Expected content checksum.
    /// NOTE(review): presumably the hex digest from `sha256_hex`, checked with
    /// `verify_checksum` — confirm at the call sites.
    pub checksum: String,
}
373
/// Set of pinned packages, persisted as YAML by `Lockfile::save` and read
/// back by `Lockfile::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Lockfile {
    /// Pinned entries; `Lockfile::add` replaces any existing entry with the
    /// same name, so names are unique when entries are added through it.
    pub entries: Vec<LockEntry>,
}
383
384impl Lockfile {
385 pub fn new() -> Self {
387 Self {
388 entries: Vec::new(),
389 }
390 }
391
392 pub fn load(path: &Path) -> Result<Self, SourceError> {
395 if !path.exists() {
396 return Ok(Self::new());
397 }
398 let data = std::fs::read_to_string(path)?;
399 serde_yaml::from_str(&data).map_err(|e| SourceError::CacheError {
400 reason: format!("failed to parse lockfile {}: {}", path.display(), e),
401 })
402 }
403
404 pub fn save(&self, path: &Path) -> Result<(), SourceError> {
406 if let Some(parent) = path.parent() {
407 std::fs::create_dir_all(parent)?;
408 }
409 let yaml = serde_yaml::to_string(self).map_err(|e| SourceError::CacheError {
410 reason: format!("failed to serialize lockfile: {}", e),
411 })?;
412 std::fs::write(path, yaml)?;
413 Ok(())
414 }
415
416 pub fn add(&mut self, entry: LockEntry) {
419 self.remove(&entry.name);
420 self.entries.push(entry);
421 }
422
423 pub fn remove(&mut self, name: &str) -> bool {
425 let before = self.entries.len();
426 self.entries.retain(|e| e.name != name);
427 self.entries.len() < before
428 }
429
430 pub fn get(&self, name: &str) -> Option<&LockEntry> {
432 self.entries.iter().find(|e| e.name == name)
433 }
434
435 pub fn entries(&self) -> &[LockEntry] {
437 &self.entries
438 }
439}
440
441impl Default for Lockfile {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
447pub fn sha256_hex(content: &str) -> String {
453 let mut hasher = Sha256::new();
454 hasher.update(content.as_bytes());
455 format!("{:x}", hasher.finalize())
456}
457
458pub fn verify_checksum(content: &str, checksum: &str) -> bool {
460 sha256_hex(content) == checksum
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // ---- ExternalSource::parse ----------------------------------------------

    #[test]
    fn parse_registry_source() {
        let src = ExternalSource::parse("registry:trustedautonomy/workflows").unwrap();
        assert_eq!(
            src,
            ExternalSource::Registry {
                org: "trustedautonomy".into(),
                name: "workflows".into(),
            }
        );
    }

    #[test]
    fn parse_github_simple() {
        let src = ExternalSource::parse("gh:myorg/ta-workflows").unwrap();
        assert_eq!(
            src,
            ExternalSource::GitHub {
                org: "myorg".into(),
                repo: "ta-workflows".into(),
                path: None,
                ref_: None,
            }
        );
    }

    #[test]
    fn parse_github_with_path() {
        let src = ExternalSource::parse("gh:myorg/repo/path/to/file.yaml").unwrap();
        assert_eq!(
            src,
            ExternalSource::GitHub {
                org: "myorg".into(),
                repo: "repo".into(),
                path: Some("path/to/file.yaml".into()),
                ref_: None,
            }
        );
    }

    #[test]
    fn parse_github_with_ref() {
        let src = ExternalSource::parse("gh:myorg/repo@v1.2.3").unwrap();
        assert_eq!(
            src,
            ExternalSource::GitHub {
                org: "myorg".into(),
                repo: "repo".into(),
                path: None,
                ref_: Some("v1.2.3".into()),
            }
        );
    }

    #[test]
    fn parse_github_with_path_and_ref() {
        let src = ExternalSource::parse("gh:myorg/repo/workflows/ci.yaml@main").unwrap();
        assert_eq!(
            src,
            ExternalSource::GitHub {
                org: "myorg".into(),
                repo: "repo".into(),
                path: Some("workflows/ci.yaml".into()),
                ref_: Some("main".into()),
            }
        );
    }

    #[test]
    fn parse_url_https() {
        let src = ExternalSource::parse("https://example.com/workflow.yaml").unwrap();
        assert_eq!(
            src,
            ExternalSource::Url {
                url: "https://example.com/workflow.yaml".into(),
            }
        );
    }

    #[test]
    fn parse_url_http() {
        let src = ExternalSource::parse("http://localhost:8080/w.yaml").unwrap();
        assert_eq!(
            src,
            ExternalSource::Url {
                url: "http://localhost:8080/w.yaml".into(),
            }
        );
    }

    #[test]
    fn parse_trims_surrounding_whitespace() {
        let src = ExternalSource::parse("  registry:ta/ci \n").unwrap();
        assert_eq!(
            src,
            ExternalSource::Registry {
                org: "ta".into(),
                name: "ci".into(),
            }
        );
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(ExternalSource::parse("ftp://bad").is_err());
        assert!(ExternalSource::parse("registry:").is_err());
        assert!(ExternalSource::parse("registry:org").is_err());
        assert!(ExternalSource::parse("gh:").is_err());
        assert!(ExternalSource::parse("gh:org").is_err());
        assert!(ExternalSource::parse("").is_err());
    }

    // ---- Display --------------------------------------------------------------

    /// Every textual form, including those with refs, must survive
    /// parse -> Display unchanged.
    #[test]
    fn display_round_trip() {
        let cases = vec![
            "registry:trustedautonomy/workflows",
            "gh:myorg/repo",
            "gh:myorg/repo/path/to/file.yaml",
            "gh:myorg/repo@v1.2.3",
            "gh:myorg/repo/workflows/ci.yaml@main",
            "https://example.com/workflow.yaml",
        ];
        for input in cases {
            let src = ExternalSource::parse(input).unwrap();
            assert_eq!(src.to_string(), input, "round-trip failed for {}", input);
        }
    }

    // ---- fetch_url ------------------------------------------------------------

    #[test]
    fn fetch_url_registry() {
        let src = ExternalSource::Registry {
            org: "ta".into(),
            name: "ci".into(),
        };
        assert_eq!(
            src.fetch_url(),
            "https://registry.trustedautonomy.dev/v1/ta/ci.yaml"
        );
    }

    #[test]
    fn fetch_url_github_defaults() {
        let src = ExternalSource::GitHub {
            org: "org".into(),
            repo: "repo".into(),
            path: None,
            ref_: None,
        };
        assert_eq!(
            src.fetch_url(),
            "https://raw.githubusercontent.com/org/repo/main/workflow-package.yaml"
        );
    }

    #[test]
    fn fetch_url_github_custom() {
        let src = ExternalSource::GitHub {
            org: "org".into(),
            repo: "repo".into(),
            path: Some("defs/w.yaml".into()),
            ref_: Some("v2".into()),
        };
        assert_eq!(
            src.fetch_url(),
            "https://raw.githubusercontent.com/org/repo/v2/defs/w.yaml"
        );
    }

    #[test]
    fn fetch_url_url_is_verbatim() {
        let src = ExternalSource::Url {
            url: "https://example.com/a.yaml?x=1".into(),
        };
        assert_eq!(src.fetch_url(), "https://example.com/a.yaml?x=1");
    }

    // ---- SourceCache ----------------------------------------------------------

    #[test]
    fn cache_store_get_list_remove() {
        let dir = tempdir().unwrap();
        let cache = SourceCache::with_dir(dir.path().to_path_buf());

        let source = ExternalSource::Registry {
            org: "ta".into(),
            name: "ci".into(),
        };
        let content = "name: ci\nversion: '1.0'\n";

        let item = cache.store("ci", content, &source, "1.0").unwrap();
        assert_eq!(item.name, "ci");
        assert_eq!(item.version, "1.0");
        assert!(item.file_path.exists());

        let fetched = cache.get("ci").unwrap();
        assert_eq!(fetched.name, "ci");
        assert_eq!(fetched.version, "1.0");

        let items = cache.list();
        assert_eq!(items.len(), 1);

        let removed = cache.remove("ci").unwrap();
        assert!(removed);
        assert!(cache.get("ci").is_none());
        assert!(cache.list().is_empty());
    }

    #[test]
    fn cache_list_is_sorted_by_name() {
        let dir = tempdir().unwrap();
        let cache = SourceCache::with_dir(dir.path().to_path_buf());
        let source = ExternalSource::Url {
            url: "https://example.com/w.yaml".into(),
        };
        cache.store("zeta", "a: 1\n", &source, "1").unwrap();
        cache.store("alpha", "b: 2\n", &source, "1").unwrap();

        let names: Vec<String> = cache.list().into_iter().map(|i| i.name).collect();
        assert_eq!(names, vec!["alpha".to_string(), "zeta".to_string()]);
    }

    #[test]
    fn cache_get_missing_returns_none() {
        let dir = tempdir().unwrap();
        let cache = SourceCache::with_dir(dir.path().to_path_buf());
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn cache_remove_missing_returns_false() {
        let dir = tempdir().unwrap();
        let cache = SourceCache::with_dir(dir.path().to_path_buf());
        assert!(!cache.remove("nonexistent").unwrap());
    }

    // ---- Lockfile -------------------------------------------------------------

    #[test]
    fn lockfile_add_get_remove() {
        let mut lock = Lockfile::new();
        assert!(lock.get("ci").is_none());

        lock.add(LockEntry {
            name: "ci".into(),
            version: "1.0".into(),
            source: "registry:ta/ci".into(),
            checksum: "abc123".into(),
        });
        assert_eq!(lock.get("ci").unwrap().version, "1.0");

        // Adding the same name again replaces rather than duplicates.
        lock.add(LockEntry {
            name: "ci".into(),
            version: "2.0".into(),
            source: "registry:ta/ci".into(),
            checksum: "def456".into(),
        });
        assert_eq!(lock.get("ci").unwrap().version, "2.0");
        assert_eq!(lock.entries.len(), 1);

        assert!(lock.remove("ci"));
        assert!(!lock.remove("ci"));
    }

    #[test]
    fn lockfile_default_is_empty() {
        assert!(Lockfile::default().entries().is_empty());
    }

    #[test]
    fn lockfile_save_load_round_trip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.lock");

        let mut lock = Lockfile::new();
        lock.add(LockEntry {
            name: "review".into(),
            version: "0.3.0".into(),
            source: "gh:ta/review".into(),
            checksum: "aabbcc".into(),
        });
        lock.save(&path).unwrap();

        let loaded = Lockfile::load(&path).unwrap();
        assert_eq!(loaded.entries.len(), 1);
        assert_eq!(loaded.get("review").unwrap().version, "0.3.0");
    }

    #[test]
    fn lockfile_load_missing_returns_empty() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("does-not-exist.lock");
        let lock = Lockfile::load(&path).unwrap();
        assert!(lock.entries.is_empty());
    }

    // ---- Checksums ------------------------------------------------------------

    #[test]
    fn sha256_and_verify() {
        let content = "hello, world";
        let hash = sha256_hex(content);
        assert!(verify_checksum(content, &hash));
        assert!(!verify_checksum("different", &hash));
    }

    /// Known-answer vectors (FIPS 180-2) pin the exact digest and its
    /// lowercase, zero-padded hex formatting, not just self-consistency.
    #[test]
    fn sha256_known_answers() {
        assert_eq!(
            sha256_hex(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            sha256_hex("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // ---- Serde ----------------------------------------------------------------

    #[test]
    fn package_manifest_yaml_round_trip() {
        let yaml = r#"
name: ci-review
version: "1.2.0"
author: trustedautonomy
description: CI review workflow
ta_version: ">=0.9.8"
files:
  - workflow.yaml
  - agents/reviewer.yaml
"#;
        let manifest: PackageManifest = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(manifest.name, "ci-review");
        assert_eq!(manifest.version, "1.2.0");
        assert_eq!(manifest.author.as_deref(), Some("trustedautonomy"));
        assert_eq!(manifest.ta_version.as_deref(), Some(">=0.9.8"));
        assert_eq!(manifest.files.len(), 2);

        let reserialized = serde_yaml::to_string(&manifest).unwrap();
        let re: PackageManifest = serde_yaml::from_str(&reserialized).unwrap();
        assert_eq!(re.name, manifest.name);
    }

    #[test]
    fn external_source_json_serde() {
        let src = ExternalSource::GitHub {
            org: "org".into(),
            repo: "repo".into(),
            path: Some("w.yaml".into()),
            ref_: Some("v1".into()),
        };
        let json = serde_json::to_string(&src).unwrap();
        let de: ExternalSource = serde_json::from_str(&json).unwrap();
        assert_eq!(de, src);
    }
}