Skip to main content

cuenv_core/tools/
provider.rs

1//! Tool provider trait for extensible tool fetching.
2//!
3//! This module defines the `ToolProvider` trait that allows different sources
4//! (GitHub releases, Nix packages, OCI images) to be registered and used
5//! uniformly for fetching development tools.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10
11use crate::Result;
12
13/// Platform identifier combining OS and architecture.
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct Platform {
16    pub os: Os,
17    pub arch: Arch,
18}
19
20impl Platform {
21    /// Create a new platform.
22    #[must_use]
23    pub fn new(os: Os, arch: Arch) -> Self {
24        Self { os, arch }
25    }
26
27    /// Get the current platform.
28    #[must_use]
29    pub fn current() -> Self {
30        Self {
31            os: Os::current(),
32            arch: Arch::current(),
33        }
34    }
35
36    /// Parse from string like "darwin-arm64".
37    pub fn parse(s: &str) -> Option<Self> {
38        let parts: Vec<&str> = s.split('-').collect();
39        if parts.len() != 2 {
40            return None;
41        }
42        Some(Self {
43            os: Os::parse(parts[0])?,
44            arch: Arch::parse(parts[1])?,
45        })
46    }
47}
48
49impl std::fmt::Display for Platform {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(f, "{}-{}", self.os, self.arch)
52    }
53}
54
55/// Operating system.
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum Os {
59    Darwin,
60    Linux,
61}
62
63impl Os {
64    /// Get the current OS.
65    #[must_use]
66    pub fn current() -> Self {
67        #[cfg(target_os = "macos")]
68        return Self::Darwin;
69        #[cfg(target_os = "linux")]
70        return Self::Linux;
71        #[cfg(not(any(target_os = "macos", target_os = "linux")))]
72        compile_error!("Unsupported OS");
73    }
74
75    /// Parse from string.
76    #[must_use]
77    pub fn parse(s: &str) -> Option<Self> {
78        match s.to_lowercase().as_str() {
79            "darwin" | "macos" => Some(Self::Darwin),
80            "linux" => Some(Self::Linux),
81            _ => None,
82        }
83    }
84}
85
86impl std::fmt::Display for Os {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        match self {
89            Self::Darwin => write!(f, "darwin"),
90            Self::Linux => write!(f, "linux"),
91        }
92    }
93}
94
95/// CPU architecture.
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
97#[serde(rename_all = "lowercase")]
98pub enum Arch {
99    Arm64,
100    X86_64,
101}
102
103impl Arch {
104    /// Get the current architecture.
105    #[must_use]
106    pub fn current() -> Self {
107        #[cfg(target_arch = "aarch64")]
108        return Self::Arm64;
109        #[cfg(target_arch = "x86_64")]
110        return Self::X86_64;
111        #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
112        compile_error!("Unsupported architecture");
113    }
114
115    /// Parse from string.
116    #[must_use]
117    pub fn parse(s: &str) -> Option<Self> {
118        match s.to_lowercase().as_str() {
119            "arm64" | "aarch64" => Some(Self::Arm64),
120            "x86_64" | "amd64" | "x64" => Some(Self::X86_64),
121            _ => None,
122        }
123    }
124}
125
126impl std::fmt::Display for Arch {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        match self {
129            Self::Arm64 => write!(f, "arm64"),
130            Self::X86_64 => write!(f, "x86_64"),
131        }
132    }
133}
134
135/// Source-specific resolution data.
136///
137/// This enum contains the provider-specific information needed to fetch a tool.
138#[derive(Debug, Clone, Serialize, Deserialize)]
139#[serde(tag = "type", rename_all = "lowercase")]
140pub enum ToolSource {
141    /// Binary extracted from an OCI container image.
142    Oci { image: String, path: String },
143    /// Asset from a GitHub release.
144    GitHub {
145        repo: String,
146        tag: String,
147        asset: String,
148        #[serde(skip_serializing_if = "Option::is_none")]
149        path: Option<String>,
150    },
151    /// Package from a Nix flake.
152    Nix {
153        flake: String,
154        package: String,
155        #[serde(skip_serializing_if = "Option::is_none")]
156        output: Option<String>,
157    },
158    /// Rust toolchain managed by rustup.
159    Rustup {
160        /// Toolchain identifier (e.g., "stable", "1.83.0", "nightly-2024-01-01").
161        toolchain: String,
162        /// Installation profile: minimal, default, complete.
163        #[serde(skip_serializing_if = "Option::is_none")]
164        profile: Option<String>,
165        /// Additional components to install (e.g., "clippy", "rustfmt", "rust-src").
166        #[serde(skip_serializing_if = "Vec::is_empty", default)]
167        components: Vec<String>,
168        /// Additional targets to install (e.g., "x86_64-unknown-linux-gnu").
169        #[serde(skip_serializing_if = "Vec::is_empty", default)]
170        targets: Vec<String>,
171    },
172}
173
174impl ToolSource {
175    /// Get the provider type name.
176    #[must_use]
177    pub fn provider_type(&self) -> &'static str {
178        match self {
179            Self::Oci { .. } => "oci",
180            Self::GitHub { .. } => "github",
181            Self::Nix { .. } => "nix",
182            Self::Rustup { .. } => "rustup",
183        }
184    }
185}
186
187/// A resolved tool ready to be fetched.
188///
189/// This represents a fully resolved tool specification with all information
190/// needed to download and cache the binary.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ResolvedTool {
193    /// Tool name (e.g., "jq", "bun").
194    pub name: String,
195    /// Version string.
196    pub version: String,
197    /// Target platform.
198    pub platform: Platform,
199    /// Source-specific data.
200    pub source: ToolSource,
201}
202
/// Outcome of a successful [`ToolProvider::fetch`].
#[derive(Debug)]
pub struct FetchedTool {
    /// Name of the tool that was fetched.
    pub name: String,
    /// Where the binary lives in the local cache.
    pub binary_path: PathBuf,
    /// SHA-256 hash of the binary.
    pub sha256: String,
}
213
/// Settings shared by tool fetch and cache operations.
#[derive(Clone, Debug, Default)]
pub struct ToolOptions {
    /// Overrides the cache location; `None` means "use the default".
    pub cache_dir: Option<PathBuf>,
    /// When `true`, fetch again even if the tool is already cached.
    pub force_refetch: bool,
}
222
223impl ToolOptions {
224    /// Create new options with default cache directory.
225    #[must_use]
226    pub fn new() -> Self {
227        Self::default()
228    }
229
230    /// Set the cache directory.
231    #[must_use]
232    pub fn with_cache_dir(mut self, path: PathBuf) -> Self {
233        self.cache_dir = Some(path);
234        self
235    }
236
237    /// Set force refetch.
238    #[must_use]
239    pub fn with_force_refetch(mut self, force: bool) -> Self {
240        self.force_refetch = force;
241        self
242    }
243
244    /// Get the cache directory, defaulting to ~/.cache/cuenv/tools.
245    #[must_use]
246    pub fn cache_dir(&self) -> PathBuf {
247        self.cache_dir.clone().unwrap_or_else(default_cache_dir)
248    }
249}
250
251/// Get the default cache directory for tools.
252#[must_use]
253pub fn default_cache_dir() -> PathBuf {
254    dirs::cache_dir()
255        .unwrap_or_else(|| PathBuf::from(".cache"))
256        .join("cuenv")
257        .join("tools")
258}
259
260/// Request parameters for tool resolution.
261pub struct ToolResolveRequest<'a> {
262    /// Name of the tool (e.g., "jq").
263    pub tool_name: &'a str,
264    /// Version string from the manifest.
265    pub version: &'a str,
266    /// Target platform.
267    pub platform: &'a Platform,
268    /// Provider-specific configuration from CUE.
269    pub config: &'a serde_json::Value,
270    /// Optional authentication token (e.g., GitHub token for rate limiting).
271    pub token: Option<&'a str>,
272}
273
274/// Trait for tool providers (GitHub, OCI, Nix).
275///
276/// Each provider implements this trait to handle resolution and fetching
277/// of tools from a specific source type. Providers are registered with
278/// the `ToolRegistry` and selected based on the source configuration.
279///
280/// # Example
281///
282/// ```ignore
283/// pub struct GitHubToolProvider { /* ... */ }
284///
285/// #[async_trait]
286/// impl ToolProvider for GitHubToolProvider {
287///     fn name(&self) -> &'static str { "github" }
288///     fn description(&self) -> &'static str { "Fetch tools from GitHub releases" }
289///     // ...
290/// }
291/// ```
292#[async_trait]
293pub trait ToolProvider: Send + Sync {
294    /// Provider name (e.g., "github", "nix", "oci").
295    ///
296    /// This should match the `type` field in the CUE schema.
297    fn name(&self) -> &'static str;
298
299    /// Human-readable description for help text.
300    fn description(&self) -> &'static str;
301
302    /// Check if this provider can handle the given source type.
303    fn can_handle(&self, source: &ToolSource) -> bool;
304
305    /// Resolve a tool specification to a fetchable artifact.
306    ///
307    /// This performs version resolution, platform matching, and returns
308    /// the concrete artifact reference (image digest, release URL, etc.)
309    ///
310    /// # Arguments
311    ///
312    /// * `request` - Resolution parameters including tool name, version, platform,
313    ///   provider-specific config, and optional authentication token
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if resolution fails (version not found, etc.)
318    async fn resolve(&self, request: &ToolResolveRequest<'_>) -> Result<ResolvedTool>;
319
320    /// Fetch and cache a resolved tool.
321    ///
322    /// Downloads the artifact, extracts binaries, and returns the local path.
323    /// If the tool is already cached and `force_refetch` is false, returns
324    /// the cached path without re-downloading.
325    ///
326    /// # Arguments
327    ///
328    /// * `resolved` - A previously resolved tool
329    /// * `options` - Fetch options (cache dir, force refetch)
330    ///
331    /// # Errors
332    ///
333    /// Returns an error if fetching or extraction fails.
334    async fn fetch(&self, resolved: &ResolvedTool, options: &ToolOptions) -> Result<FetchedTool>;
335
336    /// Check if a tool is already cached.
337    ///
338    /// Returns true if the tool binary exists in the cache directory.
339    fn is_cached(&self, resolved: &ResolvedTool, options: &ToolOptions) -> bool;
340
341    /// Check if provider prerequisites are available.
342    ///
343    /// Called early during runtime activation to fail fast if required
344    /// dependencies are missing (e.g., Nix CLI not installed).
345    ///
346    /// # Default Implementation
347    ///
348    /// Returns `Ok(())` - most providers only need HTTP access.
349    ///
350    /// # Errors
351    ///
352    /// Returns an error with a helpful message if prerequisites are not met.
353    async fn check_prerequisites(&self) -> Result<()> {
354        Ok(())
355    }
356}
357
#[cfg(test)]
mod tests {
    // Unit tests for platform parsing/formatting, ToolSource serde behaviour
    // and ToolOptions defaults.
    use super::*;

    #[test]
    fn test_platform_parse() {
        let p = Platform::parse("darwin-arm64").unwrap();
        assert_eq!(p.os, Os::Darwin);
        assert_eq!(p.arch, Arch::Arm64);

        let p = Platform::parse("linux-x86_64").unwrap();
        assert_eq!(p.os, Os::Linux);
        assert_eq!(p.arch, Arch::X86_64);

        assert!(Platform::parse("invalid").is_none());
    }

    #[test]
    fn test_platform_parse_edge_cases() {
        // Too few parts
        assert!(Platform::parse("darwin").is_none());
        // Too many parts
        assert!(Platform::parse("darwin-arm64-extra").is_none());
        // Empty string
        assert!(Platform::parse("").is_none());
        // Invalid OS
        assert!(Platform::parse("windows-arm64").is_none());
        // Invalid arch
        assert!(Platform::parse("darwin-mips").is_none());
    }

    #[test]
    fn test_platform_parse_aliases_and_case() {
        assert_eq!(
            Platform::parse("macos-aarch64"),
            Some(Platform::new(Os::Darwin, Arch::Arm64))
        );
        assert_eq!(
            Platform::parse("Linux-AMD64"),
            Some(Platform::new(Os::Linux, Arch::X86_64))
        );
        assert_eq!(
            Platform::parse("linux-x64"),
            Some(Platform::new(Os::Linux, Arch::X86_64))
        );
        // Separator only, or one half missing
        assert!(Platform::parse("-").is_none());
        assert!(Platform::parse("darwin-").is_none());
        assert!(Platform::parse("-arm64").is_none());
    }

    #[test]
    fn test_platform_display_parse_roundtrip() {
        for os in [Os::Darwin, Os::Linux] {
            for arch in [Arch::Arm64, Arch::X86_64] {
                let platform = Platform::new(os, arch);
                assert_eq!(Platform::parse(&platform.to_string()), Some(platform));
            }
        }
    }

    #[test]
    fn test_platform_serde_roundtrip() {
        let platform = Platform::new(Os::Darwin, Arch::X86_64);
        let value = serde_json::to_value(&platform).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"os": "darwin", "arch": "x86_64"})
        );
        let back: Platform = serde_json::from_value(value).unwrap();
        assert_eq!(back, platform);
    }

    #[test]
    fn test_platform_display() {
        let p = Platform::new(Os::Darwin, Arch::Arm64);
        assert_eq!(p.to_string(), "darwin-arm64");
    }

    #[test]
    fn test_platform_display_all_combinations() {
        assert_eq!(
            Platform::new(Os::Darwin, Arch::Arm64).to_string(),
            "darwin-arm64"
        );
        assert_eq!(
            Platform::new(Os::Darwin, Arch::X86_64).to_string(),
            "darwin-x86_64"
        );
        assert_eq!(
            Platform::new(Os::Linux, Arch::Arm64).to_string(),
            "linux-arm64"
        );
        assert_eq!(
            Platform::new(Os::Linux, Arch::X86_64).to_string(),
            "linux-x86_64"
        );
    }

    #[test]
    fn test_platform_current() {
        let p = Platform::current();
        // Should return a valid platform for the current system
        assert!(matches!(p.os, Os::Darwin | Os::Linux));
        assert!(matches!(p.arch, Arch::Arm64 | Arch::X86_64));
    }

    #[test]
    fn test_os_parse() {
        assert_eq!(Os::parse("darwin"), Some(Os::Darwin));
        assert_eq!(Os::parse("macos"), Some(Os::Darwin));
        assert_eq!(Os::parse("linux"), Some(Os::Linux));
        assert_eq!(Os::parse("windows"), None);
    }

    #[test]
    fn test_os_parse_case_insensitive() {
        assert_eq!(Os::parse("DARWIN"), Some(Os::Darwin));
        assert_eq!(Os::parse("Darwin"), Some(Os::Darwin));
        assert_eq!(Os::parse("LINUX"), Some(Os::Linux));
        assert_eq!(Os::parse("Linux"), Some(Os::Linux));
        assert_eq!(Os::parse("MACOS"), Some(Os::Darwin));
        assert_eq!(Os::parse("MacOS"), Some(Os::Darwin));
    }

    #[test]
    fn test_os_display() {
        assert_eq!(Os::Darwin.to_string(), "darwin");
        assert_eq!(Os::Linux.to_string(), "linux");
    }

    #[test]
    fn test_os_current() {
        let os = Os::current();
        // Should return a valid OS for the current system
        assert!(matches!(os, Os::Darwin | Os::Linux));
    }

    #[test]
    fn test_arch_parse() {
        assert_eq!(Arch::parse("arm64"), Some(Arch::Arm64));
        assert_eq!(Arch::parse("aarch64"), Some(Arch::Arm64));
        assert_eq!(Arch::parse("x86_64"), Some(Arch::X86_64));
        assert_eq!(Arch::parse("amd64"), Some(Arch::X86_64));
    }

    #[test]
    fn test_arch_parse_case_insensitive() {
        assert_eq!(Arch::parse("ARM64"), Some(Arch::Arm64));
        assert_eq!(Arch::parse("Arm64"), Some(Arch::Arm64));
        assert_eq!(Arch::parse("AARCH64"), Some(Arch::Arm64));
        assert_eq!(Arch::parse("X86_64"), Some(Arch::X86_64));
        assert_eq!(Arch::parse("AMD64"), Some(Arch::X86_64));
    }

    #[test]
    fn test_arch_parse_x64_alias() {
        assert_eq!(Arch::parse("x64"), Some(Arch::X86_64));
        assert_eq!(Arch::parse("X64"), Some(Arch::X86_64));
    }

    #[test]
    fn test_arch_parse_invalid() {
        assert!(Arch::parse("mips").is_none());
        assert!(Arch::parse("riscv").is_none());
        assert!(Arch::parse("").is_none());
    }

    #[test]
    fn test_arch_display() {
        assert_eq!(Arch::Arm64.to_string(), "arm64");
        assert_eq!(Arch::X86_64.to_string(), "x86_64");
    }

    #[test]
    fn test_arch_current() {
        let arch = Arch::current();
        // Should return a valid arch for the current system
        assert!(matches!(arch, Arch::Arm64 | Arch::X86_64));
    }

    #[test]
    fn test_tool_source_provider_type() {
        let s = ToolSource::GitHub {
            repo: "jqlang/jq".into(),
            tag: "jq-1.7.1".into(),
            asset: "jq-macos-arm64".into(),
            path: None,
        };
        assert_eq!(s.provider_type(), "github");

        let s = ToolSource::Nix {
            flake: "nixpkgs".into(),
            package: "jq".into(),
            output: None,
        };
        assert_eq!(s.provider_type(), "nix");

        let s = ToolSource::Rustup {
            toolchain: "1.83.0".into(),
            profile: Some("default".into()),
            components: vec!["clippy".into(), "rustfmt".into()],
            targets: vec!["x86_64-unknown-linux-gnu".into()],
        };
        assert_eq!(s.provider_type(), "rustup");
    }

    #[test]
    fn test_tool_source_oci_provider_type() {
        let s = ToolSource::Oci {
            image: "docker.io/library/alpine:latest".into(),
            path: "/usr/bin/jq".into(),
        };
        assert_eq!(s.provider_type(), "oci");
    }

    #[test]
    fn test_tool_source_serialization() {
        let source = ToolSource::GitHub {
            repo: "jqlang/jq".into(),
            tag: "jq-1.7.1".into(),
            asset: "jq-macos-arm64".into(),
            path: Some("jq-macos-arm64/jq".into()),
        };
        let json = serde_json::to_string(&source).unwrap();
        assert!(json.contains("\"type\":\"github\""));
        assert!(json.contains("\"repo\":\"jqlang/jq\""));
        assert!(json.contains("\"path\":\"jq-macos-arm64/jq\""));
    }

    #[test]
    fn test_tool_source_deserialization() {
        let json =
            r#"{"type":"github","repo":"jqlang/jq","tag":"jq-1.7.1","asset":"jq-macos-arm64"}"#;
        let source: ToolSource = serde_json::from_str(json).unwrap();
        match source {
            ToolSource::GitHub {
                repo, tag, asset, ..
            } => {
                assert_eq!(repo, "jqlang/jq");
                assert_eq!(tag, "jq-1.7.1");
                assert_eq!(asset, "jq-macos-arm64");
            }
            _ => panic!("Expected GitHub source"),
        }
    }

    #[test]
    fn test_tool_source_unknown_type_is_rejected() {
        let parsed = serde_json::from_str::<ToolSource>(r#"{"type":"docker","image":"alpine"}"#);
        assert!(parsed.is_err());
    }

    #[test]
    fn test_tool_source_oci_roundtrip() {
        let source = ToolSource::Oci {
            image: "docker.io/library/alpine:latest".into(),
            path: "/usr/bin/jq".into(),
        };
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "oci",
                "image": "docker.io/library/alpine:latest",
                "path": "/usr/bin/jq"
            })
        );
        match serde_json::from_value::<ToolSource>(value).unwrap() {
            ToolSource::Oci { image, path } => {
                assert_eq!(image, "docker.io/library/alpine:latest");
                assert_eq!(path, "/usr/bin/jq");
            }
            other => panic!("Expected Oci source, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_source_nix_serialization() {
        let source = ToolSource::Nix {
            flake: "nixpkgs".into(),
            package: "jq".into(),
            output: Some("bin".into()),
        };
        let json = serde_json::to_string(&source).unwrap();
        assert!(json.contains("\"type\":\"nix\""));
        assert!(json.contains("\"output\":\"bin\""));
    }

    #[test]
    fn test_tool_source_nix_omits_missing_output() {
        let source = ToolSource::Nix {
            flake: "nixpkgs".into(),
            package: "jq".into(),
            output: None,
        };
        let json = serde_json::to_string(&source).unwrap();
        assert!(!json.contains("output"));
    }

    #[test]
    fn test_tool_source_rustup_serialization() {
        let source = ToolSource::Rustup {
            toolchain: "stable".into(),
            profile: None,
            components: vec![],
            targets: vec![],
        };
        let json = serde_json::to_string(&source).unwrap();
        assert!(json.contains("\"type\":\"rustup\""));
        // Empty vecs should not be serialized
        assert!(!json.contains("components"));
        assert!(!json.contains("targets"));
    }

    #[test]
    fn test_tool_source_rustup_deserialization_defaults() {
        let source: ToolSource =
            serde_json::from_str(r#"{"type":"rustup","toolchain":"stable"}"#).unwrap();
        match source {
            ToolSource::Rustup {
                toolchain,
                profile,
                components,
                targets,
            } => {
                assert_eq!(toolchain, "stable");
                assert!(profile.is_none());
                assert!(components.is_empty());
                assert!(targets.is_empty());
            }
            other => panic!("Expected Rustup source, got {:?}", other),
        }
    }

    #[test]
    fn test_resolved_tool_serialization() {
        let tool = ResolvedTool {
            name: "jq".into(),
            version: "1.7.1".into(),
            platform: Platform::new(Os::Darwin, Arch::Arm64),
            source: ToolSource::GitHub {
                repo: "jqlang/jq".into(),
                tag: "jq-1.7.1".into(),
                asset: "jq-macos-arm64".into(),
                path: None,
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"name\":\"jq\""));
        assert!(json.contains("\"version\":\"1.7.1\""));
    }

    #[test]
    fn test_tool_options_default() {
        let opts = ToolOptions::default();
        assert!(opts.cache_dir.is_none());
        assert!(!opts.force_refetch);
    }

    #[test]
    fn test_tool_options_new() {
        let opts = ToolOptions::new();
        assert!(opts.cache_dir.is_none());
        assert!(!opts.force_refetch);
    }

    #[test]
    fn test_tool_options_builder() {
        let opts = ToolOptions::new()
            .with_cache_dir(PathBuf::from("/custom/cache"))
            .with_force_refetch(true);

        assert_eq!(opts.cache_dir, Some(PathBuf::from("/custom/cache")));
        assert!(opts.force_refetch);
    }

    #[test]
    fn test_tool_options_cache_dir_default() {
        let opts = ToolOptions::new();
        let cache_dir = opts.cache_dir();
        // Should end with cuenv/tools
        assert!(cache_dir.ends_with("cuenv/tools"));
    }

    #[test]
    fn test_tool_options_cache_dir_custom() {
        let opts = ToolOptions::new().with_cache_dir(PathBuf::from("/my/cache"));
        assert_eq!(opts.cache_dir(), PathBuf::from("/my/cache"));
    }

    #[test]
    fn test_default_cache_dir() {
        let cache_dir = default_cache_dir();
        // Should end with cuenv/tools
        assert!(cache_dir.ends_with("cuenv/tools"));
    }

    #[test]
    fn test_platform_equality() {
        let p1 = Platform::new(Os::Darwin, Arch::Arm64);
        let p2 = Platform::new(Os::Darwin, Arch::Arm64);
        let p3 = Platform::new(Os::Linux, Arch::Arm64);

        assert_eq!(p1, p2);
        assert_ne!(p1, p3);
    }

    #[test]
    fn test_platform_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(Platform::new(Os::Darwin, Arch::Arm64));
        set.insert(Platform::new(Os::Darwin, Arch::Arm64)); // Duplicate

        assert_eq!(set.len(), 1);

        set.insert(Platform::new(Os::Linux, Arch::Arm64));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_os_equality() {
        assert_eq!(Os::Darwin, Os::Darwin);
        assert_eq!(Os::Linux, Os::Linux);
        assert_ne!(Os::Darwin, Os::Linux);
    }

    #[test]
    fn test_arch_equality() {
        assert_eq!(Arch::Arm64, Arch::Arm64);
        assert_eq!(Arch::X86_64, Arch::X86_64);
        assert_ne!(Arch::Arm64, Arch::X86_64);
    }
}