use super::{Importance, RouteMetadata, RoutingDecision, RoutingStrategy, TargetModel, TaskType};
use crate::error::AppResult;
use crate::models::ModelSelector;
#[derive(Debug, Clone, Default)]
pub struct RuleBasedRouter;
impl RuleBasedRouter {
pub fn new() -> Self {
Self
}
pub async fn route(
&self,
_user_prompt: &str,
meta: &RouteMetadata,
_selector: &ModelSelector,
) -> AppResult<Option<RoutingDecision>> {
if let Some(target) = self.evaluate_rules(meta) {
return Ok(Some(RoutingDecision::new(target, RoutingStrategy::Rule)));
}
tracing::debug!(
token_estimate = meta.token_estimate,
importance = ?meta.importance,
task_type = ?meta.task_type,
"No rule matched, returning None for caller to handle"
);
Ok(None)
}
fn evaluate_rules(&self, meta: &RouteMetadata) -> Option<TargetModel> {
use Importance::*;
use TaskType::*;
if matches!(meta.task_type, CasualChat)
&& meta.token_estimate < 256
&& !matches!(meta.importance, High)
{
return Some(TargetModel::Fast);
}
if (matches!(meta.importance, High) && !matches!(meta.task_type, CasualChat))
|| matches!(meta.task_type, DeepAnalysis | CreativeWriting)
{
return Some(TargetModel::Deep);
}
if matches!(meta.task_type, Code) {
return if meta.token_estimate > 1024 {
Some(TargetModel::Deep)
} else {
Some(TargetModel::Balanced)
};
}
if meta.token_estimate >= 200
&& meta.token_estimate < 2048
&& matches!(meta.task_type, QuestionAnswer | DocumentSummary)
{
return Some(TargetModel::Balanced);
}
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use std::sync::Arc;

    /// Builds a fresh metrics instance for constructing a `ModelSelector`.
    fn test_metrics() -> Arc<crate::metrics::Metrics> {
        Arc::new(crate::metrics::Metrics::new().expect("should create metrics"))
    }

    /// Minimal config with one model per tier and the `rule` routing strategy.
    fn test_config() -> Arc<Config> {
        let toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "test-fast"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
temperature = 0.7
weight = 1.0
priority = 1
[[models.balanced]]
name = "test-balanced"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
temperature = 0.7
weight = 1.0
priority = 1
[[models.deep]]
name = "test-deep"
base_url = "http://localhost:1236/v1"
max_tokens = 8192
temperature = 0.7
weight = 1.0
priority = 1
[routing]
strategy = "rule"
default_importance = "normal"
router_tier = "balanced"
[observability]
log_level = "info"
"#;
        Arc::new(toml::from_str(toml).expect("should parse config"))
    }

    /// Builds a `ModelSelector` over the test config and metrics.
    fn test_selector() -> ModelSelector {
        ModelSelector::new(test_config(), test_metrics())
    }

    /// Runs the rule table against `meta` with a fresh router.
    fn evaluate(meta: RouteMetadata) -> Option<TargetModel> {
        RuleBasedRouter::new().evaluate_rules(&meta)
    }

    #[tokio::test]
    async fn test_router_creates() {
        let router = RuleBasedRouter::new();
        let selector = test_selector();
        let result = router
            .route("test", &RouteMetadata::new(100), &selector)
            .await;
        assert!(result.is_ok());
        assert!(
            result.unwrap().is_none(),
            "No rule match should return None"
        );
    }

    #[test]
    fn test_casual_chat_small_tokens_routes_to_fast() {
        let meta = RouteMetadata::new(100)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), Some(TargetModel::Fast));
    }

    #[test]
    fn test_casual_chat_high_importance_no_match() {
        let meta = RouteMetadata::new(100)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::High);
        assert_eq!(evaluate(meta), None);
    }

    #[test]
    fn test_casual_chat_large_tokens_no_match() {
        let meta = RouteMetadata::new(300)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), None);
    }

    #[test]
    fn test_document_summary_routes_to_balanced() {
        let meta = RouteMetadata::new(1500)
            .with_task_type(TaskType::DocumentSummary)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), Some(TargetModel::Balanced));
    }

    #[test]
    fn test_question_answer_routes_to_balanced() {
        let meta = RouteMetadata::new(500)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::Low);
        assert_eq!(evaluate(meta), Some(TargetModel::Balanced));
    }

    #[test]
    fn test_medium_task_exceeds_token_limit_no_match() {
        let meta = RouteMetadata::new(3000)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), None);
    }

    #[test]
    fn test_high_importance_routes_to_deep() {
        let meta = RouteMetadata::new(500)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::High);
        assert_eq!(evaluate(meta), Some(TargetModel::Deep));
    }

    #[test]
    fn test_high_importance_small_code_routes_to_deep() {
        // Rule 2 (high importance) takes precedence over the Code size split,
        // which would otherwise pick Balanced for 100 tokens.
        let meta = RouteMetadata::new(100)
            .with_task_type(TaskType::Code)
            .with_importance(Importance::High);
        assert_eq!(evaluate(meta), Some(TargetModel::Deep));
    }

    #[test]
    fn test_high_importance_overrides_question_answer_token_limit() {
        // Without high importance, 3000 tokens of Q&A matches no rule.
        let meta = RouteMetadata::new(3000)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::High);
        assert_eq!(evaluate(meta), Some(TargetModel::Deep));
    }

    #[test]
    fn test_deep_analysis_routes_to_deep() {
        let meta = RouteMetadata::new(1000)
            .with_task_type(TaskType::DeepAnalysis)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), Some(TargetModel::Deep));
    }

    #[test]
    fn test_creative_writing_routes_to_deep() {
        let meta = RouteMetadata::new(800)
            .with_task_type(TaskType::CreativeWriting)
            .with_importance(Importance::Low);
        assert_eq!(evaluate(meta), Some(TargetModel::Deep));
    }

    #[test]
    fn test_code_small_tokens_routes_to_balanced() {
        let meta = RouteMetadata::new(500)
            .with_task_type(TaskType::Code)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), Some(TargetModel::Balanced));
    }

    #[test]
    fn test_code_large_tokens_routes_to_deep() {
        let meta = RouteMetadata::new(2000)
            .with_task_type(TaskType::Code)
            .with_importance(Importance::Normal);
        assert_eq!(evaluate(meta), Some(TargetModel::Deep));
    }

    #[test]
    fn test_rule_table_representative_cases() {
        // Each case pins the exact expected target. A check that a `Some` is
        // "one of the tiers" can never fail and proves nothing about routing.
        let cases = [
            (0, Importance::Low, TaskType::CasualChat, Some(TargetModel::Fast)),
            (100, Importance::Normal, TaskType::Code, Some(TargetModel::Balanced)),
            (1000, Importance::High, TaskType::DeepAnalysis, Some(TargetModel::Deep)),
            (500, Importance::Low, TaskType::QuestionAnswer, Some(TargetModel::Balanced)),
            (2500, Importance::Normal, TaskType::DocumentSummary, None),
        ];
        for (tokens, importance, task_type, expected) in cases {
            let label = format!("{} tokens, {:?}, {:?}", tokens, importance, task_type);
            let meta = RouteMetadata::new(tokens)
                .with_importance(importance)
                .with_task_type(task_type);
            assert_eq!(evaluate(meta), expected, "{}", label);
        }
    }

    #[test]
    fn test_boundary_255_tokens_casual_chat_routes_to_fast() {
        let meta = RouteMetadata::new(255)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            Some(TargetModel::Fast),
            "255 tokens should match Rule 1 (< 256)"
        );
    }

    #[test]
    fn test_boundary_256_tokens_casual_chat_no_match() {
        let meta = RouteMetadata::new(256)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            None,
            "256 tokens should NOT match Rule 1 (requires < 256)"
        );
    }

    #[test]
    fn test_boundary_1024_tokens_code_routes_to_balanced() {
        let meta = RouteMetadata::new(1024)
            .with_task_type(TaskType::Code)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            Some(TargetModel::Balanced),
            "1024 tokens should match Code → Balanced (not > 1024)"
        );
    }

    #[test]
    fn test_boundary_1025_tokens_code_routes_to_deep() {
        let meta = RouteMetadata::new(1025)
            .with_task_type(TaskType::Code)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            Some(TargetModel::Deep),
            "1025 tokens should match Code → Deep (> 1024)"
        );
    }

    #[test]
    fn test_boundary_199_tokens_question_answer_no_match() {
        let meta = RouteMetadata::new(199)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            None,
            "199 tokens should NOT match Rule 4 (requires >= 200)"
        );
    }

    #[test]
    fn test_boundary_200_tokens_question_answer_routes_to_balanced() {
        let meta = RouteMetadata::new(200)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            Some(TargetModel::Balanced),
            "200 tokens should match Rule 4 (>= 200)"
        );
    }

    #[test]
    fn test_boundary_2047_tokens_question_answer_routes_to_balanced() {
        let meta = RouteMetadata::new(2047)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            Some(TargetModel::Balanced),
            "2047 tokens should match Rule 4 (< 2048)"
        );
    }

    #[test]
    fn test_boundary_2048_tokens_question_answer_no_match() {
        let meta = RouteMetadata::new(2048)
            .with_task_type(TaskType::QuestionAnswer)
            .with_importance(Importance::Normal);
        assert_eq!(
            evaluate(meta),
            None,
            "2048 tokens should NOT match Rule 4 (requires < 2048)"
        );
    }

    #[tokio::test]
    async fn test_route_returns_none_when_no_rule_matches() {
        let router = RuleBasedRouter::new();
        let selector = test_selector();
        let meta = RouteMetadata::new(100)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::High);
        let result = router.route("test prompt", &meta, &selector).await;
        assert!(
            result.is_ok(),
            "route() should succeed even when no rule matches"
        );
        assert!(
            result.unwrap().is_none(),
            "route() should return None when no rule matches (not default tier)"
        );
    }

    #[tokio::test]
    async fn test_route_returns_some_when_rule_matches() {
        let router = RuleBasedRouter::new();
        let selector = test_selector();
        let meta = RouteMetadata::new(100)
            .with_task_type(TaskType::CasualChat)
            .with_importance(Importance::Normal);
        let result = router.route("test prompt", &meta, &selector).await;
        assert!(result.is_ok(), "route() should succeed when rule matches");
        let decision = result
            .unwrap()
            .expect("route() should return Some when rule matches");
        assert_eq!(decision.target(), TargetModel::Fast);
        assert_eq!(decision.strategy(), RoutingStrategy::Rule);
    }
}