use super::{CostTier, ModelPolicy};
use std::time::Duration;
/// Routes tasks to a cost tier (and thus a model) based on a heuristic
/// complexity analysis constrained by the active [`ModelPolicy`].
#[derive(Debug, Clone)]
pub struct ModelRouter {
    /// Tier-to-model mapping plus cost/latency constraints.
    policy: ModelPolicy,
    /// Estimated-token count above which a task scores as "large".
    large_model_threshold: usize,
    /// Estimated-token count above which a task scores as "medium".
    medium_model_threshold: usize,
    /// When false, `route` always returns the policy's medium model.
    enabled: bool,
}
impl Default for ModelRouter {
    /// Equivalent to [`ModelRouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl ModelRouter {
pub fn new() -> Self {
Self {
policy: ModelPolicy::default(),
large_model_threshold: 4000,
medium_model_threshold: 1000,
enabled: true,
}
}
pub fn with_policy(policy: ModelPolicy) -> Self {
Self {
policy,
..Default::default()
}
}
pub fn with_thresholds(mut self, medium: usize, large: usize) -> Self {
self.medium_model_threshold = medium;
self.large_model_threshold = large;
self
}
pub fn policy(&self) -> &ModelPolicy {
&self.policy
}
pub fn enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn route(&self, task: &TaskAnalysis) -> RoutingDecision {
if !self.enabled {
return RoutingDecision {
model: self.policy.medium.clone(),
tier: CostTier::Medium,
reason: "Smart routing disabled".to_string(),
};
}
let tier = self.analyze_complexity(task);
let tier = self.apply_constraints(tier, task);
let model = self.policy.model_for_tier(tier).to_string();
let reason = self.explain_routing(task, tier);
RoutingDecision {
model,
tier,
reason,
}
}
fn analyze_complexity(&self, task: &TaskAnalysis) -> CostTier {
let mut score = 0;
if task.estimated_tokens > self.large_model_threshold {
score += 3;
} else if task.estimated_tokens > self.medium_model_threshold {
score += 2;
} else {
score += 1;
}
if task.requires_reasoning {
score += 2;
}
if task.requires_code_generation {
score += 2;
}
if task.requires_structured_output {
score += 1;
}
if task.multi_step {
score += 1;
}
match score {
0..=2 => CostTier::Low,
3..=5 => CostTier::Medium,
_ => CostTier::High,
}
}
fn apply_constraints(&self, tier: CostTier, task: &TaskAnalysis) -> CostTier {
let tier = match (tier, self.policy.max_cost_tier) {
(CostTier::High, CostTier::Low) => CostTier::Low,
(CostTier::High, CostTier::Medium) => CostTier::Medium,
(CostTier::Medium, CostTier::Low) => CostTier::Low,
_ => tier,
};
if let Some(max_latency) = self.policy.max_latency_ms {
let estimated_latency = self.estimate_latency(tier, task);
if estimated_latency > max_latency {
return match tier {
CostTier::High => CostTier::Medium,
CostTier::Medium => CostTier::Low,
CostTier::Low => CostTier::Low,
};
}
}
if tier == CostTier::High && !self.policy.allow_large {
return CostTier::Medium;
}
tier
}
fn estimate_latency(&self, tier: CostTier, task: &TaskAnalysis) -> u64 {
let base_latency = match tier {
CostTier::Low => 500,
CostTier::Medium => 1500,
CostTier::High => 3000,
};
let token_latency = (task.estimated_tokens as u64 / 100) * 50;
base_latency + token_latency
}
fn explain_routing(&self, task: &TaskAnalysis, tier: CostTier) -> String {
let mut reasons = Vec::new();
if task.estimated_tokens > self.large_model_threshold {
reasons.push("high token count");
}
if task.requires_reasoning {
reasons.push("requires reasoning");
}
if task.requires_code_generation {
reasons.push("requires code generation");
}
if reasons.is_empty() {
reasons.push("standard task");
}
format!("{:?} tier selected: {}", tier, reasons.join(", "))
}
pub fn route_simple(&self, prompt: &str) -> RoutingDecision {
let task = TaskAnalysis::from_prompt(prompt);
self.route(&task)
}
}
/// Hard capability/cost filters a candidate model must satisfy.
/// Zero-valued numeric fields mean "no constraint".
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelRequirements {
    pub vision: bool,
    pub audio: bool,
    pub video: bool,
    pub pdf: bool,
    /// Minimum context window, in tokens (0 = no minimum).
    pub min_context_tokens: u32,
    /// Maximum input cost per million tokens (0.0 = no cap).
    pub max_input_cost_per_m: f32,
    /// Minimum arena score (0.0 = no floor).
    pub min_arena_score: f32,
}
impl ModelRequirements {
    /// Requires vision support.
    pub fn with_vision(self) -> Self {
        Self {
            vision: true,
            ..self
        }
    }

    /// Requires a context window of at least `tokens`.
    pub fn with_min_context(self, tokens: u32) -> Self {
        Self {
            min_context_tokens: tokens,
            ..self
        }
    }

    /// Caps input cost at `cost` per million tokens.
    pub fn with_max_cost(self, cost: f32) -> Self {
        Self {
            max_input_cost_per_m: cost,
            ..self
        }
    }

