use serde::{Deserialize, Serialize};

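/// Wire protocol used to reach a model provider. Variants serialize to
/// lowercase (`"anthropic"`, `"openai"`, `"xairesponses"`); the aliases
/// additionally accept the hyphenated forms `"openai-compatible"` and
/// `"xai-responses"` when deserializing.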
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    Anthropic,
    #[serde(alias = "openai-compatible")]
    OpenAi,
    #[serde(alias = "xai-responses")]
    XaiResponses,
}

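/// One model entry, either built in or supplied by the user in
/// `~/.rho/models.toml`.
///
/// A minimal user entry looks like this (the same shape the tests at the
/// bottom of this file parse); fields marked `#[serde(default)]` may be
/// omitted:
///
/// ```toml
/// [[model]]
/// id = "claude-sonnet"
/// provider = "anthropic"
/// model_id = "claude-sonnet-4-6-custom"
/// api_key_env = "ANTHROPIC_API_KEY"
/// context_window = 200000
/// max_tokens = 8192
/// ```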
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    pub id: String,
    pub provider: ProviderType,
    pub model_id: String,
    #[serde(default)]
    pub base_url: String,
    pub api_key_env: Option<String>,
    #[serde(default = "default_context_window")]
    pub context_window: usize,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    #[serde(default)]
    pub thinking: bool,
    #[serde(default)]
    pub server_tools: Option<Vec<String>>,
}

fn default_context_window() -> usize {
    200_000
}

fn default_max_tokens() -> usize {
    8_192
}

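/// Top-level shape of `~/.rho/models.toml`; the `rename = "model"` attribute
/// maps each `[[model]]` array-of-tables entry into `models`.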
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsFile {
    #[serde(default, rename = "model")]
    pub models: Vec<ModelConfig>,
}

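/// Registry of model configurations: the built-in set plus user overrides
/// and Zen-discovered models.
///
/// A sketch of typical usage (marked `ignore` because `load` touches the
/// filesystem and environment, and the import path depends on this crate's
/// module layout):
///
/// ```ignore
/// let registry = ModelRegistry::load();
/// if let Some(config) = registry.get("claude-sonnet") {
///     let model = ModelRegistry::to_model(config);
///     let api_key = ModelRegistry::resolve_api_key(config)?;
/// }
/// ```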
pub struct ModelRegistry {
    models: Vec<ModelConfig>,
}

impl Default for ModelRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelRegistry {
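    /// Creates a registry containing only the built-in models.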
    pub fn new() -> Self {
        Self {
            models: built_in_models(),
        }
    }

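    /// Builds the full registry: built-in models first, then entries from
    /// `~/.rho/models.toml` (an entry whose `id` matches a built-in replaces
    /// it; otherwise it is appended), then Zen models if
    /// `OPENCODE_ZEN_API_KEY` is set. An unreadable or unparseable user file
    /// is logged and skipped rather than treated as fatal.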
    pub fn load() -> Self {
        let mut registry = Self::new();

        if let Some(home) = dirs::home_dir() {
            let path = home.join(".rho").join("models.toml");
            if path.is_file() {
                match std::fs::read_to_string(&path) {
                    Ok(content) => match toml::from_str::<ModelsFile>(&content) {
                        Ok(file) => {
                            for user_model in file.models {
                                if let Some(existing) =
                                    registry.models.iter_mut().find(|m| m.id == user_model.id)
                                {
                                    *existing = user_model;
                                } else {
                                    registry.models.push(user_model);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::warn!("Failed to parse ~/.rho/models.toml: {}", e);
                        }
                    },
                    Err(e) => {
                        tracing::warn!("Failed to read ~/.rho/models.toml: {}", e);
                    }
                }
            }
        }

        load_zen_models(&mut registry.models);

        registry
    }

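    /// Looks up a model by registry `id` (not by the provider `model_id`).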
    pub fn get(&self, id: &str) -> Option<&ModelConfig> {
        self.models.iter().find(|m| m.id == id)
    }

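    /// Returns all registered models in registration order.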
    pub fn list(&self) -> &[ModelConfig] {
        &self.models
    }

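    /// Converts a registry entry into the runtime [`crate::types::Model`]:
    /// the registry `id` becomes the display `name`, the provider `model_id`
    /// becomes the wire `id`, and the provider enum is stringified.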
    pub fn to_model(config: &ModelConfig) -> crate::types::Model {
        crate::types::Model {
            id: config.model_id.clone(),
            name: config.id.clone(),
            provider: match config.provider {
                ProviderType::Anthropic => "anthropic".into(),
                ProviderType::OpenAi => "openai".into(),
                ProviderType::XaiResponses => "xai-responses".into(),
            },
            base_url: config.base_url.clone(),
            reasoning: config.thinking,
            context_window: config.context_window,
            max_tokens: config.max_tokens,
        }
    }

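    /// Resolves the credential for a model, in order of preference: a
    /// non-empty value from the configured `api_key_env` variable, a stored
    /// auth token for Anthropic models, or the placeholder `"local"` when
    /// the base URL points at localhost. Otherwise returns an error that
    /// names the expected environment variable when one is configured.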
    pub fn resolve_api_key(config: &ModelConfig) -> Result<String, String> {
        if let Some(ref env_var) = config.api_key_env {
            if let Ok(val) = std::env::var(env_var) {
                if !val.is_empty() {
                    return Ok(val);
                }
            }
        }

        if config.provider == ProviderType::Anthropic {
            if let Ok(token) = crate::auth::get_token() {
                return Ok(token);
            }
        }

        if config.base_url.contains("localhost") || config.base_url.contains("127.0.0.1") {
            return Ok("local".into());
        }

        Err(match config.api_key_env.as_deref() {
            Some(var) => format!(
                "No API key found for model '{}'. Set the {} environment variable.",
                config.id, var
            ),
            None => format!(
                "No API key found for model '{}' and no API key environment variable is configured for it.",
                config.id
            ),
        })
    }
}

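/// Appends Zen-hosted models when `OPENCODE_ZEN_API_KEY` is set to a
/// non-empty value. Claude-family ids use the Anthropic protocol at
/// `https://opencode.ai/zen`; all other ids use the OpenAI-compatible
/// endpoint at `https://opencode.ai/zen/v1`. Ids already present in the
/// registry are skipped.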
fn load_zen_models(models: &mut Vec<ModelConfig>) {
    if std::env::var("OPENCODE_ZEN_API_KEY")
        .ok()
        .filter(|v| !v.is_empty())
        .is_none()
    {
        return;
    }

    let zen_ids = crate::zen::fetch_zen_models();
    for model_id in zen_ids {
        let registry_id = format!("zen-{}", model_id);
        if models.iter().any(|m| m.id == registry_id) {
            continue;
        }

        let (provider, base_url) = if model_id.contains("claude") {
            (ProviderType::Anthropic, "https://opencode.ai/zen".to_string())
        } else {
            (ProviderType::OpenAi, "https://opencode.ai/zen/v1".to_string())
        };

        models.push(ModelConfig {
            id: registry_id,
            provider,
            model_id: model_id.clone(),
            base_url,
            api_key_env: Some("OPENCODE_ZEN_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 16_384,
            thinking: model_id.contains("opus"),
            server_tools: None,
        });
    }
}

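/// The default model set compiled into the binary. Keep the count in sync
/// with the `new_has_builtin_models` test at the bottom of this file.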
fn built_in_models() -> Vec<ModelConfig> {
    vec![
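        // Anthropic Claude models. base_url is left empty here; presumably
        // the Anthropic client substitutes its default endpoint.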
        ModelConfig {
            id: "claude-sonnet".into(),
            provider: ProviderType::Anthropic,
            model_id: "claude-sonnet-4-5-20250929".into(),
            base_url: String::new(),
            api_key_env: Some("ANTHROPIC_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "claude-opus".into(),
            provider: ProviderType::Anthropic,
            model_id: "claude-opus-4-6".into(),
            base_url: String::new(),
            api_key_env: Some("ANTHROPIC_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 8_192,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "claude-haiku".into(),
            provider: ProviderType::Anthropic,
            model_id: "claude-haiku-4-5-20251001".into(),
            base_url: String::new(),
            api_key_env: Some("ANTHROPIC_API_KEY".into()),
            context_window: 200_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
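        // xAI Grok models. Most use the OpenAI-compatible chat endpoint; the
        // multi-agent variants go through the xAI Responses protocol instead.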
        ModelConfig {
            id: "grok-3".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-3".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-3-mini".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-3-mini".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-2".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-2-1212".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 32_768,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4.20-experimental-beta-0304-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-non-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4.20-experimental-beta-0304-non-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-multi-agent".into(),
            provider: ProviderType::XaiResponses,
            model_id: "grok-4.20-multi-agent-experimental-beta-0304".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-code-fast-1".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-code-fast-1".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4-1-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4-1-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-beta-0309-reasoning".into(),
            provider: ProviderType::OpenAi,
            model_id: "grok-4.20-beta-0309-reasoning".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: true,
            server_tools: None,
        },
        ModelConfig {
            id: "grok-4.20-multi-agent-beta-0309".into(),
            provider: ProviderType::XaiResponses,
            model_id: "grok-4.20-multi-agent-beta-0309".into(),
            base_url: "https://api.x.ai/v1".into(),
            api_key_env: Some("XAI_API_KEY".into()),
            context_window: 131_072,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        },
    ]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_has_builtin_models() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.list().len(), 13);
    }

    #[test]
    fn builtin_grok_models_use_openai_provider() {
        let registry = ModelRegistry::new();
        let grok = registry.get("grok-3").unwrap();
        assert_eq!(grok.provider, ProviderType::OpenAi);
        assert_eq!(grok.base_url, "https://api.x.ai/v1");
        assert_eq!(grok.api_key_env.as_deref(), Some("XAI_API_KEY"));
        assert_eq!(grok.model_id, "grok-3");
    }

    #[test]
    fn get_builtin_model() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-sonnet").unwrap();
        assert_eq!(m.model_id, "claude-sonnet-4-5-20250929");
        assert_eq!(m.provider, ProviderType::Anthropic);
        assert!(!m.thinking);
    }

    #[test]
    fn get_claude_opus_has_thinking() {
        let registry = ModelRegistry::new();
        let m = registry.get("claude-opus").unwrap();
        assert!(m.thinking);
    }

    #[test]
    fn get_missing_returns_none() {
        let registry = ModelRegistry::new();
        assert!(registry.get("gpt-4o").is_none());
    }

    #[test]
    fn to_model_maps_fields() {
        let config = ModelConfig {
            id: "test-model".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: "https://api.openai.com/v1".into(),
            api_key_env: Some("OPENAI_API_KEY".into()),
            context_window: 128_000,
            max_tokens: 16_384,
            thinking: false,
            server_tools: None,
        };
        let model = ModelRegistry::to_model(&config);
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "test-model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.base_url, "https://api.openai.com/v1");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
    }

    #[test]
    fn resolve_api_key_from_env() {
        let config = ModelConfig {
            id: "test".into(),
            provider: ProviderType::OpenAi,
            model_id: "gpt-4o".into(),
            base_url: String::new(),
            api_key_env: Some("__RHO_TEST_KEY__".into()),
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        std::env::set_var("__RHO_TEST_KEY__", "test-api-key-123");
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "test-api-key-123");
        std::env::remove_var("__RHO_TEST_KEY__");
    }

    #[test]
    fn resolve_api_key_localhost_returns_local() {
        let config = ModelConfig {
            id: "ollama".into(),
            provider: ProviderType::OpenAi,
            model_id: "llama3".into(),
            base_url: "http://localhost:11434/v1".into(),
            api_key_env: None,
            context_window: 128_000,
            max_tokens: 8_192,
            thinking: false,
            server_tools: None,
        };
        let key = ModelRegistry::resolve_api_key(&config).unwrap();
        assert_eq!(key, "local");
    }

    #[test]
    fn load_merges_user_toml_override() {
        let toml_str = r#"
[[model]]
id = "claude-sonnet"
provider = "anthropic"
model_id = "claude-sonnet-4-6-custom"
api_key_env = "ANTHROPIC_API_KEY"
context_window = 200000
max_tokens = 8192
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models.len(), 1);
        assert_eq!(file.models[0].model_id, "claude-sonnet-4-6-custom");
    }

    #[test]
    fn load_parses_openai_provider() {
        let toml_str = r#"
[[model]]
id = "gpt-4o"
provider = "openai"
model_id = "gpt-4o"
api_key_env = "OPENAI_API_KEY"
context_window = 128000
max_tokens = 16384
"#;
        let file: ModelsFile = toml::from_str(toml_str).unwrap();
        assert_eq!(file.models[0].provider, ProviderType::OpenAi);
    }
}
494}