
rho_core/models.rs

use serde::{Deserialize, Serialize};

/// Which provider backend a configured model talks to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    /// Anthropic's native API.
    Anthropic,
    /// An OpenAI-compatible endpoint (also used for xAI's standard Grok models).
    #[serde(alias = "openai-compatible")]
    OpenAi,
    /// xAI's Responses API (used by the multi-agent Grok models).
    #[serde(alias = "xai-responses")]
    XaiResponses,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// User-facing ID (e.g. "gpt-4o", "claude-sonnet")
    pub id: String,
    pub provider: ProviderType,
    /// Wire model ID sent to the API
    pub model_id: String,
    /// Empty = use provider default
    #[serde(default)]
    pub base_url: String,
    /// Env var name holding the API key (e.g. "OPENAI_API_KEY")
    pub api_key_env: Option<String>,
    #[serde(default = "default_context_window")]
    pub context_window: usize,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Whether the model supports extended thinking
    #[serde(default)]
    pub thinking: bool,
    /// Provider-managed tools to inject into the request (e.g. xAI's "web_search", "x_search").
    /// These are not executed locally — they run on the provider's servers.
    #[serde(default)]
    pub server_tools: Option<Vec<String>>,
}

fn default_context_window() -> usize {
    200_000
}
fn default_max_tokens() -> usize {
    8_192
}

/// Deserialized form of `~/.rho/models.toml`: a list of `[[model]]` tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsFile {
    #[serde(default, rename = "model")]
    pub models: Vec<ModelConfig>,
}

pub struct ModelRegistry {
    models: Vec<ModelConfig>,
}

impl Default for ModelRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelRegistry {
    /// Create a registry with the built-in default models (Anthropic and xAI).
    pub fn new() -> Self {
        Self {
            models: built_in_models(),
        }
    }

    /// Load from `~/.rho/models.toml`, merging with built-ins.
    /// User configs override built-ins by id.
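    ///
    /// A minimal illustrative entry (values are examples only; field names
    /// mirror [`ModelConfig`], and omitted fields fall back to their serde
    /// defaults):
    ///
    /// ```toml
    /// [[model]]
    /// id = "gpt-4o"
    /// provider = "openai"
    /// model_id = "gpt-4o"
    /// api_key_env = "OPENAI_API_KEY"
    /// context_window = 128000
    /// max_tokens = 16384
    /// ```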
    pub fn load() -> Self {
        let mut registry = Self::new();

        if let Some(home) = dirs::home_dir() {
            let path = home.join(".rho").join("models.toml");
            if path.is_file() {
                match std::fs::read_to_string(&path) {
                    Ok(content) => match toml::from_str::<ModelsFile>(&content) {
                        Ok(file) => {
                            for user_model in file.models {
                                if let Some(existing) =
                                    registry.models.iter_mut().find(|m| m.id == user_model.id)
                                {
                                    *existing = user_model;
                                } else {
                                    registry.models.push(user_model);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::warn!("Failed to parse ~/.rho/models.toml: {}", e);
                        }
                    },
                    Err(e) => {
                        tracing::warn!("Failed to read ~/.rho/models.toml: {}", e);
                    }
                }
            }
        }

        registry
    }

    pub fn get(&self, id: &str) -> Option<&ModelConfig> {
        self.models.iter().find(|m| m.id == id)
    }

    pub fn list(&self) -> &[ModelConfig] {
        &self.models
    }

    /// Convert a `ModelConfig` to the runtime `Model` type.
    pub fn to_model(config: &ModelConfig) -> crate::types::Model {
        crate::types::Model {
            id: config.model_id.clone(),
            name: config.id.clone(),
            provider: match config.provider {
                ProviderType::Anthropic => "anthropic".into(),
                ProviderType::OpenAi => "openai".into(),
                ProviderType::XaiResponses => "xai-responses".into(),
            },
            base_url: config.base_url.clone(),
            reasoning: config.thinking,
            context_window: config.context_window,
            max_tokens: config.max_tokens,
        }
    }

    /// Resolve the API key for a model config.
    ///
    /// Resolution order:
    /// 1. `api_key_env` env var (if set and non-empty)
    /// 2. `anthropic-auth` keychain / OAuth (for Anthropic provider only)
    /// 3. `"local"` (for localhost/127.0.0.1 base URLs — e.g. Ollama)
    /// 4. Error
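    ///
    /// Illustrative usage (a sketch, not from the test suite; assumes the id
    /// exists in the registry and the caller's error type can absorb `String`):
    ///
    /// ```ignore
    /// let registry = ModelRegistry::load();
    /// let config = registry.get("claude-sonnet").expect("known model id");
    /// let key = ModelRegistry::resolve_api_key(config)?;
    /// ```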
    pub fn resolve_api_key(config: &ModelConfig) -> Result<String, String> {
        // 1. Try designated env var
        if let Some(ref env_var) = config.api_key_env {
            if let Ok(val) = std::env::var(env_var) {
                if !val.is_empty() {
                    return Ok(val);
                }
            }
        }

        // 2. For Anthropic, try keychain / OAuth credentials
        if config.provider == ProviderType::Anthropic {
            if let Ok(token) = crate::auth::get_token() {
                return Ok(token);
            }
        }

        // 3. For localhost endpoints (Ollama, etc.), no auth needed
        if config.base_url.contains("localhost") || config.base_url.contains("127.0.0.1") {
            return Ok("local".into());
        }

        Err(format!(
            "No API key found for model '{}'. Set the {} environment variable.",
            config.id,
            config
                .api_key_env
                .as_deref()
                .unwrap_or("appropriate API key")
        ))
    }
}

fn built_in_models() -> Vec<ModelConfig> {
    vec![
        // Anthropic
        ModelConfig {
            id: "claude-sonnet".into(),
            provider: ProviderType::Anthropic,
            model_id: "claude-sonnet-4-5-20250929".into(),
            base_url: String::new(),
            api_key_env: Some("ANTHROPIC_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "claude-opus".into(),
            provider: ProviderType::Anthropic,
            model_id: "claude-opus-4-6".into(),
            base_url: String::new(),
            api_key_env: Some("ANTHROPIC_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 8_192,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "claude-haiku".into(),
            provider: ProviderType::Anthropic,
            model_id: "claude-haiku-4-5-20251001".into(),
            base_url: String::new(),
            api_key_env: Some("ANTHROPIC_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        // xAI (Grok) — OpenAI-compatible endpoint
        ModelConfig {
            id: "grok-3".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-3".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-3-mini".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-3-mini".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-2".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-2-1212".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 32_768,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        // xAI Grok 4.20 experimental beta
        ModelConfig {
            id: "grok-4.20-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4.20-experimental-beta-0304-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-non-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4.20-experimental-beta-0304-non-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-multi-agent".into(),
            provider: ProviderType::XaiResponses,
            model_id: "grok-4.20-multi-agent-experimental-beta-0304".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        // Additional xAI models
        ModelConfig {
            id: "grok-code-fast-1".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-code-fast-1".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4-1-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4-1-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-beta-0309-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4.20-beta-0309-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-multi-agent-beta-0309".into(),
            provider: ProviderType::XaiResponses,
            model_id: "grok-4.20-multi-agent-beta-0309".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
    ]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_has_builtin_models() {
        let registry = ModelRegistry::new();
        // 3 Anthropic + 10 xAI/Grok
        assert_eq!(registry.list().len(), 13);
    }

    #[test]
    fn builtin_grok_models_use_openai_provider() {
        let registry = ModelRegistry::new();
        let grok = registry.get("grok-3").unwrap();
        assert_eq!(grok.provider, ProviderType::OpenAi);
        assert_eq!(grok.base_url, "https://api.x.ai/v1");
        assert_eq!(grok.api_key_env.as_deref(), Some("XAI_API_KEY"));
        assert_eq!(grok.model_id, "grok-3");
    }

    #[test]
    fn get_builtin_model() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-sonnet").unwrap();
        assert_eq!(m.model_id, "claude-sonnet-4-5-20250929");
        assert_eq!(m.provider, ProviderType::Anthropic);
        assert!(!m.thinking);
    }

    #[test]
    fn get_claude_opus_has_thinking() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-opus").unwrap();
        assert!(m.thinking);
    }

    #[test]
    fn get_missing_returns_none() {
        let registry = ModelRegistry::new();
        assert!(registry.get("gpt-4o").is_none());
    }

    #[test]
    fn to_model_maps_fields() {
        let config = ModelConfig {
            id: "test-model".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.openai.com/v1".into(),
            api_key_env: Some("OPENAI_API_KEY".into()),
            context_window: 128_000,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        };
        let model = ModelRegistry::to_model(&config);
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "test-model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.base_url, "https://api.openai.com/v1");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
    }

    #[test]
    fn resolve_api_key_from_env() {
        let config = ModelConfig {
            id: "test".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: String::new(),
            api_key_env: Some("__RHO_TEST_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        std::env::set_var("__RHO_TEST_KEY__", "test-api-key-123");
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "test-api-key-123");
        std::env::remove_var("__RHO_TEST_KEY__");
    }

    #[test]
    fn resolve_api_key_localhost_returns_local() {
        let config = ModelConfig {
            id: "ollama".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "local");
    }

    #[test]
    fn load_merges_user_toml_override() {
        // This test only exercises the parsing logic, not file I/O
        let toml_str = r#"
[[model]]
id = "claude-sonnet"
provider = "anthropic"
model_id = "claude-sonnet-4-6-custom"
api_key_env = "ANTHROPIC_API_KEY"
context_window = 200000
max_tokens = 8192
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 1);
        assert_eq!(file.models[0].model_id, "claude-sonnet-4-6-custom");
    }

    #[test]
    fn load_parses_openai_provider() {
        let toml_str = r#"
[[model]]
id = "gpt-4o"
provider = "openai"
model_id = "gpt-4o"
api_key_env = "OPENAI_API_KEY"
context_window = 128000
max_tokens = 16384
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
    }
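
    // Illustrative extra check (not part of the original suite): the
    // "openai-compatible" serde alias on `ProviderType::OpenAi` and the optional
    // `server_tools` list should both round-trip from TOML, and omitted fields
    // should fall back to their serde defaults. The id and tool names below are
    // examples only.
    #[test]
    fn load_parses_provider_alias_and_server_tools() {
        let toml_str = r#"
[[model]]
id = "grok-web"
provider = "openai-compatible"
model_id = "grok-3"
api_key_env = "XAI_API_KEY"
server_tools = ["web_search", "x_search"]
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
        assert_eq!(
            file.models[0].server_tools,
            Some(vec!["web_search".to_string(), "x_search".to_string()])
        );
        // Omitted optional fields take their serde defaults.
        assert_eq!(file.models[0].context_window, 200_000);
        assert!(file.models[0].base_url.is_empty());
        assert!(!file.models[0].thinking);
    }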
}