1use super::cli_provider::{CliProvider, CliProviderConfig};
7use super::ollama::OllamaProvider;
8use super::openrouter::OpenRouterProvider;
9use super::provider::{AnthropicProvider, ModelResponse};
10use super::{ModelConfig, ModelRole, Provider};
11use crate::types::StepType;
12use anyhow::{Context, Result};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
/// Role-based routing table: maps a role name (e.g. "thinking", "coding")
/// to the model configuration used when a request is routed for that role.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    // Role name -> model configuration for that role.
    pub roles: HashMap<String, ModelConfig>,
}
21
22impl Default for RouterConfig {
23 fn default() -> Self {
24 let mut roles = HashMap::new();
25 roles.insert(
26 "thinking".into(),
27 ModelConfig {
28 provider: Provider::Anthropic,
29 model: "claude-opus-4-20250514".into(),
30 api_key_env: Some("ANTHROPIC_API_KEY".into()),
31 max_cost_per_call: 5.0,
32 temperature: 0.3,
33 max_tokens: 8192,
34 },
35 );
36 roles.insert(
37 "coding".into(),
38 ModelConfig {
39 provider: Provider::Anthropic,
40 model: "claude-sonnet-4-20250514".into(),
41 api_key_env: Some("ANTHROPIC_API_KEY".into()),
42 max_cost_per_call: 2.0,
43 temperature: 0.2,
44 max_tokens: 8192,
45 },
46 );
47 roles.insert(
48 "task".into(),
49 ModelConfig {
50 provider: Provider::OpenRouter,
51 model: "google/gemini-2.0-flash-001".into(),
52 api_key_env: Some("OPENROUTER_API_KEY".into()),
53 max_cost_per_call: 0.5,
54 temperature: 0.3,
55 max_tokens: 4096,
56 },
57 );
58 roles.insert(
59 "embedding".into(),
60 ModelConfig {
61 provider: Provider::Ollama,
62 model: "nomic-embed-text".into(),
63 api_key_env: None,
64 max_cost_per_call: 0.0,
65 temperature: 0.0,
66 max_tokens: 0,
67 },
68 );
69 roles.insert(
70 "auditor".into(),
71 ModelConfig {
72 provider: Provider::Ollama,
73 model: "llama3.2:3b".into(),
74 api_key_env: None,
75 max_cost_per_call: 0.0,
76 temperature: 0.1,
77 max_tokens: 2048,
78 },
79 );
80 Self { roles }
81 }
82}
83
/// On-disk shape of `~/.mur/commander/config.toml`. Every field is optional
/// so a partial (or empty) file still deserializes cleanly.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CommanderFileConfig {
    // Takes precedence over the ANTHROPIC_API_KEY env var when non-empty.
    #[serde(default)]
    anthropic_api_key: Option<String>,
    // Takes precedence over the OPENROUTER_API_KEY env var when non-empty.
    #[serde(default)]
    openrouter_api_key: Option<String>,
    // Base URL passed to `OllamaProvider::new`; provider default used if absent.
    #[serde(default)]
    ollama_url: Option<String>,
    // Optional `[ai]` section configuring CLI providers and a default provider.
    #[serde(default)]
    ai: Option<AiConfig>,
}
97
/// `[ai]` section of the commander config file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct AiConfig {
    // Name of the CLI provider to try first for all completions.
    #[serde(default)]
    pub default_provider: Option<String>,
    // Explicitly configured CLI providers; when the `[ai]` section is present,
    // this list replaces auto-detection (see `ModelRouter::new`).
    #[serde(default)]
    pub cli_providers: Vec<CliProviderConfig>,
}
108
109fn load_config_file() -> CommanderFileConfig {
111 let config_path = if let Some(base_dirs) = directories::BaseDirs::new() {
112 base_dirs
113 .home_dir()
114 .join(".mur")
115 .join("commander")
116 .join("config.toml")
117 } else {
118 return CommanderFileConfig::default();
119 };
120
121 if !config_path.exists() {
122 return CommanderFileConfig::default();
123 }
124
125 std::fs::read_to_string(&config_path)
126 .ok()
127 .and_then(|content| toml::from_str(&content).ok())
128 .unwrap_or_default()
129}
130
/// Routes completion requests to a provider based on the model role,
/// preferring a configured default CLI provider and falling back to
/// API providers (and ultimately local Ollama) as needed.
pub struct ModelRouter {
    // Role -> model configuration table.
    config: RouterConfig,
    // Anthropic API client; None when no API key was found.
    anthropic: Option<AnthropicProvider>,
    // Local Ollama client; constructed unconditionally and used as fallback.
    ollama: OllamaProvider,
    // OpenRouter API client; None when no API key was found.
    openrouter: Option<OpenRouterProvider>,
    // Detected or explicitly configured CLI providers.
    cli_providers: Vec<CliProvider>,
    // Name of the CLI provider to try before API routing, if configured.
    default_provider: Option<String>,
}
142
143impl ModelRouter {
144 pub fn new(config: RouterConfig) -> Self {
148 let file_config = load_config_file();
149
150 let anthropic_key = file_config
152 .anthropic_api_key
153 .filter(|k| !k.is_empty())
154 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
155 .filter(|k| !k.is_empty());
156 let anthropic = anthropic_key.and_then(|k| AnthropicProvider::new(Some(k)).ok());
157
158 let openrouter_key = file_config
160 .openrouter_api_key
161 .filter(|k| !k.is_empty())
162 .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
163 .filter(|k| !k.is_empty());
164 let openrouter = openrouter_key.and_then(|k| OpenRouterProvider::new(Some(k)).ok());
165
166 let ollama = OllamaProvider::new(file_config.ollama_url);
168
169 let (cli_providers, default_provider) = if let Some(ref ai) = file_config.ai {
171 let providers = ai
172 .cli_providers
173 .iter()
174 .map(|c| CliProvider::new(c.clone()))
175 .collect();
176 (providers, ai.default_provider.clone())
177 } else {
178 (CliProvider::detect_all(), None)
180 };
181
182 Self {
183 config,
184 anthropic,
185 ollama,
186 openrouter,
187 cli_providers,
188 default_provider,
189 }
190 }
191
192 pub fn default_router() -> Self {
194 Self::new(RouterConfig::default())
195 }
196
197 pub fn role_for_step(&self, step_type: &StepType) -> ModelRole {
199 ModelRole::for_step_type(step_type)
200 }
201
202 pub fn config_for_role(&self, role: &ModelRole) -> Option<&ModelConfig> {
204 let key = match role {
205 ModelRole::Thinking => "thinking",
206 ModelRole::Coding => "coding",
207 ModelRole::Task => "task",
208 ModelRole::Embedding => "embedding",
209 ModelRole::Auditor => "auditor",
210 };
211 self.config.roles.get(key)
212 }
213
214 pub fn cli_provider(&self, name: &str) -> Option<&CliProvider> {
216 self.cli_providers
217 .iter()
218 .find(|p| p.config.name == name && p.is_available())
219 }
220
221 pub fn detected_cli_providers(&self) -> &[CliProvider] {
223 &self.cli_providers
224 }
225
226 pub async fn complete_with_cli(&self, provider_name: &str, prompt: &str) -> Result<ModelResponse> {
228 let provider = self
229 .cli_provider(provider_name)
230 .context(format!("CLI provider '{}' not found or not available", provider_name))?;
231
232 let text = provider.call(prompt).await?;
233 Ok(ModelResponse {
234 content: text,
235 model: provider_name.to_string(),
236 input_tokens: 0,
237 output_tokens: 0,
238 cost: 0.0,
239 })
240 }
241
242 pub async fn complete(&self, role: &ModelRole, prompt: &str) -> Result<ModelResponse> {
245 if let Some(ref default) = self.default_provider {
247 if let Some(provider) = self.cli_provider(default) {
248 match provider.call(prompt).await {
249 Ok(text) => {
250 return Ok(ModelResponse {
251 content: text,
252 model: default.clone(),
253 input_tokens: 0,
254 output_tokens: 0,
255 cost: 0.0,
256 });
257 }
258 Err(e) => {
259 tracing::warn!("CLI provider '{}' failed: {}, falling back to API", default, e);
260 }
261 }
262 }
263 }
264
265 let config = self
266 .config_for_role(role)
267 .context(format!("No config for role {:?}", role))?;
268
269 match config.provider {
270 Provider::Anthropic => {
271 if let Some(provider) = self.anthropic.as_ref() {
272 provider
273 .complete(prompt, &config.model, config.temperature, config.max_tokens)
274 .await
275 } else {
276 tracing::info!("Anthropic not available, falling back to Ollama");
278 self.ollama
279 .complete(prompt, "llama3.2:3b", config.temperature, config.max_tokens)
280 .await
281 }
282 }
283 Provider::Ollama => {
284 self.ollama
285 .complete(prompt, &config.model, config.temperature, config.max_tokens)
286 .await
287 }
288 Provider::OpenRouter => {
289 if let Some(provider) = self.openrouter.as_ref() {
290 provider
291 .complete(prompt, &config.model, config.temperature, config.max_tokens)
292 .await
293 } else {
294 tracing::info!("OpenRouter not available, falling back to Ollama");
296 self.ollama
297 .complete(prompt, "llama3.2:3b", config.temperature, config.max_tokens)
298 .await
299 }
300 }
301 Provider::OpenAI => {
302 anyhow::bail!("OpenAI provider not yet implemented")
303 }
304 }
305 }
306
307 pub async fn complete_for_step(
309 &self,
310 step_type: &StepType,
311 prompt: &str,
312 ) -> Result<ModelResponse> {
313 let role = self.role_for_step(step_type);
314 self.complete(&role, prompt).await
315 }
316
317 pub fn list_models(&self) -> Vec<(String, &ModelConfig)> {
319 let mut models: Vec<_> = self
320 .config
321 .roles
322 .iter()
323 .map(|(k, v)| (k.clone(), v))
324 .collect();
325 models.sort_by(|a, b| a.0.cmp(&b.0));
326 models
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    // The default config must define exactly the five built-in roles.
    #[test]
    fn test_default_config() {
        let config = RouterConfig::default();
        assert_eq!(config.roles.len(), 5);
        for role in ["thinking", "coding", "task", "embedding", "auditor"] {
            assert!(config.roles.contains_key(role), "missing role: {}", role);
        }
    }

    // Each step type maps to its expected model role.
    #[test]
    fn test_role_for_step() {
        let router = ModelRouter::new(RouterConfig::default());
        let cases = [
            (StepType::Analyze, ModelRole::Thinking),
            (StepType::Code, ModelRole::Coding),
            (StepType::Search, ModelRole::Task),
            (StepType::SecurityCheck, ModelRole::Auditor),
        ];
        for (step, expected) in cases {
            assert_eq!(router.role_for_step(&step), expected);
        }
    }

    // Thinking/coding roles resolve to their default model families.
    #[test]
    fn test_config_for_role() {
        let router = ModelRouter::new(RouterConfig::default());

        let thinking = router
            .config_for_role(&ModelRole::Thinking)
            .expect("thinking role should be configured");
        assert!(thinking.model.contains("opus"));

        let coding = router
            .config_for_role(&ModelRole::Coding)
            .expect("coding role should be configured");
        assert!(coding.model.contains("sonnet"));
    }

    // list_models returns all roles, alphabetically sorted by name.
    #[test]
    fn test_list_models() {
        let router = ModelRouter::new(RouterConfig::default());
        let models = router.list_models();
        assert_eq!(models.len(), 5);
        assert_eq!(models.first().map(|m| m.0.as_str()), Some("auditor"));
        assert_eq!(models.last().map(|m| m.0.as_str()), Some("thinking"));
    }
}