use ironclad_core::config::RoutingConfig;
use ironclad_core::{IroncladError, Result};
/// Pluggable strategy that turns a numeric feature vector into a complexity score.
///
/// Implementors must be `Send + Sync + Debug` so that a boxed backend can be
/// stored inside a `Debug`-derived struct and shared across threads.
pub trait RouterBackend: Send + Sync + std::fmt::Debug {
/// Returns a complexity score for `features`.
///
/// The trait itself places no constraint on the range of the returned value.
/// NOTE(review): the bundled `HeuristicBackend` clamps to `[0.0, 1.0]`;
/// presumably other backends should do the same — confirm with callers.
fn classify_complexity(&self, features: &[f32]) -> f64;
}
/// Stateless `RouterBackend` that scores complexity with a fixed weighted formula
/// (see `heuristic_classify_complexity`).
#[derive(Debug, Default)]
pub struct HeuristicBackend;
impl RouterBackend for HeuristicBackend {
/// Delegates directly to `heuristic_classify_complexity`; the backend holds no state.
fn classify_complexity(&self, features: &[f32]) -> f64 {
heuristic_classify_complexity(features)
}
}
/// Chooses which model name to use from a primary model, an ordered list of
/// fallbacks, and an optional manual override.
#[derive(Debug)]
pub struct ModelRouter {
// Name of the primary model; used when `current_index` is 0 or the mode is "primary".
primary: String,
// Ordered fallback model names; position `k - 1` corresponds to `current_index == k`.
fallbacks: Vec<String>,
// Cursor into the primary + fallbacks sequence: 0 is the primary, `k` is `fallbacks[k - 1]`.
current_index: usize,
// Routing configuration; `select_model` reads its `mode` field.
config: RoutingConfig,
// Scoring backend used by `ModelRouter::classify_complexity`.
backend: Box<dyn RouterBackend>,
// When `Some`, `select_model` returns this name regardless of mode or cursor.
model_override: Option<String>,
}
impl ModelRouter {
    /// Creates a router positioned on the primary model with no override set.
    ///
    /// `fallbacks` is kept in the given order; `backend` is the scorer used by
    /// [`ModelRouter::classify_complexity`].
    pub fn new(
        primary: String,
        fallbacks: Vec<String>,
        config: RoutingConfig,
        backend: Box<dyn RouterBackend>,
    ) -> Self {
        Self {
            primary,
            fallbacks,
            config,
            backend,
            current_index: 0,
            model_override: None,
        }
    }

    /// Scores `features` with the configured backend and returns its result unchanged.
    pub fn classify_complexity(&self, features: &[f32]) -> f64 {
        let backend = &*self.backend;
        backend.classify_complexity(features)
    }

    /// Returns the name of the model that should be used right now.
    ///
    /// Resolution order:
    /// 1. A manual override, if one is set, always wins.
    /// 2. Mode `"primary"`: always the primary model, whatever the cursor is.
    /// 3. Mode `"round-robin"`: the cursor is taken modulo `1 + fallbacks.len()`;
    ///    slot 0 is the primary and slot `k` is `fallbacks[k - 1]`.
    /// 4. Any other mode: cursor 0 is the primary, cursor `k` is `fallbacks[k - 1]`,
    ///    and a cursor past the end of the fallback list yields the primary.
    pub fn select_model(&self) -> &str {
        if let Some(forced) = self.model_override.as_deref() {
            return forced;
        }

        let position = self.current_index;
        match self.config.mode.as_str() {
            "primary" => self.primary.as_str(),
            "round-robin" => {
                let slots = 1 + self.fallbacks.len();
                match position % slots {
                    0 => self.primary.as_str(),
                    // `slot` is in `1..slots`, so `slot - 1` is a valid fallback index.
                    slot => self.fallbacks[slot - 1].as_str(),
                }
            }
            // Cursor 0 has no predecessor, so `checked_sub` maps it to the primary.
            _ => position
                .checked_sub(1)
                .and_then(|fallback_idx| self.fallbacks.get(fallback_idx))
                .map_or(self.primary.as_str(), String::as_str),
        }
    }

    /// Forces [`ModelRouter::select_model`] to return `model` until the override is cleared.
    pub fn set_override(&mut self, model: String) {
        self.model_override = Some(model);
    }

    /// Removes any manual override so normal routing applies again.
    pub fn clear_override(&mut self) {
        self.model_override = None;
    }

    /// Returns the current manual override, if any.
    pub fn get_override(&self) -> Option<&str> {
        self.model_override.as_ref().map(String::as_str)
    }

    /// Returns the primary model name.
    pub fn primary(&self) -> &str {
        self.primary.as_str()
    }

    /// Returns the fallback model names in order.
    pub fn fallbacks(&self) -> &[String] {
        self.fallbacks.as_slice()
    }

    /// Moves the cursor one step forward and returns the model selected afterwards.
    ///
    /// The returned name comes from [`ModelRouter::select_model`], so an active
    /// override or the `"primary"` mode still determines the result even though the
    /// cursor has moved.
    ///
    /// # Errors
    ///
    /// Returns `IroncladError::Llm` when the cursor is already at or beyond the last
    /// fallback; the cursor is left unchanged in that case.
    pub fn advance_fallback(&mut self) -> Result<&str> {
        if self.current_index < self.fallbacks.len() {
            self.current_index += 1;
            Ok(self.select_model())
        } else {
            Err(IroncladError::Llm("all fallback models exhausted".into()))
        }
    }

