pub mod hybrid;
pub mod llm_based;
pub mod rule_based;
pub use hybrid::HybridRouter;
pub use llm_based::{LlmBasedRouter, LlmRouter};
pub use rule_based::RuleBasedRouter;
use crate::error::AppResult;
use serde::{Deserialize, Deserializer, Serialize, de};
/// Model tier that a request can be routed to.
///
/// Deserialization is hand-written (see the `Deserialize` impl for this type)
/// and accepts only the exact lowercase strings `"fast"`, `"balanced"` and
/// `"deep"`. The `Default` impl for this type selects `Balanced`.
///
/// NOTE(review): `Serialize` is derived with no rename attribute, so the
/// serialized form presumably uses the variant identifiers as written
/// (e.g. `"Fast"`), which the custom deserializer would reject. Confirm that
/// this asymmetry is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum TargetModel {
/// Tier deserialized from the string `"fast"`.
Fast,
/// Tier deserialized from the string `"balanced"`.
Balanced,
/// Tier deserialized from the string `"deep"`.
Deep,
}
/// Hand-written deserializer for [`TargetModel`].
///
/// Only the exact lowercase strings `"fast"`, `"balanced"` and `"deep"` are
/// accepted. Any other string is rejected with a multi-line error that names
/// the offending value and, when the input is a recognisable near miss (a
/// valid name in the wrong case, or `balance` in any case), suggests the
/// intended tier.
impl<'de> Deserialize<'de> for TargetModel {
    /// Reads a string from `deserializer` and maps it to a tier.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while deserializing the input as a
    /// `String`, and returns a custom error (via `de::Error::custom`) when the
    /// string is not one of the three accepted tier names.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        match s.as_str() {
            "fast" => Ok(TargetModel::Fast),
            "balanced" => Ok(TargetModel::Balanced),
            "deep" => Ok(TargetModel::Deep),
            _ => {
                // Lowercase once and compare every near miss case-insensitively,
                // so that "Balance" / "BALANCE" get the same hint as "balance".
                let suggestion = match s.to_lowercase().as_str() {
                    "fast" => "Did you mean 'fast' (lowercase)?",
                    "balanced" | "balance" => "Did you mean 'balanced' (lowercase, with 'd')?",
                    "deep" => "Did you mean 'deep' (lowercase)?",
                    _ => "Valid options: 'fast', 'balanced', or 'deep'",
                };
                Err(de::Error::custom(format!(
                    "Invalid router_tier '{}'. Must be 'fast', 'balanced', or 'deep' (lowercase only). \n\
                     {}\n\
                     \n\
                     Common mistakes:\n\
                     - Capitalization: 'FAST' or 'Fast' should be 'fast'\n\
                     - Typos: 'balance' should be 'balanced'\n\
                     - Invalid values: Only 'fast', 'balanced', and 'deep' are supported\n\
                     \n\
                     See config.toml documentation for tier selection guidance.",
                    s, suggestion
                )))
            }
        }
    }
}
impl Default for TargetModel {
fn default() -> Self {
TargetModel::Balanced
}
}
/// Routing strategy recorded on a `RoutingDecision`.
///
/// Serialized and deserialized as a lowercase string (`"rule"` or `"llm"`),
/// per the `rename_all = "lowercase"` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RoutingStrategy {
/// Rule-based routing; (de)serialized as `"rule"`.
Rule,
/// LLM-based routing; (de)serialized as `"llm"`.
Llm,
}
impl RoutingStrategy {
pub fn as_str(&self) -> &'static str {
match self {
Self::Rule => "rule",
Self::Llm => "llm",
}
}
}
/// Outcome of a routing call: the chosen tier, the strategy credited with the
/// choice, and any warnings attached to it.
///
/// Fields are private; read them through the accessor methods on this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct RoutingDecision {
/// Tier the request was routed to.
target: TargetModel,
/// Strategy recorded as having produced this decision.
strategy: RoutingStrategy,
/// Warnings attached to the decision; skipped during serialization when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
warnings: Vec<String>,
}
impl RoutingDecision {
    /// Builds a decision for `target` attributed to `strategy`, with an empty
    /// warning list.
    pub fn new(target: TargetModel, strategy: RoutingStrategy) -> Self {
        let warnings = Vec::new();
        Self {
            target,
            strategy,
            warnings,
        }
    }

    /// Tier this decision routes to.
    pub fn target(&self) -> TargetModel {
        self.target
    }

    /// Strategy recorded for this decision.
    pub fn strategy(&self) -> RoutingStrategy {
        self.strategy
    }

    /// Warnings attached so far, in insertion order.
    pub fn warnings(&self) -> &[String] {
        self.warnings.as_slice()
    }

