1use crate::hardware::HardwareInfo;
12use crate::models::ModelRegistry;
13use serde::{Deserialize, Serialize};
14use tracing::debug;
15
16pub use crate::adaptive_router::TaskComplexity;
18
/// Result of routing a generation request, as produced by `ModelRouter::route_generate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Name of the model selected to serve the request.
    pub model: String,
    /// Complexity class assigned to the prompt.
    pub complexity: TaskComplexity,
    /// Human-readable explanation, formatted as
    /// `"<complexity:?> task -> <model> (target: <target>)"`.
    pub reason: String,
}
26
/// Chooses which local model should handle a request, constrained by the host hardware.
pub struct ModelRouter {
    /// Hardware profile. `max_model_mb` caps the size of any selected model, and
    /// `recommended_model` is the target for complex tasks.
    hw: HardwareInfo,
}
31
/// Generation models known to the router, listed from smallest to largest by the
/// parameter count in their names.
///
/// The fallback search in `ModelRouter::best_available` depends on this ordering.
const TIERS: &[&str] = &[
    "Qwen3-0.6B",
    "Qwen3-1.7B",
    "Qwen3-4B",
    "Qwen3-8B",
    "Qwen3-30B-A3B",
];
40
41impl ModelRouter {
42 pub fn new(hw: HardwareInfo) -> Self {
43 Self { hw }
44 }
45
46 pub fn route_generate(&self, prompt: &str, registry: &ModelRegistry) -> RoutingDecision {
48 let complexity = Self::assess_complexity(prompt);
49 let target = self.target_for_complexity(complexity);
50 let model = self.best_available(&target, registry);
51 let reason = format!(
52 "{:?} task -> {} (target: {})",
53 complexity, model, target
54 );
55 debug!(
56 complexity = ?complexity,
57 target = %target,
58 selected = %model,
59 "routed generation request"
60 );
61 RoutingDecision {
62 model,
63 complexity,
64 reason,
65 }
66 }
67
68 pub fn route_small(&self, registry: &ModelRegistry) -> String {
70 let model = self.best_available("Qwen3-0.6B", registry);
71 debug!(selected = %model, "routed small task (classify)");
72 model
73 }
74
75 pub fn route_embedding(&self, _registry: &ModelRegistry) -> String {
77 "Qwen3-Embedding-0.6B".to_string()
78 }
79
80 pub fn assess_complexity(prompt: &str) -> TaskComplexity {
82 let lower = prompt.to_lowercase();
83 let word_count = prompt.split_whitespace().count();
84 let estimated_tokens = (word_count as f64 * 1.3) as usize;
85
86 let code_markers = [
88 "```", "fn ", "def ", "class ", "import ", "require(",
89 "async fn", "pub fn", "function ", "const ", "let ", "var ",
90 "#include", "package ", "impl ",
91 ];
92 let has_code = code_markers.iter().any(|m| prompt.contains(m));
93
94 let repair_markers = [
96 "fix", "repair", "debug", "refactor", "broken",
97 "failing", "error", "bug",
98 ];
99 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
100
101 let reasoning_markers = [
103 "analyze", "compare", "explain why", "step by step",
104 "think through", "evaluate", "trade-off", "tradeoff",
105 "pros and cons", "architecture", "design", "strategy",
106 "optimize", "comprehensive",
107 ];
108 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
109
110 let simple_patterns = [
112 "what is", "who is", "when did", "where is",
113 "how many", "yes or no", "true or false", "name the",
114 "list the", "define ",
115 ];
116 let is_simple_question = simple_patterns.iter().any(|p| lower.contains(p));
117
118 if has_code || has_repair {
120 TaskComplexity::Code
121 } else if has_reasoning || estimated_tokens > 500 {
122 TaskComplexity::Complex
123 } else if is_simple_question || estimated_tokens < 30 {
124 TaskComplexity::Simple
125 } else {
126 TaskComplexity::Medium
127 }
128 }
129
130 fn target_for_complexity(&self, complexity: TaskComplexity) -> String {
132 match complexity {
133 TaskComplexity::Simple => "Qwen3-0.6B".into(),
134 TaskComplexity::Medium => "Qwen3-1.7B".into(),
135 TaskComplexity::Code => "Qwen3-4B".into(),
136 TaskComplexity::Complex => self.hw.recommended_model.clone(),
137 }
138 }
139
140 fn best_available(&self, target: &str, registry: &ModelRegistry) -> String {
147 let models = registry.list_models();
148 let target_idx = TIERS.iter().position(|&t| t == target).unwrap_or(0);
149
150 if let Some(m) = models.iter().find(|m| m.name == target) {
152 if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
153 return target.to_string();
154 }
155 }
156
157 for &name in &TIERS[target_idx..] {
159 if let Some(m) = models.iter().find(|m| m.name == name) {
160 if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
161 return name.to_string();
162 }
163 }
164 }
165
166 for &name in TIERS[..target_idx].iter().rev() {
168 if let Some(m) = models.iter().find(|m| m.name == name) {
169 if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
170 return name.to_string();
171 }
172 }
173 }
174
175 target.to_string()
177 }
178}
179
#[cfg(test)]
mod tests {
    //! Unit tests for [`ModelRouter::assess_complexity`], the prompt classifier
    //! defined in this module.
    use super::*;

    /// A short factual question matches a simple-question pattern.
    #[test]
    fn simple_question() {
        assert_eq!(
            ModelRouter::assess_complexity("What is the capital of France?"),
            TaskComplexity::Simple,
        );
    }

    /// The `function ` code marker routes to the code tier.
    #[test]
    fn code_with_function_keyword() {
        assert_eq!(
            ModelRouter::assess_complexity("Write a function to sort a list"),
            TaskComplexity::Code,
        );
    }

    /// Reasoning phrases ("analyze", "trade-off", "architecture") mark a prompt complex.
    #[test]
    fn complex_reasoning() {
        assert_eq!(
            ModelRouter::assess_complexity(
                "Analyze the trade-offs between microservices and monolithic architecture for our use case"
            ),
            TaskComplexity::Complex,
        );
    }

    /// A marker-free prompt of roughly 30-500 estimated tokens falls through to Medium.
    #[test]
    fn medium_default() {
        assert_eq!(
            ModelRouter::assess_complexity(
                "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life"
            ),
            TaskComplexity::Medium,
        );
    }

    /// A fenced code block is a code marker.
    #[test]
    fn code_with_backticks() {
        assert_eq!(
            ModelRouter::assess_complexity("Here is my code:\n```rust\nfn main() {}\n```\nFix it"),
            TaskComplexity::Code,
        );
    }

    /// Repair words ("debug", "failing") route to the code tier even without code.
    #[test]
    fn repair_marker() {
        assert_eq!(
            ModelRouter::assess_complexity("Debug this failing test case"),
            TaskComplexity::Code,
        );
    }
}