mur_core/model/router.rs

//! Model Router — selects the appropriate AI model based on step type and role.
//!
//! Routes requests to different providers (Anthropic, Ollama, OpenRouter)
//! based on the 5 model roles: Thinking, Coding, Task, Embedding, Auditor.
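//!
//! # Example
//!
//! A minimal usage sketch (illustrative; assumes an async caller and elides
//! import paths):
//!
//! ```ignore
//! let router = ModelRouter::default_router();
//! let response = router
//!     .complete(&ModelRole::Coding, "Write a function that reverses a string")
//!     .await?;
//! println!("{}", response.content);
//! ```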

use super::cli_provider::{CliProvider, CliProviderConfig};
use super::ollama::OllamaProvider;
use super::openrouter::OpenRouterProvider;
use super::provider::{AnthropicProvider, ModelResponse};
use super::{ModelConfig, ModelRole, Provider};
use crate::types::StepType;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration for all model roles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    pub roles: HashMap<String, ModelConfig>,
}

impl Default for RouterConfig {
    fn default() -> Self {
        let mut roles = HashMap::new();
        roles.insert(
            "thinking".into(),
            ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-opus-4-20250514".into(),
                api_key_env: Some("ANTHROPIC_API_KEY".into()),
                max_cost_per_call: 5.0,
                temperature: 0.3,
                max_tokens: 8192,
            },
        );
        roles.insert(
            "coding".into(),
            ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".into(),
                api_key_env: Some("ANTHROPIC_API_KEY".into()),
                max_cost_per_call: 2.0,
                temperature: 0.2,
                max_tokens: 8192,
            },
        );
        roles.insert(
            "task".into(),
            ModelConfig {
                provider: Provider::OpenRouter,
                model: "google/gemini-2.0-flash-001".into(),
                api_key_env: Some("OPENROUTER_API_KEY".into()),
                max_cost_per_call: 0.5,
                temperature: 0.3,
                max_tokens: 4096,
            },
        );
        roles.insert(
            "embedding".into(),
            ModelConfig {
                provider: Provider::Ollama,
                model: "nomic-embed-text".into(),
                api_key_env: None,
                max_cost_per_call: 0.0,
                temperature: 0.0,
                max_tokens: 0,
            },
        );
        roles.insert(
            "auditor".into(),
            ModelConfig {
                provider: Provider::Ollama,
                model: "llama3.2:3b".into(),
                api_key_env: None,
                max_cost_per_call: 0.0,
                temperature: 0.1,
                max_tokens: 2048,
            },
        );
        Self { roles }
    }
}
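
// Illustrative sketch: overriding a role before constructing the router.
// `roles` is public, so callers can replace a role's entry; the model name
// below ("qwen2.5-coder:7b") is a hypothetical local model, not a default:
//
//     let mut config = RouterConfig::default();
//     config.roles.insert(
//         "coding".into(),
//         ModelConfig {
//             provider: Provider::Ollama,
//             model: "qwen2.5-coder:7b".into(),
//             api_key_env: None,
//             max_cost_per_call: 0.0,
//             temperature: 0.2,
//             max_tokens: 8192,
//         },
//     );
//     let router = ModelRouter::new(config);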

/// Commander config file structure (subset relevant to model routing).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CommanderFileConfig {
    #[serde(default)]
    anthropic_api_key: Option<String>,
    #[serde(default)]
    openrouter_api_key: Option<String>,
    #[serde(default)]
    ollama_url: Option<String>,
    /// AI section with CLI providers and default provider.
    #[serde(default)]
    ai: Option<AiConfig>,
}

/// AI configuration section.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct AiConfig {
    /// Default provider name (e.g., "claude-cli", "ollama").
    #[serde(default)]
    pub default_provider: Option<String>,
    /// CLI provider configurations.
    #[serde(default)]
    pub cli_providers: Vec<CliProviderConfig>,
}
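
// An illustrative ~/.mur/commander/config.toml matching the structs above.
// Key values are placeholders; the `[[ai.cli_providers]]` entry shows only
// the `name` field, since the remaining `CliProviderConfig` fields live in
// cli_provider.rs and are not assumed here:
//
//     anthropic_api_key = "sk-ant-..."
//     openrouter_api_key = "sk-or-..."
//     ollama_url = "http://localhost:11434"
//
//     [ai]
//     default_provider = "claude-cli"
//
//     [[ai.cli_providers]]
//     name = "claude-cli"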

/// Load API keys from ~/.mur/commander/config.toml if it exists.
fn load_config_file() -> CommanderFileConfig {
    let config_path = if let Some(base_dirs) = directories::BaseDirs::new() {
        base_dirs
            .home_dir()
            .join(".mur")
            .join("commander")
            .join("config.toml")
    } else {
        return CommanderFileConfig::default();
    };

    if !config_path.exists() {
        return CommanderFileConfig::default();
    }

    std::fs::read_to_string(&config_path)
        .ok()
        .and_then(|content| toml::from_str(&content).ok())
        .unwrap_or_default()
}

/// Multi-model router that dispatches to the right provider.
pub struct ModelRouter {
    config: RouterConfig,
    anthropic: Option<AnthropicProvider>,
    ollama: OllamaProvider,
    openrouter: Option<OpenRouterProvider>,
    /// CLI-based providers (claude -p, gemini, etc.).
    cli_providers: Vec<CliProvider>,
    /// Default provider name from config.
    default_provider: Option<String>,
}

impl ModelRouter {
    /// Create a router with the given config.
    /// Loads API keys from the config file (~/.mur/commander/config.toml),
    /// falling back to environment variables.
    pub fn new(config: RouterConfig) -> Self {
        let file_config = load_config_file();

        // Anthropic: config file key → env var
        let anthropic_key = file_config
            .anthropic_api_key
            .filter(|k| !k.is_empty())
            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
            .filter(|k| !k.is_empty());
        let anthropic = anthropic_key.and_then(|k| AnthropicProvider::new(Some(k)).ok());

        // OpenRouter: config file key → env var
        let openrouter_key = file_config
            .openrouter_api_key
            .filter(|k| !k.is_empty())
            .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
            .filter(|k| !k.is_empty());
        let openrouter = openrouter_key.and_then(|k| OpenRouterProvider::new(Some(k)).ok());

        // Ollama: config file URL → default localhost
        let ollama = OllamaProvider::new(file_config.ollama_url);

        // CLI providers: from config file, or auto-detect
        let (cli_providers, default_provider) = if let Some(ref ai) = file_config.ai {
            let providers = ai
                .cli_providers
                .iter()
                .map(|c| CliProvider::new(c.clone()))
                .collect();
            (providers, ai.default_provider.clone())
        } else {
            // Auto-detect available CLI tools
            (CliProvider::detect_all(), None)
        };

        Self {
            config,
            anthropic,
            ollama,
            openrouter,
            cli_providers,
            default_provider,
        }
    }

    /// Create a router with default config.
    pub fn default_router() -> Self {
        Self::new(RouterConfig::default())
    }

    /// Select the model role for a given step type.
    pub fn role_for_step(&self, step_type: &StepType) -> ModelRole {
        ModelRole::for_step_type(step_type)
    }

    /// Get the model config for a given role.
    pub fn config_for_role(&self, role: &ModelRole) -> Option<&ModelConfig> {
        let key = match role {
            ModelRole::Thinking => "thinking",
            ModelRole::Coding => "coding",
            ModelRole::Task => "task",
            ModelRole::Embedding => "embedding",
            ModelRole::Auditor => "auditor",
        };
        self.config.roles.get(key)
    }

    /// Get a CLI provider by name, returning it only if it is available.
    pub fn cli_provider(&self, name: &str) -> Option<&CliProvider> {
        self.cli_providers
            .iter()
            .find(|p| p.config.name == name && p.is_available())
    }

    /// List detected CLI providers.
    pub fn detected_cli_providers(&self) -> &[CliProvider] {
        &self.cli_providers
    }

    /// Complete using a CLI provider directly.
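    ///
    /// Token counts and cost in the returned [`ModelResponse`] are zeroed,
    /// since the CLI path does not report usage metadata.
    ///
    /// # Example (illustrative; assumes a provider named "claude-cli" is configured)
    /// ```ignore
    /// let resp = router.complete_with_cli("claude-cli", "Summarize this diff").await?;
    /// ```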
    pub async fn complete_with_cli(
        &self,
        provider_name: &str,
        prompt: &str,
    ) -> Result<ModelResponse> {
        let provider = self.cli_provider(provider_name).context(format!(
            "CLI provider '{}' not found or not available",
            provider_name
        ))?;

        let text = provider.call(prompt).await?;
        Ok(ModelResponse {
            content: text,
            model: provider_name.to_string(),
            input_tokens: 0,
            output_tokens: 0,
            cost: 0.0,
        })
    }

    /// Route a completion request to the appropriate provider.
    /// If a default CLI provider is configured and available, it is tried
    /// first; on failure, the request falls back to the role's API provider.
    pub async fn complete(&self, role: &ModelRole, prompt: &str) -> Result<ModelResponse> {
        // Try default CLI provider first if configured
        if let Some(ref default) = self.default_provider {
            if let Some(provider) = self.cli_provider(default) {
                match provider.call(prompt).await {
                    Ok(text) => {
                        return Ok(ModelResponse {
                            content: text,
                            model: default.clone(),
                            input_tokens: 0,
                            output_tokens: 0,
                            cost: 0.0,
                        });
                    }
                    Err(e) => {
                        tracing::warn!(
                            "CLI provider '{}' failed: {}, falling back to API",
                            default,
                            e
                        );
                    }
                }
            }
        }

        let config = self
            .config_for_role(role)
            .context(format!("No config for role {:?}", role))?;

        match config.provider {
            Provider::Anthropic => {
                if let Some(provider) = self.anthropic.as_ref() {
                    provider
                        .complete(prompt, &config.model, config.temperature, config.max_tokens)
                        .await
                } else {
                    // Fall back to Ollama when no Anthropic key is available
                    tracing::info!("Anthropic not available, falling back to Ollama");
                    self.ollama
                        .complete(prompt, "llama3.2:3b", config.temperature, config.max_tokens)
                        .await
                }
            }
            Provider::Ollama => {
                self.ollama
                    .complete(prompt, &config.model, config.temperature, config.max_tokens)
                    .await
            }
            Provider::OpenRouter => {
                if let Some(provider) = self.openrouter.as_ref() {
                    provider
                        .complete(prompt, &config.model, config.temperature, config.max_tokens)
                        .await
                } else {
                    // Fall back to Ollama when no OpenRouter key is available
                    tracing::info!("OpenRouter not available, falling back to Ollama");
                    self.ollama
                        .complete(prompt, "llama3.2:3b", config.temperature, config.max_tokens)
                        .await
                }
            }
            Provider::OpenAI => {
                anyhow::bail!("OpenAI provider not yet implemented")
            }
        }
    }

    /// Route a step to the appropriate model and get a completion.
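    ///
    /// # Example (illustrative; `StepType::Code` maps to the Coding role)
    /// ```ignore
    /// let resp = router.complete_for_step(&StepType::Code, "Implement FizzBuzz").await?;
    /// ```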
    pub async fn complete_for_step(
        &self,
        step_type: &StepType,
        prompt: &str,
    ) -> Result<ModelResponse> {
        let role = self.role_for_step(step_type);
        self.complete(&role, prompt).await
    }

    /// List all configured roles and their models.
    pub fn list_models(&self) -> Vec<(String, &ModelConfig)> {
        let mut models: Vec<_> = self
            .config
            .roles
            .iter()
            .map(|(k, v)| (k.clone(), v))
            .collect();
        models.sort_by(|a, b| a.0.cmp(&b.0));
        models
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = RouterConfig::default();
        assert_eq!(config.roles.len(), 5);
        assert!(config.roles.contains_key("thinking"));
        assert!(config.roles.contains_key("coding"));
        assert!(config.roles.contains_key("task"));
        assert!(config.roles.contains_key("embedding"));
        assert!(config.roles.contains_key("auditor"));
    }

    #[test]
    fn test_role_for_step() {
        let router = ModelRouter::new(RouterConfig::default());
        assert_eq!(
            router.role_for_step(&StepType::Analyze),
            ModelRole::Thinking
        );
        assert_eq!(router.role_for_step(&StepType::Code), ModelRole::Coding);
        assert_eq!(router.role_for_step(&StepType::Search), ModelRole::Task);
        assert_eq!(
            router.role_for_step(&StepType::SecurityCheck),
            ModelRole::Auditor
        );
    }

    #[test]
    fn test_config_for_role() {
        let router = ModelRouter::new(RouterConfig::default());
        let thinking = router.config_for_role(&ModelRole::Thinking).unwrap();
        assert!(thinking.model.contains("opus"));

        let coding = router.config_for_role(&ModelRole::Coding).unwrap();
        assert!(coding.model.contains("sonnet"));
    }

    #[test]
    fn test_list_models() {
        let router = ModelRouter::new(RouterConfig::default());
        let models = router.list_models();
        assert_eq!(models.len(), 5);
        // Should be sorted alphabetically
        assert_eq!(models[0].0, "auditor");
        assert_eq!(models[4].0, "thinking");
    }
}
379}