Skip to main content

car_inference/
router.rs

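//! Intelligent model routing — select the best model based on prompt characteristics.
//!
//! Generation requests are routed by a heuristic prompt-complexity analysis:
//! - Simple questions → Qwen3-0.6B (smallest tier)
//! - Medium tasks → Qwen3-1.7B
//! - Code tasks → Qwen3-4B
//! - Complex reasoning → the hardware-recommended model (e.g. 8B or 30B-A3B)
//!
//! Each target is then resolved to a model that is downloaded and fits in memory,
//! falling back to neighbouring tiers when the target itself is unavailable.
//!
//! Classify tasks route to the smallest available generative model; embedding tasks
//! always use the dedicated `Qwen3-Embedding-0.6B` model.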
10
11use crate::hardware::HardwareInfo;
12use crate::models::ModelRegistry;
13use serde::{Deserialize, Serialize};
14use tracing::debug;
15
16// Re-export TaskComplexity from adaptive_router (canonical definition).
17pub use crate::adaptive_router::TaskComplexity;
18
19/// The result of routing a request to a model.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct RoutingDecision {
22    pub model: String,
23    pub complexity: TaskComplexity,
24    pub reason: String,
25}
26
27/// Routes inference requests to the best available model based on prompt analysis and hardware.
28pub struct ModelRouter {
29    hw: HardwareInfo,
30}
31
/// Known generative model tiers, listed from smallest to largest.
///
/// The order matters: `ModelRouter::best_available` walks this list when it has
/// to fall back from a target tier to larger or smaller neighbours.
const TIERS: &[&str] = &[
    "Qwen3-0.6B",
    "Qwen3-1.7B",
    "Qwen3-4B",
    "Qwen3-8B",
    "Qwen3-30B-A3B",
];
40
41impl ModelRouter {
42    pub fn new(hw: HardwareInfo) -> Self {
43        Self { hw }
44    }
45
46    /// Route a generation request to the best model.
47    pub fn route_generate(&self, prompt: &str, registry: &ModelRegistry) -> RoutingDecision {
48        let complexity = Self::assess_complexity(prompt);
49        let target = self.target_for_complexity(complexity);
50        let model = self.best_available(&target, registry);
51        let reason = format!("{:?} task -> {} (target: {})", complexity, model, target);
52        debug!(
53            complexity = ?complexity,
54            target = %target,
55            selected = %model,
56            "routed generation request"
57        );
58        RoutingDecision {
59            model,
60            complexity,
61            reason,
62        }
63    }
64
65    /// Route classify to the smallest available generative model.
66    pub fn route_small(&self, registry: &ModelRegistry) -> String {
67        let model = self.best_available("Qwen3-0.6B", registry);
68        debug!(selected = %model, "routed small task (classify)");
69        model
70    }
71
72    /// Route embedding to the dedicated embedding model.
73    pub fn route_embedding(&self, _registry: &ModelRegistry) -> String {
74        "Qwen3-Embedding-0.6B".to_string()
75    }
76
77    /// Assess prompt complexity based on content analysis.
78    pub fn assess_complexity(prompt: &str) -> TaskComplexity {
79        let lower = prompt.to_lowercase();
80        let word_count = prompt.split_whitespace().count();
81        let estimated_tokens = (word_count as f64 * 1.3) as usize;
82
83        // Code markers
84        let code_markers = [
85            "```",
86            "fn ",
87            "def ",
88            "class ",
89            "import ",
90            "require(",
91            "async fn",
92            "pub fn",
93            "function ",
94            "const ",
95            "let ",
96            "var ",
97            "#include",
98            "package ",
99            "impl ",
100        ];
101        let has_code = code_markers.iter().any(|m| prompt.contains(m));
102
103        // Repair/debug markers (checked case-insensitive)
104        let repair_markers = [
105            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
106        ];
107        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
108
109        // Complex reasoning markers (checked case-insensitive)
110        let reasoning_markers = [
111            "analyze",
112            "compare",
113            "explain why",
114            "step by step",
115            "think through",
116            "evaluate",
117            "trade-off",
118            "tradeoff",
119            "pros and cons",
120            "architecture",
121            "design",
122            "strategy",
123            "optimize",
124            "comprehensive",
125        ];
126        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
127
128        // Simple question patterns (checked case-insensitive)
129        let simple_patterns = [
130            "what is",
131            "who is",
132            "when did",
133            "where is",
134            "how many",
135            "yes or no",
136            "true or false",
137            "name the",
138            "list the",
139            "define ",
140        ];
141        let is_simple_question = simple_patterns.iter().any(|p| lower.contains(p));
142
143        // Priority: code > reasoning/long > simple > medium
144        if has_code || has_repair {
145            TaskComplexity::Code
146        } else if has_reasoning || estimated_tokens > 500 {
147            TaskComplexity::Complex
148        } else if is_simple_question || estimated_tokens < 30 {
149            TaskComplexity::Simple
150        } else {
151            TaskComplexity::Medium
152        }
153    }
154
155    /// Get the target model name for a complexity level.
156    fn target_for_complexity(&self, complexity: TaskComplexity) -> String {
157        match complexity {
158            TaskComplexity::Simple => "Qwen3-0.6B".into(),
159            TaskComplexity::Medium => "Qwen3-1.7B".into(),
160            TaskComplexity::Code => "Qwen3-4B".into(),
161            TaskComplexity::Complex => self.hw.recommended_model.clone(),
162        }
163    }
164
165    /// Find the best available model at or near the target.
166    ///
167    /// Prefers the target if downloaded and fits in memory. Otherwise falls back
168    /// upward to larger models first (prefer better quality), then downward to
169    /// smaller models. If nothing is downloaded, returns the target name (which
170    /// will trigger a download on first use).
171    fn best_available(&self, target: &str, registry: &ModelRegistry) -> String {
172        let models = registry.list_models();
173        let target_idx = TIERS.iter().position(|&t| t == target).unwrap_or(0);
174
175        // Check if target is downloaded and fits
176        if let Some(m) = models.iter().find(|m| m.name == target) {
177            if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
178                return target.to_string();
179            }
180        }
181
182        // Try models from target upward (prefer better quality)
183        for &name in &TIERS[target_idx..] {
184            if let Some(m) = models.iter().find(|m| m.name == name) {
185                if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
186                    return name.to_string();
187                }
188            }
189        }
190
191        // Try models below target (fallback to smaller)
192        for &name in TIERS[..target_idx].iter().rev() {
193            if let Some(m) = models.iter().find(|m| m.name == name) {
194                if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
195                    return name.to_string();
196                }
197            }
198        }
199
200        // Nothing downloaded — return target anyway (will trigger download)
201        target.to_string()
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that both the canonical `TaskComplexity::assess` and the router's own
    /// `ModelRouter::assess_complexity` (the one `route_generate` calls) classify
    /// `prompt` as `expected`.
    fn assert_complexity(prompt: &str, expected: TaskComplexity) {
        assert_eq!(
            TaskComplexity::assess(prompt),
            expected,
            "TaskComplexity::assess({:?})",
            prompt
        );
        assert_eq!(
            ModelRouter::assess_complexity(prompt),
            expected,
            "ModelRouter::assess_complexity({:?})",
            prompt
        );
    }

    #[test]
    fn simple_question() {
        assert_complexity("What is the capital of France?", TaskComplexity::Simple);
    }

    #[test]
    fn code_with_function_keyword() {
        assert_complexity("Write a function to sort a list", TaskComplexity::Code);
    }

    #[test]
    fn complex_reasoning() {
        assert_complexity(
            "Analyze the trade-offs between microservices and monolithic architecture for our use case",
            TaskComplexity::Complex,
        );
    }

    #[test]
    fn medium_default() {
        // Needs at least 30 estimated tokens (24+ words) to avoid Simple, and no
        // code/repair/reasoning markers.
        assert_complexity(
            "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life",
            TaskComplexity::Medium,
        );
    }

    #[test]
    fn code_with_backticks() {
        assert_complexity(
            "Here is my code:\n```rust\nfn main() {}\n```\nFix it",
            TaskComplexity::Code,
        );
    }

    #[test]
    fn repair_marker() {
        assert_complexity("Debug this failing test case", TaskComplexity::Code);
    }

    #[test]
    fn router_empty_prompt_is_simple() {
        // Zero words -> zero estimated tokens, below the Simple threshold of 30.
        assert_eq!(ModelRouter::assess_complexity(""), TaskComplexity::Simple);
    }

    #[test]
    fn router_long_prompt_is_complex() {
        // 400 words * 1.3 = 520 estimated tokens, above the Complex threshold of 500.
        let prompt = "word ".repeat(400);
        assert_eq!(
            ModelRouter::assess_complexity(&prompt),
            TaskComplexity::Complex
        );
    }
}