    /// Requires an arena score of at least `score`.
    pub fn with_min_arena(self, score: f32) -> Self {
        Self {
            min_arena_score: score,
            ..self
        }
    }
}
/// A model that passed requirement filtering, together with its computed
/// score and the profile data that produced it.
#[derive(Debug, Clone)]
pub struct ScoredModel {
    pub name: String,
    /// Higher is better; either a custom priority or a strategy-derived score.
    pub score: f32,
    pub arena_rank: Option<f32>,
    pub input_cost_per_m: Option<f32>,
    pub supports_vision: bool,
    /// 0 when the model's context size is unknown.
    pub max_input_tokens: u32,
}
/// How [`ModelSelector`] scores candidates when no custom priority is set.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum SelectionStrategy {
    /// Highest arena score wins (default).
    #[default]
    BestQuality,
    /// Lowest input cost wins.
    CheapestFirst,
    /// Largest context window wins.
    LargestContext,
    /// Quality weighted by cost efficiency.
    ValueOptimal,
}
/// Picks models from a pool by scoring each against requirement filters.
#[derive(Debug, Clone)]
pub struct ModelSelector {
    /// (lowercased model name, optional custom priority score)
    models: Vec<(String, Option<f32>)>,
    strategy: SelectionStrategy,
}
impl ModelSelector {
    /// Builds a selector over a pool of model names (stored lowercased).
    pub fn new(models: &[&str]) -> Self {
        Self {
            models: models.iter().map(|m| (m.to_lowercase(), None)).collect(),
            strategy: SelectionStrategy::default(),
        }
    }

    /// Like [`new`](Self::new), but takes ownership of the names.
    pub fn from_owned(models: Vec<String>) -> Self {
        Self {
            models: models
                .into_iter()
                .map(|m| (m.to_lowercase(), None))
                .collect(),
            strategy: SelectionStrategy::default(),
        }
    }

    /// Sets the auto-scoring strategy used when no custom priority applies.
    pub fn set_strategy(&mut self, strategy: SelectionStrategy) {
        self.strategy = strategy;
    }

    /// Pins a custom score for `model`, overriding auto-scoring.
    /// Adds the model to the pool if it is not already present.
    pub fn set_priority(&mut self, model: &str, priority: f32) {
        let lower = model.to_lowercase();
        match self.models.iter_mut().find(|(name, _)| *name == lower) {
            Some((_, prio)) => *prio = Some(priority),
            None => self.models.push((lower, Some(priority))),
        }
    }

    /// Adds `model` to the pool unless already present (case-insensitive).
    pub fn add_model(&mut self, model: &str) {
        let lower = model.to_lowercase();
        if !self.models.iter().any(|(n, _)| *n == lower) {
            self.models.push((lower, None));
        }
    }

    /// Best model satisfying `reqs`, or `None` if nothing qualifies.
    ///
    /// BUG FIX: pools of two or fewer models previously skipped scoring
    /// entirely and returned the first passing model in insertion order,
    /// silently ignoring the strategy and any custom priorities. Selection
    /// now always goes through [`ranked`](Self::ranked).
    pub fn select(&self, reqs: &ModelRequirements) -> Option<ScoredModel> {
        self.ranked(reqs).into_iter().next()
    }

    /// All models satisfying `reqs`, best score first.
    ///
    /// BUG FIX: sorting previously only happened with more than two
    /// candidates, so two-model pools came back in insertion order
    /// regardless of score. The list is now always sorted (stable sort,
    /// so ties keep insertion order and results stay deterministic).
    pub fn ranked(&self, reqs: &ModelRequirements) -> Vec<ScoredModel> {
        let mut candidates: Vec<ScoredModel> = self
            .models
            .iter()
            .filter_map(|(name, custom_prio)| self.score_model(name, *custom_prio, reqs))
            .collect();
        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates
    }

    /// Assigns one model per requirement set, avoiding reuse while unused
    /// candidates remain. When nothing unused qualifies, falls back to a
    /// plain [`select`](Self::select), which may reuse a model.
    pub fn select_multi(&self, requirements: &[ModelRequirements]) -> Vec<Option<ScoredModel>> {
        let mut used: Vec<bool> = vec![false; self.models.len()];
        let mut results = Vec::with_capacity(requirements.len());
        for reqs in requirements {
            let mut best: Option<(ScoredModel, usize)> = None;
            for (idx, (name, custom_prio)) in self.models.iter().enumerate() {
                if used[idx] {
                    continue;
                }
                if let Some(scored) = self.score_model(name, *custom_prio, reqs) {
                    let is_better = best
                        .as_ref()
                        .map_or(true, |(current, _)| scored.score > current.score);
                    if is_better {
                        best = Some((scored, idx));
                    }
                }
            }
            match best {
                Some((model, idx)) => {
                    used[idx] = true;
                    results.push(Some(model));
                }
                None => results.push(self.select(reqs)),
            }
        }
        results
    }

    /// Scores `name` against `reqs`; `None` means a hard filter rejected it.
    /// Models with unknown pricing/arena data pass those soft filters; only
    /// explicit capability mismatches and context shortfalls reject.
    fn score_model(
        &self,
        name: &str,
        custom_prio: Option<f32>,
        reqs: &ModelRequirements,
    ) -> Option<ScoredModel> {
        let profile = llm_models_spider::model_profile(name);
        // Capability lookups fall back to heuristic helpers (audio has no
        // heuristic fallback and defaults to false).
        let has_vision = profile
            .as_ref()
            .map(|p| p.capabilities.vision)
            .unwrap_or_else(|| llm_models_spider::supports_vision(name));
        let has_audio = profile
            .as_ref()
            .map(|p| p.capabilities.audio)
            .unwrap_or(false);
        let has_video = profile
            .as_ref()
            .map(|p| p.capabilities.video)
            .unwrap_or_else(|| llm_models_spider::supports_video(name));
        let has_pdf = profile
            .as_ref()
            .map(|p| p.capabilities.file)
            .unwrap_or_else(|| llm_models_spider::supports_pdf(name));
        // Hard capability filters.
        if reqs.vision && !has_vision {
            return None;
        }
        if reqs.audio && !has_audio {
            return None;
        }
        if reqs.video && !has_video {
            return None;
        }
        if reqs.pdf && !has_pdf {
            return None;
        }
        // Context-window floor (unknown context counts as 0 and is rejected
        // by any positive minimum).
        let max_input = profile.as_ref().map(|p| p.max_input_tokens).unwrap_or(0);
        if reqs.min_context_tokens > 0 && max_input < reqs.min_context_tokens {
            return None;
        }
        let arena = profile.as_ref().and_then(|p| p.ranks.overall);
        let input_cost = profile
            .as_ref()
            .and_then(|p| p.pricing.input_cost_per_m_tokens);
        // Cost cap: models with unknown pricing are allowed through.
        if reqs.max_input_cost_per_m > 0.0 {
            if let Some(cost) = input_cost {
                if cost > reqs.max_input_cost_per_m {
                    return None;
                }
            }
        }
        // Arena floor: models with no arena data are allowed through.
        if reqs.min_arena_score > 0.0 {
            match arena {
                Some(score) if score >= reqs.min_arena_score => {}
                Some(_) => return None,
                None => {}
            }
        }
        let score =
            custom_prio.unwrap_or_else(|| self.auto_score(arena, input_cost, max_input));
        Some(ScoredModel {
            name: name.to_string(),
            score,
            arena_rank: arena,
            input_cost_per_m: input_cost,
            supports_vision: has_vision,
            max_input_tokens: max_input,
        })
    }

