use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use super::claude_integration::ClaudeModel;
use super::{AgentType, ClaudeFlowAgent, ClaudeFlowTask};
use crate::error::Result;
/// ASCII case-insensitive substring search over raw bytes.
///
/// Returns `true` when `needle` occurs anywhere in `haystack`, ignoring
/// ASCII case on *both* sides. An empty needle always matches.
///
/// Fix: the previous hand-rolled scan only uppercased the needle bytes,
/// so it was correct only for all-lowercase needles; an uppercase needle
/// byte could never match its lowercase haystack counterpart.
/// `eq_ignore_ascii_case` compares both sides case-folded.
#[inline]
fn contains_ci(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() {
        return true;
    }
    // `windows` yields nothing when the haystack is shorter than the
    // needle, so that case naturally falls through to `false`.
    haystack
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}
/// Raw complexity signals extracted from a task description.
///
/// Every `f32` factor is a heuristic score, produced in `[0.0, 1.0]`
/// by the `TaskComplexityAnalyzer::analyze_*` helpers.
#[derive(Debug, Clone, Default)]
pub struct ComplexityFactors {
    /// Rough estimate of tokens needed to complete the task.
    pub token_estimate: usize,
    /// How much multi-step reasoning the task appears to require.
    pub reasoning_depth: f32,
    /// How much specialist domain knowledge is implied.
    pub domain_expertise: f32,
    /// Expected intricacy of the code to be written or changed.
    pub code_complexity: f32,
    /// How much multi-step planning / orchestration is implied.
    pub planning_complexity: f32,
    /// How security-critical the task appears to be.
    pub security_sensitivity: f32,
    /// How performance-critical the task appears to be.
    pub performance_criticality: f32,
}
/// Process-wide default factor weights, built lazily on first use by
/// `ComplexityFactors::weighted_score`.
static DEFAULT_WEIGHTS: std::sync::LazyLock<ComplexityWeights> =
    std::sync::LazyLock::new(ComplexityWeights::default);
