1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
4#[serde(rename_all = "lowercase")]
5pub enum ProviderType {
6 Anthropic,
7 #[serde(alias = "openai-compatible")]
8 OpenAi,
9 #[serde(alias = "xai-responses")]
10 XaiResponses,
11}
12
/// One model entry, either built-in or loaded from `~/.rho/models.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    // Registry identifier users refer to (e.g. "claude-sonnet").
    pub id: String,
    // Which provider API this model is served through.
    pub provider: ProviderType,
    // Provider-side model name sent on the wire (e.g. "claude-sonnet-4-5-20250929").
    pub model_id: String,
    // API base URL; built-in Anthropic entries leave it empty (provider default).
    #[serde(default)]
    pub base_url: String,
    // Name of the environment variable holding the API key, if any.
    pub api_key_env: Option<String>,
    // Context window size in tokens; defaults to 200_000 when omitted.
    #[serde(default = "default_context_window")]
    pub context_window: usize,
    // Maximum output tokens per response; defaults to 8_192 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    // Whether thinking/reasoning mode is enabled (maps to `Model::reasoning`).
    #[serde(default)]
    pub thinking: bool,
    // Optional list of server-side tool names — presumably provider-hosted
    // tools; exact semantics not visible from this file (NOTE(review): confirm).
    #[serde(default)]
    pub server_tools: Option<Vec<String>>,
}
37
/// Serde default for `ModelConfig::context_window`: 200k tokens.
fn default_context_window() -> usize { 200_000 }
/// Serde default for `ModelConfig::max_tokens`: 8_192 tokens.
fn default_max_tokens() -> usize { 8_192 }
44
/// On-disk shape of `~/.rho/models.toml`: a list of `[[model]]` tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsFile {
    // `rename = "model"` makes the TOML key `[[model]]`; `default` makes
    // an empty or absent list parse successfully.
    #[serde(default, rename = "model")]
    pub models: Vec<ModelConfig>,
}
50
/// Registry of available models: built-ins plus user overrides from
/// `~/.rho/models.toml` (see `ModelRegistry::load`).
pub struct ModelRegistry {
    // Ordered list, searched linearly by `ModelConfig::id`.
    models: Vec<ModelConfig>,
}
54
55impl Default for ModelRegistry {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl ModelRegistry {
62 pub fn new() -> Self {
64 Self {
65 models: built_in_models(),
66 }
67 }
68
69 pub fn load() -> Self {
72 let mut registry = Self::new();
73
74 if let Some(home) = dirs::home_dir() {
75 let path = home.join(".rho").join("models.toml");
76 if path.is_file() {
77 match std::fs::read_to_string(&path) {
78 Ok(content) => match toml::from_str::<ModelsFile>(&content) {
79 Ok(file) => {
80 for user_model in file.models {
81 if let Some(existing) =
82 registry.models.iter_mut().find(|m| m.id == user_model.id)
83 {
84 *existing = user_model;
85 } else {
86 registry.models.push(user_model);
87 }
88 }
89 }
90 Err(e) => {
91 tracing::warn!("Failed to parse ~/.rho/models.toml: {}", e);
92 }
93 },
94 Err(e) => {
95 tracing::warn!("Failed to read ~/.rho/models.toml: {}", e);
96 }
97 }
98 }
99 }
100
101 registry
102 }
103
104 pub fn get(&self, id: &str) -> Option<&ModelConfig> {
105 self.models.iter().find(|m| m.id == id)
106 }
107
108 pub fn list(&self) -> &[ModelConfig] {
109 &self.models
110 }
111
112 pub fn to_model(config: &ModelConfig) -> crate::types::Model {
114 crate::types::Model {
115 id: config.model_id.clone(),
116 name: config.id.clone(),
117 provider: match config.provider {
118 ProviderType::Anthropic => "anthropic".into(),
119 ProviderType::OpenAi => "openai".into(),
120 ProviderType::XaiResponses => "xai-responses".into(),
121 },
122 base_url: config.base_url.clone(),
123 reasoning: config.thinking,
124 context_window: config.context_window,
125 max_tokens: config.max_tokens,
126 }
127 }
128
129 pub fn resolve_api_key(config: &ModelConfig) -> Result<String, String> {
137 if let Some(ref env_var) = config.api_key_env {
139 if let Ok(val) = std::env::var(env_var) {
140 if !val.is_empty() {
141 return Ok(val);
142 }
143 }
144 }
145
146 if config.provider == ProviderType::Anthropic {
148 if let Ok(token) = crate::auth::get_token() {
149 return Ok(token);
150 }
151 }
152
153 if config.base_url.contains("localhost") || config.base_url.contains("127.0.0.1") {
155 return Ok("local".into());
156 }
157
158 Err(format!(
159 "No API key found for model '{}'. Set the {} environment variable.",
160 config.id,
161 config
162 .api_key_env
163 .as_deref()
164 .unwrap_or("appropriate API key env var")
165 ))
166 }
167}
168
169fn built_in_models() -> Vec<ModelConfig> {
170 vec![
171 ModelConfig {
173 id: "claude-sonnet".into(),
174 provider: ProviderType::Anthropic,
175 model_id: "claude-sonnet-4-5-20250929".into(),
176 base_url: String::new(),
177 api_key_env: Some("ANTHROPIC_API_KEY".into()),
178 context_window: 200_000,
179 max_tokens: 8_192,
180 thinking: false,
181 server_tools: None,
182 },
183 ModelConfig {
184 id: "claude-opus".into(),
185 provider: ProviderType::Anthropic,
186 model_id: "claude-opus-4-6".into(),
187 base_url: String::new(),
188 api_key_env: Some("ANTHROPIC_API_KEY".into()),
189 context_window: 200_000,
190 max_tokens: 8_192,
191 thinking: true,
192 server_tools: None,
193 },
194 ModelConfig {
195 id: "claude-haiku".into(),
196 provider: ProviderType::Anthropic,
197 model_id: "claude-haiku-4-5-20251001".into(),
198 base_url: String::new(),
199 api_key_env: Some("ANTHROPIC_API_KEY".into()),
200 context_window: 200_000,
201 max_tokens: 8_192,
202 thinking: false,
203 server_tools: None,
204 },
205 ModelConfig {
207 id: "grok-3".into(),
208 provider: ProviderType::OpenAi,
209 model_id: "grok-3".into(),
210 base_url: "https://api.x.ai/v1".into(),
211 api_key_env: Some("XAI_API_KEY".into()),
212 context_window: 131_072,
213 max_tokens: 16_384,
214 thinking: false,
215 server_tools: None,
216 },
217 ModelConfig {
218 id: "grok-3-mini".into(),
219 provider: ProviderType::OpenAi,
220 model_id: "grok-3-mini".into(),
221 base_url: "https://api.x.ai/v1".into(),
222 api_key_env: Some("XAI_API_KEY".into()),
223 context_window: 131_072,
224 max_tokens: 8_192,
225 thinking: false,
226 server_tools: None,
227 },
228 ModelConfig {
229 id: "grok-2".into(),
230 provider: ProviderType::OpenAi,
231 model_id: "grok-2-1212".into(),
232 base_url: "https://api.x.ai/v1".into(),
233 api_key_env: Some("XAI_API_KEY".into()),
234 context_window: 32_768,
235 max_tokens: 8_192,
236 thinking: false,
237 server_tools: None,
238 },
239 ModelConfig {
241 id: "grok-4.20-reasoning".into(),
242 provider: ProviderType::OpenAi,
243 model_id: "grok-4.20-experimental-beta-0304-reasoning".into(),
244 base_url: "https://api.x.ai/v1".into(),
245 api_key_env: Some("XAI_API_KEY".into()),
246 context_window: 131_072,
247 max_tokens: 16_384,
248 thinking: true,
249 server_tools: None,
250 },
251 ModelConfig {
252 id: "grok-4.20-non-reasoning".into(),
253 provider: ProviderType::OpenAi,
254 model_id: "grok-4.20-experimental-beta-0304-non-reasoning".into(),
255 base_url: "https://api.x.ai/v1".into(),
256 api_key_env: Some("XAI_API_KEY".into()),
257 context_window: 131_072,
258 max_tokens: 16_384,
259 thinking: false,
260 server_tools: None,
261 },
262 ModelConfig {
263 id: "grok-4.20-multi-agent".into(),
264 provider: ProviderType::XaiResponses,
265 model_id: "grok-4.20-multi-agent-experimental-beta-0304".into(),
266 base_url: "https://api.x.ai/v1".into(),
267 api_key_env: Some("XAI_API_KEY".into()),
268 context_window: 131_072,
269 max_tokens: 16_384,
270 thinking: false,
271 server_tools: None,
272 },
273 ModelConfig {
275 id: "grok-code-fast-1".into(),
276 provider: ProviderType::OpenAi,
277 model_id: "grok-code-fast-1".into(),
278 base_url: "https://api.x.ai/v1".into(),
279 api_key_env: Some("XAI_API_KEY".into()),
280 context_window: 131_072,
281 max_tokens: 16_384,
282 thinking: false,
283 server_tools: None,
284 },
285 ModelConfig {
286 id: "grok-4-1-reasoning".into(),
287 provider: ProviderType::OpenAi,
288 model_id: "grok-4-1-reasoning".into(),
289 base_url: "https://api.x.ai/v1".into(),
290 api_key_env: Some("XAI_API_KEY".into()),
291 context_window: 131_072,
292 max_tokens: 16_384,
293 thinking: true,
294 server_tools: None,
295 },
296 ModelConfig {
297 id: "grok-4.20-beta-0309-reasoning".into(),
298 provider: ProviderType::OpenAi,
299 model_id: "grok-4.20-beta-0309-reasoning".into(),
300 base_url: "https://api.x.ai/v1".into(),
301 api_key_env: Some("XAI_API_KEY".into()),
302 context_window: 131_072,
303 max_tokens: 16_384,
304 thinking: true,
305 server_tools: None,
306 },
307 ModelConfig {
308 id: "grok-4.20-multi-agent-beta-0309".into(),
309 provider: ProviderType::XaiResponses,
310 model_id: "grok-4.20-multi-agent-beta-0309".into(),
311 base_url: "https://api.x.ai/v1".into(),
312 api_key_env: Some("XAI_API_KEY".into()),
313 context_window: 131_072,
314 max_tokens: 16_384,
315 thinking: false,
316 server_tools: None,
317 },
318
319 ]
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the size of the built-in catalogue; bump when adding a built-in.
    #[test]
    fn new_has_builtin_models() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.list().len(), 13);
    }

    // xAI chat models go through the OpenAI-compatible provider against
    // the api.x.ai/v1 endpoint.
    #[test]
    fn builtin_grok_models_use_openai_provider() {
        let registry = ModelRegistry::new();
        let grok = registry.get("grok-3").unwrap();
        assert_eq!(grok.provider, ProviderType::OpenAi);
        assert_eq!(grok.base_url, "https://api.x.ai/v1");
        assert_eq!(grok.api_key_env.as_deref(), Some("XAI_API_KEY"));
        assert_eq!(grok.model_id, "grok-3");
    }

    #[test]
    fn get_builtin_model() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-sonnet").unwrap();
        assert_eq!(m.model_id, "claude-sonnet-4-5-20250929");
        assert_eq!(m.provider, ProviderType::Anthropic);
        assert!(!m.thinking);
    }

    #[test]
    fn get_claude_opus_has_thinking() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-opus").unwrap();
        assert!(m.thinking);
    }

    #[test]
    fn get_missing_returns_none() {
        let registry = ModelRegistry::new();
        assert!(registry.get("gpt-4o").is_none());
    }

    // `to_model` intentionally swaps id/name: Model::id is the wire-level
    // model_id, Model::name is the registry id.
    #[test]
    fn to_model_maps_fields() {
        let config = ModelConfig {
            id: "test-model".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.openai.com/v1".into(),
            api_key_env: Some("OPENAI_API_KEY".into()),
            context_window: 128_000,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        };
        let model = ModelRegistry::to_model(&config);
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "test-model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.base_url, "https://api.openai.com/v1");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
    }

    // NOTE(review): set_var/remove_var mutate process-global state; this is
    // safe only because the var name is unique to this test. Confirm the
    // test harness doesn't share it elsewhere.
    #[test]
    fn resolve_api_key_from_env() {
        let config = ModelConfig {
            id: "test".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: String::new(),
            api_key_env: Some("__RHO_TEST_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        std::env::set_var("__RHO_TEST_KEY__", "test-api-key-123");
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "test-api-key-123");
        std::env::remove_var("__RHO_TEST_KEY__");
    }

    // Localhost base URLs resolve to the "local" placeholder key.
    #[test]
    fn resolve_api_key_localhost_returns_local() {
        let config = ModelConfig {
            id: "ollama".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "local");
    }

    // Exercises only TOML parsing of a user override file, not the
    // filesystem-reading merge in `load()`.
    #[test]
    fn load_merges_user_toml_override() {
        let toml_str = r#"
[[model]]
id = "claude-sonnet"
provider = "anthropic"
model_id = "claude-sonnet-4-6-custom"
api_key_env = "ANTHROPIC_API_KEY"
context_window = 200000
max_tokens = 8192
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 1);
        assert_eq!(file.models[0].model_id, "claude-sonnet-4-6-custom");
    }

    #[test]
    fn load_parses_openai_provider() {
        let toml_str = r#"
[[model]]
id = "gpt-4o"
provider = "openai"
model_id = "gpt-4o"
api_key_env = "OPENAI_API_KEY"
context_window = 128000
max_tokens = 16384
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
    }
}