    /// Strategy-dependent heuristic score; higher is better. Missing data
    /// falls back to neutral defaults (arena 50.0, cost factor 1.0).
    fn auto_score(&self, arena: Option<f32>, cost: Option<f32>, context: u32) -> f32 {
        match self.strategy {
            SelectionStrategy::BestQuality => arena.unwrap_or(50.0),
            SelectionStrategy::CheapestFirst => match cost {
                Some(c) if c > 0.0 => 1000.0 / c,
                _ => 50.0,
            },
            SelectionStrategy::LargestContext => context as f32 / 1000.0,
            SelectionStrategy::ValueOptimal => {
                let quality = arena.unwrap_or(50.0);
                let cost_factor = match cost {
                    Some(c) if c > 0.0 => 100.0 / c,
                    _ => 1.0,
                };
                quality * cost_factor.sqrt()
            }
        }
    }
}
/// Derives a tiered [`ModelPolicy`] from a pool of available model names.
///
/// - Empty pool: the default policy.
/// - One model: that model fills every tier.
/// - Two or more: models are ranked by arena score (unknown models default
///   to a neutral 50.0; the sort is stable, so equally-scored models keep
///   input order). The best becomes `large`, the worst `small`, and the
///   median becomes `medium` (with fewer than three models, `medium`
///   falls back to `large`).
///
/// Cleanups vs. the previous version: the unreachable `is_empty` re-check
/// after the length guards is gone, the per-model cost was fetched but
/// never used and is no longer collected, and the two-model case now ranks
/// by arena score like the general case instead of trusting input order.
pub fn auto_policy(available_models: &[&str]) -> ModelPolicy {
    if available_models.is_empty() {
        return ModelPolicy::default();
    }
    if available_models.len() == 1 {
        let m = available_models[0].to_string();
        return ModelPolicy {
            small: m.clone(),
            medium: m.clone(),
            large: m,
            allow_large: true,
            max_latency_ms: None,
            max_cost_tier: CostTier::High,
        };
    }
    // Rank by arena score, best first; unknown models score a neutral 50.0.
    let mut models: Vec<(&str, f32)> = available_models
        .iter()
        .map(|&name| {
            let arena = llm_models_spider::model_profile(name)
                .as_ref()
                .and_then(|p| p.ranks.overall)
                .unwrap_or(50.0);
            (name, arena)
        })
        .collect();
    models.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let large = models[0].0.to_string();
    let small = models[models.len() - 1].0.to_string();
    let medium = if models.len() >= 3 {
        models[models.len() / 2].0.to_string()
    } else {
        large.clone()
    };
    ModelPolicy {
        small,
        medium,
        large,
        allow_large: true,
        max_latency_ms: None,
        max_cost_tier: CostTier::High,
    }
}
/// Heuristic description of an LLM task, used as routing input.
#[derive(Debug, Clone, Default)]
pub struct TaskAnalysis {
    pub estimated_tokens: usize,
    pub requires_reasoning: bool,
    pub requires_code_generation: bool,
    pub requires_structured_output: bool,
    pub multi_step: bool,
    /// Optional per-task latency budget.
    pub max_latency: Option<Duration>,
    pub category: TaskCategory,
    pub requires_vision: bool,
    pub requires_audio: bool,
}
impl TaskAnalysis {
    /// Builds an analysis from a free-form prompt via case-insensitive
    /// keyword spotting (substring matches, not whole words).
    pub fn from_prompt(prompt: &str) -> Self {
        let lower = prompt.to_lowercase();
        let mentions_any = |keywords: &[&str]| keywords.iter().any(|k| lower.contains(k));
        Self {
            estimated_tokens: estimate_tokens(prompt),
            requires_reasoning: mentions_any(&["analyze", "compare", "explain", "why"]),
            requires_code_generation: mentions_any(&["code", "implement", "function", "script"]),
            requires_structured_output: mentions_any(&["json", "extract", "list"]),
            multi_step: mentions_any(&["then", "step", "first", "next"]),
            max_latency: None,
            category: TaskCategory::General,
            requires_vision: mentions_any(&["screenshot", "image", "picture", "visual"]),
            requires_audio: mentions_any(&["audio", "voice", "speech"]),
        }
    }

    /// Profile for HTML extraction: structured output only, with tokens
    /// estimated from the document length plus a fixed overhead.
    pub fn extraction(html_length: usize) -> Self {
        Self {
            estimated_tokens: html_length / 4 + 200,
            requires_structured_output: true,
            category: TaskCategory::Extraction,
            ..Default::default()
        }
    }

    /// Profile for an action instruction: prompt-derived analysis, tagged
    /// as `Action` with structured output forced on.
    pub fn action(instruction: &str) -> Self {
        Self {
            category: TaskCategory::Action,
            requires_structured_output: true,
            ..Self::from_prompt(instruction)
        }
    }

