Skip to main content

car_inference/
router.rs

1//! Intelligent model routing — select the best model based on prompt characteristics.
2//!
3//! Routes generation requests based on prompt complexity analysis:
4//! - Simple questions → smallest model (0.6B)
5//! - Medium tasks → Qwen3-1.7B
6//! - Code tasks → Qwen3-4B
7//! - Complex reasoning → best available (8B or 30B-A3B)
8//!
9//! Classify/embed tasks always route to the smallest model.
10
11use crate::hardware::HardwareInfo;
12use crate::models::ModelRegistry;
13use serde::{Deserialize, Serialize};
14use tracing::debug;
15
16// Re-export TaskComplexity from adaptive_router (canonical definition).
17pub use crate::adaptive_router::TaskComplexity;
18
19/// The result of routing a request to a model.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct RoutingDecision {
22    pub model: String,
23    pub complexity: TaskComplexity,
24    pub reason: String,
25}
26
27/// Routes inference requests to the best available model based on prompt analysis and hardware.
28pub struct ModelRouter {
29    hw: HardwareInfo,
30}
31
/// Generative model tiers in ascending order of size.
///
/// The order is significant: fallback logic walks this slice upward and
/// downward from a target tier, so new entries must keep it sorted.
const TIERS: &[&str] = &[
    "Qwen3-0.6B",
    "Qwen3-1.7B",
    "Qwen3-4B",
    "Qwen3-8B",
    "Qwen3-30B-A3B",
];
40
41impl ModelRouter {
42    pub fn new(hw: HardwareInfo) -> Self {
43        Self { hw }
44    }
45
46    /// Route a generation request to the best model.
47    pub fn route_generate(&self, prompt: &str, registry: &ModelRegistry) -> RoutingDecision {
48        let complexity = Self::assess_complexity(prompt);
49        let target = self.target_for_complexity(complexity);
50        let model = self.best_available(&target, registry);
51        let reason = format!(
52            "{:?} task -> {} (target: {})",
53            complexity, model, target
54        );
55        debug!(
56            complexity = ?complexity,
57            target = %target,
58            selected = %model,
59            "routed generation request"
60        );
61        RoutingDecision {
62            model,
63            complexity,
64            reason,
65        }
66    }
67
68    /// Route classify to the smallest available generative model.
69    pub fn route_small(&self, registry: &ModelRegistry) -> String {
70        let model = self.best_available("Qwen3-0.6B", registry);
71        debug!(selected = %model, "routed small task (classify)");
72        model
73    }
74
75    /// Route embedding to the dedicated embedding model.
76    pub fn route_embedding(&self, _registry: &ModelRegistry) -> String {
77        "Qwen3-Embedding-0.6B".to_string()
78    }
79
80    /// Assess prompt complexity based on content analysis.
81    pub fn assess_complexity(prompt: &str) -> TaskComplexity {
82        let lower = prompt.to_lowercase();
83        let word_count = prompt.split_whitespace().count();
84        let estimated_tokens = (word_count as f64 * 1.3) as usize;
85
86        // Code markers
87        let code_markers = [
88            "```", "fn ", "def ", "class ", "import ", "require(",
89            "async fn", "pub fn", "function ", "const ", "let ", "var ",
90            "#include", "package ", "impl ",
91        ];
92        let has_code = code_markers.iter().any(|m| prompt.contains(m));
93
94        // Repair/debug markers (checked case-insensitive)
95        let repair_markers = [
96            "fix", "repair", "debug", "refactor", "broken",
97            "failing", "error", "bug",
98        ];
99        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
100
101        // Complex reasoning markers (checked case-insensitive)
102        let reasoning_markers = [
103            "analyze", "compare", "explain why", "step by step",
104            "think through", "evaluate", "trade-off", "tradeoff",
105            "pros and cons", "architecture", "design", "strategy",
106            "optimize", "comprehensive",
107        ];
108        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
109
110        // Simple question patterns (checked case-insensitive)
111        let simple_patterns = [
112            "what is", "who is", "when did", "where is",
113            "how many", "yes or no", "true or false", "name the",
114            "list the", "define ",
115        ];
116        let is_simple_question = simple_patterns.iter().any(|p| lower.contains(p));
117
118        // Priority: code > reasoning/long > simple > medium
119        if has_code || has_repair {
120            TaskComplexity::Code
121        } else if has_reasoning || estimated_tokens > 500 {
122            TaskComplexity::Complex
123        } else if is_simple_question || estimated_tokens < 30 {
124            TaskComplexity::Simple
125        } else {
126            TaskComplexity::Medium
127        }
128    }
129
130    /// Get the target model name for a complexity level.
131    fn target_for_complexity(&self, complexity: TaskComplexity) -> String {
132        match complexity {
133            TaskComplexity::Simple => "Qwen3-0.6B".into(),
134            TaskComplexity::Medium => "Qwen3-1.7B".into(),
135            TaskComplexity::Code => "Qwen3-4B".into(),
136            TaskComplexity::Complex => self.hw.recommended_model.clone(),
137        }
138    }
139
140    /// Find the best available model at or near the target.
141    ///
142    /// Prefers the target if downloaded and fits in memory. Otherwise falls back
143    /// upward to larger models first (prefer better quality), then downward to
144    /// smaller models. If nothing is downloaded, returns the target name (which
145    /// will trigger a download on first use).
146    fn best_available(&self, target: &str, registry: &ModelRegistry) -> String {
147        let models = registry.list_models();
148        let target_idx = TIERS.iter().position(|&t| t == target).unwrap_or(0);
149
150        // Check if target is downloaded and fits
151        if let Some(m) = models.iter().find(|m| m.name == target) {
152            if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
153                return target.to_string();
154            }
155        }
156
157        // Try models from target upward (prefer better quality)
158        for &name in &TIERS[target_idx..] {
159            if let Some(m) = models.iter().find(|m| m.name == name) {
160                if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
161                    return name.to_string();
162                }
163            }
164        }
165
166        // Try models below target (fallback to smaller)
167        for &name in TIERS[..target_idx].iter().rev() {
168            if let Some(m) = models.iter().find(|m| m.name == name) {
169                if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
170                    return name.to_string();
171                }
172            }
173        }
174
175        // Nothing downloaded — return target anyway (will trigger download)
176        target.to_string()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these cases exercise `TaskComplexity::assess`, which is
    // defined outside this file (re-exported from `adaptive_router`) — confirm
    // it is meant to agree with `ModelRouter::assess_complexity`.

    /// A short factual question is classified as simple.
    #[test]
    fn simple_question() {
        let verdict = TaskComplexity::assess("What is the capital of France?");
        assert_eq!(verdict, TaskComplexity::Simple);
    }

    /// The word "function" followed by a space counts as a code marker.
    #[test]
    fn code_with_function_keyword() {
        let verdict = TaskComplexity::assess("Write a function to sort a list");
        assert_eq!(verdict, TaskComplexity::Code);
    }

    /// Reasoning vocabulary ("analyze", "trade-off", "architecture") yields complex.
    #[test]
    fn complex_reasoning() {
        let prompt =
            "Analyze the trade-offs between microservices and monolithic architecture for our use case";
        let verdict = TaskComplexity::assess(prompt);
        assert_eq!(verdict, TaskComplexity::Complex);
    }

    /// A mid-length prompt without any markers falls through to medium.
    #[test]
    fn medium_default() {
        // Must be >30 estimated tokens (~24 words) to avoid Simple, and no code/reasoning markers
        let prompt = "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life";
        let verdict = TaskComplexity::assess(prompt);
        assert_eq!(verdict, TaskComplexity::Medium);
    }

    /// A fenced code block in the prompt is a code marker.
    #[test]
    fn code_with_backticks() {
        let prompt = "Here is my code:\n```rust\nfn main() {}\n```\nFix it";
        let verdict = TaskComplexity::assess(prompt);
        assert_eq!(verdict, TaskComplexity::Code);
    }

    /// Repair vocabulary ("debug", "failing") routes to the code class.
    #[test]
    fn repair_marker() {
        let verdict = TaskComplexity::assess("Debug this failing test case");
        assert_eq!(verdict, TaskComplexity::Code);
    }
}