1use crate::model_db::{self, ModelEntry};
7use crate::{Complexity, Context, Message, MessageContent};
8
9pub trait ComplexityRouter: Send + Sync {
11 fn classify(&self, context: &Context) -> Complexity;
13
14 fn route(
16 &self,
17 complexity: Complexity,
18 prefer_cost_efficient: bool,
19 ) -> Vec<&'static ModelEntry>;
20}
21
/// Stateless router that classifies requests with keyword and size heuristics.
///
/// The private unit field keeps the struct from being built with a struct literal outside
/// this module; use [`DefaultRouter::new`] or [`Default::default`] instead.
#[derive(Clone, Debug, Default)]
pub struct DefaultRouter {
    _private: (),
}
30
31impl DefaultRouter {
32 pub fn new() -> Self {
34 Self { _private: () }
35 }
36
37 fn extract_content_text(&self, content: &MessageContent) -> String {
39 match content {
40 MessageContent::Text(s) => s.clone(),
41 MessageContent::Blocks(blocks) => blocks
42 .iter()
43 .filter_map(|b| b.as_text())
44 .collect::<Vec<_>>()
45 .join(" "),
46 }
47 }
48
49 fn get_last_user_message_text(&self, context: &Context) -> Option<String> {
51 context.messages.iter().rev().find_map(|msg| {
52 if let Message::User(user_msg) = msg {
53 let text = self.extract_content_text(&user_msg.content);
54 if !text.is_empty() {
55 Some(text)
56 } else {
57 None
58 }
59 } else {
60 None
61 }
62 })
63 }
64
65 fn count_tokens(&self, text: &str) -> usize {
67 crate::high_level::tokens::estimate(text)
68 }
69
70 fn analyze_keywords(&self, text: &str) -> i32 {
73 let lower = text.to_lowercase();
74
75 let complex_keywords = [
78 "build a",
79 "build the",
80 "create a service",
81 "write a full",
82 "implement a complete",
83 "implement a full",
84 "microservice",
85 "distributed system",
86 "concurrent",
87 "parallel processing",
88 "full-stack",
89 "full stack",
90 "end-to-end",
91 "enterprise",
92 "complete application",
93 "complete system",
94 ];
95 let has_complex = complex_keywords.iter().any(|kw| lower.contains(*kw));
96
97 let research_keywords = [
99 "analyze deeply",
100 "research",
101 "evaluate thoroughly",
102 "investigate",
103 "compare and contrast",
104 "benchmark",
105 "comprehensive analysis",
106 "thorough",
107 "in-depth",
108 "deep research",
109 "study of",
110 ];
111 let has_research = research_keywords.iter().any(|kw| lower.contains(*kw));
112
113 let moderate_keywords = [
115 "architect",
116 "design a",
117 "refactor",
118 "implement",
119 "create a class",
120 "optimize",
121 "debug",
122 "review code",
123 "parse",
124 "validate",
125 "schema",
126 "api",
127 "build a",
128 ];
129 let has_moderate = moderate_keywords.iter().any(|kw| lower.contains(*kw));
130
131 let simple_keywords = [
133 "explain",
134 "write function",
135 "fix typo",
136 "list",
137 "describe",
138 "define",
139 "convert",
140 "calculate",
141 "simple",
142 ];
143 let has_simple = simple_keywords.iter().any(|kw| lower.contains(*kw));
144
145 let trivial_keywords = [
147 "translate",
148 "summarize",
149 "spell check",
150 "format",
151 "capitalize",
152 "lowercase",
153 "uppercase",
154 "trim",
155 "count words",
156 ];
157 let has_trivial = trivial_keywords.iter().any(|kw| lower.contains(*kw));
158
159 if has_research {
161 4
162 } else if has_complex {
163 3
164 } else if has_moderate {
165 2
166 } else if has_simple {
167 1
168 } else if has_trivial {
169 0
170 } else {
171 1 }
173 }
174
175 fn analyze_system_prompt(&self, system_prompt: Option<&str>) -> i32 {
177 let Some(prompt) = system_prompt else {
178 return 0;
179 };
180
181 let lower = prompt.to_lowercase();
182
183 if lower.contains("research")
185 || lower.contains("deep analysis")
186 || lower.contains("thorough")
187 {
188 return 2;
189 }
190
191 if lower.contains("helpful assistant")
193 && !lower.contains("expert")
194 && !lower.contains("advanced")
195 {
196 return 0;
197 }
198
199 if lower.contains("expert")
201 || lower.contains("senior developer")
202 || lower.contains("architect")
203 {
204 return 1;
205 }
206
207 0
208 }
209
210 fn score_to_complexity(&self, score: i32) -> Complexity {
213 match score {
214 0 => Complexity::Trivial,
215 1 => Complexity::Simple,
216 2 => Complexity::Moderate,
217 3 => Complexity::Complex,
218 _ => Complexity::Research,
219 }
220 }
221
222 fn get_models_for_complexity(&self, complexity: Complexity) -> Vec<&'static ModelEntry> {
224 let complexity_tier = complexity.cost_tier();
225
226 let patterns: Vec<&str> = match complexity {
229 Complexity::Trivial => vec!["haiku", "gpt-4o-mini", "mini"],
230 Complexity::Simple => vec!["haiku", "sonnet", "gpt-4o-mini", "mini"],
231 Complexity::Moderate => vec!["sonnet", "opus", "gpt-4o", "gpt-4.1"],
232 Complexity::Complex => vec!["opus", "gemini-2.5-pro", "gpt-4.1", "claude-sonnet"],
233 Complexity::Research => vec![
234 "opus-4.5",
235 "opus-4.6",
236 "gemini-3-pro",
237 "gemini-2.5-pro",
238 "claude-opus",
239 ],
240 };
241
242 let mut candidates: Vec<&'static ModelEntry> = Vec::new();
244
245 for pattern in &patterns {
246 let matches = model_db::search_models(pattern);
247 for model in matches {
248 if self.model_suitable_for_tier(model, complexity_tier)
251 && !candidates.contains(&model)
252 {
253 candidates.push(model);
254 }
255 }
256 }
257
258 candidates.truncate(20);
260 candidates
261 }
262
263 fn model_suitable_for_tier(&self, model: &ModelEntry, tier: u8) -> bool {
265 match tier {
266 0 => {
268 !model.supports_reasoning() || model.cost_input < 0.5
270 }
271 1 => !model.supports_reasoning() || model.cost_input < 1.5,
273 2 => {
275 model.cost_input < 5.0 || model.supports_reasoning()
277 }
278 3 => {
280 model.supports_reasoning() || model.cost_input < 15.0
282 }
283 _ => {
285 model.supports_reasoning()
287 || model.context_window >= 200_000
288 || model.name.to_lowercase().contains("pro")
289 || model.name.to_lowercase().contains("opus")
290 }
291 }
292 }
293
294 fn sort_by_cost(&self, candidates: &mut [&'static ModelEntry]) {
296 candidates.sort_by(|a, b| {
297 let cost_a = a.cost_input + a.cost_output;
298 let cost_b = b.cost_input + b.cost_output;
299 cost_a
300 .partial_cmp(&cost_b)
301 .unwrap_or(std::cmp::Ordering::Equal)
302 });
303 }
304
305 fn sort_by_capability(&self, candidates: &mut [&'static ModelEntry]) {
307 candidates.sort_by(|a, b| {
308 let a_reasoning = if a.supports_reasoning() { 1 } else { 0 };
310 let b_reasoning = if b.supports_reasoning() { 1 } else { 0 };
311 if a_reasoning != b_reasoning {
312 return b_reasoning.cmp(&a_reasoning);
313 }
314
315 let a_context = a.context_window;
317 let b_context = b.context_window;
318 if a_context != b_context {
319 return b_context.cmp(&a_context);
320 }
321
322 let a_output = a.max_tokens;
324 let b_output = b.max_tokens;
325 if a_output != b_output {
326 return b_output.cmp(&a_output);
327 }
328
329 let cost_a = a.cost_input + a.cost_output;
331 let cost_b = b.cost_input + b.cost_output;
332 cost_a
333 .partial_cmp(&cost_b)
334 .unwrap_or(std::cmp::Ordering::Equal)
335 });
336 }
337}
338
339impl ComplexityRouter for DefaultRouter {
340 fn classify(&self, context: &Context) -> Complexity {
341 let last_user_text = self.get_last_user_message_text(context);
343
344 let Some(text) = last_user_text else {
345 let prompt_score = self.analyze_system_prompt(context.system_prompt.as_deref());
347 if !context.tools.is_empty() {
348 let bumped = (prompt_score + 1).min(4);
349 return self.score_to_complexity(bumped);
350 }
351 return self.score_to_complexity(prompt_score);
352 };
353
354 let token_count = self.count_tokens(&text);
356
357 let keyword_score = self.analyze_keywords(&text);
359
360 let base_score = if token_count < 100 {
363 keyword_score
366 } else if token_count > 2000 {
367 (keyword_score + 2).min(4)
369 } else if token_count > 500 {
370 (keyword_score + 1).min(4)
372 } else {
373 keyword_score
374 };
375
376 let system_score = self.analyze_system_prompt(context.system_prompt.as_deref());
378 let final_score = if system_score > base_score {
379 system_score
380 } else {
381 base_score
382 };
383
384 let final_score = if !context.tools.is_empty() {
386 (final_score + 1).min(4)
387 } else {
388 final_score
389 };
390
391 self.score_to_complexity(final_score)
392 }
393
394 fn route(
395 &self,
396 complexity: Complexity,
397 prefer_cost_efficient: bool,
398 ) -> Vec<&'static ModelEntry> {
399 let mut candidates = self.get_models_for_complexity(complexity);
401
402 let tier = complexity.cost_tier();
404 candidates.retain(|m| self.model_suitable_for_tier(m, tier));
405
406 if prefer_cost_efficient {
408 self.sort_by_cost(&mut candidates);
409 } else {
410 self.sort_by_capability(&mut candidates);
411 }
412
413 candidates.truncate(3);
415 candidates
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Message, UserMessage};

    /// Builds a context whose only message is a user message containing `text`.
    fn create_context_with_user_message(text: &str) -> Context {
        let mut context = Context::new();
        context.add_message(Message::User(UserMessage::new(text.to_string())));
        context
    }

    /// Classifies a lone user message with a freshly constructed router.
    fn classify_text(text: &str) -> Complexity {
        DefaultRouter::new().classify(&create_context_with_user_message(text))
    }

    #[test]
    fn test_trivial_keywords() {
        for prompt in [
            "Please translate this to Spanish",
            "Summarize this text for me",
            "spell check this document",
        ] {
            assert_eq!(classify_text(prompt), Complexity::Trivial);
        }
    }

    #[test]
    fn test_simple_keywords() {
        for prompt in [
            "Explain how this code works",
            "Write a function to reverse a string",
            "List all files in the directory",
        ] {
            assert_eq!(classify_text(prompt), Complexity::Simple);
        }
    }

    #[test]
    fn test_moderate_keywords() {
        for prompt in [
            "Architect a REST API service",
            "Design a database schema",
            "Refactor this module",
        ] {
            assert_eq!(classify_text(prompt), Complexity::Moderate);
        }
    }

    #[test]
    fn test_complex_keywords() {
        for prompt in [
            "Build a complete microservices architecture with distributed tracing",
            "Implement a full-stack application with authentication and database",
        ] {
            assert!(classify_text(prompt) >= Complexity::Complex);
        }
    }

    #[test]
    fn test_research_keywords() {
        for prompt in [
            "Analyze deeply the performance characteristics of this system",
            "Conduct a comprehensive research study on machine learning",
        ] {
            assert_eq!(classify_text(prompt), Complexity::Research);
        }
    }

    #[test]
    fn test_tools_bump_complexity() {
        let router = DefaultRouter::new();
        let mut context = create_context_with_user_message("List files");

        // Baseline without any tools registered.
        assert_eq!(router.classify(&context), Complexity::Simple);

        // Registering a tool raises the classification by one level.
        let tool = crate::Tool::new("list_files", "List files", serde_json::json!({}));
        context.add_tool(tool);
        assert_eq!(router.classify(&context), Complexity::Moderate);
    }

    #[test]
    fn test_token_count_affects_complexity() {
        // A message with no keywords at all falls back to the default score.
        let shortest = classify_text("a");
        assert!(
            shortest >= Complexity::Simple,
            "Short text should be at least Simple, got {:?}",
            shortest
        );

        // A short message is classified by its keyword alone.
        let short = classify_text("explain this");
        assert_eq!(short, Complexity::Simple, "'explain' should be Simple");

        // The same keyword in a long message is bumped by the length bonus.
        let long_text = "Explain this code in detail. ".repeat(100);
        let long = classify_text(&long_text);
        assert!(
            long >= Complexity::Moderate,
            "Long text should be at least Moderate, got {:?}",
            long
        );
    }

    #[test]
    fn test_routing_trivial() {
        let models = DefaultRouter::new().route(Complexity::Trivial, true);

        assert!(!models.is_empty());
        assert!(models.len() <= 3);
    }

    #[test]
    fn test_routing_research() {
        let models = DefaultRouter::new().route(Complexity::Research, false);

        assert!(!models.is_empty());
        assert!(models.len() <= 3);

        for model in models.iter() {
            let capable = model.supports_reasoning() || model.context_window >= 200_000;
            assert!(
                capable,
                "Model {} should support reasoning or have large context",
                model.name
            );
        }
    }

    #[test]
    fn test_cost_efficient_sorting() {
        let models = DefaultRouter::new().route(Complexity::Moderate, true);

        // Every adjacent pair must be in non-decreasing order of combined cost.
        for pair in models.windows(2) {
            let prev_cost = pair[0].cost_input + pair[0].cost_output;
            let curr_cost = pair[1].cost_input + pair[1].cost_output;
            assert!(
                prev_cost <= curr_cost,
                "Cost-efficient sorting failed: {:?} > {:?}",
                prev_cost,
                curr_cost
            );
        }
    }

    #[test]
    fn test_capability_sorting() {
        let models = DefaultRouter::new().route(Complexity::Complex, false);

        // Only meaningful when there is something to order and a reasoning model is present.
        if models.len() > 1 && models.iter().any(|m| m.supports_reasoning()) {
            assert!(
                models[0].supports_reasoning(),
                "First model should support reasoning when sorting by capability"
            );
        }
    }

    #[test]
    fn test_system_prompt_analysis() {
        // Classifies a neutral "Hello" message under the given system prompt.
        let classify_hello_under = |prompt: &str| {
            let mut context = Context::new();
            context.set_system_prompt(prompt);
            context.add_message(Message::User(UserMessage::new("Hello")));
            DefaultRouter::new().classify(&context)
        };

        let generic = classify_hello_under("You are a helpful assistant.");
        assert!(generic <= Complexity::Simple);

        let demanding = classify_hello_under(
            "You are an expert senior software architect conducting thorough deep analysis.",
        );
        assert!(demanding >= Complexity::Moderate);
    }

    #[test]
    fn test_empty_context() {
        let empty = Context::new();

        assert_eq!(DefaultRouter::new().classify(&empty), Complexity::Trivial);
    }

    #[test]
    fn test_default_router() {
        let context = create_context_with_user_message("translate this text");

        assert_eq!(DefaultRouter::default().classify(&context), Complexity::Trivial);
    }

    #[test]
    fn test_complexity_trait_object() {
        use std::sync::Arc;

        let router: Arc<dyn ComplexityRouter> = Arc::new(DefaultRouter::new());
        let context = create_context_with_user_message("refactor this code");

        let complexity = router.classify(&context);
        assert_eq!(complexity, Complexity::Moderate);

        assert!(!router.route(complexity, true).is_empty());
    }
}