    /// Attaches a per-task latency budget.
    pub fn with_max_latency(mut self, latency: Duration) -> Self {
        self.max_latency = Some(latency);
        self
    }

    /// Maps the task's modality needs onto capability requirements.
    pub fn to_requirements(&self) -> ModelRequirements {
        ModelRequirements {
            vision: self.requires_vision,
            audio: self.requires_audio,
            ..Default::default()
        }
    }
}
/// Broad category tag attached to a [`TaskAnalysis`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TaskCategory {
    #[default]
    General,
    Extraction,
    Action,
    Code,
    Analysis,
    Classification,
}
/// The outcome of routing: which model, which tier, and why.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    pub model: String,
    pub tier: CostTier,
    /// Human-readable explanation of the choice.
    pub reason: String,
}
impl RoutingDecision {
    /// True when the task was routed to the cheapest tier.
    pub fn is_fast(&self) -> bool {
        matches!(self.tier, CostTier::Low)
    }

    /// True when the task was routed to the most capable tier.
    pub fn is_powerful(&self) -> bool {
        matches!(self.tier, CostTier::High)
    }
}
/// Rough token estimate: ~4 bytes per token, plus one so the result is
/// never zero (even for an empty string).
pub fn estimate_tokens(text: &str) -> usize {
    1 + text.len() / 4
}
/// Rough token estimate for a message list: each message contributes its
/// text estimate plus ~4 tokens of framing overhead.
pub fn estimate_message_tokens(messages: &[crate::Message]) -> usize {
    let mut total = 0;
    for message in messages {
        total += estimate_tokens(message.content.as_text()) + 4;
    }
    total
}
/// Classifies one agent round for routing purposes.
///
/// Starts from the prompt-derived analysis and escalates: the first round
/// and stagnated rounds force reasoning (stagnation also marks the task
/// multi-step), and very large pages (> 50k chars of HTML) force reasoning.
/// Structured output is always required for agent rounds.
///
/// Cleanup: the previous version re-assigned `estimated_tokens` with
/// `user_prompt.len() / 4 + 1`, which is exactly what `from_prompt` had
/// already stored via `estimate_tokens` — the redundant assignment is gone.
pub fn classify_round_complexity(
    user_prompt: &str,
    html_len: usize,
    round_idx: usize,
    stagnated: bool,
) -> TaskAnalysis {
    let mut analysis = TaskAnalysis::from_prompt(user_prompt);
    if round_idx == 0 {
        analysis.requires_reasoning = true;
    }
    if stagnated {
        analysis.requires_reasoning = true;
        analysis.multi_step = true;
    }
    if html_len > 50_000 {
        analysis.requires_reasoning = true;
    }
    analysis.requires_structured_output = true;
    analysis
}
#[cfg(test)]
mod tests {
    // NOTE(review): several tests below assert against live data from the
    // `llm_models_spider` crate (arena scores, pricing, context sizes), so
    // they can break when that data set is updated rather than when this
    // module's logic changes.
    use super::*;

    // --- ModelRouter: routing and constraint handling ---

    #[test]
    fn test_model_router_simple() {
        let router = ModelRouter::new();
        let decision = router.route_simple("Extract the title from this page");
        assert!(!decision.model.is_empty());
    }

    #[test]
    fn test_model_router_complex() {
        let router = ModelRouter::new();
        let task = TaskAnalysis {
            estimated_tokens: 5000,
            requires_reasoning: true,
            requires_code_generation: true,
            ..Default::default()
        };
        let decision = router.route(&task);
        assert_eq!(decision.tier, CostTier::High);
    }

    #[test]
    fn test_model_router_constrained() {
        let policy = ModelPolicy {
            max_cost_tier: CostTier::Medium,
            ..Default::default()
        };
        let router = ModelRouter::with_policy(policy);
        let task = TaskAnalysis {
            estimated_tokens: 5000,
            requires_reasoning: true,
            ..Default::default()
        };
        let decision = router.route(&task);
        assert!(decision.tier != CostTier::High);
    }

    // --- TaskAnalysis: prompt heuristics ---

    #[test]
    fn test_task_analysis_from_prompt() {
        let analysis = TaskAnalysis::from_prompt(
            "Analyze the code and explain why it's slow, then implement a fix",
        );
        assert!(analysis.requires_reasoning);
        assert!(analysis.requires_code_generation);
        assert!(analysis.multi_step);
    }

    #[test]
    fn test_task_analysis_vision_detection() {
        let analysis = TaskAnalysis::from_prompt("Look at this screenshot and describe it");
        assert!(analysis.requires_vision);
        let analysis = TaskAnalysis::from_prompt("Summarize this text");
        assert!(!analysis.requires_vision);
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello world"), 3);
        assert_eq!(estimate_tokens(""), 1);
    }

    // --- ModelSelector: filtering, scoring, strategies ---