    /// Consumes the decision and returns it with `warning` appended.
    pub fn with_warning(mut self, warning: String) -> Self {
        let collected = &mut self.warnings;
        collected.push(warning);
        self
    }
}
/// Importance level carried in `RouteMetadata`.
///
/// (De)serialized as a lowercase string (`"low"`, `"normal"`, `"high"`);
/// the default is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Importance {
/// (De)serialized as `"low"`.
Low,
/// (De)serialized as `"normal"`; the `Default` value.
#[default]
Normal,
/// (De)serialized as `"high"`.
High,
}
/// Kind of task carried in `RouteMetadata`.
///
/// (De)serialized in `snake_case` (e.g. `"casual_chat"`); the default is
/// `QuestionAnswer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TaskType {
/// (De)serialized as `"casual_chat"`.
CasualChat,
/// (De)serialized as `"code"`.
Code,
/// (De)serialized as `"creative_writing"`.
CreativeWriting,
/// (De)serialized as `"deep_analysis"`.
DeepAnalysis,
/// (De)serialized as `"document_summary"`.
DocumentSummary,
/// (De)serialized as `"question_answer"`; the `Default` value.
#[default]
QuestionAnswer,
}
/// Per-request metadata passed to the routers alongside the prompt text.
#[derive(Debug, Clone, Copy)]
pub struct RouteMetadata {
/// Estimated number of tokens in the prompt.
pub token_estimate: usize,
/// Importance level of the request.
pub importance: Importance,
/// Kind of task the request represents.
pub task_type: TaskType,
}
impl RouteMetadata {
    /// Creates metadata with the given token estimate and the default
    /// importance and task type.
    pub fn new(token_estimate: usize) -> Self {
        let importance = Importance::default();
        let task_type = TaskType::default();
        Self {
            token_estimate,
            importance,
            task_type,
        }
    }

    /// Returns a copy of this metadata with `importance` replaced.
    pub fn with_importance(self, importance: Importance) -> Self {
        Self { importance, ..self }
    }

    /// Returns a copy of this metadata with `task_type` replaced.
    pub fn with_task_type(self, task_type: TaskType) -> Self {
        Self { task_type, ..self }
    }

