//! Model configuration and registry for `rho_core` (`models.rs`).
//!
//! Defines the on-disk model config schema, built-in model catalog, and the
//! registry that merges user overrides from `~/.rho/models.toml`.
1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
4#[serde(rename_all = "lowercase")]
5pub enum ProviderType {
6    Anthropic,
7    #[serde(alias = "openai-compatible")]
8    OpenAi,
9}
10
/// One configurable model entry, as stored in `~/.rho/models.toml` and in the
/// built-in catalog. Missing optional fields fall back to serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// User-facing ID (e.g. "gpt-4o", "claude-sonnet")
    pub id: String,
    pub provider: ProviderType,
    /// Wire model ID sent to the API
    pub model_id: String,
    /// Empty = use provider default
    #[serde(default)]
    pub base_url: String,
    /// Env var name holding the API key (e.g. "OPENAI_API_KEY")
    pub api_key_env: Option<String>,
    /// Context window in tokens; defaults to 200_000 when omitted.
    #[serde(default = "default_context_window")]
    pub context_window: usize,
    /// Max output tokens per request; defaults to 8_192 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Whether the model supports extended thinking
    #[serde(default)]
    pub thinking: bool,
    /// Provider-managed tools to inject into the request (e.g. xAI's "web_search", "x_search").
    /// These are not executed locally — they run on the provider's servers.
    #[serde(default)]
    pub server_tools: Option<Vec<String>>,
}
35
/// Serde default for [`ModelConfig::context_window`]: 200k tokens.
fn default_context_window() -> usize {
    const TOKENS: usize = 200_000;
    TOKENS
}
/// Serde default for [`ModelConfig::max_tokens`]: 8192 output tokens.
fn default_max_tokens() -> usize {
    const TOKENS: usize = 8_192;
    TOKENS
}
42
/// On-disk shape of `~/.rho/models.toml`: a sequence of `[[model]]` tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsFile {
    /// TOML `[[model]]` array; an absent key deserializes to an empty vec.
    #[serde(default, rename = "model")]
    pub models: Vec<ModelConfig>,
}
48
/// In-memory catalog of available models: built-ins plus any user overrides.
pub struct ModelRegistry {
    // Small, linearly-scanned list; lookups go through `get` by id.
    models: Vec<ModelConfig>,
}
52
impl Default for ModelRegistry {
    /// Same as [`ModelRegistry::new`]: built-in models only, no user config.
    fn default() -> Self {
        Self::new()
    }
}
58
impl ModelRegistry {
    /// Create registry with built-in Anthropic defaults.
    pub fn new() -> Self {
        Self {
            models: built_in_models(),
        }
    }