    #[test]
    fn test_model_selector_basic() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let reqs = ModelRequirements::default();
        let result = selector.select(&reqs);
        assert!(result.is_some());
    }

    #[test]
    fn test_model_selector_vision_filter() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-3.5-turbo"]);
        let reqs = ModelRequirements::default().with_vision();
        let ranked = selector.ranked(&reqs);
        assert!(!ranked.is_empty());
        for m in &ranked {
            assert!(
                m.supports_vision,
                "non-vision model {} passed filter",
                m.name
            );
        }
    }

    #[test]
    fn test_model_selector_custom_priority() {
        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        selector.set_priority("gpt-4o-mini", 999.0);
        let reqs = ModelRequirements::default();
        let best = selector.select(&reqs).unwrap();
        assert_eq!(best.name, "gpt-4o-mini");
        assert_eq!(best.score, 999.0);
    }

    #[test]
    fn test_model_selector_cheapest_strategy() {
        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini"]);
        selector.set_strategy(SelectionStrategy::CheapestFirst);
        let reqs = ModelRequirements::default();
        let ranked = selector.ranked(&reqs);
        assert!(ranked.len() >= 1);
    }

    #[test]
    fn test_model_selector_multi_dispatch() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let requirements = vec![
            ModelRequirements::default().with_vision(),
            ModelRequirements::default(),
        ];
        let results = selector.select_multi(&requirements);
        assert_eq!(results.len(), 2);
        assert!(results[0].is_some());
        assert!(results[0].as_ref().unwrap().supports_vision);
        assert!(results[1].is_some());
    }

    #[test]
    fn test_model_selector_add_model() {
        let mut selector = ModelSelector::new(&["gpt-4o"]);
        selector.add_model("gpt-4o-mini");
        assert_eq!(selector.models.len(), 2);
        // Adding a duplicate must be a no-op.
        selector.add_model("gpt-4o");
        assert_eq!(selector.models.len(), 2);
    }

    // --- auto_policy: tier derivation from a model pool ---

    #[test]
    fn test_auto_policy_single_model() {
        let policy = auto_policy(&["gpt-4o"]);
        assert_eq!(policy.small, "gpt-4o");
        assert_eq!(policy.medium, "gpt-4o");
        assert_eq!(policy.large, "gpt-4o");
    }

    #[test]
    fn test_auto_policy_multiple_models() {
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        assert!(!policy.small.is_empty());
        assert!(!policy.medium.is_empty());
        assert!(!policy.large.is_empty());
    }

    #[test]
    fn test_auto_policy_empty() {
        // Empty pool falls back to ModelPolicy::default(), which is
        // expected to carry these model names.
        let policy = auto_policy(&[]);
        assert_eq!(policy.small, "gpt-4o-mini");
        assert_eq!(policy.medium, "gpt-4o");
    }

    #[test]
    fn test_model_requirements_builder() {
        let reqs = ModelRequirements::default()
            .with_vision()
            .with_min_context(100_000)
            .with_max_cost(10.0)
            .with_min_arena(60.0);
        assert!(reqs.vision);
        assert_eq!(reqs.min_context_tokens, 100_000);
        assert_eq!(reqs.max_input_cost_per_m, 10.0);
        assert_eq!(reqs.min_arena_score, 60.0);
    }

    #[test]
    fn test_task_to_requirements() {
        let task = TaskAnalysis::from_prompt("Look at this screenshot and extract data");
        let reqs = task.to_requirements();
        assert!(reqs.vision);
    }

    // --- llm_models_spider data sanity checks ---

    #[test]
    fn test_llm_data_vision_models_detected() {
        for model in &[
            "gpt-4o",
            "gpt-4o-mini",
            "claude-sonnet-4-5-20250514",
            "gemini-2.0-flash",
            "qwen2-vl-72b-instruct",
            "llama-3.2-11b-vision-instruct",
        ] {
            assert!(
                llm_models_spider::supports_vision(model),
                "{model} should support vision"
            );
        }
        for model in &["gpt-3.5-turbo", "deepseek-chat"] {
            assert!(
                !llm_models_spider::supports_vision(model),
                "{model} should NOT support vision"
            );
        }
    }

    #[test]
    fn test_llm_data_model_profiles_exist() {
        let must_have = [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "claude-3-5-sonnet-20241022",
            "gemini-2.0-flash",
            "deepseek-chat",
        ];
        for name in &must_have {
            let profile = llm_models_spider::model_profile(name);
            assert!(
                profile.is_some(),
                "model_profile({name}) should return Some"
            );
            let p = profile.unwrap();
            assert!(
                p.max_input_tokens > 0,
                "{name} should have max_input_tokens > 0, got {}",
                p.max_input_tokens
            );
        }
    }

    #[test]
    fn test_llm_data_arena_scores_present() {
        for name in &["claude-3.5-sonnet", "chatgpt-4o-latest", "claude-opus-4"] {
            let profile = llm_models_spider::model_profile(name);
            assert!(profile.is_some(), "{name} should have a profile");
            let p = profile.unwrap();
            assert!(
                p.ranks.overall.is_some(),
                "{name} should have an arena score"
            );
            assert!(
                p.ranks.overall.unwrap() > 0.0,
                "{name} arena score should be > 0"
            );
        }
    }

    #[test]
    fn test_llm_data_pricing_ordering() {
        let cheap = llm_models_spider::model_profile("gpt-4o-mini");
        let expensive = llm_models_spider::model_profile("claude-opus-4-20250514");
        assert!(cheap.is_some() && expensive.is_some());
        let cheap_cost = cheap.unwrap().pricing.input_cost_per_m_tokens.unwrap();
        let expensive_cost = expensive.unwrap().pricing.input_cost_per_m_tokens.unwrap();
        assert!(
            cheap_cost < expensive_cost,
            "gpt-4o-mini (${cheap_cost}) should be cheaper than claude-opus-4 (${expensive_cost})"
        );
    }

    #[test]
    fn test_llm_data_context_window_ordering() {
        let large_ctx = llm_models_spider::model_profile("gemini-2.5-pro-preview-05-06");
        let small_ctx = llm_models_spider::model_profile("gpt-3.5-turbo");
        assert!(large_ctx.is_some() && small_ctx.is_some());
        let large_tokens = large_ctx.unwrap().max_input_tokens;
        let small_tokens = small_ctx.unwrap().max_input_tokens;
        assert!(
            large_tokens > small_tokens,
            "gemini-2.5-pro ({large_tokens}) should have more context than gpt-3.5-turbo ({small_tokens})"
        );
    }

    // --- ModelSelector: realistic pools and edge cases ---

    #[test]
    fn test_selector_realistic_pool_best_quality() {
        let selector = ModelSelector::new(&[
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "claude-3-5-sonnet-20241022",
            "gemini-2.0-flash",
            "deepseek-chat",
        ]);
        let reqs = ModelRequirements::default();
        let ranked = selector.ranked(&reqs);
        assert!(!ranked.is_empty());
        let top = &ranked[0];
        for other in &ranked[1..] {
            assert!(
                top.score >= other.score,
                "top model {} (score {}) should beat {} (score {})",
                top.name,
                top.score,
                other.name,
                other.score
            );
        }
    }

    #[test]
    fn test_selector_realistic_pool_cheapest() {
        let mut selector = ModelSelector::new(&[
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "claude-3-5-sonnet-20241022",
        ]);
        selector.set_strategy(SelectionStrategy::CheapestFirst);
        let reqs = ModelRequirements::default();
        let ranked = selector.ranked(&reqs);
        assert!(ranked.len() >= 2);
        let top = &ranked[0];
        let bottom = ranked.last().unwrap();
        if let (Some(top_cost), Some(bottom_cost)) =
            (top.input_cost_per_m, bottom.input_cost_per_m)
        {
            assert!(
                top_cost <= bottom_cost,
                "cheapest ({}, ${top_cost}) should rank above expensive ({}, ${bottom_cost})",
                top.name,
                bottom.name
            );
        }
    }

    #[test]
    fn test_selector_vision_filter_rejects_text_only() {
        let selector = ModelSelector::new(&["gpt-3.5-turbo", "deepseek-chat"]);
        let reqs = ModelRequirements::default().with_vision();
        let result = selector.select(&reqs);
        assert!(
            result.is_none(),
            "text-only pool should return None for vision requirement"
        );
    }

    #[test]
    fn test_selector_unknown_models_graceful() {
        let selector = ModelSelector::new(&["my-custom-model", "local-llama"]);
        let reqs = ModelRequirements::default();
        let result = selector.select(&reqs);
        assert!(
            result.is_some(),
            "unknown models should still return Some with default score"
        );
        let scored = result.unwrap();
        assert_eq!(scored.score, 50.0, "unknown model gets default score 50.0");
    }

    #[test]
    fn test_selector_single_model_all_strategies() {
        for strategy in &[
            SelectionStrategy::BestQuality,
            SelectionStrategy::CheapestFirst,
            SelectionStrategy::LargestContext,
            SelectionStrategy::ValueOptimal,
        ] {
            let mut selector = ModelSelector::new(&["gpt-4o"]);
            selector.set_strategy(*strategy);
            let reqs = ModelRequirements::default();
            let result = selector.select(&reqs);
            assert!(
                result.is_some(),
                "single model should be returned for {strategy:?}"
            );
            assert_eq!(result.unwrap().name, "gpt-4o");
        }
    }

    #[test]
    fn test_selector_deterministic_ordering() {
        let selector = ModelSelector::new(&[
            "gpt-4o",
            "gpt-4o-mini",
            "claude-3-5-sonnet-20241022",
            "gemini-2.0-flash",
        ]);
        let reqs = ModelRequirements::default();
        let first_run: Vec<String> = selector
            .ranked(&reqs)
            .iter()
            .map(|m| m.name.clone())
            .collect();
        let second_run: Vec<String> = selector
            .ranked(&reqs)
            .iter()
            .map(|m| m.name.clone())
            .collect();
        assert_eq!(
            first_run, second_run,
            "repeated calls must produce identical ordering"
        );
    }

    #[test]
    fn test_selector_cost_filter_strict() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let reqs = ModelRequirements::default().with_max_cost(1.0);
        let ranked = selector.ranked(&reqs);
        for m in &ranked {
            if let Some(cost) = m.input_cost_per_m {
                assert!(
                    cost <= 1.0,
                    "{} has cost ${cost} which exceeds max 1.0",
                    m.name
                );
            }
        }
    }

    #[test]
    fn test_selector_min_context_filter() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-3.5-turbo", "gemini-2.0-flash"]);
        let reqs = ModelRequirements::default().with_min_context(500_000);
        let ranked = selector.ranked(&reqs);
        for m in &ranked {
            assert!(
                m.max_input_tokens >= 500_000,
                "{} has {} tokens, below 500k minimum",
                m.name,
                m.max_input_tokens
            );
        }
    }

    #[test]
    fn test_selector_value_optimal_balances() {
        let mut selector = ModelSelector::new(&[
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
        ]);
        selector.set_strategy(SelectionStrategy::ValueOptimal);
        let reqs = ModelRequirements::default();
        let ranked = selector.ranked(&reqs);
        assert!(!ranked.is_empty());
        let top = &ranked[0];
        assert!(top.score > 0.0, "ValueOptimal score should be positive");
    }

    #[test]
    fn test_select_multi_no_reuse() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let requirements = vec![
            ModelRequirements::default(),
            ModelRequirements::default(),
            ModelRequirements::default(),
        ];
        let results = selector.select_multi(&requirements);
        assert_eq!(results.len(), 3);
        let names: Vec<&str> = results
            .iter()
            .filter_map(|r| r.as_ref().map(|m| m.name.as_str()))
            .collect();
        assert_eq!(names.len(), 3, "all 3 requests should get a model");
        let mut deduped = names.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(
            deduped.len(),
            3,
            "no model should be reused when pool is large enough"
        );
    }

    #[test]
    fn test_select_multi_exhaustion_fallback() {
        let selector = ModelSelector::new(&["gpt-4o"]);
        let requirements = vec![
            ModelRequirements::default(),
            ModelRequirements::default(),
            ModelRequirements::default(),
        ];
        let results = selector.select_multi(&requirements);
        assert_eq!(results.len(), 3);
        assert!(results[0].is_some());
        assert!(
            results[1].is_some(),
            "fallback should reuse the single model"
        );
        assert!(
            results[2].is_some(),
            "fallback should reuse the single model"
        );
        assert_eq!(results[0].as_ref().unwrap().name, "gpt-4o");
        assert_eq!(results[1].as_ref().unwrap().name, "gpt-4o");
        assert_eq!(results[2].as_ref().unwrap().name, "gpt-4o");
    }

    #[test]
    fn test_selector_priority_override_beats_arena() {
        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        selector.set_priority("gpt-3.5-turbo", 999.0);
        let reqs = ModelRequirements::default();
        let best = selector.select(&reqs).unwrap();
        assert_eq!(
            best.name, "gpt-3.5-turbo",
            "priority override should beat natural arena score"
        );
        assert_eq!(best.score, 999.0);
    }

    // --- auto_policy: tiering behavior on realistic pools ---

    #[test]
    fn test_auto_policy_realistic_tiering() {
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        assert_ne!(
            policy.large, policy.small,
            "large and small should be different models"
        );
        assert_eq!(policy.model_for_tier(CostTier::High), policy.large);
        assert_eq!(policy.model_for_tier(CostTier::Low), policy.small);
        assert_eq!(policy.model_for_tier(CostTier::Medium), policy.medium);
    }

    #[test]
    fn test_auto_policy_2_models() {
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini"]);
        assert_eq!(policy.large, "gpt-4o");
        assert_eq!(policy.medium, "gpt-4o");
        assert_eq!(policy.small, "gpt-4o-mini");
        assert_eq!(
            policy.medium, policy.large,
            "2-model policy should have medium == large"
        );
        assert_ne!(policy.large, policy.small, "large and small should differ");
    }

    #[test]
    fn test_auto_policy_unknown_models() {
        let policy = auto_policy(&["my-custom-llm", "local-model-7b", "test-endpoint"]);
        assert!(!policy.small.is_empty());
        assert!(!policy.medium.is_empty());
        assert!(!policy.large.is_empty());
        assert!(policy.allow_large);
    }

    #[test]
    fn test_auto_policy_to_router_e2e() {
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let router = ModelRouter::with_policy(policy.clone());
        let simple_task = TaskAnalysis {
            estimated_tokens: 100,
            ..Default::default()
        };
        let decision = router.route(&simple_task);
        assert_eq!(decision.tier, CostTier::Low);
        assert_eq!(decision.model, policy.small);
        let hard_task = TaskAnalysis {
            estimated_tokens: 5000,
            requires_reasoning: true,
            requires_code_generation: true,
            ..Default::default()
        };
        let decision = router.route(&hard_task);
        assert_eq!(decision.tier, CostTier::High);
        assert_eq!(decision.model, policy.large);
        let medium_task = TaskAnalysis {
            estimated_tokens: 2000,
            requires_structured_output: true,
            multi_step: true,
            ..Default::default()
        };
        let decision = router.route(&medium_task);
        assert_eq!(decision.tier, CostTier::Medium);
        assert_eq!(decision.model, policy.medium);
    }

    #[test]
    fn test_auto_policy_to_router_e2e_single_model() {
        let policy = auto_policy(&["gpt-4o"]);
        assert_eq!(policy.small, "gpt-4o");
        assert_eq!(policy.medium, "gpt-4o");
        assert_eq!(policy.large, "gpt-4o");
        let router = ModelRouter::with_policy(policy);
        let simple = TaskAnalysis {
            estimated_tokens: 50,
            ..Default::default()
        };
        let medium = TaskAnalysis {
            estimated_tokens: 2000,
            requires_structured_output: true,
            ..Default::default()
        };
        let hard = TaskAnalysis {
            estimated_tokens: 5000,
            requires_reasoning: true,
            requires_code_generation: true,
            ..Default::default()
        };
        for (label, task) in [("simple", &simple), ("medium", &medium), ("hard", &hard)] {
            let decision = router.route(task);
            assert_eq!(
                decision.model, "gpt-4o",
                "{label} task should still route to the only model"
            );
        }
    }

    #[test]
    fn test_selector_single_model_vision_mismatch() {
        let selector = ModelSelector::new(&["gpt-3.5-turbo"]);
        let reqs = ModelRequirements::default().with_vision();
        assert!(
            selector.select(&reqs).is_none(),
            "single text-only model should not satisfy vision requirement"
        );
        let selector = ModelSelector::new(&["gpt-4o"]);
        let result = selector.select(&reqs);
        assert!(
            result.is_some(),
            "single vision model should satisfy vision"
        );
        assert_eq!(result.unwrap().name, "gpt-4o");
    }

    #[test]
    fn test_selector_single_model_with_cost_filter() {
        let selector = ModelSelector::new(&["gpt-4o"]);
        let reqs = ModelRequirements::default().with_max_cost(0.01);
        assert!(
            selector.select(&reqs).is_none(),
            "single expensive model should be filtered by strict cost limit"
        );
        let reqs = ModelRequirements::default().with_max_cost(100.0);
        let result = selector.select(&reqs);
        assert!(result.is_some());
        assert_eq!(result.unwrap().name, "gpt-4o");
    }

    #[test]
    fn test_selector_single_unknown_model_e2e() {
        let policy = auto_policy(&["my-local-llama"]);
        assert_eq!(policy.small, "my-local-llama");
        assert_eq!(policy.medium, "my-local-llama");
        assert_eq!(policy.large, "my-local-llama");
        let router = ModelRouter::with_policy(policy);
        let decision = router.route_simple("do something complex and analyze the code");
        assert_eq!(
            decision.model, "my-local-llama",
            "unknown single model should still be routed to"
        );
        let selector = ModelSelector::new(&["my-local-llama"]);
        let result = selector.select(&ModelRequirements::default());
        assert!(result.is_some());
        let scored = result.unwrap();
        assert_eq!(scored.name, "my-local-llama");
        assert_eq!(scored.score, 50.0, "unknown model gets default score");
        assert_eq!(
            scored.max_input_tokens, 0,
            "unknown model has no context data"
        );
        assert!(
            scored.arena_rank.is_none(),
            "unknown model has no arena data"
        );
    }

    // --- ModelRouter: constraint edge cases ---

    #[test]
    fn test_router_latency_constraint_downgrade() {
        let policy = ModelPolicy {
            max_latency_ms: Some(1000),
            ..Default::default()
        };
        let router = ModelRouter::with_policy(policy);
        let task = TaskAnalysis {
            estimated_tokens: 5000,
            requires_reasoning: true,
            requires_code_generation: true,
            ..Default::default()
        };
        let decision = router.route(&task);
        assert_ne!(
            decision.tier,
            CostTier::High,
            "latency constraint should prevent High tier"
        );
    }

    #[test]
    fn test_router_allow_large_false() {
        let policy = ModelPolicy {
            allow_large: false,
            ..Default::default()
        };
        let router = ModelRouter::with_policy(policy);
        let task = TaskAnalysis {
            estimated_tokens: 5000,
            requires_reasoning: true,
            requires_code_generation: true,
            ..Default::default()
        };
        let decision = router.route(&task);
        assert_ne!(
            decision.tier,
            CostTier::High,
            "allow_large=false should cap at Medium"
        );
    }

    #[test]
    fn test_router_threshold_customization() {
        let router = ModelRouter::new().with_thresholds(100, 200);
        let task = TaskAnalysis {
            estimated_tokens: 300,
            requires_reasoning: true,
            requires_code_generation: true,
            ..Default::default()
        };
        let decision = router.route(&task);
        assert_eq!(
            decision.tier,
            CostTier::High,
            "lowered thresholds should promote to High tier sooner"
        );
        let default_router = ModelRouter::new();
        let decision = default_router.route(&task);
        assert_eq!(
            decision.tier,
            CostTier::Medium,
            "default thresholds should keep this at Medium"
        );
    }

    #[test]
    fn test_selector_empty_pool() {
        let selector = ModelSelector::new(&[]);
        let reqs = ModelRequirements::default();
        let result = selector.select(&reqs);
        assert!(result.is_none(), "empty pool should return None");
    }

    #[test]
    fn test_selector_duplicate_models() {
        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o", "gpt-4o"]);
        let requirements = vec![
            ModelRequirements::default(),
            ModelRequirements::default(),
            ModelRequirements::default(),
        ];
        let results = selector.select_multi(&requirements);
        assert_eq!(results.len(), 3, "should not hang on duplicates");
        for (i, r) in results.iter().enumerate() {
            assert!(r.is_some(), "request {i} should get a model");
        }
    }

    #[test]
    fn test_task_analysis_edge_cases() {
        // Empty prompt: minimum token estimate, no flags.
        let analysis = TaskAnalysis::from_prompt("");
        assert_eq!(analysis.estimated_tokens, 1);
        assert!(!analysis.requires_reasoning);
        // Keyword-dense prompt: every flag should trip.
        let analysis = TaskAnalysis::from_prompt(
            "analyze compare explain why code implement function script json extract list then step first next screenshot image",
        );
        assert!(analysis.requires_reasoning);
        assert!(analysis.requires_code_generation);
        assert!(analysis.requires_structured_output);
        assert!(analysis.multi_step);
        assert!(analysis.requires_vision);
        // Non-ASCII prompt: no keywords match, token estimate still positive.
        let analysis = TaskAnalysis::from_prompt("δ½ ε₯½δΈη π ζ₯ζ¬θͺγγΉγ");
        assert!(!analysis.requires_reasoning);
        assert!(!analysis.requires_code_generation);
        assert!(analysis.estimated_tokens > 0);
    }

    #[test]
    fn test_auto_policy_large_pool() {
        let models: Vec<&str> = vec![
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            "gemini-2.0-flash",
            "deepseek-chat",
            "unknown-model-1",
            "unknown-model-2",
            "unknown-model-3",
            "unknown-model-4",
            "unknown-model-5",
            "unknown-model-6",
            "unknown-model-7",
            "unknown-model-8",
            "unknown-model-9",
            "unknown-model-10",
            "unknown-model-11",
            "unknown-model-12",
            "unknown-model-13",
        ];
        let policy = auto_policy(&models);
        assert!(!policy.small.is_empty());
        assert!(!policy.medium.is_empty());
        assert!(!policy.large.is_empty());
        assert!(policy.allow_large);
        assert_eq!(policy.max_cost_tier, CostTier::High);
    }

    // --- classify_round_complexity: escalation rules ---

    #[test]
    fn test_classify_round_complexity_round_0() {
        let analysis = classify_round_complexity("click button", 1000, 0, false);
        assert!(
            analysis.requires_reasoning,
            "round 0 always requires reasoning"
        );
        assert!(analysis.requires_structured_output);
    }

    #[test]
    fn test_classify_round_complexity_stagnated() {
        let analysis = classify_round_complexity("click button", 1000, 5, true);
        assert!(
            analysis.requires_reasoning,
            "stagnated rounds need reasoning"
        );
        assert!(analysis.multi_step, "stagnated rounds are multi-step");
    }

    #[test]
    fn test_classify_round_complexity_large_html() {
        let analysis = classify_round_complexity("click button", 60_000, 3, false);
        assert!(analysis.requires_reasoning, "large HTML needs reasoning");
    }

    #[test]
    fn test_classify_round_complexity_simple() {
        let analysis = classify_round_complexity("click button", 1000, 3, false);
        assert!(!analysis.requires_reasoning);
        assert!(!analysis.multi_step);
        assert!(analysis.requires_structured_output);
    }

    #[test]
    fn test_policy_getter() {
        let policy = ModelPolicy {
            small: "small-model".to_string(),
            medium: "medium-model".to_string(),
            large: "large-model".to_string(),
            allow_large: true,
            max_latency_ms: None,
            max_cost_tier: CostTier::High,
        };
        let router = ModelRouter::with_policy(policy.clone());
        assert_eq!(router.policy().small, "small-model");
        assert_eq!(router.policy().large, "large-model");
    }
}