1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
4#[serde(rename_all = "lowercase")]
5pub enum ProviderType {
6 Anthropic,
7 #[serde(alias = "openai-compatible")]
8 OpenAi,
9}
10
/// One model entry — either built in (see `built_in_models`) or loaded
/// from the `[[model]]` array in `~/.rho/models.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Registry lookup key (e.g. "claude-sonnet"); user entries with a
    /// matching `id` override built-ins in `ModelRegistry::load`.
    pub id: String,
    /// Provider-side model identifier sent on the wire.
    pub provider: ProviderType,
    pub model_id: String,
    /// API endpoint base URL; empty string means the provider default.
    #[serde(default)]
    pub base_url: String,
    /// Name of the environment variable holding the API key, if any.
    pub api_key_env: Option<String>,
    /// Context window in tokens; defaults to 200_000 when absent.
    #[serde(default = "default_context_window")]
    pub context_window: usize,
    /// Max output tokens per response; defaults to 8_192 when absent.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Mapped to `Model::reasoning` in `ModelRegistry::to_model`.
    #[serde(default)]
    pub thinking: bool,
    /// Optional server-side tool names.
    /// NOTE(review): consumer semantics not visible in this file — confirm.
    #[serde(default)]
    pub server_tools: Option<Vec<String>>,
}
35
/// Serde default for `ModelConfig::context_window` (200k tokens).
// `const fn` documents that this is a pure constant; it still satisfies
// serde's `default = "path"` function-path requirement.
const fn default_context_window() -> usize {
    200_000
}
/// Serde default for `ModelConfig::max_tokens` (8_192 tokens).
// `const fn` documents that this is a pure constant; it still satisfies
// serde's `default = "path"` function-path requirement.
const fn default_max_tokens() -> usize {
    8_192
}
42
/// On-disk shape of `~/.rho/models.toml`: a TOML `[[model]]` array of
/// tables, each deserializing into a [`ModelConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsFile {
    // `rename = "model"` maps the singular `[[model]]` table name to the
    // plural field; `default` makes an empty file valid.
    #[serde(default, rename = "model")]
    pub models: Vec<ModelConfig>,
}
48
/// In-memory catalog of available models: built-ins plus any user
/// additions/overrides merged in by [`ModelRegistry::load`].
pub struct ModelRegistry {
    // Kept ordered: built-ins first, user additions appended.
    models: Vec<ModelConfig>,
}
52
53impl Default for ModelRegistry {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl ModelRegistry {
60 pub fn new() -> Self {
62 Self {
63 models: built_in_models(),
64 }
65 }
66
67 pub fn load() -> Self {
70 let mut registry = Self::new();
71
72 if let Some(home) = dirs::home_dir() {
73 let path = home.join(".rho").join("models.toml");
74 if path.is_file() {
75 match std::fs::read_to_string(&path) {
76 Ok(content) => match toml::from_str::<ModelsFile>(&content) {
77 Ok(file) => {
78 for user_model in file.models {
79 if let Some(existing) =
80 registry.models.iter_mut().find(|m| m.id == user_model.id)
81 {
82 *existing = user_model;
83 } else {
84 registry.models.push(user_model);
85 }
86 }
87 }
88 Err(e) => {
89 tracing::warn!("Failed to parse ~/.rho/models.toml: {}", e);
90 }
91 },
92 Err(e) => {
93 tracing::warn!("Failed to read ~/.rho/models.toml: {}", e);
94 }
95 }
96 }
97 }
98
99 registry
100 }
101
102 pub fn get(&self, id: &str) -> Option<&ModelConfig> {
103 self.models.iter().find(|m| m.id == id)
104 }
105
106 pub fn list(&self) -> &[ModelConfig] {
107 &self.models
108 }
109
110 pub fn to_model(config: &ModelConfig) -> crate::types::Model {
112 crate::types::Model {
113 id: config.model_id.clone(),
114 name: config.id.clone(),
115 provider: match config.provider {
116 ProviderType::Anthropic => "anthropic".into(),
117 ProviderType::OpenAi => "openai".into(),
118 },
119 base_url: config.base_url.clone(),
120 reasoning: config.thinking,
121 context_window: config.context_window,
122 max_tokens: config.max_tokens,
123 }
124 }
125
126 pub fn resolve_api_key(config: &ModelConfig) -> Result<String, String> {
134 if let Some(ref env_var) = config.api_key_env {
136 if let Ok(val) = std::env::var(env_var) {
137 if !val.is_empty() {
138 return Ok(val);
139 }
140 }
141 }
142
143 if config.provider == ProviderType::Anthropic {
145 if let Ok(token) = crate::auth::get_token() {
146 return Ok(token);
147 }
148 }
149
150 if config.base_url.contains("localhost") || config.base_url.contains("127.0.0.1") {
152 return Ok("local".into());
153 }
154
155 Err(format!(
156 "No API key found for model '{}'. Set the {} environment variable.",
157 config.id,
158 config
159 .api_key_env
160 .as_deref()
161 .unwrap_or("appropriate API key env var")
162 ))
163 }
164}
165
166fn built_in_models() -> Vec<ModelConfig> {
167 vec![
168 ModelConfig {
170 id: "claude-sonnet".into(),
171 provider: ProviderType::Anthropic,
172 model_id: "claude-sonnet-4-5-20250929".into(),
173 base_url: String::new(),
174 api_key_env: Some("ANTHROPIC_API_KEY".into()),
175 context_window: 200_000,
176 max_tokens: 8_192,
177 thinking: false,
178 server_tools: None,
179 },
180 ModelConfig {
181 id: "claude-opus".into(),
182 provider: ProviderType::Anthropic,
183 model_id: "claude-opus-4-6".into(),
184 base_url: String::new(),
185 api_key_env: Some("ANTHROPIC_API_KEY".into()),
186 context_window: 200_000,
187 max_tokens: 8_192,
188 thinking: true,
189 server_tools: None,
190 },
191 ModelConfig {
192 id: "claude-haiku".into(),
193 provider: ProviderType::Anthropic,
194 model_id: "claude-haiku-4-5-20251001".into(),
195 base_url: String::new(),
196 api_key_env: Some("ANTHROPIC_API_KEY".into()),
197 context_window: 200_000,
198 max_tokens: 8_192,
199 thinking: false,
200 server_tools: None,
201 },
202 ModelConfig {
204 id: "grok-3".into(),
205 provider: ProviderType::OpenAi,
206 model_id: "grok-3".into(),
207 base_url: "https://api.x.ai/v1".into(),
208 api_key_env: Some("XAI_API_KEY".into()),
209 context_window: 131_072,
210 max_tokens: 16_384,
211 thinking: false,
212 server_tools: None,
213 },
214 ModelConfig {
215 id: "grok-3-mini".into(),
216 provider: ProviderType::OpenAi,
217 model_id: "grok-3-mini".into(),
218 base_url: "https://api.x.ai/v1".into(),
219 api_key_env: Some("XAI_API_KEY".into()),
220 context_window: 131_072,
221 max_tokens: 8_192,
222 thinking: false,
223 server_tools: None,
224 },
225 ModelConfig {
226 id: "grok-2".into(),
227 provider: ProviderType::OpenAi,
228 model_id: "grok-2-1212".into(),
229 base_url: "https://api.x.ai/v1".into(),
230 api_key_env: Some("XAI_API_KEY".into()),
231 context_window: 32_768,
232 max_tokens: 8_192,
233 thinking: false,
234 server_tools: None,
235 },
236 ModelConfig {
238 id: "grok-4.20-reasoning".into(),
239 provider: ProviderType::OpenAi,
240 model_id: "grok-4.20-experimental-beta-0304-reasoning".into(),
241 base_url: "https://api.x.ai/v1".into(),
242 api_key_env: Some("XAI_API_KEY".into()),
243 context_window: 131_072,
244 max_tokens: 16_384,
245 thinking: true,
246 server_tools: None,
247 },
248 ModelConfig {
249 id: "grok-4.20-non-reasoning".into(),
250 provider: ProviderType::OpenAi,
251 model_id: "grok-4.20-experimental-beta-0304-non-reasoning".into(),
252 base_url: "https://api.x.ai/v1".into(),
253 api_key_env: Some("XAI_API_KEY".into()),
254 context_window: 131_072,
255 max_tokens: 16_384,
256 thinking: false,
257 server_tools: None,
258 },
259 ModelConfig {
260 id: "grok-4.20-multi-agent".into(),
261 provider: ProviderType::OpenAi,
262 model_id: "grok-4.20-multi-agent-experimental-beta-0304".into(),
263 base_url: "https://api.x.ai/v1".into(),
264 api_key_env: Some("XAI_API_KEY".into()),
265 context_window: 131_072,
266 max_tokens: 16_384,
267 thinking: false,
268 server_tools: None,
269 },
270 ModelConfig {
272 id: "grok-code-fast-1".into(),
273 provider: ProviderType::OpenAi,
274 model_id: "grok-code-fast-1".into(),
275 base_url: "https://api.x.ai/v1".into(),
276 api_key_env: Some("XAI_API_KEY".into()),
277 context_window: 131_072,
278 max_tokens: 16_384,
279 thinking: false,
280 server_tools: None,
281 },
282 ModelConfig {
283 id: "grok-4-1-reasoning".into(),
284 provider: ProviderType::OpenAi,
285 model_id: "grok-4-1-reasoning".into(),
286 base_url: "https://api.x.ai/v1".into(),
287 api_key_env: Some("XAI_API_KEY".into()),
288 context_window: 131_072,
289 max_tokens: 16_384,
290 thinking: true,
291 server_tools: None,
292 },
293 ModelConfig {
294 id: "grok-4.20-beta-0309-reasoning".into(),
295 provider: ProviderType::OpenAi,
296 model_id: "grok-4.20-beta-0309-reasoning".into(),
297 base_url: "https://api.x.ai/v1".into(),
298 api_key_env: Some("XAI_API_KEY".into()),
299 context_window: 131_072,
300 max_tokens: 16_384,
301 thinking: true,
302 server_tools: None,
303 },
304 ModelConfig {
305 id: "grok-4.20-multi-agent-beta-0309".into(),
306 provider: ProviderType::OpenAi,
307 model_id: "grok-4.20-multi-agent-beta-0309".into(),
308 base_url: "https://api.x.ai/v1".into(),
309 api_key_env: Some("XAI_API_KEY".into()),
310 context_window: 131_072,
311 max_tokens: 16_384,
312 thinking: false,
313 server_tools: None,
314 },
315
316 ]
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the built-in catalog size; update when adding/removing models.
    #[test]
    fn new_has_builtin_models() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.list().len(), 13);
    }

    // xAI models are routed through the OpenAI-compatible provider.
    #[test]
    fn builtin_grok_models_use_openai_provider() {
        let registry = ModelRegistry::new();
        let grok = registry.get("grok-3").unwrap();
        assert_eq!(grok.provider, ProviderType::OpenAi);
        assert_eq!(grok.base_url, "https://api.x.ai/v1");
        assert_eq!(grok.api_key_env.as_deref(), Some("XAI_API_KEY"));
        assert_eq!(grok.model_id, "grok-3");
    }

    #[test]
    fn get_builtin_model() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-sonnet").unwrap();
        assert_eq!(m.model_id, "claude-sonnet-4-5-20250929");
        assert_eq!(m.provider, ProviderType::Anthropic);
        assert!(!m.thinking);
    }

    #[test]
    fn get_claude_opus_has_thinking() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-opus").unwrap();
        assert!(m.thinking);
    }

    #[test]
    fn get_missing_returns_none() {
        let registry = ModelRegistry::new();
        assert!(registry.get("gpt-4o").is_none());
    }

    // `to_model` swaps id/name: Model::id is the wire model_id,
    // Model::name is the registry id.
    #[test]
    fn to_model_maps_fields() {
        let config = ModelConfig {
            id: "test-model".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.openai.com/v1".into(),
            api_key_env: Some("OPENAI_API_KEY".into()),
            context_window: 128_000,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        };
        let model = ModelRegistry::to_model(&config);
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "test-model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.base_url, "https://api.openai.com/v1");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
    }

    // NOTE(review): set_var/remove_var mutate process-global state; this
    // can race with other env-reading tests when run in parallel.
    #[test]
    fn resolve_api_key_from_env() {
        let config = ModelConfig {
            id: "test".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: String::new(),
            api_key_env: Some("__RHO_TEST_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        std::env::set_var("__RHO_TEST_KEY__", "test-api-key-123");
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "test-api-key-123");
        std::env::remove_var("__RHO_TEST_KEY__");
    }

    // Localhost base URLs resolve to the "local" placeholder key.
    #[test]
    fn resolve_api_key_localhost_returns_local() {
        let config = ModelConfig {
            id: "ollama".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "local");
    }

    // Parses the on-disk [[model]] shape used for overrides (parse only;
    // the merge itself happens in `load`, which reads the real home dir).
    #[test]
    fn load_merges_user_toml_override() {
        let toml_str = r#"
[[model]]
id = "claude-sonnet"
provider = "anthropic"
model_id = "claude-sonnet-4-6-custom"
api_key_env = "ANTHROPIC_API_KEY"
context_window = 200000
max_tokens = 8192
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 1);
        assert_eq!(file.models[0].model_id, "claude-sonnet-4-6-custom");
    }

    // "openai" deserializes via the lowercase rename on ProviderType.
    #[test]
    fn load_parses_openai_provider() {
        let toml_str = r#"
[[model]]
id = "gpt-4o"
provider = "openai"
model_id = "gpt-4o"
api_key_env = "OPENAI_API_KEY"
context_window = 128000
max_tokens = 16384
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
    }
}
451}