Skip to main content

rho_core/
models.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
4#[serde(rename_all = "lowercase")]
5pub enum ProviderType {
6    Anthropic,
7    #[serde(alias = "openai-compatible")]
8    OpenAi,
9    #[serde(alias = "xai-responses")]
10    XaiResponses,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ModelConfig {
15    /// User-facing ID (e.g. "gpt-4o", "claude-sonnet")
16    pub id: String,
17    pub provider: ProviderType,
18    /// Wire model ID sent to the API
19    pub model_id: String,
20    /// Empty = use provider default
21    #[serde(default)]
22    pub base_url: String,
23    /// Env var name holding the API key (e.g. "OPENAI_API_KEY")
24    pub api_key_env: Option<String>,
25    #[serde(default = "default_context_window")]
26    pub context_window: usize,
27    #[serde(default = "default_max_tokens")]
28    pub max_tokens: usize,
29    /// Whether the model supports extended thinking
30    #[serde(default)]
31    pub thinking: bool,
32    /// Provider-managed tools to inject into the request (e.g. xAI's "web_search", "x_search").
33    /// These are not executed locally — they run on the provider's servers.
34    #[serde(default)]
35    pub server_tools: Option<Vec<String>>,
36}
37
/// Serde fallback for `ModelConfig::context_window` when a config omits it.
fn default_context_window() -> usize {
    const DEFAULT_CONTEXT_WINDOW: usize = 200_000;
    DEFAULT_CONTEXT_WINDOW
}

/// Serde fallback for `ModelConfig::max_tokens` when a config omits it.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 8_192;
    DEFAULT_MAX_TOKENS
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ModelsFile {
47    #[serde(default, rename = "model")]
48    pub models: Vec<ModelConfig>,
49}
50
51pub struct ModelRegistry {
52    models: Vec<ModelConfig>,
53}
54
55impl Default for ModelRegistry {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl ModelRegistry {
62    /// Create registry with built-in Anthropic defaults.
63    pub fn new() -> Self {
64        Self {
65            models: built_in_models(),
66        }
67    }
68
69    /// Load from `~/.rho/models.toml`, merging with built-ins.
70    /// User configs override built-ins by id.
71    pub fn load() -> Self {
72        let mut registry = Self::new();
73
74        if let Some(home) = dirs::home_dir() {
75            let path = home.join(".rho").join("models.toml");
76            if path.is_file() {
77                match std::fs::read_to_string(&path) {
78                    Ok(content) => match toml::from_str::<ModelsFile>(&content) {
79                        Ok(file) => {
80                            for user_model in file.models {
81                                if let Some(existing) =
82                                    registry.models.iter_mut().find(|m| m.id == user_model.id)
83                                {
84                                    *existing = user_model;
85                                } else {
86                                    registry.models.push(user_model);
87                                }
88                            }
89                        }
90                        Err(e) => {
91                            tracing::warn!("Failed to parse ~/.rho/models.toml: {}", e);
92                        }
93                    },
94                    Err(e) => {
95                        tracing::warn!("Failed to read ~/.rho/models.toml: {}", e);
96                    }
97                }
98            }
99        }
100
101        load_zen_models(&mut registry.models);
102
103        registry
104    }
105
106    pub fn get(&self, id: &str) -> Option<&ModelConfig> {
107        self.models.iter().find(|m| m.id == id)
108    }
109
110    pub fn list(&self) -> &[ModelConfig] {
111        &self.models
112    }
113
114    /// Convert a `ModelConfig` to the runtime `Model` type.
115    pub fn to_model(config: &ModelConfig) -> crate::types::Model {
116        crate::types::Model {
117            id: config.model_id.clone(),
118            name: config.id.clone(),
119            provider: match config.provider {
120                ProviderType::Anthropic => "anthropic".into(),
121                ProviderType::OpenAi => "openai".into(),
122                ProviderType::XaiResponses => "xai-responses".into(),
123            },
124            base_url: config.base_url.clone(),
125            reasoning: config.thinking,
126            context_window: config.context_window,
127            max_tokens: config.max_tokens,
128        }
129    }
130
131    /// Resolve the API key for a model config.
132    ///
133    /// Resolution order:
134    /// 1. `api_key_env` env var (if set and non-empty)
135    /// 2. `anthropic-auth` keychain / OAuth (for Anthropic provider only)
136    /// 3. `"local"` (for localhost/127.0.0.1 base URLs — e.g. Ollama)
137    /// 4. Error
138    pub fn resolve_api_key(config: &ModelConfig) -> Result<String, String> {
139        // 1. Try designated env var
140        if let Some(ref env_var) = config.api_key_env {
141            if let Ok(val) = std::env::var(env_var) {
142                if !val.is_empty() {
143                    return Ok(val);
144                }
145            }
146        }
147
148        // 2. For Anthropic, try keychain / OAuth credentials
149        if config.provider == ProviderType::Anthropic {
150            if let Ok(token) = crate::auth::get_token() {
151                return Ok(token);
152            }
153        }
154
155        // 3. For localhost endpoints (Ollama, etc.), no auth needed
156        if config.base_url.contains("localhost") || config.base_url.contains("127.0.0.1") {
157            return Ok("local".into());
158        }
159
160        Err(format!(
161            "No API key found for model '{}'. Set the {} environment variable.",
162            config.id,
163            config
164                .api_key_env
165                .as_deref()
166                .unwrap_or("appropriate API key env var")
167        ))
168    }
169}
170
171/// Load Zen models into the registry (only if OPENCODE_ZEN_API_KEY is set).
172fn load_zen_models(models: &mut Vec<ModelConfig>) {
173    if std::env::var("OPENCODE_ZEN_API_KEY")
174        .ok()
175        .filter(|v| !v.is_empty())
176        .is_none()
177    {
178        return;
179    }
180
181    let zen_ids = crate::zen::fetch_zen_models();
182    for model_id in zen_ids {
183        let registry_id = format!("zen-{}", model_id);
184        // Skip if user already defined this ID
185        if models.iter().any(|m| m.id == registry_id) {
186            continue;
187        }
188
189        let (provider, base_url) = if model_id.contains("claude") {
190            (ProviderType::Anthropic, "https://opencode.ai/zen".to_string())
191        } else {
192            (ProviderType::OpenAi, "https://opencode.ai/zen/v1".to_string())
193        };
194
195        models.push(ModelConfig {
196            id: registry_id,
197            provider,
198            model_id: model_id.clone(),
199            base_url,
200            api_key_env: Some("OPENCODE_ZEN_API_KEY".into()),
201            context_window: 200_000,
202            max_tokens: 16_384,
203            thinking: model_id.contains("opus"),
204            server_tools: None,
205        });
206    }
207}
208
209fn built_in_models() -> Vec<ModelConfig> {
210    vec![
211        // Anthropic
212        ModelConfig {
213            id: "claude-sonnet".into(),
214            provider: ProviderType::Anthropic,
215            model_id: "claude-sonnet-4-5-20250929".into(),
216            base_url: String::new(),
217            api_key_env: Some("ANTHROPIC_API_KEY".into()),
218            context_window: 200_000,
219            max_tokens: 8_192,
220            thinking: false,
221            server_tools: None,
222        },
223        ModelConfig {
224            id: "claude-opus".into(),
225            provider: ProviderType::Anthropic,
226            model_id: "claude-opus-4-6".into(),
227            base_url: String::new(),
228            api_key_env: Some("ANTHROPIC_API_KEY".into()),
229            context_window: 200_000,
230            max_tokens: 8_192,
231            thinking: true,
232            server_tools: None,
233        },
234        ModelConfig {
235            id: "claude-haiku".into(),
236            provider: ProviderType::Anthropic,
237            model_id: "claude-haiku-4-5-20251001".into(),
238            base_url: String::new(),
239            api_key_env: Some("ANTHROPIC_API_KEY".into()),
240            context_window: 200_000,
241            max_tokens: 8_192,
242            thinking: false,
243            server_tools: None,
244        },
245        // xAI (Grok) — OpenAI-compatible endpoint
246        ModelConfig {
247            id: "grok-3".into(),
248            provider: ProviderType::OpenAi,
249            model_id: "grok-3".into(),
250            base_url: "https://api.x.ai/v1".into(),
251            api_key_env: Some("XAI_API_KEY".into()),
252            context_window: 131_072,
253            max_tokens: 16_384,
254            thinking: false,
255            server_tools: None,
256        },
257        ModelConfig {
258            id: "grok-3-mini".into(),
259            provider: ProviderType::OpenAi,
260            model_id: "grok-3-mini".into(),
261            base_url: "https://api.x.ai/v1".into(),
262            api_key_env: Some("XAI_API_KEY".into()),
263            context_window: 131_072,
264            max_tokens: 8_192,
265            thinking: false,
266            server_tools: None,
267        },
268        ModelConfig {
269            id: "grok-2".into(),
270            provider: ProviderType::OpenAi,
271            model_id: "grok-2-1212".into(),
272            base_url: "https://api.x.ai/v1".into(),
273            api_key_env: Some("XAI_API_KEY".into()),
274            context_window: 32_768,
275            max_tokens: 8_192,
276            thinking: false,
277            server_tools: None,
278        },
279        // xAI Grok 4.20 experimental beta
280        ModelConfig {
281            id: "grok-4.20-reasoning".into(),
282            provider: ProviderType::OpenAi,
283            model_id: "grok-4.20-experimental-beta-0304-reasoning".into(),
284            base_url: "https://api.x.ai/v1".into(),
285            api_key_env: Some("XAI_API_KEY".into()),
286            context_window: 131_072,
287            max_tokens: 16_384,
288            thinking: true,
289            server_tools: None,
290        },
291        ModelConfig {
292            id: "grok-4.20-non-reasoning".into(),
293            provider: ProviderType::OpenAi,
294            model_id: "grok-4.20-experimental-beta-0304-non-reasoning".into(),
295            base_url: "https://api.x.ai/v1".into(),
296            api_key_env: Some("XAI_API_KEY".into()),
297            context_window: 131_072,
298            max_tokens: 16_384,
299            thinking: false,
300            server_tools: None,
301        },
302        ModelConfig {
303            id: "grok-4.20-multi-agent".into(),
304            provider: ProviderType::XaiResponses,
305            model_id: "grok-4.20-multi-agent-experimental-beta-0304".into(),
306            base_url: "https://api.x.ai/v1".into(),
307            api_key_env: Some("XAI_API_KEY".into()),
308            context_window: 131_072,
309            max_tokens: 16_384,
310            thinking: false,
311            server_tools: None,
312        },
313        // Additional xAI models
314        ModelConfig {
315            id: "grok-code-fast-1".into(),
316            provider: ProviderType::OpenAi,
317            model_id: "grok-code-fast-1".into(),
318            base_url: "https://api.x.ai/v1".into(),
319            api_key_env: Some("XAI_API_KEY".into()),
320            context_window: 131_072,
321            max_tokens: 16_384,
322            thinking: false,
323            server_tools: None,
324        },
325        ModelConfig {
326            id: "grok-4-1-reasoning".into(),
327            provider: ProviderType::OpenAi,
328            model_id: "grok-4-1-reasoning".into(),
329            base_url: "https://api.x.ai/v1".into(),
330            api_key_env: Some("XAI_API_KEY".into()),
331            context_window: 131_072,
332            max_tokens: 16_384,
333            thinking: true,
334            server_tools: None,
335        },
336        ModelConfig {
337            id: "grok-4.20-beta-0309-reasoning".into(),
338            provider: ProviderType::OpenAi,
339            model_id: "grok-4.20-beta-0309-reasoning".into(),
340            base_url: "https://api.x.ai/v1".into(),
341            api_key_env: Some("XAI_API_KEY".into()),
342            context_window: 131_072,
343            max_tokens: 16_384,
344            thinking: true,
345            server_tools: None,
346        },
347        ModelConfig {
348            id: "grok-4.20-multi-agent-beta-0309".into(),
349            provider: ProviderType::XaiResponses,
350            model_id: "grok-4.20-multi-agent-beta-0309".into(),
351            base_url: "https://api.x.ai/v1".into(),
352            api_key_env: Some("XAI_API_KEY".into()),
353            context_window: 131_072,
354            max_tokens: 16_384,
355            thinking: false,
356            server_tools: None,
357        },
358
359    ]
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_has_builtin_models() {
        let registry = ModelRegistry::new();
        // 3 Anthropic + 10 xAI/Grok
        assert_eq!(registry.list().len(), 13);
    }

    #[test]
    fn builtin_model_ids_are_unique() {
        // `get` returns the first match, so a duplicate id would shadow an entry.
        let registry = ModelRegistry::new();
        let mut ids: Vec<&str> = registry.list().iter().map(|m| m.id.as_str()).collect();
        let total = ids.len();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), total, "duplicate built-in model id");
    }

    #[test]
    fn builtin_grok_models_use_openai_provider() {
        let registry = ModelRegistry::new();
        let grok = registry.get("grok-3").unwrap();
        assert_eq!(grok.provider, ProviderType::OpenAi);
        assert_eq!(grok.base_url, "https://api.x.ai/v1");
        assert_eq!(grok.api_key_env.as_deref(), Some("XAI_API_KEY"));
        assert_eq!(grok.model_id, "grok-3");
    }

    #[test]
    fn get_builtin_model() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-sonnet").unwrap();
        assert_eq!(m.model_id, "claude-sonnet-4-5-20250929");
        assert_eq!(m.provider, ProviderType::Anthropic);
        assert!(!m.thinking);
    }

    #[test]
    fn get_claude_opus_has_thinking() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-opus").unwrap();
        assert!(m.thinking);
    }

    #[test]
    fn get_missing_returns_none() {
        let registry = ModelRegistry::new();
        assert!(registry.get("gpt-4o").is_none());
    }

    #[test]
    fn to_model_maps_fields() {
        let config = ModelConfig {
            id: "test-model".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.openai.com/v1".into(),
            api_key_env: Some("OPENAI_API_KEY".into()),
            context_window: 128_000,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        };
        let model = ModelRegistry::to_model(&config);
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "test-model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.base_url, "https://api.openai.com/v1");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
    }

    #[test]
    fn resolve_api_key_from_env() {
        let config = ModelConfig {
            id: "test".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: String::new(),
            api_key_env: Some("__RHO_TEST_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        std::env::set_var("__RHO_TEST_KEY__", "test-api-key-123");
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "test-api-key-123");
        std::env::remove_var("__RHO_TEST_KEY__");
    }

    #[test]
    fn resolve_api_key_localhost_returns_local() {
        let config = ModelConfig {
            id: "ollama".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "local");
    }

    #[test]
    fn resolve_api_key_loopback_ip_returns_local() {
        let config = ModelConfig {
            id: "vllm".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://127.0.0.1:8000/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        assert_eq!(ModelRegistry::resolve_api_key(&config).unwrap(), "local");
    }

    #[test]
    fn resolve_api_key_empty_env_var_falls_through() {
        // An empty value must not be returned as the key; resolution continues.
        std::env::set_var("__RHO_TEST_EMPTY_KEY__", "");
        let config = ModelConfig {
            id: "ollama-empty".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: Some("__RHO_TEST_EMPTY_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config);
        std::env::remove_var("__RHO_TEST_EMPTY_KEY__");
        assert_eq!(key.unwrap(), "local");
    }

    #[test]
    fn resolve_api_key_remote_without_key_errors() {
        std::env::remove_var("__RHO_TEST_MISSING_KEY__");
        let config = ModelConfig {
            id: "remote".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.example.com/v1".into(),
            api_key_env: Some("__RHO_TEST_MISSING_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let err = ModelRegistry::resolve_api_key(&config).unwrap_err();
        assert!(err.contains("'remote'"), "unexpected error: {}", err);
        assert!(err.contains("__RHO_TEST_MISSING_KEY__"), "unexpected error: {}", err);
    }

    #[test]
    fn load_merges_user_toml_override() {
        // This test only exercises the parsing logic, not file I/O
        let toml_str = r#"
[[model]]
id = "claude-sonnet"
provider = "anthropic"
model_id = "claude-sonnet-4-6-custom"
api_key_env = "ANTHROPIC_API_KEY"
context_window = 200000
max_tokens = 8192
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 1);
        assert_eq!(file.models[0].model_id, "claude-sonnet-4-6-custom");
    }

    #[test]
    fn load_parses_openai_provider() {
        let toml_str = r#"
[[model]]
id = "gpt-4o"
provider = "openai"
model_id = "gpt-4o"
api_key_env = "OPENAI_API_KEY"
context_window = 128000
max_tokens = 16384
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
    }

    #[test]
    fn load_parses_provider_aliases() {
        let toml_str = r#"
[[model]]
id = "compat"
provider = "openai-compatible"
model_id = "m1"

[[model]]
id = "responses"
provider = "xai-responses"
model_id = "m2"
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 2);
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
        assert_eq!(file.models[1].provider, ProviderType::XaiResponses);
    }

    #[test]
    fn load_applies_field_defaults() {
        let toml_str = r#"
[[model]]
id = "minimal"
provider = "anthropic"
model_id = "claude-x"
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        let m = &file.models[0];
        assert_eq!(m.base_url, "");
        assert!(m.api_key_env.is_none());
        assert_eq!(m.context_window, 200_000);
        assert_eq!(m.max_tokens, 8_192);
        assert!(!m.thinking);
        assert!(m.server_tools.is_none());
    }

    #[test]
    fn empty_models_file_parses() {
        let file: ModelsFile = toml::from_str("").unwrap();
        assert!(file.models.is_empty());
    }
}