    /// Estimates the token count of `prompt` as its number of `char`s divided
    /// by four, rounded down (so prompts shorter than four characters give 0).
    pub fn estimate_tokens(prompt: &str) -> usize {
        const CHARS_PER_TOKEN: usize = 4;
        let char_count = prompt.chars().count();
        char_count / CHARS_PER_TOKEN
    }
}
/// Dispatcher over the available router implementations.
///
/// Each variant wraps one concrete router; `Router::route` forwards the call
/// to whichever one is held.
pub enum Router {
/// Wraps a `RuleBasedRouter`.
Rule(RuleBasedRouter),
/// Wraps an `LlmBasedRouter`.
Llm(LlmBasedRouter),
/// Wraps a `HybridRouter`.
Hybrid(HybridRouter),
}
impl Router {
pub async fn route(
&self,
user_prompt: &str,
meta: &RouteMetadata,
selector: &crate::models::ModelSelector,
) -> AppResult<RoutingDecision> {
match self {
Router::Rule(r) => {
match r.route(user_prompt, meta, selector).await? {
Some(decision) => Ok(decision),
None => {
let default_target = selector.default_tier().ok_or_else(|| {
crate::error::AppError::Config(
"No routing rule matched and no endpoints configured for default fallback"
.to_string(),
)
})?;
let exclusion_set = crate::models::ExclusionSet::new();
if selector
.select(default_target, &exclusion_set)
.await
.is_none()
{
return Err(crate::error::AppError::RoutingFailed(format!(
"No rule matched and default tier {:?} has no healthy endpoints available",
default_target
)));
}
tracing::info!(
default_tier = ?default_target,
token_estimate = meta.token_estimate,
importance = ?meta.importance,
task_type = ?meta.task_type,
"No rule matched, using default tier (rule-only mode)"
);
Ok(RoutingDecision::new(default_target, RoutingStrategy::Rule))
}
}
}
Router::Llm(r) => r.route(user_prompt, meta).await,
Router::Hybrid(r) => r.route(user_prompt, meta).await,
}
}
}
// Unit tests for the plain data types in this module: defaults, builders,
// the token estimate, and the serde string forms.
#[cfg(test)]
mod tests {
use super::*;
// The three tier variants exist and compare equal to themselves.
#[test]
fn test_target_model_enum_values() {
let fast = TargetModel::Fast;
let balanced = TargetModel::Balanced;
let deep = TargetModel::Deep;
assert_eq!(fast, TargetModel::Fast);
assert_eq!(balanced, TargetModel::Balanced);
assert_eq!(deep, TargetModel::Deep);
}
// `Importance` defaults to `Normal`.
#[test]
fn test_importance_default() {
assert_eq!(Importance::default(), Importance::Normal);
}
// `TaskType` defaults to `QuestionAnswer`.
#[test]
fn test_task_type_default() {
assert_eq!(TaskType::default(), TaskType::QuestionAnswer);
}
// `RouteMetadata::new` keeps the estimate and fills in the default enums.
#[test]
fn test_route_metadata_new() {
let meta = RouteMetadata::new(100);
assert_eq!(meta.token_estimate, 100);
assert_eq!(meta.importance, Importance::Normal);
assert_eq!(meta.task_type, TaskType::QuestionAnswer);
}
// The `with_*` builders override one field each and leave the estimate alone.
#[test]
fn test_route_metadata_builder() {
let meta = RouteMetadata::new(200)
.with_importance(Importance::High)
.with_task_type(TaskType::Code);
assert_eq!(meta.token_estimate, 200);
assert_eq!(meta.importance, Importance::High);
assert_eq!(meta.task_type, TaskType::Code);
}
// The estimate is chars / 4 rounded down: 13 chars -> 3, 1000 chars -> 250.
#[test]
fn test_estimate_tokens() {
let prompt = "Hello, world!";
let estimate = RouteMetadata::estimate_tokens(prompt);
assert_eq!(estimate, 3);
let long_prompt = "a".repeat(1000);
let long_estimate = RouteMetadata::estimate_tokens(&long_prompt);
assert_eq!(long_estimate, 250); }
// `Importance` deserializes from its lowercase names.
#[test]
fn test_importance_serde() {
assert_eq!(
serde_json::from_str::<Importance>(r#""low""#).unwrap(),
Importance::Low
);
assert_eq!(
serde_json::from_str::<Importance>(r#""normal""#).unwrap(),
Importance::Normal
);
assert_eq!(
serde_json::from_str::<Importance>(r#""high""#).unwrap(),
Importance::High
);
}
// `TaskType` deserializes from snake_case names.
#[test]
fn test_task_type_serde() {
assert_eq!(
serde_json::from_str::<TaskType>(r#""casual_chat""#).unwrap(),
TaskType::CasualChat
);
assert_eq!(
serde_json::from_str::<TaskType>(r#""code""#).unwrap(),
TaskType::Code
);
assert_eq!(
serde_json::from_str::<TaskType>(r#""creative_writing""#).unwrap(),
TaskType::CreativeWriting
);
}
// `as_str` yields the lowercase strategy names.
#[test]
fn test_routing_strategy_as_str() {
assert_eq!(RoutingStrategy::Rule.as_str(), "rule");
assert_eq!(RoutingStrategy::Llm.as_str(), "llm");
}
// `RoutingStrategy` round-trips through JSON as lowercase strings.
#[test]
fn test_routing_strategy_serde() {
assert_eq!(
serde_json::from_str::<RoutingStrategy>(r#""rule""#).unwrap(),
RoutingStrategy::Rule
);
assert_eq!(
serde_json::from_str::<RoutingStrategy>(r#""llm""#).unwrap(),
RoutingStrategy::Llm
);
assert_eq!(
serde_json::to_string(&RoutingStrategy::Rule).unwrap(),
r#""rule""#
);
assert_eq!(
serde_json::to_string(&RoutingStrategy::Llm).unwrap(),
r#""llm""#
);
}
// `RoutingDecision::new` stores the target and strategy it was given.
#[test]
fn test_routing_decision_new() {
let decision = RoutingDecision::new(TargetModel::Fast, RoutingStrategy::Rule);
assert_eq!(decision.target(), TargetModel::Fast);
assert_eq!(decision.strategy(), RoutingStrategy::Rule);
}
// The accessors return the stored values (with explicit failure messages).
#[test]
fn test_routing_decision_accessors() {
let decision = RoutingDecision::new(TargetModel::Balanced, RoutingStrategy::Llm);
assert_eq!(
decision.target(),
TargetModel::Balanced,
"target() should return the target model"
);
assert_eq!(
decision.strategy(),
RoutingStrategy::Llm,
"strategy() should return the routing strategy"
);
}
// Accessors hold for every (tier, strategy) combination.
#[test]
fn test_routing_decision_accessors_all_variants() {
let test_cases = vec![
(TargetModel::Fast, RoutingStrategy::Rule),
(TargetModel::Balanced, RoutingStrategy::Rule),
(TargetModel::Deep, RoutingStrategy::Rule),
(TargetModel::Fast, RoutingStrategy::Llm),
(TargetModel::Balanced, RoutingStrategy::Llm),
(TargetModel::Deep, RoutingStrategy::Llm),
];
for (target, strategy) in test_cases {
let decision = RoutingDecision::new(target, strategy);
assert_eq!(decision.target(), target);
assert_eq!(decision.strategy(), strategy);
}
}
// Decisions compare equal exactly when their fields match.
#[test]
fn test_routing_decision_equality() {
let decision1 = RoutingDecision::new(TargetModel::Balanced, RoutingStrategy::Llm);
let decision2 = RoutingDecision::new(TargetModel::Balanced, RoutingStrategy::Llm);
let decision3 = RoutingDecision::new(TargetModel::Fast, RoutingStrategy::Rule);
assert_eq!(decision1, decision2);
assert_ne!(decision1, decision3);
}
}