Skip to main content

ta_changeset/
registry_client.rs

1// registry_client.rs — Plugin registry client and platform detection.
2//
3// The registry is a static JSON index served over HTTP:
4//   https://registry.trustedautonomy.dev/v1/index.json
5//
6// This module handles:
7// - Fetching and caching the registry index
8// - Platform detection (os + arch → registry platform key)
9// - Resolving plugin download URLs from registry entries
10// - GitHub release URL construction
11
12use std::collections::HashMap;
13use std::path::PathBuf;
14
15use serde::{Deserialize, Serialize};
16
/// URL of the registry index used when no custom URL is configured.
pub const DEFAULT_REGISTRY_URL: &str = "https://registry.trustedautonomy.dev/v1/index.json";

/// How long a cached index stays fresh by default: one hour, in seconds.
pub const DEFAULT_CACHE_TTL_SECS: u64 = 60 * 60;
22
23/// Registry index — the top-level JSON structure.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct RegistryIndex {
26    /// Schema version for forward compatibility.
27    pub schema_version: u32,
28
29    /// Map of plugin name → plugin entry.
30    pub plugins: HashMap<String, RegistryPluginEntry>,
31}
32
33/// A single plugin's registry entry.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct RegistryPluginEntry {
36    /// Plugin type (e.g., "channel").
37    #[serde(rename = "type")]
38    pub plugin_type: String,
39
40    /// Human-readable description.
41    #[serde(default)]
42    pub description: Option<String>,
43
44    /// Available versions with platform-specific downloads.
45    pub versions: HashMap<String, RegistryVersion>,
46}
47
48/// A specific version's release information.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RegistryVersion {
51    /// Plugin protocol version.
52    #[serde(default = "default_protocol_version")]
53    pub protocol_version: u32,
54
55    /// Minimum TA CLI version required.
56    #[serde(default)]
57    pub min_ta_version: Option<String>,
58
59    /// Platform-specific download information.
60    pub platforms: HashMap<String, PlatformDownload>,
61}
62
/// Serde default for `RegistryVersion::protocol_version` when the field is missing.
fn default_protocol_version() -> u32 {
    const PROTOCOL_V1: u32 = 1;
    PROTOCOL_V1
}
66
67/// Download info for a specific platform.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct PlatformDownload {
70    /// Download URL for the tarball.
71    pub url: String,
72
73    /// SHA-256 hash of the tarball for integrity verification.
74    pub sha256: String,
75}
76
/// Map the running OS and CPU architecture to a registry platform key.
///
/// Recognized combinations:
/// - macOS on Apple Silicon → `aarch64-apple-darwin`
/// - macOS on Intel → `x86_64-apple-darwin`
/// - Linux on x86_64 → `x86_64-unknown-linux-musl`
/// - Linux on ARM64 → `aarch64-unknown-linux-musl`
/// - Windows on x86_64 → `x86_64-pc-windows-msvc`
///
/// Anything else falls back to `{arch}-unknown-{os}` built from
/// `std::env::consts`.
pub fn detect_platform() -> String {
    use std::env::consts::{ARCH, OS};

    let key = match (OS, ARCH) {
        ("macos", "aarch64") => "aarch64-apple-darwin",
        ("macos", "x86_64") => "x86_64-apple-darwin",
        ("linux", "x86_64") => "x86_64-unknown-linux-musl",
        ("linux", "aarch64") => "aarch64-unknown-linux-musl",
        ("windows", "x86_64") => "x86_64-pc-windows-msvc",
        (os, arch) => return format!("{}-unknown-{}", arch, os),
    };
    key.to_string()
}
98
/// Client that fetches the plugin registry index and keeps an on-disk cache of it.
///
/// Build one with `RegistryClient::new` for the defaults, or with
/// `RegistryClient::with_config` to override URL, cache location and TTL.
pub struct RegistryClient {
    /// URL of the registry's JSON index.
    registry_url: String,
    /// Directory holding the cached index and its timestamp file
    /// (for example `~/.cache/ta/registry/`).
    cache_dir: PathBuf,
    /// How many seconds a cached index is considered fresh.
    cache_ttl_secs: u64,
}
108
109/// Errors from registry operations.
110#[derive(Debug, thiserror::Error)]
111pub enum RegistryError {
112    #[error("failed to fetch registry index from {url}: {reason}")]
113    FetchFailed { url: String, reason: String },
114
115    #[error("failed to parse registry index: {0}")]
116    ParseFailed(String),
117
118    #[error("plugin '{name}' not found in registry")]
119    PluginNotFound { name: String },
120
121    #[error("plugin '{name}' version '{version}' not found in registry")]
122    VersionNotFound { name: String, version: String },
123
124    #[error("plugin '{name}' version '{version}' has no binary for platform '{platform}'")]
125    PlatformNotAvailable {
126        name: String,
127        version: String,
128        platform: String,
129    },
130
131    #[error("I/O error: {0}")]
132    Io(#[from] std::io::Error),
133}
134
135impl Default for RegistryClient {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141impl RegistryClient {
142    /// Create a new registry client with the default registry URL.
143    pub fn new() -> Self {
144        Self {
145            registry_url: DEFAULT_REGISTRY_URL.to_string(),
146            cache_dir: default_cache_dir(),
147            cache_ttl_secs: DEFAULT_CACHE_TTL_SECS,
148        }
149    }
150
151    /// Create a registry client with a custom URL and cache directory.
152    pub fn with_config(registry_url: String, cache_dir: PathBuf, cache_ttl_secs: u64) -> Self {
153        Self {
154            registry_url,
155            cache_dir,
156            cache_ttl_secs,
157        }
158    }
159
160    /// Get the cache file path for the registry index.
161    fn cache_path(&self) -> PathBuf {
162        self.cache_dir.join("index.json")
163    }
164
165    /// Get the timestamp file path for cache TTL tracking.
166    fn cache_timestamp_path(&self) -> PathBuf {
167        self.cache_dir.join("index.timestamp")
168    }
169
170    /// Check if the cached index is still valid (within TTL).
171    fn is_cache_valid(&self) -> bool {
172        let ts_path = self.cache_timestamp_path();
173        if !ts_path.exists() || !self.cache_path().exists() {
174            return false;
175        }
176        match std::fs::metadata(&ts_path) {
177            Ok(meta) => {
178                if let Ok(modified) = meta.modified() {
179                    if let Ok(elapsed) = modified.elapsed() {
180                        return elapsed.as_secs() < self.cache_ttl_secs;
181                    }
182                }
183                false
184            }
185            Err(_) => false,
186        }
187    }
188
189    /// Load the registry index from cache.
190    fn load_cached(&self) -> Option<RegistryIndex> {
191        if !self.is_cache_valid() {
192            return None;
193        }
194        let content = std::fs::read_to_string(self.cache_path()).ok()?;
195        serde_json::from_str(&content).ok()
196    }
197
198    /// Save the registry index to cache.
199    fn save_cache(&self, index: &RegistryIndex) -> Result<(), RegistryError> {
200        std::fs::create_dir_all(&self.cache_dir)?;
201        let json = serde_json::to_string_pretty(index)
202            .map_err(|e| RegistryError::ParseFailed(e.to_string()))?;
203        std::fs::write(self.cache_path(), json)?;
204        std::fs::write(self.cache_timestamp_path(), "")?;
205        Ok(())
206    }
207
208    /// Fetch the registry index, using cache if available.
209    ///
210    /// This is a blocking HTTP call. Returns the cached index if within TTL,
211    /// otherwise fetches from the registry URL and updates the cache.
212    pub fn fetch_index(&self) -> Result<RegistryIndex, RegistryError> {
213        // Try cache first.
214        if let Some(cached) = self.load_cached() {
215            tracing::debug!(
216                url = %self.registry_url,
217                "Using cached registry index"
218            );
219            return Ok(cached);
220        }
221
222        // Fetch from network.
223        tracing::info!(
224            url = %self.registry_url,
225            "Fetching plugin registry index"
226        );
227
228        let client = reqwest::blocking::Client::builder()
229            .timeout(std::time::Duration::from_secs(30))
230            .build()
231            .map_err(|e| RegistryError::FetchFailed {
232                url: self.registry_url.clone(),
233                reason: e.to_string(),
234            })?;
235
236        let resp =
237            client
238                .get(&self.registry_url)
239                .send()
240                .map_err(|e| RegistryError::FetchFailed {
241                    url: self.registry_url.clone(),
242                    reason: e.to_string(),
243                })?;
244
245        if !resp.status().is_success() {
246            return Err(RegistryError::FetchFailed {
247                url: self.registry_url.clone(),
248                reason: format!("HTTP {}", resp.status()),
249            });
250        }
251
252        let body = resp.text().map_err(|e| RegistryError::FetchFailed {
253            url: self.registry_url.clone(),
254            reason: e.to_string(),
255        })?;
256
257        let index: RegistryIndex =
258            serde_json::from_str(&body).map_err(|e| RegistryError::ParseFailed(e.to_string()))?;
259
260        // Cache the result.
261        if let Err(e) = self.save_cache(&index) {
262            tracing::warn!(error = %e, "Failed to cache registry index");
263        }
264
265        Ok(index)
266    }
267
268    /// Load a registry index from a JSON string (for testing or offline use).
269    pub fn parse_index(json: &str) -> Result<RegistryIndex, RegistryError> {
270        serde_json::from_str(json).map_err(|e| RegistryError::ParseFailed(e.to_string()))
271    }
272
273    /// Look up a plugin in the registry and find the best matching version
274    /// for the given constraint and platform.
275    pub fn resolve(
276        &self,
277        index: &RegistryIndex,
278        plugin_name: &str,
279        version_constraint: &str,
280        platform: &str,
281    ) -> Result<ResolvedPlugin, RegistryError> {
282        let entry =
283            index
284                .plugins
285                .get(plugin_name)
286                .ok_or_else(|| RegistryError::PluginNotFound {
287                    name: plugin_name.to_string(),
288                })?;
289
290        // Find the latest version that satisfies the constraint.
291        let min_version =
292            super::project_manifest::parse_min_version(version_constraint).unwrap_or("0.0.0");
293
294        let mut best: Option<(&str, &RegistryVersion)> = None;
295        for (ver_str, ver_info) in &entry.versions {
296            if super::project_manifest::compare_versions(ver_str, min_version)
297                != std::cmp::Ordering::Less
298            {
299                match &best {
300                    Some((best_ver, _)) => {
301                        if super::project_manifest::compare_versions(ver_str, best_ver)
302                            == std::cmp::Ordering::Greater
303                        {
304                            best = Some((ver_str, ver_info));
305                        }
306                    }
307                    None => {
308                        best = Some((ver_str, ver_info));
309                    }
310                }
311            }
312        }
313
314        let (resolved_version, version_info) =
315            best.ok_or_else(|| RegistryError::VersionNotFound {
316                name: plugin_name.to_string(),
317                version: version_constraint.to_string(),
318            })?;
319
320        let download = version_info.platforms.get(platform).ok_or_else(|| {
321            RegistryError::PlatformNotAvailable {
322                name: plugin_name.to_string(),
323                version: resolved_version.to_string(),
324                platform: platform.to_string(),
325            }
326        })?;
327
328        Ok(ResolvedPlugin {
329            name: plugin_name.to_string(),
330            version: resolved_version.to_string(),
331            download_url: download.url.clone(),
332            sha256: download.sha256.clone(),
333            plugin_type: entry.plugin_type.clone(),
334        })
335    }
336
337    /// Construct a GitHub release download URL for a plugin.
338    ///
339    /// Format: `https://github.com/{owner}/{repo}/releases/download/v{version}/{name}-{version}-{platform}.tar.gz`
340    pub fn github_release_url(
341        repo: &str,
342        plugin_name: &str,
343        version: &str,
344        platform: &str,
345    ) -> String {
346        format!(
347            "https://github.com/{}/releases/download/v{}/{}-{}-{}.tar.gz",
348            repo, version, plugin_name, version, platform
349        )
350    }
351}
352
/// Result of `RegistryClient::resolve`: one concrete, downloadable plugin build.
#[derive(Clone, Debug)]
pub struct ResolvedPlugin {
    /// Name of the plugin as listed in the registry.
    pub name: String,
    /// The concrete version that was selected.
    pub version: String,
    /// URL of the tarball for the requested platform.
    pub download_url: String,
    /// SHA-256 digest the downloaded tarball is expected to have.
    pub sha256: String,
    /// Kind of plugin, e.g. `"channel"`.
    pub plugin_type: String,
}
367
/// Get the default cache directory for registry data.
///
/// Resolution order:
/// 1. `$XDG_CACHE_HOME/ta/registry`
/// 2. `$HOME/.cache/ta/registry`
/// 3. `<system temp dir>/ta-registry-cache` (see `std::env::temp_dir`; on Unix
///    this is `/tmp` unless `TMPDIR` is set)
///
/// As the XDG Base Directory spec requires, a variable that is unset, empty,
/// or holds a relative path is ignored. This keeps the cache from landing
/// relative to the process's working directory. Non-UTF-8 values are accepted
/// as-is.
fn default_cache_dir() -> PathBuf {
    /// Read `var` as a path, rejecting unset, empty, or relative values.
    fn absolute_env_path(var: &str) -> Option<PathBuf> {
        std::env::var_os(var)
            .map(PathBuf::from)
            .filter(|path| path.is_absolute())
    }

    if let Some(cache) = absolute_env_path("XDG_CACHE_HOME") {
        return cache.join("ta").join("registry");
    }
    if let Some(home) = absolute_env_path("HOME") {
        return home.join(".cache").join("ta").join("registry");
    }
    // Portable temp dir instead of a hard-coded `/tmp`, which is wrong on Windows.
    std::env::temp_dir().join("ta-registry-cache")
}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_platform_returns_known_format() {
        let platform = detect_platform();
        // Every key, known or fallback, joins its components with dashes.
        assert!(
            platform.contains('-'),
            "platform key should contain a dash: {}",
            platform
        );
        let known = [
            "aarch64-apple-darwin",
            "x86_64-apple-darwin",
            "x86_64-unknown-linux-musl",
            "aarch64-unknown-linux-musl",
            "x86_64-pc-windows-msvc",
        ];
        // CI may run on a platform without a dedicated mapping; then only the
        // `{arch}-unknown-{os}` fallback shape can be checked.
        if !known.contains(&platform.as_str()) {
            assert!(
                platform.contains("unknown"),
                "fallback should contain 'unknown': {}",
                platform
            );
        }
    }

    #[test]
    fn parse_registry_index() {
        let json = r#"{
            "schema_version": 1,
            "plugins": {
                "ta-channel-discord": {
                    "type": "channel",
                    "description": "Discord channel plugin",
                    "versions": {
                        "0.1.0": {
                            "protocol_version": 1,
                            "min_ta_version": "0.11.0",
                            "platforms": {
                                "aarch64-apple-darwin": {
                                    "url": "https://example.com/discord-0.1.0-aarch64-apple-darwin.tar.gz",
                                    "sha256": "abc123"
                                },
                                "x86_64-unknown-linux-musl": {
                                    "url": "https://example.com/discord-0.1.0-linux.tar.gz",
                                    "sha256": "def456"
                                }
                            }
                        },
                        "0.2.0": {
                            "protocol_version": 1,
                            "platforms": {
                                "aarch64-apple-darwin": {
                                    "url": "https://example.com/discord-0.2.0-aarch64-apple-darwin.tar.gz",
                                    "sha256": "ghi789"
                                }
                            }
                        }
                    }
                }
            }
        }"#;

        let index = RegistryClient::parse_index(json).unwrap();
        assert_eq!(index.schema_version, 1);
        assert_eq!(index.plugins.len(), 1);
        let discord = &index.plugins["ta-channel-discord"];
        assert_eq!(discord.plugin_type, "channel");
        assert_eq!(discord.description.as_deref(), Some("Discord channel plugin"));
        assert_eq!(discord.versions.len(), 2);
        assert_eq!(
            discord.versions["0.1.0"].min_ta_version.as_deref(),
            Some("0.11.0")
        );
        // `min_ta_version` is optional and becomes `None` when absent.
        assert!(discord.versions["0.2.0"].min_ta_version.is_none());
    }

    #[test]
    fn parse_index_rejects_malformed_json() {
        let err = RegistryClient::parse_index("{not json").unwrap_err();
        assert!(matches!(err, RegistryError::ParseFailed(_)));
    }

    #[test]
    fn resolve_latest_version() {
        let json = r#"{
            "schema_version": 1,
            "plugins": {
                "test-plugin": {
                    "type": "channel",
                    "versions": {
                        "0.1.0": {
                            "platforms": {
                                "aarch64-apple-darwin": {
                                    "url": "https://example.com/v0.1.0.tar.gz",
                                    "sha256": "aaa"
                                }
                            }
                        },
                        "0.2.0": {
                            "platforms": {
                                "aarch64-apple-darwin": {
                                    "url": "https://example.com/v0.2.0.tar.gz",
                                    "sha256": "bbb"
                                }
                            }
                        },
                        "0.3.0": {
                            "platforms": {
                                "aarch64-apple-darwin": {
                                    "url": "https://example.com/v0.3.0.tar.gz",
                                    "sha256": "ccc"
                                }
                            }
                        }
                    }
                }
            }
        }"#;

        let index = RegistryClient::parse_index(json).unwrap();
        // `protocol_version` is omitted above, so the serde default applies.
        assert_eq!(
            index.plugins["test-plugin"].versions["0.1.0"].protocol_version,
            1
        );
        let client = RegistryClient::new();

        // Should resolve to 0.3.0 (latest satisfying >=0.1.0).
        let resolved = client
            .resolve(&index, "test-plugin", ">=0.1.0", "aarch64-apple-darwin")
            .unwrap();
        assert_eq!(resolved.version, "0.3.0");
        assert_eq!(resolved.sha256, "ccc");
        assert_eq!(resolved.plugin_type, "channel");

        // Should resolve to 0.3.0 (latest satisfying >=0.2.0).
        let resolved = client
            .resolve(&index, "test-plugin", ">=0.2.0", "aarch64-apple-darwin")
            .unwrap();
        assert_eq!(resolved.version, "0.3.0");

        // Lower bound equal to the newest version still matches it.
        let resolved = client
            .resolve(&index, "test-plugin", ">=0.3.0", "aarch64-apple-darwin")
            .unwrap();
        assert_eq!(resolved.version, "0.3.0");
    }

    #[test]
    fn resolve_version_not_found() {
        let json = r#"{
            "schema_version": 1,
            "plugins": {
                "test-plugin": {
                    "type": "channel",
                    "versions": {
                        "0.1.0": {
                            "platforms": {
                                "aarch64-apple-darwin": {
                                    "url": "https://example.com/v0.1.0.tar.gz",
                                    "sha256": "aaa"
                                }
                            }
                        }
                    }
                }
            }
        }"#;

        let index = RegistryClient::parse_index(json).unwrap();
        let client = RegistryClient::new();

        let err = client
            .resolve(&index, "test-plugin", ">=1.0.0", "aarch64-apple-darwin")
            .unwrap_err();
        assert!(matches!(err, RegistryError::VersionNotFound { .. }));
    }

    #[test]
    fn resolve_plugin_not_found() {
        let json = r#"{"schema_version": 1, "plugins": {}}"#;
        let index = RegistryClient::parse_index(json).unwrap();
        let client = RegistryClient::new();

        let err = client
            .resolve(&index, "nonexistent", ">=0.1.0", "aarch64-apple-darwin")
            .unwrap_err();
        assert!(matches!(err, RegistryError::PluginNotFound { .. }));
    }

    #[test]
    fn resolve_platform_not_available() {
        let json = r#"{
            "schema_version": 1,
            "plugins": {
                "test-plugin": {
                    "type": "channel",
                    "versions": {
                        "0.1.0": {
                            "platforms": {
                                "x86_64-unknown-linux-musl": {
                                    "url": "https://example.com/v0.1.0.tar.gz",
                                    "sha256": "aaa"
                                }
                            }
                        }
                    }
                }
            }
        }"#;

        let index = RegistryClient::parse_index(json).unwrap();
        let client = RegistryClient::new();

        let err = client
            .resolve(&index, "test-plugin", ">=0.1.0", "aarch64-apple-darwin")
            .unwrap_err();
        assert!(matches!(err, RegistryError::PlatformNotAvailable { .. }));
    }

    #[test]
    fn github_release_url_format() {
        let url = RegistryClient::github_release_url(
            "Trusted-Autonomy/ta-channel-discord",
            "ta-channel-discord",
            "0.1.0",
            "aarch64-apple-darwin",
        );
        assert_eq!(
            url,
            "https://github.com/Trusted-Autonomy/ta-channel-discord/releases/download/v0.1.0/ta-channel-discord-0.1.0-aarch64-apple-darwin.tar.gz"
        );
    }

    #[test]
    fn cache_validity() {
        let dir = tempfile::tempdir().unwrap();
        let client = RegistryClient::with_config(
            "https://example.com/index.json".to_string(),
            dir.path().to_path_buf(),
            3600,
        );

        // No cache yet.
        assert!(!client.is_cache_valid());

        // Create cache files.
        let index = RegistryIndex {
            schema_version: 1,
            plugins: HashMap::new(),
        };
        client.save_cache(&index).unwrap();

        // Cache should be valid now.
        assert!(client.is_cache_valid());

        // Load from cache.
        let cached = client.load_cached();
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().schema_version, 1);
    }

    #[test]
    fn cache_invalid_without_index_file() {
        let dir = tempfile::tempdir().unwrap();
        let client = RegistryClient::with_config(
            "https://example.com/index.json".to_string(),
            dir.path().to_path_buf(),
            3600,
        );

        // A fresh timestamp alone must not make the cache usable.
        std::fs::write(client.cache_timestamp_path(), "").unwrap();
        assert!(!client.is_cache_valid());
        assert!(client.load_cached().is_none());
    }

    #[test]
    fn corrupt_cache_is_ignored() {
        let dir = tempfile::tempdir().unwrap();
        let client = RegistryClient::with_config(
            "https://example.com/index.json".to_string(),
            dir.path().to_path_buf(),
            3600,
        );

        let index = RegistryIndex {
            schema_version: 1,
            plugins: HashMap::new(),
        };
        client.save_cache(&index).unwrap();
        std::fs::write(client.cache_path(), "not json").unwrap();

        // The cache is still "fresh", but an unparseable index is not returned.
        assert!(client.is_cache_valid());
        assert!(client.load_cached().is_none());
    }

    #[test]
    fn cache_expired() {
        let dir = tempfile::tempdir().unwrap();
        let client = RegistryClient::with_config(
            "https://example.com/index.json".to_string(),
            dir.path().to_path_buf(),
            0, // 0 second TTL = always expired.
        );

        let index = RegistryIndex {
            schema_version: 1,
            plugins: HashMap::new(),
        };
        client.save_cache(&index).unwrap();

        // With TTL=0 the freshness check `age_secs < 0` can never hold, so the
        // cache is expired immediately regardless of machine speed or
        // filesystem timestamp resolution.
        assert!(!client.is_cache_valid());
    }

    #[test]
    fn registry_error_display() {
        let err = RegistryError::PluginNotFound {
            name: "test".into(),
        };
        assert!(err.to_string().contains("test"));

        let err = RegistryError::PlatformNotAvailable {
            name: "test".into(),
            version: "0.1.0".into(),
            platform: "arm".into(),
        };
        assert!(err.to_string().contains("arm"));
    }
}