Skip to main content

rab/provider/
mod.rs

1//! Provider and model system.
2//!
3//! Loads a built-in model catalog from `models.json` in this directory,
4//! overlays user overrides from `~/.rab/agent/models.json`,
5//! and provides the right `StreamProvider` for each model's API protocol.
6
7use std::path::Path;
8
9use anyhow::bail;
10use yoagent::provider::model::{CostConfig, ModelConfig};
11use yoagent::types::Usage;
12
13pub mod anthropic;
14pub mod compat;
15pub mod generate_models;
16pub mod models;
17pub mod oauth;
18pub mod openai_compat;
19
20/// A resolved model ready for use by the agent.
21#[derive(Debug, Clone)]
22pub struct ResolvedModel {
23    /// The yoagent ModelConfig with correct base URL, compat, pricing, etc.
24    pub model_config: ModelConfig,
25    /// The API key for this provider (from auth.json or env var).
26    pub api_key: String,
27}
28
29/// The provider registry — holds all known providers and their models.
30pub struct ProviderRegistry {
31    entries: Vec<models::ProviderEntry>,
32    /// Auth storage for API key lookups.
33    auth_storage: crate::auth::AuthStorage,
34}
35
36impl ProviderRegistry {
37    /// Load the provider registry from built-in + user models.json.
38    pub fn load(agent_dir: &Path) -> anyhow::Result<Self> {
39        // Register built-in OAuth providers once
40        crate::provider::oauth::register_builtins();
41
42        let builtin_json = include_str!("models.json");
43        let builtin = models::load_builtin(builtin_json)?;
44
45        let user_path = agent_dir.join("models.json");
46        let user = models::load_user(&user_path)?;
47
48        let entries = models::merge(builtin, user);
49        let auth_storage = crate::auth::AuthStorage::load()?;
50
51        Ok(Self {
52            entries,
53            auth_storage,
54        })
55    }
56
57    /// Reload from disk (for /reload support).
58    pub fn reload(&mut self, agent_dir: &Path) -> anyhow::Result<()> {
59        let fresh = Self::load(agent_dir)?;
60        self.entries = fresh.entries;
61        self.auth_storage = fresh.auth_storage;
62        Ok(())
63    }
64
65    /// Resolve a model ID (e.g. "deepseek-v4-flash") to a `ResolvedModel`.
66    ///
67    /// Scans all providers for a matching model ID. If `preferred_provider` is
68    /// set, that provider is checked first; otherwise returns the first match.
69    /// Also resolves the API key for that provider.
70    pub fn resolve(
71        &self,
72        model_id: &str,
73        preferred_provider: Option<&str>,
74    ) -> anyhow::Result<ResolvedModel> {
75        // Try preferred provider first when specified.
76        if let Some(preferred) = preferred_provider
77            && let Some(result) = self.resolve_from_provider(model_id, preferred)
78        {
79            return Ok(result);
80        }
81
82        for entry in &self.entries {
83            if let Some(model_config) = entry.models.iter().find(|m| m.id == model_id) {
84                let api_key = self
85                    .auth_storage
86                    .api_key(&entry.id)
87                    .or_else(|| {
88                        // Check for valid OAuth access token
89                        self.auth_storage.oauth_token(&entry.id)
90                    })
91                    .or_else(|| {
92                        // Fallback: check environment variable
93                        let env_var = entry.env_var_name();
94                        std::env::var(env_var).ok()
95                    })
96                    .unwrap_or_default();
97
98                let mut model_config = model_config.clone();
99
100                // For GitHub Copilot, derive the API base URL from the OAuth
101                // token's proxy-ep field (pi-compatible dynamic endpoint).
102                if entry.id == "github-copilot" {
103                    let enterprise_domain =
104                        self.auth_storage
105                            .oauth_credential(&entry.id)
106                            .and_then(|c| match c {
107                                crate::auth::AuthCredential::Oauth { enterprise_url, .. } => {
108                                    enterprise_url
109                                }
110                                _ => None,
111                            });
112                    let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
113                        Some(&api_key),
114                        enterprise_domain.as_deref(),
115                    );
116                    model_config.base_url = derived;
117                }
118
119                return Ok(ResolvedModel {
120                    model_config,
121                    api_key,
122                });
123            }
124        }
125
126        bail!(
127            "Unknown model '{}'. Available models: {}",
128            model_id,
129            self.list_models().join(", ")
130        );
131    }
132
133    /// Resolve from a specific provider. Returns `None` if the provider doesn't
134    /// exist or doesn't have the given model.
135    fn resolve_from_provider(&self, model_id: &str, provider_id: &str) -> Option<ResolvedModel> {
136        let entry = self.entries.iter().find(|e| e.id == provider_id)?;
137        let mut model_config = entry.models.iter().find(|m| m.id == model_id)?.clone();
138        let api_key = self
139            .auth_storage
140            .api_key(provider_id)
141            .or_else(|| {
142                // Check for valid OAuth access token
143                self.auth_storage.oauth_token(provider_id)
144            })
145            .or_else(|| {
146                let env_var = entry.env_var_name();
147                std::env::var(env_var).ok()
148            })
149            .unwrap_or_default();
150
151        // For GitHub Copilot, derive the API base URL from the OAuth
152        // token's proxy-ep field (pi-compatible dynamic endpoint).
153        if provider_id == "github-copilot" {
154            let enterprise_domain = self
155                .auth_storage
156                .oauth_credential(provider_id)
157                .and_then(|c| match c {
158                    crate::auth::AuthCredential::Oauth { enterprise_url, .. } => enterprise_url,
159                    _ => None,
160                });
161            let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
162                Some(&api_key),
163                enterprise_domain.as_deref(),
164            );
165            model_config.base_url = derived;
166        }
167
168        Some(ResolvedModel {
169            model_config,
170            api_key,
171        })
172    }
173
174    /// List all available model IDs (for UI selector and /model command).
175    /// Deduplicated: each model ID appears only once even if registered
176    /// under multiple providers.
177    pub fn list_models(&self) -> Vec<String> {
178        let mut model_set = std::collections::BTreeSet::new();
179        for entry in &self.entries {
180            for m in &entry.models {
181                model_set.insert(m.id.clone());
182            }
183        }
184        model_set.into_iter().collect()
185    }
186
187    /// List model IDs from providers that have valid authentication.
188    /// Used by the model cycle and selector to hide unconfigured providers.
189    pub fn list_authenticated_model_ids(&self) -> Vec<String> {
190        let mut model_set = std::collections::BTreeSet::new();
191        for entry in &self.entries {
192            if self.provider_has_auth(&entry.id) {
193                for m in &entry.models {
194                    model_set.insert(m.id.clone());
195                }
196            }
197        }
198        model_set.into_iter().collect()
199    }
200
201    /// List all (provider, model_id, model_name) tuples, one per provider entry.
202    /// Unlike `list_models()`, the same model ID can appear under multiple
203    /// providers. Used by the model selector to show provider-prefixed entries.
204    pub fn list_model_provider_tuples(&self) -> Vec<(String, String, String)> {
205        let mut result = Vec::new();
206        for entry in &self.entries {
207            for m in &entry.models {
208                result.push((entry.id.clone(), m.id.clone(), m.name.clone()));
209            }
210        }
211        result
212    }
213
214    /// Get the provider name for a model ID.
215    ///
216    /// When `preferred_provider` is set and that provider has the model,
217    /// returns the preferred provider. Otherwise returns the first match.
218    pub fn provider_for_model(
219        &self,
220        model_id: &str,
221        preferred_provider: Option<&str>,
222    ) -> Option<String> {
223        // Try preferred provider first.
224        if let Some(preferred) = preferred_provider
225            && self
226                .entries
227                .iter()
228                .any(|e| e.id == preferred && e.models.iter().any(|m| m.id == model_id))
229        {
230            return Some(preferred.to_string());
231        }
232
233        for entry in &self.entries {
234            if entry.models.iter().any(|m| m.id == model_id) {
235                return Some(entry.id.clone());
236            }
237        }
238        None
239    }
240
241    /// Get the API key for a provider.
242    pub fn api_key_for_provider(&self, provider_id: &str) -> Option<String> {
243        self.auth_storage.api_key(provider_id)
244    }
245
246    /// Count the number of distinct providers in the registry.
247    pub fn count_providers(&self) -> usize {
248        self.entries.len()
249    }
250
251    /// List all provider (id, name) tuples.
252    pub fn list_providers(&self) -> Vec<(String, String)> {
253        self.entries
254            .iter()
255            .map(|e| (e.id.clone(), e.name.clone()))
256            .collect()
257    }
258
259    /// Get the list of provider IDs that have stored credentials.
260    pub fn configured_providers(&self) -> Vec<String> {
261        self.entries
262            .iter()
263            .filter_map(|e| {
264                if self.auth_storage.api_key(&e.id).is_some() {
265                    Some(e.id.clone())
266                } else {
267                    None
268                }
269            })
270            .collect()
271    }
272
273    /// Check whether a provider has valid authentication (stored credential or env var).
274    pub fn provider_has_auth(&self, provider_id: &str) -> bool {
275        if self.auth_storage.api_key(provider_id).is_some()
276            || self.auth_storage.oauth_token(provider_id).is_some()
277        {
278            return true;
279        }
280        // Check if this is an OAuth provider that could be logged in
281        if crate::provider::oauth::is_built_in(provider_id) {
282            return self.auth_storage.oauth_token(provider_id).is_some();
283        }
284        // Check env var
285        self.entries
286            .iter()
287            .find(|e| e.id == provider_id)
288            .and_then(|e| {
289                let env_name = e.env_var_name();
290                if std::env::var(env_name).is_ok() {
291                    Some(())
292                } else {
293                    None
294                }
295            })
296            .is_some()
297    }
298
299    /// Get auth status for a provider (for UI display).
300    pub fn auth_status_for_provider(
301        &self,
302        provider_id: &str,
303    ) -> crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
304        let has_stored = self.auth_storage.api_key(provider_id).is_some()
305            || self.auth_storage.oauth_token(provider_id).is_some();
306
307        // Check env var
308        let env_var = self
309            .entries
310            .iter()
311            .find(|e| e.id == provider_id)
312            .and_then(|e| {
313                let env_name = e.env_var_name();
314                if std::env::var(env_name).is_ok() {
315                    Some(env_name.to_string())
316                } else {
317                    None
318                }
319            });
320
321        let configured = has_stored || env_var.is_some();
322        let (source, label) = if has_stored {
323            (Some("stored".to_string()), None)
324        } else if let Some(env) = env_var {
325            (Some("environment".to_string()), Some(env))
326        } else {
327            (None, None)
328        };
329
330        crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
331            configured,
332            source,
333            label,
334        }
335    }
336}
337
338/// Calculate the USD cost components of a usage record given a model's cost config.
339///
340/// Matches pi's `calculateCost()` in `packages/ai/src/models.ts`.
341/// Returns `(input, output, cache_read, cache_write, total)`.
342///
343/// Note: pi also handles Anthropic's 1h cache write pricing (cacheWrite1h),
344/// but yoagent's Usage struct doesn't expose that field, so it's omitted here.
345pub fn calculate_cost(cost_config: &CostConfig, usage: &Usage) -> (f64, f64, f64, f64, f64) {
346    let input_cost = (cost_config.input_per_million / 1_000_000.0) * usage.input as f64;
347    let output_cost = (cost_config.output_per_million / 1_000_000.0) * usage.output as f64;
348    let cache_read_cost =
349        (cost_config.cache_read_per_million / 1_000_000.0) * usage.cache_read as f64;
350    let cache_write_cost =
351        (cost_config.cache_write_per_million / 1_000_000.0) * usage.cache_write as f64;
352    let total = input_cost + output_cost + cache_read_cost + cache_write_cost;
353    (
354        input_cost,
355        output_cost,
356        cache_read_cost,
357        cache_write_cost,
358        total,
359    )
360}
361
362/// Get the agent config directory (~/.rab/agent).
363pub fn get_agent_dir() -> std::path::PathBuf {
364    directories::BaseDirs::new()
365        .map(|d| d.home_dir().join(".rab").join("agent"))
366        .unwrap_or_else(|| std::path::PathBuf::from("/tmp/.rab/agent"))
367}