agentctl/provider/
registry.rs1use agent_provider_claude::ClaudeProviderAdapter;
2use agent_provider_codex::CodexProviderAdapter;
3use agent_provider_gemini::GeminiProviderAdapter;
4use agent_runtime_core::provider::ProviderAdapterV1;
5use std::collections::BTreeMap;
6use std::fmt;
7
/// Id of the provider preferred when neither a CLI nor an environment override is given.
///
/// If no provider with this id is registered, the registry falls back to the first registered
/// id in ascending order.
pub const DEFAULT_PROVIDER_ID: &str = "codex";

/// Name of the environment variable that overrides provider selection.
///
/// A non-blank CLI override takes precedence over this variable.
pub const PROVIDER_OVERRIDE_ENV: &str = "AGENTCTL_PROVIDER";

11pub struct ProviderRegistry {
12 providers: BTreeMap<String, Box<dyn ProviderAdapterV1>>,
13 default_provider_id: String,
14}
15
16impl ProviderRegistry {
17 pub fn with_builtins() -> Self {
18 let mut registry = Self::new(DEFAULT_PROVIDER_ID);
19 registry.register(CodexProviderAdapter::new());
20 registry.register(ClaudeProviderAdapter::new());
21 registry.register(GeminiProviderAdapter::new());
22 registry
23 }
24
25 pub fn new(default_provider_id: impl Into<String>) -> Self {
26 Self {
27 providers: BTreeMap::new(),
28 default_provider_id: default_provider_id.into(),
29 }
30 }
31
32 pub fn register<T>(&mut self, adapter: T)
33 where
34 T: ProviderAdapterV1 + 'static,
35 {
36 let provider_id = adapter.metadata().id;
37 self.providers.insert(provider_id, Box::new(adapter));
38 }
39
40 pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn ProviderAdapterV1)> + '_ {
41 self.providers
42 .iter()
43 .map(|(provider_id, adapter)| (provider_id.as_str(), adapter.as_ref()))
44 }
45
46 pub fn get(&self, provider_id: &str) -> Option<&dyn ProviderAdapterV1> {
47 self.providers.get(provider_id).map(Box::as_ref)
48 }
49
50 pub fn default_provider_id(&self) -> Option<&str> {
51 if self.providers.is_empty() {
52 return None;
53 }
54
55 if self
56 .providers
57 .contains_key(self.default_provider_id.as_str())
58 {
59 return Some(self.default_provider_id.as_str());
60 }
61
62 self.providers.keys().next().map(String::as_str)
63 }
64
65 pub fn resolve_selection(
66 &self,
67 cli_override: Option<&str>,
68 ) -> Result<ProviderSelection, ResolveProviderError> {
69 let env_override = std::env::var(PROVIDER_OVERRIDE_ENV).ok();
70 self.resolve_selection_with_env(cli_override, env_override.as_deref())
71 }
72
73 pub fn resolve_selection_with_env(
74 &self,
75 cli_override: Option<&str>,
76 env_override: Option<&str>,
77 ) -> Result<ProviderSelection, ResolveProviderError> {
78 if let Some(provider_id) = normalize_provider_id(cli_override) {
79 return self.resolve_override(provider_id, ProviderSelectionSource::CliArgument);
80 }
81
82 if let Some(provider_id) = normalize_provider_id(env_override) {
83 return self.resolve_override(provider_id, ProviderSelectionSource::Environment);
84 }
85
86 let provider_id = self
87 .default_provider_id()
88 .ok_or(ResolveProviderError::NoProvidersRegistered)?;
89 Ok(ProviderSelection {
90 provider_id: provider_id.to_string(),
91 source: ProviderSelectionSource::Default,
92 })
93 }
94
95 fn resolve_override(
96 &self,
97 provider_id: &str,
98 source: ProviderSelectionSource,
99 ) -> Result<ProviderSelection, ResolveProviderError> {
100 if !self.providers.contains_key(provider_id) {
101 return Err(ResolveProviderError::UnknownProvider {
102 provider_id: provider_id.to_string(),
103 source,
104 });
105 }
106
107 Ok(ProviderSelection {
108 provider_id: provider_id.to_string(),
109 source,
110 })
111 }
112}
113
114impl Default for ProviderRegistry {
115 fn default() -> Self {
116 Self::with_builtins()
117 }
118}
119
/// Where a selected provider id came from, listed from highest to lowest precedence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderSelectionSource {
    /// Explicit override passed on the command line.
    CliArgument,
    /// Override read from the provider-override environment variable.
    Environment,
    /// No override was given, so the registry default was used.
    Default,
}

impl ProviderSelectionSource {
    /// Returns a stable, kebab-case label for this source, for use in messages and logs.
    pub const fn as_str(self) -> &'static str {
        match self {
            ProviderSelectionSource::CliArgument => "cli-argument",
            ProviderSelectionSource::Environment => "environment",
            ProviderSelectionSource::Default => "default",
        }
    }
}

137#[derive(Debug, Clone, PartialEq, Eq)]
138pub struct ProviderSelection {
139 pub provider_id: String,
140 pub source: ProviderSelectionSource,
141}
142
143#[derive(Debug, Clone, PartialEq, Eq)]
144pub enum ResolveProviderError {
145 NoProvidersRegistered,
146 UnknownProvider {
147 provider_id: String,
148 source: ProviderSelectionSource,
149 },
150}
151
152impl fmt::Display for ResolveProviderError {
153 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154 match self {
155 Self::NoProvidersRegistered => f.write_str("no providers are registered"),
156 Self::UnknownProvider {
157 provider_id,
158 source,
159 } => write!(
160 f,
161 "unknown provider '{}' from {} override",
162 provider_id,
163 source.as_str()
164 ),
165 }
166 }
167}
168
169impl std::error::Error for ResolveProviderError {}
170
/// Trims surrounding whitespace from an optional provider id.
///
/// A missing value, or one that is empty after trimming, is reported as `None` ("no
/// override"). Otherwise the trimmed slice of the input is returned.
fn normalize_provider_id(raw: Option<&str>) -> Option<&str> {
    let trimmed = raw?.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}