    /// Moves the cursor back to the primary model.
    pub fn reset(&mut self) {
        self.current_index = 0;
    }

    /// Returns the current cursor value (0 means the primary model).
    pub fn current_index(&self) -> usize {
        self.current_index
    }

    /// Returns the routing configuration currently in use.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }

    /// Replaces the primary model, fallback list and configuration in one step and
    /// rewinds the cursor to the primary.
    ///
    /// The manual override and the scoring backend are left untouched.
    pub fn sync_runtime(&mut self, primary: String, fallbacks: Vec<String>, config: RoutingConfig) {
        self.primary = primary;
        self.fallbacks = fallbacks;
        self.config = config;
        self.reset();
    }
}
/// Packs the raw request signals into the three-element feature vector
/// `[message byte length, tool call count, conversation depth]`.
///
/// The message length is `str::len`, i.e. the number of UTF-8 bytes rather than
/// characters. Each value is converted to `f32` with an `as` cast.
pub fn extract_features(
    message: &str,
    tool_call_count: usize,
    conversation_depth: usize,
) -> Vec<f32> {
    let byte_len = message.len();
    [byte_len, tool_call_count, conversation_depth]
        .iter()
        .map(|&raw| raw as f32)
        .collect()
}
/// Convenience wrapper that scores `features` with a `HeuristicBackend`,
/// without needing a `ModelRouter`.
pub fn classify_complexity(features: &[f32]) -> f64 {
HeuristicBackend.classify_complexity(features)
}
/// Computes a complexity score in `[0.0, 1.0]` from the first three features.
///
/// The score is the clamped sum of three weighted terms:
/// `features[0] / 1000 * 0.3 + features[1] / 5 * 0.3 + features[2] / 10 * 0.4`.
/// Elements beyond the third are ignored, and a slice with fewer than three
/// elements scores `0.0`. The element order matches what `extract_features`
/// produces (message length, tool call count, conversation depth).
fn heuristic_classify_complexity(features: &[f32]) -> f64 {
    match features {
        [msg_len, tool_calls, depth, ..] => {
            let weighted_msg = (f64::from(*msg_len) / 1000.0) * 0.3;
            let weighted_tools = (f64::from(*tool_calls) / 5.0) * 0.3;
            let weighted_depth = (f64::from(*depth) / 10.0) * 0.4;
            (weighted_msg + weighted_tools + weighted_depth).clamp(0.0, 1.0)
        }
        // Not enough features to score anything.
        _ => 0.0,
    }
}
// Unit tests for model selection, fallback advancement, overrides and the
// heuristic complexity score.
#[cfg(test)]
mod tests {
use super::*;
// All tests route with the default `RoutingConfig`.
fn test_config() -> RoutingConfig {
RoutingConfig::default()
}
// A freshly built router selects the primary model and sits at cursor 0.
#[test]
fn select_primary_model() {
let router = ModelRouter::new(
"ollama/qwen3:8b".into(),
vec!["openai/gpt-4o".into()],
test_config(),
Box::new(HeuristicBackend),
);
assert_eq!(router.select_model(), "ollama/qwen3:8b");
assert_eq!(router.current_index(), 0);
}
// Each `advance_fallback` call moves the cursor by one and yields the next
// fallback in order. This passes only if the default config's mode is not
// "primary", since that mode always selects the primary model.
#[test]
fn advance_through_fallbacks() {
let mut router = ModelRouter::new(
"primary".into(),
vec!["fallback1".into(), "fallback2".into()],
test_config(),
Box::new(HeuristicBackend),
);
let first = router.advance_fallback().unwrap();
assert_eq!(first, "fallback1");
assert_eq!(router.current_index(), 1);
let second = router.advance_fallback().unwrap();
assert_eq!(second, "fallback2");
assert_eq!(router.current_index(), 2);
}
// A manual override wins over normal routing until it is cleared.
#[test]
fn model_override_takes_precedence() {
let mut router = ModelRouter::new(
"ollama/qwen3:8b".into(),
vec!["openai/gpt-4o".into()],
test_config(),
Box::new(HeuristicBackend),
);
router.set_override("anthropic/claude-sonnet".into());
assert_eq!(router.select_model(), "anthropic/claude-sonnet");
assert_eq!(router.get_override(), Some("anthropic/claude-sonnet"));
router.clear_override();
assert_eq!(router.select_model(), "ollama/qwen3:8b");
}
// The heuristic score stays within [0, 1] and a long, tool-heavy, deep
// conversation scores above 0.5.
#[test]
fn complexity_classification() {
let features = extract_features("hello world", 0, 0);
let score = classify_complexity(&features);
assert!((0.0..=1.0).contains(&score));
let long_msg = "x".repeat(2000);
let heavy = extract_features(&long_msg, 5, 10);
let heavy_score = classify_complexity(&heavy);
assert!(heavy_score > 0.5);
}
}