    /// Load from `~/.rho/models.toml`, merging with built-ins.
    /// User configs override built-ins by id.
    ///
    /// Read or parse failures are logged at WARN and otherwise ignored, so a
    /// broken user config degrades to the built-in list instead of failing
    /// startup. If no home directory can be determined, built-ins are used.
    pub fn load() -> Self {
        let mut registry = Self::new();

        if let Some(home) = dirs::home_dir() {
            let path = home.join(".rho").join("models.toml");
            if path.is_file() {
                match std::fs::read_to_string(&path) {
                    Ok(content) => match toml::from_str::<ModelsFile>(&content) {
                        Ok(file) => {
                            // Merge rule: a user model whose `id` matches a
                            // built-in replaces it wholesale; new ids append.
                            for user_model in file.models {
                                if let Some(existing) =
                                    registry.models.iter_mut().find(|m| m.id == user_model.id)
                                {
                                    *existing = user_model;
                                } else {
                                    registry.models.push(user_model);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::warn!("Failed to parse ~/.rho/models.toml: {}", e);
                        }
                    },
                    Err(e) => {
                        tracing::warn!("Failed to read ~/.rho/models.toml: {}", e);
                    }
                }
            }
        }

        registry
    }

    /// Look up a model by its user-facing `id` (linear scan).
    pub fn get(&self, id: &str) -> Option<&ModelConfig> {
        self.models.iter().find(|m| m.id == id)
    }

    /// All registered models: built-ins first, then user additions, in order.
    pub fn list(&self) -> &[ModelConfig] {
        &self.models
    }

    /// Convert a `ModelConfig` to the runtime `Model` type.
    ///
    /// Note the deliberate swap: runtime `Model.id` carries the wire
    /// `model_id`, while `Model.name` carries the user-facing config `id`.
    pub fn to_model(config: &ModelConfig) -> crate::types::Model {
        crate::types::Model {
            id: config.model_id.clone(),
            name: config.id.clone(),
            provider: match config.provider {
                ProviderType::Anthropic => "anthropic".into(),
                ProviderType::OpenAi => "openai".into(),
            },
            base_url: config.base_url.clone(),
            reasoning: config.thinking,
            context_window: config.context_window,
            max_tokens: config.max_tokens,
        }
    }

    /// Resolve the API key for a model config.
    ///
    /// Resolution order:
    /// 1. `api_key_env` env var (if set and non-empty)
    /// 2. `anthropic-auth` keychain / OAuth (for Anthropic provider only)
    /// 3. `"local"` (for localhost/127.0.0.1 base URLs — e.g. Ollama)
    /// 4. Error
    pub fn resolve_api_key(config: &ModelConfig) -> Result<String, String> {
        // 1. Try designated env var
        // Empty-string values are treated the same as an unset variable.
        if let Some(ref env_var) = config.api_key_env {
            if let Ok(val) = std::env::var(env_var) {
                if !val.is_empty() {
                    return Ok(val);
                }
            }
        }

        // 2. For Anthropic, try keychain / OAuth credentials
        if config.provider == ProviderType::Anthropic {
            if let Ok(token) = crate::auth::get_token() {
                return Ok(token);
            }
        }

        // 3. For localhost endpoints (Ollama, etc.), no auth needed
        // NOTE(review): substring match also accepts hosts like "notlocalhost";
        // an exact host comparison would be stricter — confirm this is intended.
        if config.base_url.contains("localhost") || config.base_url.contains("127.0.0.1") {
            return Ok("local".into());
        }

        // NOTE(review): when `api_key_env` is None the fallback reads
        // "...the appropriate API key env var environment variable" — awkward
        // wording, but changing the message is a behavior change left alone here.
        Err(format!(
            "No API key found for model '{}'. Set the {} environment variable.",
            config.id,
            config
                .api_key_env
                .as_deref()
                .unwrap_or("appropriate API key env var")
        ))
    }
}
165
166fn built_in_models() -> Vec<ModelConfig> {
167    vec![
168        // Anthropic
169        ModelConfig {
170            id: "claude-sonnet".into(),
171            provider: ProviderType::Anthropic,
172            model_id: "claude-sonnet-4-5-20250929".into(),
173            base_url: String::new(),
174            api_key_env: Some("ANTHROPIC_API_KEY".into()),
175            context_window: 200_000,
176            max_tokens: 8_192,
177            thinking: false,
178            server_tools: None,
179        },
180        ModelConfig {
181            id: "claude-opus".into(),
182            provider: ProviderType::Anthropic,
183            model_id: "claude-opus-4-6".into(),
184            base_url: String::new(),
185            api_key_env: Some("ANTHROPIC_API_KEY".into()),
186            context_window: 200_000,
187            max_tokens: 8_192,
188            thinking: true,
189            server_tools: None,
190        },
191        ModelConfig {
192            id: "claude-haiku".into(),
193            provider: ProviderType::Anthropic,
194            model_id: "claude-haiku-4-5-20251001".into(),
195            base_url: String::new(),
196            api_key_env: Some("ANTHROPIC_API_KEY".into()),
197            context_window: 200_000,
198            max_tokens: 8_192,
199            thinking: false,
200            server_tools: None,
201        },
202        // xAI (Grok) — OpenAI-compatible endpoint
203        ModelConfig {
204            id: "grok-3".into(),
205            provider: ProviderType::OpenAi,
206            model_id: "grok-3".into(),
207            base_url: "https://api.x.ai/v1".into(),
208            api_key_env: Some("XAI_API_KEY".into()),
209            context_window: 131_072,
210            max_tokens: 16_384,
211            thinking: false,
212            server_tools: None,
213        },
214        ModelConfig {
215            id: "grok-3-mini".into(),
216            provider: ProviderType::OpenAi,
217            model_id: "grok-3-mini".into(),
218            base_url: "https://api.x.ai/v1".into(),
219            api_key_env: Some("XAI_API_KEY".into()),
220            context_window: 131_072,
221            max_tokens: 8_192,
222            thinking: false,
223            server_tools: None,
224        },
225        ModelConfig {
226            id: "grok-2".into(),
227            provider: ProviderType::OpenAi,
228            model_id: "grok-2-1212".into(),
229            base_url: "https://api.x.ai/v1".into(),
230            api_key_env: Some("XAI_API_KEY".into()),
231            context_window: 32_768,
232            max_tokens: 8_192,
233            thinking: false,
234            server_tools: None,
235        },
236        // xAI Grok 4.20 experimental beta
237        ModelConfig {
238            id: "grok-4.20-reasoning".into(),
239            provider: ProviderType::OpenAi,
240            model_id: "grok-4.20-experimental-beta-0304-reasoning".into(),
241            base_url: "https://api.x.ai/v1".into(),
242            api_key_env: Some("XAI_API_KEY".into()),
243            context_window: 131_072,
244            max_tokens: 16_384,
245            thinking: true,
246            server_tools: None,
247        },
248        ModelConfig {
249            id: "grok-4.20-non-reasoning".into(),
250            provider: ProviderType::OpenAi,
251            model_id: "grok-4.20-experimental-beta-0304-non-reasoning".into(),
252            base_url: "https://api.x.ai/v1".into(),
253            api_key_env: Some("XAI_API_KEY".into()),
254            context_window: 131_072,
255            max_tokens: 16_384,
256            thinking: false,
257            server_tools: None,
258        },
259        ModelConfig {
260            id: "grok-4.20-multi-agent".into(),
261            provider: ProviderType::OpenAi,
262            model_id: "grok-4.20-multi-agent-experimental-beta-0304".into(),
263            base_url: "https://api.x.ai/v1".into(),
264            api_key_env: Some("XAI_API_KEY".into()),
265            context_window: 131_072,
266            max_tokens: 16_384,
267            thinking: false,
268            server_tools: None,
269        },
270        // Additional xAI models
271        ModelConfig {
272            id: "grok-code-fast-1".into(),
273            provider: ProviderType::OpenAi,
274            model_id: "grok-code-fast-1".into(),
275            base_url: "https://api.x.ai/v1".into(),
276            api_key_env: Some("XAI_API_KEY".into()),
277            context_window: 131_072,
278            max_tokens: 16_384,
279            thinking: false,
280            server_tools: None,
281        },
282        ModelConfig {
283            id: "grok-4-1-reasoning".into(),
284            provider: ProviderType::OpenAi,
285            model_id: "grok-4-1-reasoning".into(),
286            base_url: "https://api.x.ai/v1".into(),
287            api_key_env: Some("XAI_API_KEY".into()),
288            context_window: 131_072,
289            max_tokens: 16_384,
290            thinking: true,
291            server_tools: None,
292        },
293        ModelConfig {
294            id: "grok-4.20-beta-0309-reasoning".into(),
295            provider: ProviderType::OpenAi,
296            model_id: "grok-4.20-beta-0309-reasoning".into(),
297            base_url: "https://api.x.ai/v1".into(),
298            api_key_env: Some("XAI_API_KEY".into()),
299            context_window: 131_072,
300            max_tokens: 16_384,
301            thinking: true,
302            server_tools: None,
303        },
304        ModelConfig {
305            id: "grok-4.20-multi-agent-beta-0309".into(),
306            provider: ProviderType::OpenAi,
307            model_id: "grok-4.20-multi-agent-beta-0309".into(),
308            base_url: "https://api.x.ai/v1".into(),
309            api_key_env: Some("XAI_API_KEY".into()),
310            context_window: 131_072,
311            max_tokens: 16_384,
312            thinking: false,
313            server_tools: None,
314        },
315
316    ]
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    // Built-in catalog size must track `built_in_models()` exactly.
    #[test]
    fn new_has_builtin_models() {
        let registry = ModelRegistry::new();
        // 3 Anthropic + 10 xAI/Grok
        assert_eq!(registry.list().len(), 13);
    }

    #[test]
    fn builtin_grok_models_use_openai_provider() {
        let registry = ModelRegistry::new();
        let grok = registry.get("grok-3").unwrap();
        assert_eq!(grok.provider, ProviderType::OpenAi);
        assert_eq!(grok.base_url, "https://api.x.ai/v1");
        assert_eq!(grok.api_key_env.as_deref(), Some("XAI_API_KEY"));
        assert_eq!(grok.model_id, "grok-3");
    }

    #[test]
    fn get_builtin_model() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-sonnet").unwrap();
        assert_eq!(m.model_id, "claude-sonnet-4-5-20250929");
        assert_eq!(m.provider, ProviderType::Anthropic);
        assert!(!m.thinking);
    }

    #[test]
    fn get_claude_opus_has_thinking() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-opus").unwrap();
        assert!(m.thinking);
    }

