Skip to main content

oxi_ai/
complexity_router.rs

1//! Complexity-based model routing for oxi-ai
2//!
3//! This module provides a router that classifies task complexity
4//! and selects appropriate models based on capability and cost requirements.
5
6use crate::model_db::{self, ModelEntry};
7use crate::{Complexity, Context, Message, MessageContent};
8
/// Routes tasks to models based on estimated complexity.
///
/// The trait splits routing into two steps: [`classify`](Self::classify)
/// turns a conversation [`Context`] into a [`Complexity`] tier, and
/// [`route`](Self::route) turns a tier into a list of candidate models.
/// The `Send + Sync` supertraits allow a router to be shared across threads,
/// e.g. behind an `Arc<dyn ComplexityRouter>`.
pub trait ComplexityRouter: Send + Sync {
    /// Classify the complexity of the given context.
    ///
    /// Returns the [`Complexity`] tier the implementation assigns to `context`.
    fn classify(&self, context: &Context) -> Complexity;

    /// Pick the best models for a given complexity.
    ///
    /// `prefer_cost_efficient` selects the ranking strategy; in
    /// `DefaultRouter`, `true` orders candidates by combined input + output
    /// cost and `false` orders them by capability. The returned entries are
    /// `'static` references into the model database.
    fn route(
        &self,
        complexity: Complexity,
        prefer_cost_efficient: bool,
    ) -> Vec<&'static ModelEntry>;
}
21
/// Default implementation of ComplexityRouter.
///
/// Uses keyword analysis, token counting, and system prompt hints
/// to determine task complexity, then routes to appropriate models.
///
/// The router holds no state; construct it with `DefaultRouter::new()` or
/// `DefaultRouter::default()`.
#[derive(Debug, Clone, Default)]
pub struct DefaultRouter {
    // Private unit field: code outside this module cannot build the struct
    // with a literal and has to go through `new()` / `Default`.
    _private: (),
}
30
31impl DefaultRouter {
32    /// Create a new DefaultRouter
33    pub fn new() -> Self {
34        Self { _private: () }
35    }
36
37    /// Extract text from MessageContent
38    fn extract_content_text(&self, content: &MessageContent) -> String {
39        match content {
40            MessageContent::Text(s) => s.clone(),
41            MessageContent::Blocks(blocks) => blocks
42                .iter()
43                .filter_map(|b| b.as_text())
44                .collect::<Vec<_>>()
45                .join(" "),
46        }
47    }
48
49    /// Get the last user message text from the context
50    fn get_last_user_message_text(&self, context: &Context) -> Option<String> {
51        context.messages.iter().rev().find_map(|msg| {
52            if let Message::User(user_msg) = msg {
53                let text = self.extract_content_text(&user_msg.content);
54                if !text.is_empty() {
55                    Some(text)
56                } else {
57                    None
58                }
59            } else {
60                None
61            }
62        })
63    }
64
65    /// Count tokens in text using the high-level estimator
66    fn count_tokens(&self, text: &str) -> usize {
67        crate::high_level::tokens::estimate(text)
68    }
69
70    /// Analyze text for complexity keywords and return a base complexity score
71    /// Returns 0-4 mapping to Complexity tiers (0=Trivial, 1=Simple, 2=Moderate, etc.)
72    fn analyze_keywords(&self, text: &str) -> i32 {
73        let lower = text.to_lowercase();
74
75        // Check for complex/research keywords FIRST (more specific patterns)
76        // Complex keywords: score 3
77        let complex_keywords = [
78            "build a",
79            "build the",
80            "create a service",
81            "write a full",
82            "implement a complete",
83            "implement a full",
84            "microservice",
85            "distributed system",
86            "concurrent",
87            "parallel processing",
88            "full-stack",
89            "full stack",
90            "end-to-end",
91            "enterprise",
92            "complete application",
93            "complete system",
94        ];
95        let has_complex = complex_keywords.iter().any(|kw| lower.contains(*kw));
96
97        // Research keywords: score 4
98        let research_keywords = [
99            "analyze deeply",
100            "research",
101            "evaluate thoroughly",
102            "investigate",
103            "compare and contrast",
104            "benchmark",
105            "comprehensive analysis",
106            "thorough",
107            "in-depth",
108            "deep research",
109            "study of",
110        ];
111        let has_research = research_keywords.iter().any(|kw| lower.contains(*kw));
112
113        // Moderate keywords: score 2
114        let moderate_keywords = [
115            "architect",
116            "design a",
117            "refactor",
118            "implement",
119            "create a class",
120            "optimize",
121            "debug",
122            "review code",
123            "parse",
124            "validate",
125            "schema",
126            "api",
127            "build a",
128        ];
129        let has_moderate = moderate_keywords.iter().any(|kw| lower.contains(*kw));
130
131        // Simple keywords: score 1
132        let simple_keywords = [
133            "explain",
134            "write function",
135            "fix typo",
136            "list",
137            "describe",
138            "define",
139            "convert",
140            "calculate",
141            "simple",
142        ];
143        let has_simple = simple_keywords.iter().any(|kw| lower.contains(*kw));
144
145        // Trivial keywords: score 0
146        let trivial_keywords = [
147            "translate",
148            "summarize",
149            "spell check",
150            "format",
151            "capitalize",
152            "lowercase",
153            "uppercase",
154            "trim",
155            "count words",
156        ];
157        let has_trivial = trivial_keywords.iter().any(|kw| lower.contains(*kw));
158
159        // Return the highest matching score (research > complex > moderate > simple > trivial)
160        if has_research {
161            4
162        } else if has_complex {
163            3
164        } else if has_moderate {
165            2
166        } else if has_simple {
167            1
168        } else if has_trivial {
169            0
170        } else {
171            1 // Default to simple
172        }
173    }
174
175    /// Analyze system prompt for complexity hints
176    fn analyze_system_prompt(&self, system_prompt: Option<&str>) -> i32 {
177        let Some(prompt) = system_prompt else {
178            return 0;
179        };
180
181        let lower = prompt.to_lowercase();
182
183        // System prompts with "research" or "deep analysis" suggest higher complexity
184        if lower.contains("research")
185            || lower.contains("deep analysis")
186            || lower.contains("thorough")
187        {
188            return 2;
189        }
190
191        // "helpful assistant" without specific guidance suggests trivial/simple
192        if lower.contains("helpful assistant")
193            && !lower.contains("expert")
194            && !lower.contains("advanced")
195        {
196            return 0;
197        }
198
199        // "expert" or "senior" suggests higher complexity
200        if lower.contains("expert")
201            || lower.contains("senior developer")
202            || lower.contains("architect")
203        {
204            return 1;
205        }
206
207        0
208    }
209
210    /// Convert keyword score to Complexity enum
211    /// Score maps directly to complexity tier (0=Trivial, 1=Simple, 2=Moderate, etc.)
212    fn score_to_complexity(&self, score: i32) -> Complexity {
213        match score {
214            0 => Complexity::Trivial,
215            1 => Complexity::Simple,
216            2 => Complexity::Moderate,
217            3 => Complexity::Complex,
218            _ => Complexity::Research,
219        }
220    }
221
222    /// Get models filtered by complexity tier
223    fn get_models_for_complexity(&self, complexity: Complexity) -> Vec<&'static ModelEntry> {
224        let complexity_tier = complexity.cost_tier();
225
226        // Model mapping per complexity level
227        // We search for models by name/id pattern
228        let patterns: Vec<&str> = match complexity {
229            Complexity::Trivial => vec!["haiku", "gpt-4o-mini", "mini"],
230            Complexity::Simple => vec!["haiku", "sonnet", "gpt-4o-mini", "mini"],
231            Complexity::Moderate => vec!["sonnet", "opus", "gpt-4o", "gpt-4.1"],
232            Complexity::Complex => vec!["opus", "gemini-2.5-pro", "gpt-4.1", "claude-sonnet"],
233            Complexity::Research => vec![
234                "opus-4.5",
235                "opus-4.6",
236                "gemini-3-pro",
237                "gemini-2.5-pro",
238                "claude-opus",
239            ],
240        };
241
242        // Collect matching models from the database
243        let mut candidates: Vec<&'static ModelEntry> = Vec::new();
244
245        for pattern in &patterns {
246            let matches = model_db::search_models(pattern);
247            for model in matches {
248                // Prefer models that have been updated more recently (higher version numbers)
249                // Also filter by relevance to complexity tier
250                if self.model_suitable_for_tier(model, complexity_tier)
251                    && !candidates.contains(&model)
252                {
253                    candidates.push(model);
254                }
255            }
256        }
257
258        // Deduplicate and limit
259        candidates.truncate(20);
260        candidates
261    }
262
263    /// Check if a model is suitable for a given complexity tier
264    fn model_suitable_for_tier(&self, model: &ModelEntry, tier: u8) -> bool {
265        match tier {
266            // Trivial: fast, cheap models
267            0 => {
268                // Prefer models without reasoning (faster, cheaper)
269                !model.supports_reasoning() || model.cost_input < 0.5
270            }
271            // Simple: moderate capability
272            1 => !model.supports_reasoning() || model.cost_input < 1.5,
273            // Moderate: good capability
274            2 => {
275                // Mid-range models
276                model.cost_input < 5.0 || model.supports_reasoning()
277            }
278            // Complex: high capability
279            3 => {
280                // High-end models
281                model.supports_reasoning() || model.cost_input < 15.0
282            }
283            // Research: top tier only
284            _ => {
285                // Best models for research
286                model.supports_reasoning()
287                    || model.context_window >= 200_000
288                    || model.name.to_lowercase().contains("pro")
289                    || model.name.to_lowercase().contains("opus")
290            }
291        }
292    }
293
294    /// Sort candidates by cost efficiency
295    fn sort_by_cost(&self, candidates: &mut [&'static ModelEntry]) {
296        candidates.sort_by(|a, b| {
297            let cost_a = a.cost_input + a.cost_output;
298            let cost_b = b.cost_input + b.cost_output;
299            cost_a
300                .partial_cmp(&cost_b)
301                .unwrap_or(std::cmp::Ordering::Equal)
302        });
303    }
304
305    /// Sort candidates by capability (reasoning > context_window > cost)
306    fn sort_by_capability(&self, candidates: &mut [&'static ModelEntry]) {
307        candidates.sort_by(|a, b| {
308            // Primary: reasoning capability
309            let a_reasoning = if a.supports_reasoning() { 1 } else { 0 };
310            let b_reasoning = if b.supports_reasoning() { 1 } else { 0 };
311            if a_reasoning != b_reasoning {
312                return b_reasoning.cmp(&a_reasoning);
313            }
314
315            // Secondary: context window size
316            let a_context = a.context_window;
317            let b_context = b.context_window;
318            if a_context != b_context {
319                return b_context.cmp(&a_context);
320            }
321
322            // Tertiary: output capability
323            let a_output = a.max_tokens;
324            let b_output = b.max_tokens;
325            if a_output != b_output {
326                return b_output.cmp(&a_output);
327            }
328
329            // Quaternary: cost (prefer cheaper for same capability)
330            let cost_a = a.cost_input + a.cost_output;
331            let cost_b = b.cost_input + b.cost_output;
332            cost_a
333                .partial_cmp(&cost_b)
334                .unwrap_or(std::cmp::Ordering::Equal)
335        });
336    }
337}
338
339impl ComplexityRouter for DefaultRouter {
340    fn classify(&self, context: &Context) -> Complexity {
341        // Get last user message text
342        let last_user_text = self.get_last_user_message_text(context);
343
344        let Some(text) = last_user_text else {
345            // No user message - check system prompt
346            let prompt_score = self.analyze_system_prompt(context.system_prompt.as_deref());
347            if !context.tools.is_empty() {
348                let bumped = (prompt_score + 1).min(4);
349                return self.score_to_complexity(bumped);
350            }
351            return self.score_to_complexity(prompt_score);
352        };
353
354        // Count tokens
355        let token_count = self.count_tokens(&text);
356
357        // Analyze keywords for complexity
358        let keyword_score = self.analyze_keywords(&text);
359
360        // Determine base score based on keywords and token count
361        // For short trivial inputs, don't bump (they're genuinely simple)
362        let base_score = if token_count < 100 {
363            // Short inputs: trust keyword detection
364            // But ensure trivial keywords don't get bumped
365            keyword_score
366        } else if token_count > 2000 {
367            // Very long inputs: increase complexity
368            (keyword_score + 2).min(4)
369        } else if token_count > 500 {
370            // Medium inputs: slightly increase complexity
371            (keyword_score + 1).min(4)
372        } else {
373            keyword_score
374        };
375
376        // Analyze system prompt for hints (can increase complexity)
377        let system_score = self.analyze_system_prompt(context.system_prompt.as_deref());
378        let final_score = if system_score > base_score {
379            system_score
380        } else {
381            base_score
382        };
383
384        // If context has tools, bump complexity by 1 (capped at Research)
385        let final_score = if !context.tools.is_empty() {
386            (final_score + 1).min(4)
387        } else {
388            final_score
389        };
390
391        self.score_to_complexity(final_score)
392    }
393
394    fn route(
395        &self,
396        complexity: Complexity,
397        prefer_cost_efficient: bool,
398    ) -> Vec<&'static ModelEntry> {
399        // Get candidates for this complexity
400        let mut candidates = self.get_models_for_complexity(complexity);
401
402        // Filter to models that support the complexity tier
403        let tier = complexity.cost_tier();
404        candidates.retain(|m| self.model_suitable_for_tier(m, tier));
405
406        // Sort based on preference
407        if prefer_cost_efficient {
408            self.sort_by_cost(&mut candidates);
409        } else {
410            self.sort_by_capability(&mut candidates);
411        }
412
413        // Return top 3 candidates
414        candidates.truncate(3);
415        candidates
416    }
417}
418
419#[cfg(test)]
420mod tests {
421    use super::*;
422    use crate::{Message, UserMessage};
423
424    fn create_context_with_user_message(text: &str) -> Context {
425        let mut ctx = Context::new();
426        ctx.add_message(Message::User(UserMessage::new(text.to_string())));
427        ctx
428    }
429
430    #[test]
431    fn test_trivial_keywords() {
432        let router = DefaultRouter::new();
433
434        // Trivial keywords should be detected
435        let ctx = create_context_with_user_message("Please translate this to Spanish");
436        assert_eq!(router.classify(&ctx), Complexity::Trivial);
437
438        let ctx = create_context_with_user_message("Summarize this text for me");
439        assert_eq!(router.classify(&ctx), Complexity::Trivial);
440
441        // "spell check" as a phrase should be trivial
442        let ctx = create_context_with_user_message("spell check this document");
443        assert_eq!(router.classify(&ctx), Complexity::Trivial);
444    }
445
446    #[test]
447    fn test_simple_keywords() {
448        let router = DefaultRouter::new();
449
450        let ctx = create_context_with_user_message("Explain how this code works");
451        assert_eq!(router.classify(&ctx), Complexity::Simple);
452
453        let ctx = create_context_with_user_message("Write a function to reverse a string");
454        assert_eq!(router.classify(&ctx), Complexity::Simple);
455
456        let ctx = create_context_with_user_message("List all files in the directory");
457        assert_eq!(router.classify(&ctx), Complexity::Simple);
458    }
459
460    #[test]
461    fn test_moderate_keywords() {
462        let router = DefaultRouter::new();
463
464        let ctx = create_context_with_user_message("Architect a REST API service");
465        assert_eq!(router.classify(&ctx), Complexity::Moderate);
466
467        let ctx = create_context_with_user_message("Design a database schema");
468        assert_eq!(router.classify(&ctx), Complexity::Moderate);
469
470        let ctx = create_context_with_user_message("Refactor this module");
471        assert_eq!(router.classify(&ctx), Complexity::Moderate);
472    }
473
474    #[test]
475    fn test_complex_keywords() {
476        let router = DefaultRouter::new();
477
478        // Complex keywords should be detected
479        let ctx = create_context_with_user_message(
480            "Build a complete microservices architecture with distributed tracing",
481        );
482        assert!(router.classify(&ctx) >= Complexity::Complex);
483
484        let ctx = create_context_with_user_message(
485            "Implement a full-stack application with authentication and database",
486        );
487        assert!(router.classify(&ctx) >= Complexity::Complex);
488    }
489
490    #[test]
491    fn test_research_keywords() {
492        let router = DefaultRouter::new();
493
494        let ctx = create_context_with_user_message(
495            "Analyze deeply the performance characteristics of this system",
496        );
497        assert_eq!(router.classify(&ctx), Complexity::Research);
498
499        let ctx = create_context_with_user_message(
500            "Conduct a comprehensive research study on machine learning",
501        );
502        assert_eq!(router.classify(&ctx), Complexity::Research);
503    }
504
505    #[test]
506    fn test_tools_bump_complexity() {
507        let router = DefaultRouter::new();
508
509        let mut ctx = create_context_with_user_message("List files");
510        assert_eq!(router.classify(&ctx), Complexity::Simple);
511
512        // Add a tool - should bump complexity
513        ctx.add_tool(crate::Tool::new(
514            "list_files",
515            "List files",
516            serde_json::json!({}),
517        ));
518        assert_eq!(router.classify(&ctx), Complexity::Moderate);
519    }
520
521    #[test]
522    fn test_token_count_affects_complexity() {
523        let router = DefaultRouter::new();
524
525        // Test single character - very short
526        let ctx = create_context_with_user_message("a");
527        let complexity = router.classify(&ctx);
528        assert!(
529            complexity >= Complexity::Simple,
530            "Short text should be at least Simple, got {:?}",
531            complexity
532        );
533
534        // Test "explain" keyword alone
535        let ctx = create_context_with_user_message("explain this");
536        let complexity = router.classify(&ctx);
537        assert_eq!(complexity, Complexity::Simple, "'explain' should be Simple");
538
539        // Very long text should increase complexity (use enough to exceed 500 tokens)
540        // "Explain this code in detail. " is ~7 words, ~10 chars, ~13 tokens
541        // Need ~40+ repetitions to exceed 500 tokens
542        let long_text = "Explain this code in detail. ".repeat(100);
543        let ctx = create_context_with_user_message(&long_text);
544        let complexity = router.classify(&ctx);
545        assert!(
546            complexity >= Complexity::Moderate,
547            "Long text should be at least Moderate, got {:?}",
548            complexity
549        );
550    }
551
552    #[test]
553    fn test_routing_trivial() {
554        let router = DefaultRouter::new();
555
556        let models = router.route(Complexity::Trivial, true);
557        assert!(!models.is_empty());
558        assert!(models.len() <= 3);
559    }
560
561    #[test]
562    fn test_routing_research() {
563        let router = DefaultRouter::new();
564
565        let models = router.route(Complexity::Research, false);
566        assert!(!models.is_empty());
567        assert!(models.len() <= 3);
568
569        // Research models should support reasoning
570        for model in &models {
571            assert!(
572                model.supports_reasoning() || model.context_window >= 200_000,
573                "Model {} should support reasoning or have large context",
574                model.name
575            );
576        }
577    }
578
579    #[test]
580    fn test_cost_efficient_sorting() {
581        let router = DefaultRouter::new();
582
583        let models = router.route(Complexity::Moderate, true);
584
585        if models.len() > 1 {
586            // Verify cost sorting
587            for i in 1..models.len() {
588                let prev_cost = models[i - 1].cost_input + models[i - 1].cost_output;
589                let curr_cost = models[i].cost_input + models[i].cost_output;
590                assert!(
591                    prev_cost <= curr_cost,
592                    "Cost-efficient sorting failed: {:?} > {:?}",
593                    prev_cost,
594                    curr_cost
595                );
596            }
597        }
598    }
599
600    #[test]
601    fn test_capability_sorting() {
602        let router = DefaultRouter::new();
603
604        let models = router.route(Complexity::Complex, false);
605
606        if models.len() > 1 {
607            // First model should have reasoning if any do
608            let any_reasoning = models.iter().any(|m| m.supports_reasoning());
609            if any_reasoning {
610                assert!(
611                    models[0].supports_reasoning(),
612                    "First model should support reasoning when sorting by capability"
613                );
614            }
615        }
616    }
617
618    #[test]
619    fn test_system_prompt_analysis() {
620        let router = DefaultRouter::new();
621
622        let mut ctx = Context::new();
623        ctx.set_system_prompt("You are a helpful assistant.");
624        ctx.add_message(Message::User(UserMessage::new("Hello")));
625
626        // Simple system prompt should not increase complexity
627        let complexity = router.classify(&ctx);
628        assert!(complexity <= Complexity::Simple);
629
630        let mut ctx = Context::new();
631        ctx.set_system_prompt(
632            "You are an expert senior software architect conducting thorough deep analysis.",
633        );
634        ctx.add_message(Message::User(UserMessage::new("Hello")));
635
636        // Expert system prompt should increase complexity
637        let complexity = router.classify(&ctx);
638        assert!(complexity >= Complexity::Moderate);
639    }
640
641    #[test]
642    fn test_empty_context() {
643        let router = DefaultRouter::new();
644
645        let ctx = Context::new();
646        let complexity = router.classify(&ctx);
647        // Empty context defaults to Trivial
648        assert_eq!(complexity, Complexity::Trivial);
649    }
650
651    #[test]
652    fn test_default_router() {
653        let router = DefaultRouter::default();
654        let ctx = create_context_with_user_message("translate this text");
655        let complexity = router.classify(&ctx);
656        // "translate" is a trivial keyword
657        assert_eq!(complexity, Complexity::Trivial);
658    }
659
660    #[test]
661    fn test_complexity_trait_object() {
662        use std::sync::Arc;
663
664        let router: Arc<dyn ComplexityRouter> = Arc::new(DefaultRouter::new());
665        let ctx = create_context_with_user_message("refactor this code");
666        let complexity = router.classify(&ctx);
667        assert_eq!(complexity, Complexity::Moderate);
668
669        let models = router.route(complexity, true);
670        assert!(!models.is_empty());
671    }
672}