Skip to main content

rab/provider/
mod.rs

1//! Provider and model system.
2//!
3//! Loads a built-in model catalog from `models.json` in this directory,
4//! overlays user overrides from `~/.rab/agent/models.json`,
5//! and provides the right `StreamProvider` for each model's API protocol.
6
7use std::path::Path;
8
9use anyhow::bail;
10use yoagent::provider::model::ModelConfig;
11
12pub mod anthropic;
13pub mod compat;
14pub mod generate_models;
15pub mod models;
16pub mod oauth;
17pub mod openai_compat;
18
19/// A resolved model ready for use by the agent.
20#[derive(Debug, Clone)]
21pub struct ResolvedModel {
22    /// The yoagent ModelConfig with correct base URL, compat, pricing, etc.
23    pub model_config: ModelConfig,
24    /// The API key for this provider (from auth.json or env var).
25    pub api_key: String,
26}
27
28/// The provider registry — holds all known providers and their models.
29pub struct ProviderRegistry {
30    entries: Vec<models::ProviderEntry>,
31    /// Auth storage for API key lookups.
32    auth_storage: crate::auth::AuthStorage,
33}
34
35impl ProviderRegistry {
36    /// Load the provider registry from built-in + user models.json.
37    pub fn load(agent_dir: &Path) -> anyhow::Result<Self> {
38        // Register built-in OAuth providers once
39        crate::provider::oauth::register_builtins();
40
41        let builtin_json = include_str!("models.json");
42        let builtin = models::load_builtin(builtin_json)?;
43
44        let user_path = agent_dir.join("models.json");
45        let user = models::load_user(&user_path)?;
46
47        let entries = models::merge(builtin, user);
48        let auth_storage = crate::auth::AuthStorage::load()?;
49
50        Ok(Self {
51            entries,
52            auth_storage,
53        })
54    }
55
56    /// Reload from disk (for /reload support).
57    pub fn reload(&mut self, agent_dir: &Path) -> anyhow::Result<()> {
58        let fresh = Self::load(agent_dir)?;
59        self.entries = fresh.entries;
60        self.auth_storage = fresh.auth_storage;
61        Ok(())
62    }
63
64    /// Resolve a model ID (e.g. "deepseek-v4-flash") to a `ResolvedModel`.
65    ///
66    /// Scans all providers for a matching model ID. If `preferred_provider` is
67    /// set, that provider is checked first; otherwise returns the first match.
68    /// Also resolves the API key for that provider.
69    pub fn resolve(
70        &self,
71        model_id: &str,
72        preferred_provider: Option<&str>,
73    ) -> anyhow::Result<ResolvedModel> {
74        // Try preferred provider first when specified.
75        if let Some(preferred) = preferred_provider
76            && let Some(result) = self.resolve_from_provider(model_id, preferred)
77        {
78            return Ok(result);
79        }
80
81        for entry in &self.entries {
82            if let Some(model_config) = entry.models.iter().find(|m| m.id == model_id) {
83                let api_key = self
84                    .auth_storage
85                    .api_key(&entry.id)
86                    .or_else(|| {
87                        // Check for valid OAuth access token
88                        self.auth_storage.oauth_token(&entry.id)
89                    })
90                    .or_else(|| {
91                        // Fallback: check environment variable
92                        let env_var = entry.env_var_name();
93                        std::env::var(env_var).ok()
94                    })
95                    .unwrap_or_default();
96
97                let mut model_config = model_config.clone();
98
99                // For GitHub Copilot, derive the API base URL from the OAuth
100                // token's proxy-ep field (pi-compatible dynamic endpoint).
101                if entry.id == "github-copilot" {
102                    let enterprise_domain =
103                        self.auth_storage
104                            .oauth_credential(&entry.id)
105                            .and_then(|c| match c {
106                                crate::auth::AuthCredential::Oauth { enterprise_url, .. } => {
107                                    enterprise_url
108                                }
109                                _ => None,
110                            });
111                    let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
112                        Some(&api_key),
113                        enterprise_domain.as_deref(),
114                    );
115                    model_config.base_url = derived;
116                }
117
118                return Ok(ResolvedModel {
119                    model_config,
120                    api_key,
121                });
122            }
123        }
124
125        bail!(
126            "Unknown model '{}'. Available models: {}",
127            model_id,
128            self.list_models().join(", ")
129        );
130    }
131
132    /// Resolve from a specific provider. Returns `None` if the provider doesn't
133    /// exist or doesn't have the given model.
134    fn resolve_from_provider(&self, model_id: &str, provider_id: &str) -> Option<ResolvedModel> {
135        let entry = self.entries.iter().find(|e| e.id == provider_id)?;
136        let mut model_config = entry.models.iter().find(|m| m.id == model_id)?.clone();
137        let api_key = self
138            .auth_storage
139            .api_key(provider_id)
140            .or_else(|| {
141                // Check for valid OAuth access token
142                self.auth_storage.oauth_token(provider_id)
143            })
144            .or_else(|| {
145                let env_var = entry.env_var_name();
146                std::env::var(env_var).ok()
147            })
148            .unwrap_or_default();
149
150        // For GitHub Copilot, derive the API base URL from the OAuth
151        // token's proxy-ep field (pi-compatible dynamic endpoint).
152        if provider_id == "github-copilot" {
153            let enterprise_domain = self
154                .auth_storage
155                .oauth_credential(provider_id)
156                .and_then(|c| match c {
157                    crate::auth::AuthCredential::Oauth { enterprise_url, .. } => enterprise_url,
158                    _ => None,
159                });
160            let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
161                Some(&api_key),
162                enterprise_domain.as_deref(),
163            );
164            model_config.base_url = derived;
165        }
166
167        Some(ResolvedModel {
168            model_config,
169            api_key,
170        })
171    }
172
173    /// List all available model IDs (for UI selector and /model command).
174    /// Deduplicated: each model ID appears only once even if registered
175    /// under multiple providers.
176    pub fn list_models(&self) -> Vec<String> {
177        let mut model_set = std::collections::BTreeSet::new();
178        for entry in &self.entries {
179            for m in &entry.models {
180                model_set.insert(m.id.clone());
181            }
182        }
183        model_set.into_iter().collect()
184    }
185
186    /// List model IDs from providers that have valid authentication.
187    /// Used by the model cycle and selector to hide unconfigured providers.
188    pub fn list_authenticated_model_ids(&self) -> Vec<String> {
189        let mut model_set = std::collections::BTreeSet::new();
190        for entry in &self.entries {
191            if self.provider_has_auth(&entry.id) {
192                for m in &entry.models {
193                    model_set.insert(m.id.clone());
194                }
195            }
196        }
197        model_set.into_iter().collect()
198    }
199
200    /// List all (provider, model_id, model_name) tuples, one per provider entry.
201    /// Unlike `list_models()`, the same model ID can appear under multiple
202    /// providers. Used by the model selector to show provider-prefixed entries.
203    pub fn list_model_provider_tuples(&self) -> Vec<(String, String, String)> {
204        let mut result = Vec::new();
205        for entry in &self.entries {
206            for m in &entry.models {
207                result.push((entry.id.clone(), m.id.clone(), m.name.clone()));
208            }
209        }
210        result
211    }
212
213    /// Get the provider name for a model ID.
214    ///
215    /// When `preferred_provider` is set and that provider has the model,
216    /// returns the preferred provider. Otherwise returns the first match.
217    pub fn provider_for_model(
218        &self,
219        model_id: &str,
220        preferred_provider: Option<&str>,
221    ) -> Option<String> {
222        // Try preferred provider first.
223        if let Some(preferred) = preferred_provider
224            && self
225                .entries
226                .iter()
227                .any(|e| e.id == preferred && e.models.iter().any(|m| m.id == model_id))
228        {
229            return Some(preferred.to_string());
230        }
231
232        for entry in &self.entries {
233            if entry.models.iter().any(|m| m.id == model_id) {
234                return Some(entry.id.clone());
235            }
236        }
237        None
238    }
239
240    /// Get the API key for a provider.
241    pub fn api_key_for_provider(&self, provider_id: &str) -> Option<String> {
242        self.auth_storage.api_key(provider_id)
243    }
244
245    /// Count the number of distinct providers in the registry.
246    pub fn count_providers(&self) -> usize {
247        self.entries.len()
248    }
249
250    /// List all provider (id, name) tuples.
251    pub fn list_providers(&self) -> Vec<(String, String)> {
252        self.entries
253            .iter()
254            .map(|e| (e.id.clone(), e.name.clone()))
255            .collect()
256    }
257
258    /// Get the list of provider IDs that have stored credentials.
259    pub fn configured_providers(&self) -> Vec<String> {
260        self.entries
261            .iter()
262            .filter_map(|e| {
263                if self.auth_storage.api_key(&e.id).is_some() {
264                    Some(e.id.clone())
265                } else {
266                    None
267                }
268            })
269            .collect()
270    }
271
272    /// Check whether a provider has valid authentication (stored credential or env var).
273    pub fn provider_has_auth(&self, provider_id: &str) -> bool {
274        if self.auth_storage.api_key(provider_id).is_some()
275            || self.auth_storage.oauth_token(provider_id).is_some()
276        {
277            return true;
278        }
279        // Check if this is an OAuth provider that could be logged in
280        if crate::provider::oauth::is_built_in(provider_id) {
281            return self.auth_storage.oauth_token(provider_id).is_some();
282        }
283        // Check env var
284        self.entries
285            .iter()
286            .find(|e| e.id == provider_id)
287            .and_then(|e| {
288                let env_name = e.env_var_name();
289                if std::env::var(env_name).is_ok() {
290                    Some(())
291                } else {
292                    None
293                }
294            })
295            .is_some()
296    }
297
298    /// Get auth status for a provider (for UI display).
299    pub fn auth_status_for_provider(
300        &self,
301        provider_id: &str,
302    ) -> crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
303        let has_stored = self.auth_storage.api_key(provider_id).is_some()
304            || self.auth_storage.oauth_token(provider_id).is_some();
305
306        // Check env var
307        let env_var = self
308            .entries
309            .iter()
310            .find(|e| e.id == provider_id)
311            .and_then(|e| {
312                let env_name = e.env_var_name();
313                if std::env::var(env_name).is_ok() {
314                    Some(env_name.to_string())
315                } else {
316                    None
317                }
318            });
319
320        let configured = has_stored || env_var.is_some();
321        let (source, label) = if has_stored {
322            (Some("stored".to_string()), None)
323        } else if let Some(env) = env_var {
324            (Some("environment".to_string()), Some(env))
325        } else {
326            (None, None)
327        };
328
329        crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
330            configured,
331            source,
332            label,
333        }
334    }
335}
336
337/// Get the agent config directory (~/.rab/agent).
338pub fn get_agent_dir() -> std::path::PathBuf {
339    directories::BaseDirs::new()
340        .map(|d| d.home_dir().join(".rab").join("agent"))
341        .unwrap_or_else(|| std::path::PathBuf::from("/tmp/.rab/agent"))
342}