impl ComplexityFactors {
    /// Collapse all factors into a single score in `[0.0, 1.0]`.
    ///
    /// The result is an even blend of (a) the weighted mean of every
    /// factor and (b) the mean of the two strongest non-token factors,
    /// so one dominant signal is not washed out by the average.
    #[inline]
    pub fn weighted_score(&self) -> f32 {
        let weights = &*DEFAULT_WEIGHTS;
        // Bucket the raw token estimate into a coarse 0.2..=1.0 factor.
        let token_factor = match self.token_estimate {
            0..=500 => 0.2,
            501..=1000 => 0.4,
            1001..=2000 => 0.6,
            2001..=5000 => 0.8,
            _ => 1.0,
        };
        // (value, weight) pairs feeding the weighted mean.
        let contributions = [
            (token_factor, weights.token_weight),
            (self.reasoning_depth, weights.reasoning_weight),
            (self.domain_expertise, weights.domain_weight),
            (self.code_complexity, weights.code_weight),
            (self.planning_complexity, weights.planning_weight),
            (self.security_sensitivity, weights.security_weight),
            (self.performance_criticality, weights.performance_weight),
        ];
        let weighted: f32 = contributions.iter().map(|&(value, w)| value * w).sum();
        let total_weight: f32 = contributions.iter().map(|&(_, w)| w).sum();
        // Guard against a degenerate all-zero weight table.
        let avg = if total_weight > 0.0 {
            weighted / total_weight
        } else {
            0.0
        };
        // Peak signal: mean of the two largest non-token factors.
        let mut ranked = [
            self.reasoning_depth,
            self.domain_expertise,
            self.code_complexity,
            self.planning_complexity,
            self.security_sensitivity,
            self.performance_criticality,
        ];
        ranked.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let peak = (ranked[0] + ranked[1]) * 0.5;
        (avg * 0.5 + peak * 0.5).clamp(0.0, 1.0)
    }
}
/// Per-factor weights used by `ComplexityFactors::weighted_score`.
///
/// The weights are normalized by their sum at scoring time, so they do
/// not strictly need to add up to 1.0.
#[derive(Debug, Clone)]
pub struct ComplexityWeights {
    pub token_weight: f32,
    pub reasoning_weight: f32,
    pub domain_weight: f32,
    pub code_weight: f32,
    pub planning_weight: f32,
    pub security_weight: f32,
    pub performance_weight: f32,
}
impl Default for ComplexityWeights {
    /// Default weighting: reasoning depth dominates (0.30), domain
    /// expertise is next (0.20), everything else contributes 0.10.
    /// The values sum to 1.0.
    fn default() -> Self {
        Self {
            token_weight: 0.10,
            reasoning_weight: 0.30,
            domain_weight: 0.20,
            code_weight: 0.10,
            planning_weight: 0.10,
            security_weight: 0.10,
            performance_weight: 0.10,
        }
    }
}
/// Result of analyzing one task: the blended score, its inputs, and the
/// tier/model recommendation derived from them.
#[derive(Debug, Clone)]
pub struct ComplexityScore {
    /// Blended complexity in `[0.0, 1.0]` (see `weighted_score`).
    pub overall: f32,
    /// The raw factors the score was computed from.
    pub factors: ComplexityFactors,
    /// 1 = Haiku, 2 = Sonnet, anything else = Opus.
    pub recommended_tier: u8,
    /// Analyzer confidence in this score, capped at 0.95.
    pub confidence: f32,
    /// Human-readable justification for the recommendation.
    pub reasoning: String,
}
impl ComplexityScore {
    /// Translate the numeric tier into a concrete Claude model
    /// (1 → Haiku, 2 → Sonnet, anything else → Opus).
    #[inline]
    pub fn recommended_model(&self) -> ClaudeModel {
        if self.recommended_tier == 1 {
            ClaudeModel::Haiku
        } else if self.recommended_tier == 2 {
            ClaudeModel::Sonnet
        } else {
            ClaudeModel::Opus
        }
    }
    /// A task is "simple" when both the overall score and the token
    /// estimate sit below the Haiku-tier thresholds.
    #[inline]
    pub fn is_simple(&self) -> bool {
        self.factors.token_estimate < 500 && self.overall < 0.35
    }
    /// Hard floor for Opus: triggered by the overall score or by any
    /// single high-stakes factor (size, security, reasoning depth).
    #[inline]
    pub fn requires_opus(&self) -> bool {
        if self.overall > 0.7 {
            return true;
        }
        let f = &self.factors;
        f.token_estimate > 2000 || f.security_sensitivity > 0.8 || f.reasoning_depth > 0.8
    }
}
/// Keyword fragments that mark a task as high complexity
/// (weighted 0.9 in `build_pattern_weights`).
const HIGH_COMPLEXITY_PATTERNS: &[&str] = &[
    "architecture",
    "design pattern",
    "distributed",
    "concurrent",
    "security audit",
    "vulnerability",
    "performance optimization",
    "scalability",
    "migration",
    "refactor entire",
    "redesign",
    "multi-agent",
    "complex algorithm",
    "machine learning",
    "cryptography",
];
/// Keyword fragments for moderately complex tasks (weighted 0.5).
const MODERATE_COMPLEXITY_PATTERNS: &[&str] = &[
    "implement",
    "create feature",
    "add functionality",
    "write tests",
    "integration test",
    "api endpoint",
    "database query",
    "refactor",
    "debugging",
    "error handling",
    "validation",
];
/// Keyword fragments for trivial tasks (weighted 0.2).
const SIMPLE_PATTERNS: &[&str] = &[
    "fix typo",
    "rename",
    "add comment",
    "format",
    "simple change",
    "quick fix",
    "update config",
    "minor change",
    "small update",
    "add import",
    "remove unused",
];
/// Heuristic, keyword-driven complexity analyzer for task descriptions.
///
/// Stateful: it counts analyses and keeps a bounded history of
/// prediction-vs-actual feedback for calibration statistics.
pub struct TaskComplexityAnalyzer {
    /// Keyword pattern -> complexity weight (built from the *_PATTERNS consts).
    pattern_weights: HashMap<String, f32>,
    /// Task category name -> baseline complexity.
    task_type_complexity: HashMap<String, f32>,
    /// Bounded (<= 1000 entries) feedback history, oldest first.
    accuracy_history: Vec<AccuracyRecord>,
    /// Total number of `analyze` calls over the analyzer's lifetime.
    analysis_count: u64,
}
/// One prediction-vs-outcome data point used for calibration stats.
#[derive(Debug, Clone)]
struct AccuracyRecord {
    /// Complexity the analyzer predicted.
    predicted: f32,
    /// Observed complexity, once feedback arrives.
    actual: Option<f32>,
    /// Model that was used for the task.
    model: ClaudeModel,
    /// When the feedback was recorded.
    timestamp: Instant,
}
impl TaskComplexityAnalyzer {
pub fn new() -> Self {
Self {
pattern_weights: Self::build_pattern_weights(),
task_type_complexity: Self::build_task_type_complexity(),
accuracy_history: Vec::new(),
analysis_count: 0,
}
}
fn build_pattern_weights() -> HashMap<String, f32> {
let mut weights = HashMap::new();
for pattern in HIGH_COMPLEXITY_PATTERNS {
weights.insert(pattern.to_string(), 0.9);
}
for pattern in MODERATE_COMPLEXITY_PATTERNS {
weights.insert(pattern.to_string(), 0.5);
}
for pattern in SIMPLE_PATTERNS {
weights.insert(pattern.to_string(), 0.2);
}
weights
}
fn build_task_type_complexity() -> HashMap<String, f32> {
let mut map = HashMap::new();
map.insert("CodeGeneration".to_string(), 0.5);
map.insert("CodeReview".to_string(), 0.6);
map.insert("Testing".to_string(), 0.4);
map.insert("Research".to_string(), 0.5);
map.insert("Documentation".to_string(), 0.3);
map.insert("Debugging".to_string(), 0.5);
map.insert("Refactoring".to_string(), 0.6);
map.insert("Security".to_string(), 0.8);
map.insert("Performance".to_string(), 0.7);
map.insert("Architecture".to_string(), 0.9);
map
}
pub fn analyze(&mut self, task: &str) -> ComplexityScore {
self.analysis_count += 1;
let lower_task = task.to_lowercase();
let token_estimate = self.estimate_tokens(task);
let reasoning_depth = self.analyze_reasoning_depth(&lower_task);
let domain_expertise = self.analyze_domain_expertise(&lower_task);
let code_complexity = self.analyze_code_complexity(&lower_task);
let planning_complexity = self.analyze_planning(&lower_task);
let security_sensitivity = self.analyze_security(&lower_task);
let performance_criticality = self.analyze_performance(&lower_task);
let factors = ComplexityFactors {
token_estimate,
reasoning_depth,
domain_expertise,
code_complexity,
planning_complexity,
security_sensitivity,
performance_criticality,
};
let overall = factors.weighted_score();
let recommended_tier = if overall < 0.35 && token_estimate < 500 {
1 } else if overall < 0.7 && token_estimate < 2000 {
2 } else {
3 };
let confidence = self.calculate_confidence(&lower_task);
let reasoning = self.generate_reasoning(&factors, recommended_tier);
ComplexityScore {
overall,
factors,
recommended_tier,
confidence,
reasoning,
}
}
#[inline]
fn estimate_tokens(&self, task: &str) -> usize {
let base_tokens = task.len() / 4;
let task_bytes = task.as_bytes();
let multiplier = if contains_ci(task_bytes, b"entire")
|| contains_ci(task_bytes, b"all")
|| contains_ci(task_bytes, b"comprehensive")
{
3.0
} else if contains_ci(task_bytes, b"full") || contains_ci(task_bytes, b"complete") {
2.5
} else if contains_ci(task_bytes, b"implement") || contains_ci(task_bytes, b"create") {
2.0
} else if contains_ci(task_bytes, b"fix") || contains_ci(task_bytes, b"update") {
1.2
} else {
1.5
};
let factor = if contains_ci(task_bytes, b"architecture")
|| contains_ci(task_bytes, b"design")
{
3.0
} else if contains_ci(task_bytes, b"test") {
1.5
} else if contains_ci(task_bytes, b"comment") || contains_ci(task_bytes, b"documentation") {
1.2
} else {
1.0
};
((base_tokens as f32 * multiplier * factor) as usize).max(100)
}
#[inline]
fn analyze_reasoning_depth(&self, task: &str) -> f32 {
let mut depth: f32 = 0.3;
if task.contains("why") || task.contains("explain") || task.contains("analyze") {
depth += 0.2;
}
if task.contains("trade-off") || task.contains("compare") || task.contains("evaluate") {
depth += 0.2;
}
if task.contains("design") || task.contains("architect") || task.contains("pattern") {
depth += 0.3;
}
if task.contains("debug") || task.contains("investigate") || task.contains("root cause") {
depth += 0.2;
}
if task.contains("distributed") || task.contains("concurrent") || task.contains("parallel")
{
depth += 0.3;
}
depth.min(1.0_f32)
}
#[inline]
fn analyze_domain_expertise(&self, task: &str) -> f32 {
let mut expertise: f32 = 0.2;
if task.contains("database") || task.contains("sql") || task.contains("query") {
expertise += 0.2;
}
if task.contains("network")
|| task.contains("protocol")
|| task.contains("http")
|| task.contains("rest")
|| task.contains("api")
|| task.contains("endpoint")
{
expertise += 0.2;
}
if task.contains("security") || task.contains("crypto") || task.contains("auth") {
expertise += 0.3;
}
if task.contains("ml") || task.contains("machine learning") || task.contains("model") {
expertise += 0.3;
}
if task.contains("system") || task.contains("kernel") || task.contains("low-level") {
expertise += 0.3;
}
expertise.min(1.0_f32)
}
#[inline]
fn analyze_code_complexity(&self, task: &str) -> f32 {
let mut complexity: f32 = 0.3;
if task.contains("algorithm") || task.contains("data structure") {
complexity += 0.3;
}
if task.contains("recursive") || task.contains("dynamic programming") {
complexity += 0.3;
}
if task.contains("async") || task.contains("concurrent") || task.contains("thread") {
complexity += 0.2;
}
if task.contains("generic") || task.contains("trait") || task.contains("interface") {
complexity += 0.1;
}
if task.contains("validation")
|| task.contains("validate")
|| task.contains("registration")
|| task.contains("error handling")
{
complexity += 0.2;
}
if task.contains("simple")
|| task.contains("basic")
|| task.contains("minor")
|| task.contains("typo")
{
complexity -= 0.2;
}
complexity.clamp(0.0_f32, 1.0_f32)
}
#[inline]
fn analyze_planning(&self, task: &str) -> f32 {
let mut planning: f32 = 0.2;
if task.contains("then") || task.contains("after") || task.contains("first") {
planning += 0.2;
}
if task.contains("workflow") || task.contains("pipeline") || task.contains("process") {
planning += 0.3;
}
if task.contains("migrate") || task.contains("upgrade") || task.contains("transition") {
planning += 0.3;
}
if task.contains("coordinate") || task.contains("orchestrate") {
planning += 0.2;
}
planning.min(1.0_f32)
}
#[inline]
fn analyze_security(&self, task: &str) -> f32 {
let mut sensitivity: f32 = 0.1;
if task.contains("security") || task.contains("secure") || task.contains("auth") {
sensitivity += 0.3;
}
if task.contains("vulnerability") || task.contains("cve") || task.contains("exploit") {
sensitivity += 0.4;
}
if task.contains("encrypt") || task.contains("decrypt") || task.contains("crypto") {
sensitivity += 0.3;
}
if task.contains("password") || task.contains("secret") || task.contains("key") {
sensitivity += 0.2;
}
if task.contains("injection") || task.contains("xss") || task.contains("csrf") {
sensitivity += 0.3;
}
sensitivity.min(1.0_f32)
}
#[inline]
fn analyze_performance(&self, task: &str) -> f32 {
let mut criticality: f32 = 0.1;
if task.contains("performance") || task.contains("optimize") || task.contains("speed") {
criticality += 0.3;
}
if task.contains("benchmark") || task.contains("profile") || task.contains("latency") {
criticality += 0.2;
}
if task.contains("memory") || task.contains("cache") || task.contains("efficient") {
criticality += 0.2;
}
if task.contains("scale") || task.contains("throughput") || task.contains("concurrent") {
criticality += 0.2;
}
criticality.min(1.0_f32)
}
fn calculate_confidence(&self, task: &str) -> f32 {
let mut matches = 0;
let total_patterns = self.pattern_weights.len();
for pattern in self.pattern_weights.keys() {
if task.contains(pattern) {
matches += 1;
}
}
let pattern_confidence = if matches > 0 {
0.5 + (matches as f32 / total_patterns as f32) * 0.4
} else {
0.4
};
let length_factor = if task.len() > 100 {
1.0
} else if task.len() > 50 {
0.9
} else {
0.7
};
(pattern_confidence * length_factor).min(0.95)
}
fn generate_reasoning(&self, factors: &ComplexityFactors, tier: u8) -> String {
let model = match tier {
1 => "Haiku",
2 => "Sonnet",
_ => "Opus",
};
let mut reasons = Vec::new();
if factors.token_estimate < 500 {
reasons.push(format!("low token estimate (~{})", factors.token_estimate));
} else if factors.token_estimate > 2000 {
reasons.push(format!("high token estimate (~{})", factors.token_estimate));
}
if factors.reasoning_depth > 0.7 {
reasons.push("deep reasoning required".to_string());
}
if factors.security_sensitivity > 0.7 {
reasons.push("security-sensitive task".to_string());
}
if factors.code_complexity > 0.7 {
reasons.push("complex code patterns".to_string());
}
if reasons.is_empty() {
reasons.push("balanced complexity factors".to_string());
}
format!("Recommended {} due to: {}", model, reasons.join(", "))
}
pub fn record_feedback(&mut self, predicted: f32, actual: f32, model: ClaudeModel) {
self.accuracy_history.push(AccuracyRecord {
predicted,
actual: Some(actual),
model,
timestamp: Instant::now(),
});
if self.accuracy_history.len() > 1000 {
self.accuracy_history.remove(0);
}
}
pub fn calibration_bias(&self) -> f32 {
let with_feedback: Vec<_> = self
.accuracy_history
.iter()
.filter(|r| r.actual.is_some())
.collect();
if with_feedback.is_empty() {
return 0.0;
}
let diffs: Vec<f32> = with_feedback
.iter()
.map(|r| r.predicted - r.actual.unwrap())
.filter(|v| v.is_finite())
.collect();
if diffs.is_empty() {
return 0.0;
}
let sum: f32 = diffs.iter().sum();
sum / diffs.len() as f32
}
pub fn accuracy_stats(&self) -> AnalyzerStats {
let with_feedback: Vec<_> = self
.accuracy_history
.iter()
.filter(|r| r.actual.is_some())
.collect();
if with_feedback.is_empty() {
return AnalyzerStats::default();
}
let total_error: f32 = with_feedback
.iter()
.map(|r| (r.predicted - r.actual.unwrap()).abs())
.sum();
let avg_error = total_error / with_feedback.len() as f32;
AnalyzerStats {
total_analyses: self.analysis_count,
feedback_count: with_feedback.len(),
average_error: avg_error,
accuracy: 1.0 - avg_error,
}
}
}
impl Default for TaskComplexityAnalyzer {
    /// Equivalent to `TaskComplexityAnalyzer::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Aggregate calibration statistics for a `TaskComplexityAnalyzer`.
#[derive(Debug, Clone, Default)]
pub struct AnalyzerStats {
    /// Lifetime count of `analyze` calls.
    pub total_analyses: u64,
    /// Number of feedback records that contributed to the stats.
    pub feedback_count: usize,
    /// Mean absolute prediction error.
    pub average_error: f32,
    /// Derived as `1.0 - average_error`.
    pub accuracy: f32,
}
/// Tunable constraints and preferences for model selection.
#[derive(Debug, Clone)]
pub struct SelectionCriteria {
    /// Bias toward cheaper models on lower-complexity tasks.
    pub prefer_cost: bool,
    /// Bias toward faster models on lower-complexity tasks.
    pub prefer_latency: bool,
    // Minimum acceptable quality; currently not consulted by
    // `apply_criteria` — NOTE(review): confirm intended use.
    pub min_quality: f32,
    /// Hard cost cap per request (same units as `estimate_cost`).
    pub max_cost: Option<f64>,
    /// Hard latency cap in milliseconds.
    pub max_latency: Option<u64>,
}
impl Default for SelectionCriteria {
    /// Unconstrained defaults: no caps, no cost/latency bias,
    /// a 0.6 quality floor.
    fn default() -> Self {
        Self {
            prefer_cost: false,
            prefer_latency: false,
            min_quality: 0.6,
            max_cost: None,
            max_latency: None,
        }
    }
}
/// Final routing outcome for one task: the chosen model plus the
/// analysis and estimates that justified it.
#[derive(Debug, Clone)]
pub struct ModelRoutingDecision {
    /// Model to use for the task.
    pub model: ClaudeModel,
    /// Full complexity analysis that informed the choice.
    pub complexity_score: ComplexityScore,
    /// Estimated request cost (see `ModelSelector::estimate_cost`).
    pub estimated_cost: f64,
    /// Estimated end-to-end latency in milliseconds.
    pub estimated_latency: u64,
    /// Confidence carried over from the complexity analysis.
    pub confidence: f32,
    /// Human-readable justification.
    pub reasoning: String,
    /// Other viable models with a short rationale each.
    pub alternatives: Vec<(ClaudeModel, String)>,
}
/// Chooses a Claude model per task by combining complexity analysis
/// with user-supplied selection criteria, and tracks outcomes.
pub struct ModelSelector {
    /// Complexity analyzer driving the base recommendation.
    analyzer: TaskComplexityAnalyzer,
    /// Active constraints/preferences applied on top of the analysis.
    criteria: SelectionCriteria,
    /// Bounded (<= 1000) history of selections, oldest first.
    selection_history: Vec<SelectionRecord>,
    /// Lifetime count of `select_model` calls.
    total_selections: u64,
}
/// One past model selection, optionally annotated with its outcome.
#[derive(Debug, Clone)]
struct SelectionRecord {
    /// Model that was selected.
    model: ClaudeModel,
    /// Overall complexity score at selection time.
    complexity: f32,
    /// Outcome reported later via `record_outcome`, if any.
    success: Option<bool>,
    /// When the selection was made.
    timestamp: Instant,
}
impl ModelSelector {
pub fn new(criteria: SelectionCriteria) -> Self {
Self {
analyzer: TaskComplexityAnalyzer::new(),
criteria,
selection_history: Vec::new(),
total_selections: 0,
}
}
pub fn select_model(&mut self, task: &str) -> ModelRoutingDecision {
self.total_selections += 1;
let complexity_score = self.analyzer.analyze(task);
let base_model = complexity_score.recommended_model();
let model = self.apply_criteria(&complexity_score, base_model);
let estimated_tokens = complexity_score.factors.token_estimate;
let estimated_cost = self.estimate_cost(model, estimated_tokens);
let estimated_latency = self.estimate_latency(model, estimated_tokens);
let alternatives = self.generate_alternatives(model, &complexity_score);
self.selection_history.push(SelectionRecord {
model,
complexity: complexity_score.overall,
success: None,
timestamp: Instant::now(),
});
if self.selection_history.len() > 1000 {
self.selection_history.remove(0);
}
ModelRoutingDecision {
model,
complexity_score: complexity_score.clone(),
estimated_cost,
estimated_latency,
confidence: complexity_score.confidence,
reasoning: complexity_score.reasoning.clone(),
alternatives,
}
}
fn apply_criteria(&self, score: &ComplexityScore, base_model: ClaudeModel) -> ClaudeModel {
let mut model = base_model;
if let Some(max_cost) = self.criteria.max_cost {
let estimated_cost = self.estimate_cost(model, score.factors.token_estimate);
if estimated_cost > max_cost {
model = match model {
ClaudeModel::Opus => ClaudeModel::Sonnet,
ClaudeModel::Sonnet => ClaudeModel::Haiku,
ClaudeModel::Haiku => ClaudeModel::Haiku,
};
}
}
if let Some(max_latency) = self.criteria.max_latency {
let estimated_latency = self.estimate_latency(model, score.factors.token_estimate);
if estimated_latency > max_latency {
model = match model {
ClaudeModel::Opus => ClaudeModel::Sonnet,
ClaudeModel::Sonnet => ClaudeModel::Haiku,
ClaudeModel::Haiku => ClaudeModel::Haiku,
};
}
}
if self.criteria.prefer_cost && score.overall < 0.5 {
model = match model {
ClaudeModel::Opus => ClaudeModel::Sonnet,
ClaudeModel::Sonnet if score.is_simple() => ClaudeModel::Haiku,
_ => model,
};
}
if self.criteria.prefer_latency && score.overall < 0.6 {
model = match model {
ClaudeModel::Opus => ClaudeModel::Sonnet,
ClaudeModel::Sonnet if score.is_simple() => ClaudeModel::Haiku,
_ => model,
};
}
if score.requires_opus() && model != ClaudeModel::Opus {
model = ClaudeModel::Opus;
}
model
}
#[inline]
fn estimate_cost(&self, model: ClaudeModel, token_estimate: usize) -> f64 {
let input_tokens = token_estimate as f64;
let output_tokens = input_tokens * 1.5;
let input_cost = (input_tokens * model.input_cost_per_1k()) / 1000.0;
let output_cost = (output_tokens * model.output_cost_per_1k()) / 1000.0;
input_cost + output_cost
}
#[inline]
fn estimate_latency(&self, model: ClaudeModel, token_estimate: usize) -> u64 {
let base_ttft = model.typical_ttft_ms();
let tokens_per_second = match model {
ClaudeModel::Haiku => 200.0,
ClaudeModel::Sonnet => 100.0,
ClaudeModel::Opus => 50.0,
};
let generation_time = (token_estimate as f64 / tokens_per_second * 1000.0) as u64;
base_ttft + generation_time
}
fn generate_alternatives(
&self,
selected: ClaudeModel,
score: &ComplexityScore,
) -> Vec<(ClaudeModel, String)> {
let mut alternatives = Vec::new();
match selected {
ClaudeModel::Haiku => {
alternatives.push((
ClaudeModel::Sonnet,
"For better quality if needed".to_string(),
));
}
ClaudeModel::Sonnet => {
if score.is_simple() {
alternatives.push((
ClaudeModel::Haiku,
"For cost savings on simple task".to_string(),
));
}
if score.factors.reasoning_depth > 0.5 {
alternatives.push((
ClaudeModel::Opus,
"For deeper reasoning if quality insufficient".to_string(),
));
}
}
ClaudeModel::Opus => {
if !score.requires_opus() {
alternatives.push((
ClaudeModel::Sonnet,
"May suffice for cost savings".to_string(),
));
}
}
}
alternatives
}
pub fn record_outcome(&mut self, success: bool) {
if let Some(record) = self.selection_history.last_mut() {
record.success = Some(success);
}
}
pub fn stats(&self) -> SelectorStats {
let with_outcome: Vec<_> = self
.selection_history
.iter()
.filter(|r| r.success.is_some())
.collect();
let success_count = with_outcome
.iter()
.filter(|r| r.success == Some(true))
.count();
let success_rate = if !with_outcome.is_empty() {
success_count as f32 / with_outcome.len() as f32
} else {
0.0
};
let mut by_model: HashMap<ClaudeModel, usize> = HashMap::new();
for record in &self.selection_history {
*by_model.entry(record.model).or_insert(0) += 1;
}
SelectorStats {
total_selections: self.total_selections,
feedback_count: with_outcome.len(),
success_rate,
selections_by_model: by_model,
analyzer_stats: self.analyzer.accuracy_stats(),
}
}
pub fn set_criteria(&mut self, criteria: SelectionCriteria) {
self.criteria = criteria;
}
pub fn criteria(&self) -> &SelectionCriteria {
&self.criteria
}
}
impl Default for ModelSelector {
    /// A selector with `SelectionCriteria::default()` (no constraints).
    fn default() -> Self {
        Self::new(SelectionCriteria::default())
    }
}
/// Aggregate statistics for a `ModelSelector`.
#[derive(Debug, Clone)]
pub struct SelectorStats {
    /// Lifetime count of `select_model` calls.
    pub total_selections: u64,
    /// Selections that received outcome feedback.
    pub feedback_count: usize,
    /// Success fraction over selections with feedback (0.0 if none).
    pub success_rate: f32,
    /// How many times each model was selected (bounded history only).
    pub selections_by_model: HashMap<ClaudeModel, usize>,
    /// Calibration stats from the underlying complexity analyzer.
    pub analyzer_stats: AnalyzerStats,
}
/// Top-level router: wraps a `ModelSelector` and applies per-agent and
/// per-task-type model overrides before falling back to selection.
pub struct ModelRouter {
    /// Complexity-based selector used when no override matches.
    selector: ModelSelector,
    /// Forced model per agent type (checked first).
    agent_overrides: HashMap<AgentType, ClaudeModel>,
    /// Forced model per task type (checked second).
    task_overrides: HashMap<ClaudeFlowTask, ClaudeModel>,
}
impl ModelRouter {
    /// Router with default criteria and the built-in overrides.
    pub fn new() -> Self {
        Self {
            selector: ModelSelector::default(),
            agent_overrides: Self::default_agent_overrides(),
            task_overrides: Self::default_task_overrides(),
        }
    }
    /// Router with custom selection criteria and the built-in overrides.
    pub fn with_criteria(criteria: SelectionCriteria) -> Self {
        Self {
            selector: ModelSelector::new(criteria),
            agent_overrides: Self::default_agent_overrides(),
            task_overrides: Self::default_task_overrides(),
        }
    }
    /// Built-in agent overrides: security work always gets Opus,
    /// review work always gets Sonnet.
    fn default_agent_overrides() -> HashMap<AgentType, ClaudeModel> {
        let mut map = HashMap::new();
        map.insert(AgentType::Security, ClaudeModel::Opus);
        map.insert(AgentType::Reviewer, ClaudeModel::Sonnet);
        map
    }
    /// Built-in task overrides: architecture/security pinned to Opus,
    /// documentation pinned to Haiku.
    fn default_task_overrides() -> HashMap<ClaudeFlowTask, ClaudeModel> {
        let mut map = HashMap::new();
        map.insert(ClaudeFlowTask::Architecture, ClaudeModel::Opus);
        map.insert(ClaudeFlowTask::Security, ClaudeModel::Opus);
        map.insert(ClaudeFlowTask::Documentation, ClaudeModel::Haiku);
        map
    }
    /// Run selection with the forced `model`, keeping the analysis but
    /// replacing the model and prefixing the reasoning with `label`.
    ///
    /// NOTE(review): `select_model` records the *pre-override* model in
    /// the selector's history, so per-model stats count that choice,
    /// not the forced one — confirm whether that is intended.
    fn overridden(
        &mut self,
        task: &str,
        model: ClaudeModel,
        label: String,
    ) -> ModelRoutingDecision {
        let mut decision = self.selector.select_model(task);
        decision.model = model;
        decision.reasoning = format!("{}: {}", label, decision.reasoning);
        decision
    }
    /// Route a task: agent override first, then task-type override,
    /// then plain complexity-based selection.
    pub fn route(
        &mut self,
        task: &str,
        agent_type: Option<AgentType>,
        task_type: Option<ClaudeFlowTask>,
    ) -> ModelRoutingDecision {
        if let Some(agent) = agent_type {
            if let Some(&model) = self.agent_overrides.get(&agent) {
                return self.overridden(task, model, format!("Agent type {:?} override", agent));
            }
        }
        if let Some(task_t) = task_type {
            if let Some(&model) = self.task_overrides.get(&task_t) {
                return self.overridden(task, model, format!("Task type {:?} override", task_t));
            }
        }
        self.selector.select_model(task)
    }
    /// Install or replace a per-agent model override.
    pub fn set_agent_override(&mut self, agent: AgentType, model: ClaudeModel) {
        self.agent_overrides.insert(agent, model);
    }
    /// Remove a per-agent override, if present.
    pub fn remove_agent_override(&mut self, agent: AgentType) {
        self.agent_overrides.remove(&agent);
    }
    /// Install or replace a per-task-type model override.
    pub fn set_task_override(&mut self, task: ClaudeFlowTask, model: ClaudeModel) {
        self.task_overrides.insert(task, model);
    }
    /// Remove a per-task-type override, if present.
    pub fn remove_task_override(&mut self, task: ClaudeFlowTask) {
        self.task_overrides.remove(&task);
    }
    /// Forward an outcome to the underlying selector's latest record.
    pub fn record_outcome(&mut self, success: bool) {
        self.selector.record_outcome(success);
    }
    /// Aggregate statistics from the underlying selector.
    pub fn stats(&self) -> SelectorStats {
        self.selector.stats()
    }
    /// Replace the underlying selector's criteria.
    pub fn set_criteria(&mut self, criteria: SelectionCriteria) {
        self.selector.set_criteria(criteria);
    }
}
impl Default for ModelRouter {
    /// Equivalent to `ModelRouter::new()`.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A trivial edit should land on the cheapest tier.
    #[test]
    fn test_complexity_analyzer_simple_task() {
        let mut subject = TaskComplexityAnalyzer::new();
        let result = subject.analyze("fix typo in readme");
        assert!(result.overall < 0.5);
        assert!(result.is_simple());
        assert_eq!(result.recommended_tier, 1);
    }

