Skip to main content

agentctl/provider/
registry.rs

1use agent_provider_claude::ClaudeProviderAdapter;
2use agent_provider_codex::CodexProviderAdapter;
3use agent_provider_gemini::GeminiProviderAdapter;
4use agent_runtime_core::provider::ProviderAdapterV1;
5use std::collections::BTreeMap;
6use std::fmt;
7
/// Provider id that [`ProviderRegistry::with_builtins`] installs as the preferred default.
pub const DEFAULT_PROVIDER_ID: &str = "codex";
/// Name of the environment variable that [`ProviderRegistry::resolve_selection`] reads
/// to obtain a provider override.
pub const PROVIDER_OVERRIDE_ENV: &str = "AGENTCTL_PROVIDER";
10
11pub struct ProviderRegistry {
12    providers: BTreeMap<String, Box<dyn ProviderAdapterV1>>,
13    default_provider_id: String,
14}
15
16impl ProviderRegistry {
17    pub fn with_builtins() -> Self {
18        let mut registry = Self::new(DEFAULT_PROVIDER_ID);
19        registry.register(CodexProviderAdapter::new());
20        registry.register(ClaudeProviderAdapter::new());
21        registry.register(GeminiProviderAdapter::new());
22        registry
23    }
24
25    pub fn new(default_provider_id: impl Into<String>) -> Self {
26        Self {
27            providers: BTreeMap::new(),
28            default_provider_id: default_provider_id.into(),
29        }
30    }
31
32    pub fn register<T>(&mut self, adapter: T)
33    where
34        T: ProviderAdapterV1 + 'static,
35    {
36        let provider_id = adapter.metadata().id;
37        self.providers.insert(provider_id, Box::new(adapter));
38    }
39
40    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn ProviderAdapterV1)> + '_ {
41        self.providers
42            .iter()
43            .map(|(provider_id, adapter)| (provider_id.as_str(), adapter.as_ref()))
44    }
45
46    pub fn get(&self, provider_id: &str) -> Option<&dyn ProviderAdapterV1> {
47        self.providers.get(provider_id).map(Box::as_ref)
48    }
49
50    pub fn default_provider_id(&self) -> Option<&str> {
51        if self.providers.is_empty() {
52            return None;
53        }
54
55        if self
56            .providers
57            .contains_key(self.default_provider_id.as_str())
58        {
59            return Some(self.default_provider_id.as_str());
60        }
61
62        self.providers.keys().next().map(String::as_str)
63    }
64
65    pub fn resolve_selection(
66        &self,
67        cli_override: Option<&str>,
68    ) -> Result<ProviderSelection, ResolveProviderError> {
69        let env_override = std::env::var(PROVIDER_OVERRIDE_ENV).ok();
70        self.resolve_selection_with_env(cli_override, env_override.as_deref())
71    }
72
73    pub fn resolve_selection_with_env(
74        &self,
75        cli_override: Option<&str>,
76        env_override: Option<&str>,
77    ) -> Result<ProviderSelection, ResolveProviderError> {
78        if let Some(provider_id) = normalize_provider_id(cli_override) {
79            return self.resolve_override(provider_id, ProviderSelectionSource::CliArgument);
80        }
81
82        if let Some(provider_id) = normalize_provider_id(env_override) {
83            return self.resolve_override(provider_id, ProviderSelectionSource::Environment);
84        }
85
86        let provider_id = self
87            .default_provider_id()
88            .ok_or(ResolveProviderError::NoProvidersRegistered)?;
89        Ok(ProviderSelection {
90            provider_id: provider_id.to_string(),
91            source: ProviderSelectionSource::Default,
92        })
93    }
94
95    fn resolve_override(
96        &self,
97        provider_id: &str,
98        source: ProviderSelectionSource,
99    ) -> Result<ProviderSelection, ResolveProviderError> {
100        if !self.providers.contains_key(provider_id) {
101            return Err(ResolveProviderError::UnknownProvider {
102                provider_id: provider_id.to_string(),
103                source,
104            });
105        }
106
107        Ok(ProviderSelection {
108            provider_id: provider_id.to_string(),
109            source,
110        })
111    }
112}
113
114impl Default for ProviderRegistry {
115    fn default() -> Self {
116        Self::with_builtins()
117    }
118}
119
/// Where a provider selection (or a rejected override) originated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderSelectionSource {
    /// The id was passed as a command-line override.
    CliArgument,
    /// The id was taken from the environment override.
    Environment,
    /// No override applied; the registry default was used.
    Default,
}

impl ProviderSelectionSource {
    /// Returns the stable kebab-case label for this source, as used in error messages.
    pub const fn as_str(self) -> &'static str {
        match self {
            ProviderSelectionSource::CliArgument => "cli-argument",
            ProviderSelectionSource::Environment => "environment",
            ProviderSelectionSource::Default => "default",
        }
    }
}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub struct ProviderSelection {
139    pub provider_id: String,
140    pub source: ProviderSelectionSource,
141}
142
143#[derive(Debug, Clone, PartialEq, Eq)]
144pub enum ResolveProviderError {
145    NoProvidersRegistered,
146    UnknownProvider {
147        provider_id: String,
148        source: ProviderSelectionSource,
149    },
150}
151
152impl fmt::Display for ResolveProviderError {
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        match self {
155            Self::NoProvidersRegistered => f.write_str("no providers are registered"),
156            Self::UnknownProvider {
157                provider_id,
158                source,
159            } => write!(
160                f,
161                "unknown provider '{}' from {} override",
162                provider_id,
163                source.as_str()
164            ),
165        }
166    }
167}
168
169impl std::error::Error for ResolveProviderError {}
170
/// Trims surrounding whitespace from an optional provider id.
///
/// Returns `None` when the input is `None` or contains nothing but whitespace;
/// otherwise returns the trimmed slice of the input.
fn normalize_provider_id(raw: Option<&str>) -> Option<&str> {
    let trimmed = raw?.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}