    #[test]
    fn get_missing_returns_none() {
        let registry = ModelRegistry::new();
        assert!(registry.get("gpt-4o").is_none());
    }

    // Verifies the id/name swap documented on `to_model`.
    #[test]
    fn to_model_maps_fields() {
        let config = ModelConfig {
            id: "test-model".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.openai.com/v1".into(),
            api_key_env: Some("OPENAI_API_KEY".into()),
            context_window: 128_000,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        };
        let model = ModelRegistry::to_model(&config);
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "test-model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.base_url, "https://api.openai.com/v1");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
    }

    // NOTE(review): mutates process-wide env state — could race with other
    // env-reading tests if the suite runs in parallel; verify isolation.
    #[test]
    fn resolve_api_key_from_env() {
        let config = ModelConfig {
            id: "test".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: String::new(),
            api_key_env: Some("__RHO_TEST_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        std::env::set_var("__RHO_TEST_KEY__", "test-api-key-123");
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "test-api-key-123");
        std::env::remove_var("__RHO_TEST_KEY__");
    }

    #[test]
    fn resolve_api_key_localhost_returns_local() {
        let config = ModelConfig {
            id: "ollama".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "local");
    }

    #[test]
    fn load_merges_user_toml_override() {
        // This test only exercises the parsing logic, not file I/O
        let toml_str = r#"
[[model]]
id = "claude-sonnet"
provider = "anthropic"
model_id = "claude-sonnet-4-6-custom"
api_key_env = "ANTHROPIC_API_KEY"
context_window = 200000
max_tokens = 8192
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 1);
        assert_eq!(file.models[0].model_id, "claude-sonnet-4-6-custom");
    }

    #[test]
    fn load_parses_openai_provider() {
        let toml_str = r#"
[[model]]
id = "gpt-4o"
provider = "openai"
model_id = "gpt-4o"
api_key_env = "OPENAI_API_KEY"
context_window = 128000
max_tokens = 16384
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
    }
}