    /// Distributed/security/design wording should force the top tier.
    #[test]
    fn test_complexity_analyzer_complex_task() {
        let mut subject = TaskComplexityAnalyzer::new();
        let result = subject.analyze(
            "Design and implement a distributed authentication system with OAuth2, JWT tokens, \
             and comprehensive security audit for vulnerabilities",
        );
        assert!(result.overall > 0.7);
        assert!(result.requires_opus());
        assert_eq!(result.recommended_tier, 3);
    }

    /// A mid-sized feature request should fall in the middle tier.
    #[test]
    fn test_complexity_analyzer_moderate_task() {
        let mut subject = TaskComplexityAnalyzer::new();
        let result = subject
            .analyze("Implement a REST API endpoint for user registration with input validation");
        assert!(result.overall >= 0.35);
        assert!(result.overall < 0.7);
        assert_eq!(result.recommended_tier, 2);
    }

    /// End-to-end selection: trivial -> Haiku, architectural -> Opus.
    #[test]
    fn test_model_selector() {
        let mut subject = ModelSelector::default();
        let simple = subject.select_model("rename variable x to count");
        assert_eq!(simple.model, ClaudeModel::Haiku);
        let hard = subject.select_model(
            "Design microservices architecture with distributed tracing and security audit",
        );
        assert_eq!(hard.model, ClaudeModel::Opus);
    }

