1use crate::hardware::HardwareInfo;
12use crate::models::ModelRegistry;
13use serde::{Deserialize, Serialize};
14use tracing::debug;
15
16pub use crate::adaptive_router::TaskComplexity;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct RoutingDecision {
22 pub model: String,
23 pub complexity: TaskComplexity,
24 pub reason: String,
25}
26
27pub struct ModelRouter {
29 hw: HardwareInfo,
30}
31
/// Generation model tiers, listed from the smallest model to the largest.
///
/// The order is significant: when a target from this list is unavailable,
/// substitutes are searched upward from the target's position first and
/// downward afterwards.
const TIERS: &[&str] = &[
    "Qwen3-0.6B",
    "Qwen3-1.7B",
    "Qwen3-4B",
    "Qwen3-8B",
    "Qwen3-30B-A3B",
];
40
41impl ModelRouter {
42 pub fn new(hw: HardwareInfo) -> Self {
43 Self { hw }
44 }
45
46 pub fn route_generate(&self, prompt: &str, registry: &ModelRegistry) -> RoutingDecision {
48 let complexity = Self::assess_complexity(prompt);
49 let target = self.target_for_complexity(complexity);
50 let model = self.best_available(&target, registry);
51 let reason = format!("{:?} task -> {} (target: {})", complexity, model, target);
52 debug!(
53 complexity = ?complexity,
54 target = %target,
55 selected = %model,
56 "routed generation request"
57 );
58 RoutingDecision {
59 model,
60 complexity,
61 reason,
62 }
63 }
64
65 pub fn route_small(&self, registry: &ModelRegistry) -> String {
67 let model = self.best_available("Qwen3-0.6B", registry);
68 debug!(selected = %model, "routed small task (classify)");
69 model
70 }
71
72 pub fn route_embedding(&self, _registry: &ModelRegistry) -> String {
74 "Qwen3-Embedding-0.6B".to_string()
75 }
76
77 pub fn assess_complexity(prompt: &str) -> TaskComplexity {
79 let lower = prompt.to_lowercase();
80 let word_count = prompt.split_whitespace().count();
81 let estimated_tokens = (word_count as f64 * 1.3) as usize;
82
83 let code_markers = [
85 "```",
86 "fn ",
87 "def ",
88 "class ",
89 "import ",
90 "require(",
91 "async fn",
92 "pub fn",
93 "function ",
94 "const ",
95 "let ",
96 "var ",
97 "#include",
98 "package ",
99 "impl ",
100 ];
101 let has_code = code_markers.iter().any(|m| prompt.contains(m));
102
103 let repair_markers = [
105 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
106 ];
107 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
108
109 let reasoning_markers = [
111 "analyze",
112 "compare",
113 "explain why",
114 "step by step",
115 "think through",
116 "evaluate",
117 "trade-off",
118 "tradeoff",
119 "pros and cons",
120 "architecture",
121 "design",
122 "strategy",
123 "optimize",
124 "comprehensive",
125 ];
126 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
127
128 let simple_patterns = [
130 "what is",
131 "who is",
132 "when did",
133 "where is",
134 "how many",
135 "yes or no",
136 "true or false",
137 "name the",
138 "list the",
139 "define ",
140 ];
141 let is_simple_question = simple_patterns.iter().any(|p| lower.contains(p));
142
143 if has_code || has_repair {
145 TaskComplexity::Code
146 } else if has_reasoning || estimated_tokens > 500 {
147 TaskComplexity::Complex
148 } else if is_simple_question || estimated_tokens < 30 {
149 TaskComplexity::Simple
150 } else {
151 TaskComplexity::Medium
152 }
153 }
154
155 fn target_for_complexity(&self, complexity: TaskComplexity) -> String {
157 match complexity {
158 TaskComplexity::Simple => "Qwen3-0.6B".into(),
159 TaskComplexity::Medium => "Qwen3-1.7B".into(),
160 TaskComplexity::Code => "Qwen3-4B".into(),
161 TaskComplexity::Complex => self.hw.recommended_model.clone(),
162 }
163 }
164
165 fn best_available(&self, target: &str, registry: &ModelRegistry) -> String {
172 let models = registry.list_models();
173 let target_idx = TIERS.iter().position(|&t| t == target).unwrap_or(0);
174
175 if let Some(m) = models.iter().find(|m| m.name == target) {
177 if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
178 return target.to_string();
179 }
180 }
181
182 for &name in &TIERS[target_idx..] {
184 if let Some(m) = models.iter().find(|m| m.name == name) {
185 if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
186 return name.to_string();
187 }
188 }
189 }
190
191 for &name in TIERS[..target_idx].iter().rev() {
193 if let Some(m) = models.iter().find(|m| m.name == name) {
194 if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb {
195 return name.to_string();
196 }
197 }
198 }
199
200 target.to_string()
202 }
203}
204
// Unit tests for the prompt classifier defined in this module
// (`ModelRouter::assess_complexity`).
#[cfg(test)]
mod tests {
    use super::*;

    // A "what is" question with no code/repair/reasoning keywords is Simple.
    #[test]
    fn simple_question() {
        assert_eq!(
            ModelRouter::assess_complexity("What is the capital of France?"),
            TaskComplexity::Simple,
        );
    }

    // The "function " marker alone is enough to classify a prompt as Code.
    #[test]
    fn code_with_function_keyword() {
        assert_eq!(
            ModelRouter::assess_complexity("Write a function to sort a list"),
            TaskComplexity::Code,
        );
    }

    // Reasoning keywords ("analyze", "trade-off", "architecture") yield Complex.
    #[test]
    fn complex_reasoning() {
        assert_eq!(
            ModelRouter::assess_complexity(
                "Analyze the trade-offs between microservices and monolithic architecture for our use case"
            ),
            TaskComplexity::Complex,
        );
    }

    // No keyword matches and an estimated length between 30 and 500 tokens.
    #[test]
    fn medium_default() {
        assert_eq!(
            ModelRouter::assess_complexity(
                "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life"
            ),
            TaskComplexity::Medium,
        );
    }

    // A fenced code block is a code marker.
    #[test]
    fn code_with_backticks() {
        assert_eq!(
            ModelRouter::assess_complexity("Here is my code:\n```rust\nfn main() {}\n```\nFix it"),
            TaskComplexity::Code,
        );
    }

    // Repair keywords are matched case-insensitively and classify as Code.
    #[test]
    fn repair_marker() {
        assert_eq!(
            ModelRouter::assess_complexity("Debug this failing test case"),
            TaskComplexity::Code,
        );
    }

    // 400 words estimate to 520 tokens, above the 500-token Complex threshold.
    #[test]
    fn long_prompt_is_complex() {
        let prompt = "word ".repeat(400);
        assert_eq!(
            ModelRouter::assess_complexity(&prompt),
            TaskComplexity::Complex,
        );
    }

    // A short prompt without any keyword falls under the 30-token Simple threshold.
    #[test]
    fn short_prompt_is_simple() {
        assert_eq!(
            ModelRouter::assess_complexity("Hello there"),
            TaskComplexity::Simple,
        );
    }
}