1use std::path::Path;
8
9use anyhow::bail;
10use yoagent::provider::model::ModelConfig;
11
12pub mod anthropic;
13pub mod compat;
14pub mod generate_models;
15pub mod models;
16pub mod oauth;
17pub mod openai_compat;
18
/// A model configuration paired with the credential selected for its provider.
///
/// Produced by `ProviderRegistry::resolve`.
#[derive(Debug, Clone)]
pub struct ResolvedModel {
    /// The model's configuration, cloned from the provider catalogue. For the
    /// `github-copilot` provider, `base_url` is replaced by a derived value.
    pub model_config: ModelConfig,
    /// Credential for the provider: a stored API key, else a stored OAuth token, else the
    /// provider's environment variable. Empty string when none of these is available.
    pub api_key: String,
}
27
/// Catalogue of model providers (built-in merged with user-defined) plus the stored
/// credentials used to authenticate against them.
pub struct ProviderRegistry {
    /// Provider entries, each listing its models; lookups scan this in order.
    entries: Vec<models::ProviderEntry>,
    /// Persisted API keys and OAuth credentials, consulted before environment variables.
    auth_storage: crate::auth::AuthStorage,
}
34
35impl ProviderRegistry {
36 pub fn load(agent_dir: &Path) -> anyhow::Result<Self> {
38 crate::provider::oauth::register_builtins();
40
41 let builtin_json = include_str!("models.json");
42 let builtin = models::load_builtin(builtin_json)?;
43
44 let user_path = agent_dir.join("models.json");
45 let user = models::load_user(&user_path)?;
46
47 let entries = models::merge(builtin, user);
48 let auth_storage = crate::auth::AuthStorage::load()?;
49
50 Ok(Self {
51 entries,
52 auth_storage,
53 })
54 }
55
56 pub fn reload(&mut self, agent_dir: &Path) -> anyhow::Result<()> {
58 let fresh = Self::load(agent_dir)?;
59 self.entries = fresh.entries;
60 self.auth_storage = fresh.auth_storage;
61 Ok(())
62 }
63
64 pub fn resolve(
70 &self,
71 model_id: &str,
72 preferred_provider: Option<&str>,
73 ) -> anyhow::Result<ResolvedModel> {
74 if let Some(preferred) = preferred_provider
76 && let Some(result) = self.resolve_from_provider(model_id, preferred)
77 {
78 return Ok(result);
79 }
80
81 for entry in &self.entries {
82 if let Some(model_config) = entry.models.iter().find(|m| m.id == model_id) {
83 let api_key = self
84 .auth_storage
85 .api_key(&entry.id)
86 .or_else(|| {
87 self.auth_storage.oauth_token(&entry.id)
89 })
90 .or_else(|| {
91 let env_var = entry.env_var_name();
93 std::env::var(env_var).ok()
94 })
95 .unwrap_or_default();
96
97 let mut model_config = model_config.clone();
98
99 if entry.id == "github-copilot" {
102 let enterprise_domain =
103 self.auth_storage
104 .oauth_credential(&entry.id)
105 .and_then(|c| match c {
106 crate::auth::AuthCredential::Oauth { enterprise_url, .. } => {
107 enterprise_url
108 }
109 _ => None,
110 });
111 let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
112 Some(&api_key),
113 enterprise_domain.as_deref(),
114 );
115 model_config.base_url = derived;
116 }
117
118 return Ok(ResolvedModel {
119 model_config,
120 api_key,
121 });
122 }
123 }
124
125 bail!(
126 "Unknown model '{}'. Available models: {}",
127 model_id,
128 self.list_models().join(", ")
129 );
130 }
131
132 fn resolve_from_provider(&self, model_id: &str, provider_id: &str) -> Option<ResolvedModel> {
135 let entry = self.entries.iter().find(|e| e.id == provider_id)?;
136 let mut model_config = entry.models.iter().find(|m| m.id == model_id)?.clone();
137 let api_key = self
138 .auth_storage
139 .api_key(provider_id)
140 .or_else(|| {
141 self.auth_storage.oauth_token(provider_id)
143 })
144 .or_else(|| {
145 let env_var = entry.env_var_name();
146 std::env::var(env_var).ok()
147 })
148 .unwrap_or_default();
149
150 if provider_id == "github-copilot" {
153 let enterprise_domain = self
154 .auth_storage
155 .oauth_credential(provider_id)
156 .and_then(|c| match c {
157 crate::auth::AuthCredential::Oauth { enterprise_url, .. } => enterprise_url,
158 _ => None,
159 });
160 let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
161 Some(&api_key),
162 enterprise_domain.as_deref(),
163 );
164 model_config.base_url = derived;
165 }
166
167 Some(ResolvedModel {
168 model_config,
169 api_key,
170 })
171 }
172
173 pub fn list_models(&self) -> Vec<String> {
177 let mut model_set = std::collections::BTreeSet::new();
178 for entry in &self.entries {
179 for m in &entry.models {
180 model_set.insert(m.id.clone());
181 }
182 }
183 model_set.into_iter().collect()
184 }
185
186 pub fn list_authenticated_model_ids(&self) -> Vec<String> {
189 let mut model_set = std::collections::BTreeSet::new();
190 for entry in &self.entries {
191 if self.provider_has_auth(&entry.id) {
192 for m in &entry.models {
193 model_set.insert(m.id.clone());
194 }
195 }
196 }
197 model_set.into_iter().collect()
198 }
199
200 pub fn list_model_provider_tuples(&self) -> Vec<(String, String, String)> {
204 let mut result = Vec::new();
205 for entry in &self.entries {
206 for m in &entry.models {
207 result.push((entry.id.clone(), m.id.clone(), m.name.clone()));
208 }
209 }
210 result
211 }
212
213 pub fn provider_for_model(
218 &self,
219 model_id: &str,
220 preferred_provider: Option<&str>,
221 ) -> Option<String> {
222 if let Some(preferred) = preferred_provider
224 && self
225 .entries
226 .iter()
227 .any(|e| e.id == preferred && e.models.iter().any(|m| m.id == model_id))
228 {
229 return Some(preferred.to_string());
230 }
231
232 for entry in &self.entries {
233 if entry.models.iter().any(|m| m.id == model_id) {
234 return Some(entry.id.clone());
235 }
236 }
237 None
238 }
239
240 pub fn api_key_for_provider(&self, provider_id: &str) -> Option<String> {
242 self.auth_storage.api_key(provider_id)
243 }
244
245 pub fn count_providers(&self) -> usize {
247 self.entries.len()
248 }
249
250 pub fn list_providers(&self) -> Vec<(String, String)> {
252 self.entries
253 .iter()
254 .map(|e| (e.id.clone(), e.name.clone()))
255 .collect()
256 }
257
258 pub fn configured_providers(&self) -> Vec<String> {
260 self.entries
261 .iter()
262 .filter_map(|e| {
263 if self.auth_storage.api_key(&e.id).is_some() {
264 Some(e.id.clone())
265 } else {
266 None
267 }
268 })
269 .collect()
270 }
271
272 pub fn provider_has_auth(&self, provider_id: &str) -> bool {
274 if self.auth_storage.api_key(provider_id).is_some()
275 || self.auth_storage.oauth_token(provider_id).is_some()
276 {
277 return true;
278 }
279 if crate::provider::oauth::is_built_in(provider_id) {
281 return self.auth_storage.oauth_token(provider_id).is_some();
282 }
283 self.entries
285 .iter()
286 .find(|e| e.id == provider_id)
287 .and_then(|e| {
288 let env_name = e.env_var_name();
289 if std::env::var(env_name).is_ok() {
290 Some(())
291 } else {
292 None
293 }
294 })
295 .is_some()
296 }
297
298 pub fn auth_status_for_provider(
300 &self,
301 provider_id: &str,
302 ) -> crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
303 let has_stored = self.auth_storage.api_key(provider_id).is_some()
304 || self.auth_storage.oauth_token(provider_id).is_some();
305
306 let env_var = self
308 .entries
309 .iter()
310 .find(|e| e.id == provider_id)
311 .and_then(|e| {
312 let env_name = e.env_var_name();
313 if std::env::var(env_name).is_ok() {
314 Some(env_name.to_string())
315 } else {
316 None
317 }
318 });
319
320 let configured = has_stored || env_var.is_some();
321 let (source, label) = if has_stored {
322 (Some("stored".to_string()), None)
323 } else if let Some(env) = env_var {
324 (Some("environment".to_string()), Some(env))
325 } else {
326 (None, None)
327 };
328
329 crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
330 configured,
331 source,
332 label,
333 }
334 }
335}
336
337pub fn get_agent_dir() -> std::path::PathBuf {
339 directories::BaseDirs::new()
340 .map(|d| d.home_dir().join(".rab").join("agent"))
341 .unwrap_or_else(|| std::path::PathBuf::from("/tmp/.rab/agent"))
342}