    /// With `prefer_cost` set, a simple task should stay on Haiku.
    #[test]
    fn test_model_selector_cost_preference() {
        let mut subject = ModelSelector::new(SelectionCriteria {
            prefer_cost: true,
            ..Default::default()
        });
        let decision = subject.select_model("write a simple unit test");
        assert_eq!(decision.model, ClaudeModel::Haiku);
    }

    /// Agent and task-type overrides must beat complexity routing.
    #[test]
    fn test_model_router_overrides() {
        let mut router = ModelRouter::new();
        let by_agent = router.route("fix a bug", Some(AgentType::Security), None);
        assert_eq!(by_agent.model, ClaudeModel::Opus);
        let by_task = router.route("update config", None, Some(ClaudeFlowTask::Architecture));
        assert_eq!(by_task.model, ClaudeModel::Opus);
    }

    /// Sanity check: the blended score stays inside (0.5, 1.0] for a
    /// clearly heavyweight factor set.
    #[test]
    fn test_complexity_factors_weighted_score() {
        let factors = ComplexityFactors {
            token_estimate: 2500,
            reasoning_depth: 0.8,
            domain_expertise: 0.5,
            code_complexity: 0.6,
            planning_complexity: 0.7,
            security_sensitivity: 0.9,
            performance_criticality: 0.3,
        };
        let score = factors.weighted_score();
        assert!(score > 0.5);
        assert!(score <= 1.0);
    }

    /// Cost estimates must be strictly ordered across tiers.
    #[test]
    fn test_cost_estimation() {
        let subject = ModelSelector::default();
        let haiku_cost = subject.estimate_cost(ClaudeModel::Haiku, 1000);
        let sonnet_cost = subject.estimate_cost(ClaudeModel::Sonnet, 1000);
        let opus_cost = subject.estimate_cost(ClaudeModel::Opus, 1000);
        assert!(haiku_cost < sonnet_cost);
        assert!(sonnet_cost < opus_cost);
    }

    /// Latency estimates must be strictly ordered across tiers.
    #[test]
    fn test_latency_estimation() {
        let subject = ModelSelector::default();
        let haiku_latency = subject.estimate_latency(ClaudeModel::Haiku, 500);
        let sonnet_latency = subject.estimate_latency(ClaudeModel::Sonnet, 500);
        let opus_latency = subject.estimate_latency(ClaudeModel::Opus, 500);
        assert!(haiku_latency < sonnet_latency);
        assert!(sonnet_latency < opus_latency);
    }
}