1use std::path::Path;
8
9use anyhow::bail;
10use yoagent::provider::model::{CostConfig, ModelConfig};
11use yoagent::types::Usage;
12
13pub mod anthropic;
14pub mod compat;
15pub mod generate_models;
16pub mod models;
17pub mod oauth;
18pub mod openai_compat;
19
20#[derive(Debug, Clone)]
22pub struct ResolvedModel {
23 pub model_config: ModelConfig,
25 pub api_key: String,
27}
28
29pub struct ProviderRegistry {
31 entries: Vec<models::ProviderEntry>,
32 auth_storage: crate::auth::AuthStorage,
34}
35
36impl ProviderRegistry {
37 pub fn load(agent_dir: &Path) -> anyhow::Result<Self> {
39 crate::provider::oauth::register_builtins();
41
42 let builtin_json = include_str!("models.json");
43 let builtin = models::load_builtin(builtin_json)?;
44
45 let user_path = agent_dir.join("models.json");
46 let user = models::load_user(&user_path)?;
47
48 let entries = models::merge(builtin, user);
49 let auth_storage = crate::auth::AuthStorage::load()?;
50
51 Ok(Self {
52 entries,
53 auth_storage,
54 })
55 }
56
57 pub fn reload(&mut self, agent_dir: &Path) -> anyhow::Result<()> {
59 let fresh = Self::load(agent_dir)?;
60 self.entries = fresh.entries;
61 self.auth_storage = fresh.auth_storage;
62 Ok(())
63 }
64
65 pub fn resolve(
71 &self,
72 model_id: &str,
73 preferred_provider: Option<&str>,
74 ) -> anyhow::Result<ResolvedModel> {
75 if let Some(preferred) = preferred_provider
77 && let Some(result) = self.resolve_from_provider(model_id, preferred)
78 {
79 return Ok(result);
80 }
81
82 for entry in &self.entries {
83 if let Some(model_config) = entry.models.iter().find(|m| m.id == model_id) {
84 let api_key = self
85 .auth_storage
86 .api_key(&entry.id)
87 .or_else(|| {
88 self.auth_storage.oauth_token(&entry.id)
90 })
91 .or_else(|| {
92 let env_var = entry.env_var_name();
94 std::env::var(env_var).ok()
95 })
96 .unwrap_or_default();
97
98 let mut model_config = model_config.clone();
99
100 if entry.id == "github-copilot" {
103 let enterprise_domain =
104 self.auth_storage
105 .oauth_credential(&entry.id)
106 .and_then(|c| match c {
107 crate::auth::AuthCredential::Oauth { enterprise_url, .. } => {
108 enterprise_url
109 }
110 _ => None,
111 });
112 let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
113 Some(&api_key),
114 enterprise_domain.as_deref(),
115 );
116 model_config.base_url = derived;
117 }
118
119 return Ok(ResolvedModel {
120 model_config,
121 api_key,
122 });
123 }
124 }
125
126 bail!(
127 "Unknown model '{}'. Available models: {}",
128 model_id,
129 self.list_models().join(", ")
130 );
131 }
132
133 fn resolve_from_provider(&self, model_id: &str, provider_id: &str) -> Option<ResolvedModel> {
136 let entry = self.entries.iter().find(|e| e.id == provider_id)?;
137 let mut model_config = entry.models.iter().find(|m| m.id == model_id)?.clone();
138 let api_key = self
139 .auth_storage
140 .api_key(provider_id)
141 .or_else(|| {
142 self.auth_storage.oauth_token(provider_id)
144 })
145 .or_else(|| {
146 let env_var = entry.env_var_name();
147 std::env::var(env_var).ok()
148 })
149 .unwrap_or_default();
150
151 if provider_id == "github-copilot" {
154 let enterprise_domain = self
155 .auth_storage
156 .oauth_credential(provider_id)
157 .and_then(|c| match c {
158 crate::auth::AuthCredential::Oauth { enterprise_url, .. } => enterprise_url,
159 _ => None,
160 });
161 let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
162 Some(&api_key),
163 enterprise_domain.as_deref(),
164 );
165 model_config.base_url = derived;
166 }
167
168 Some(ResolvedModel {
169 model_config,
170 api_key,
171 })
172 }
173
174 pub fn list_models(&self) -> Vec<String> {
178 let mut model_set = std::collections::BTreeSet::new();
179 for entry in &self.entries {
180 for m in &entry.models {
181 model_set.insert(m.id.clone());
182 }
183 }
184 model_set.into_iter().collect()
185 }
186
187 pub fn list_authenticated_model_ids(&self) -> Vec<String> {
190 let mut model_set = std::collections::BTreeSet::new();
191 for entry in &self.entries {
192 if self.provider_has_auth(&entry.id) {
193 for m in &entry.models {
194 model_set.insert(m.id.clone());
195 }
196 }
197 }
198 model_set.into_iter().collect()
199 }
200
201 pub fn list_model_provider_tuples(&self) -> Vec<(String, String, String)> {
205 let mut result = Vec::new();
206 for entry in &self.entries {
207 for m in &entry.models {
208 result.push((entry.id.clone(), m.id.clone(), m.name.clone()));
209 }
210 }
211 result
212 }
213
214 pub fn provider_for_model(
219 &self,
220 model_id: &str,
221 preferred_provider: Option<&str>,
222 ) -> Option<String> {
223 if let Some(preferred) = preferred_provider
225 && self
226 .entries
227 .iter()
228 .any(|e| e.id == preferred && e.models.iter().any(|m| m.id == model_id))
229 {
230 return Some(preferred.to_string());
231 }
232
233 for entry in &self.entries {
234 if entry.models.iter().any(|m| m.id == model_id) {
235 return Some(entry.id.clone());
236 }
237 }
238 None
239 }
240
241 pub fn api_key_for_provider(&self, provider_id: &str) -> Option<String> {
243 self.auth_storage.api_key(provider_id)
244 }
245
246 pub fn count_providers(&self) -> usize {
248 self.entries.len()
249 }
250
251 pub fn list_providers(&self) -> Vec<(String, String)> {
253 self.entries
254 .iter()
255 .map(|e| (e.id.clone(), e.name.clone()))
256 .collect()
257 }
258
259 pub fn configured_providers(&self) -> Vec<String> {
261 self.entries
262 .iter()
263 .filter_map(|e| {
264 if self.auth_storage.api_key(&e.id).is_some() {
265 Some(e.id.clone())
266 } else {
267 None
268 }
269 })
270 .collect()
271 }
272
273 pub fn provider_has_auth(&self, provider_id: &str) -> bool {
275 if self.auth_storage.api_key(provider_id).is_some()
276 || self.auth_storage.oauth_token(provider_id).is_some()
277 {
278 return true;
279 }
280 if crate::provider::oauth::is_built_in(provider_id) {
282 return self.auth_storage.oauth_token(provider_id).is_some();
283 }
284 self.entries
286 .iter()
287 .find(|e| e.id == provider_id)
288 .and_then(|e| {
289 let env_name = e.env_var_name();
290 if std::env::var(env_name).is_ok() {
291 Some(())
292 } else {
293 None
294 }
295 })
296 .is_some()
297 }
298
299 pub fn auth_status_for_provider(
301 &self,
302 provider_id: &str,
303 ) -> crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
304 let has_stored = self.auth_storage.api_key(provider_id).is_some()
305 || self.auth_storage.oauth_token(provider_id).is_some();
306
307 let env_var = self
309 .entries
310 .iter()
311 .find(|e| e.id == provider_id)
312 .and_then(|e| {
313 let env_name = e.env_var_name();
314 if std::env::var(env_name).is_ok() {
315 Some(env_name.to_string())
316 } else {
317 None
318 }
319 });
320
321 let configured = has_stored || env_var.is_some();
322 let (source, label) = if has_stored {
323 (Some("stored".to_string()), None)
324 } else if let Some(env) = env_var {
325 (Some("environment".to_string()), Some(env))
326 } else {
327 (None, None)
328 };
329
330 crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
331 configured,
332 source,
333 label,
334 }
335 }
336}
337
338pub fn calculate_cost(cost_config: &CostConfig, usage: &Usage) -> (f64, f64, f64, f64, f64) {
346 let input_cost = (cost_config.input_per_million / 1_000_000.0) * usage.input as f64;
347 let output_cost = (cost_config.output_per_million / 1_000_000.0) * usage.output as f64;
348 let cache_read_cost =
349 (cost_config.cache_read_per_million / 1_000_000.0) * usage.cache_read as f64;
350 let cache_write_cost =
351 (cost_config.cache_write_per_million / 1_000_000.0) * usage.cache_write as f64;
352 let total = input_cost + output_cost + cache_read_cost + cache_write_cost;
353 (
354 input_cost,
355 output_cost,
356 cache_read_cost,
357 cache_write_cost,
358 total,
359 )
360}
361
362pub fn get_agent_dir() -> std::path::PathBuf {
364 directories::BaseDirs::new()
365 .map(|d| d.home_dir().join(".rab").join("agent"))
366 .unwrap_or_else(|| std::path::PathBuf::from("/tmp/.rab/agent"))
367}