use crate::hardware::HardwareInfo;
use crate::models::ModelRegistry;
use serde::{Deserialize, Serialize};
use tracing::debug;
pub use crate::adaptive_router::TaskComplexity;
/// Outcome of routing a generation request.
///
/// Built by `ModelRouter::route_generate`; derives serde's `Serialize` and
/// `Deserialize`, so it can be logged or returned to callers as data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
/// Name of the model that was selected for the request.
pub model: String,
/// Complexity class assigned to the prompt.
pub complexity: TaskComplexity,
/// Human-readable summary of why `model` was picked.
pub reason: String,
}
/// Picks a model name for a request, taking the host hardware into account.
pub struct ModelRouter {
// Hardware description supplied at construction; routing reads its
// `recommended_model` and `max_model_mb` fields.
hw: HardwareInfo,
}
/// Generation model names considered when the preferred model is unusable.
///
/// Order matters: `ModelRouter::best_available` scans from the target's
/// position towards the end of this list first, then walks the earlier
/// entries in reverse.
// NOTE(review): the names suggest the list is ordered smallest to largest —
// confirm against the model registry before relying on that.
const TIERS: &[&str] = &[
"Qwen3-0.6B",
"Qwen3-1.7B",
"Qwen3-4B",
"Qwen3-8B",
"Qwen3-30B-A3B",
];
impl ModelRouter {
    /// Creates a router that makes its decisions against `hw`.
    pub fn new(hw: HardwareInfo) -> Self {
        Self { hw }
    }

    /// Routes a text-generation prompt.
    ///
    /// The prompt is classified with [`Self::assess_complexity`], mapped to a
    /// preferred model, and then resolved against `registry` to a model that
    /// is actually usable. The returned decision carries the chosen model,
    /// the complexity class, and a formatted explanation.
    pub fn route_generate(&self, prompt: &str, registry: &ModelRegistry) -> RoutingDecision {
        let complexity = Self::assess_complexity(prompt);
        let preferred = self.target_for_complexity(complexity);
        let chosen = self.best_available(&preferred, registry);

        debug!(
            complexity = ?complexity,
            target = %preferred,
            selected = %chosen,
            "routed generation request"
        );

        let reason = format!(
            "{:?} task -> {} (target: {})",
            complexity, chosen, preferred
        );
        RoutingDecision {
            model: chosen,
            complexity,
            reason,
        }
    }

    /// Routes a small task (the log message names classification) by
    /// resolving the `"Qwen3-0.6B"` target against `registry`.
    pub fn route_small(&self, registry: &ModelRegistry) -> String {
        let selected = self.best_available("Qwen3-0.6B", registry);
        debug!(selected = %selected, "routed small task (classify)");
        selected
    }

    /// Returns the embedding model name.
    ///
    /// The name is fixed; the registry argument is currently not consulted.
    pub fn route_embedding(&self, _registry: &ModelRegistry) -> String {
        String::from("Qwen3-Embedding-0.6B")
    }

    /// Classifies `prompt` using keyword heuristics and a rough token count.
    ///
    /// Checks are applied in priority order; the first one that matches wins:
    ///
    /// 1. `Code` — the raw prompt contains a code marker (case-sensitive), or
    ///    the lowercased prompt contains a repair keyword.
    /// 2. `Complex` — the lowercased prompt contains a reasoning keyword, or
    ///    the estimated token count exceeds 500.
    /// 3. `Simple` — the lowercased prompt contains a simple-question phrase,
    ///    or the estimated token count is below 30.
    /// 4. `Medium` — everything else.
    ///
    /// Tokens are estimated as whitespace-separated words times 1.3,
    /// truncated. All matching is plain substring search, so a keyword also
    /// matches inside a longer word.
    pub fn assess_complexity(prompt: &str) -> TaskComplexity {
        // Matched against the prompt exactly as written.
        const CODE_MARKERS: &[&str] = &[
            "```",
            "fn ",
            "def ",
            "class ",
            "import ",
            "require(",
            "async fn",
            "pub fn",
            "function ",
            "const ",
            "let ",
            "var ",
            "#include",
            "package ",
            "impl ",
        ];
        // The remaining tables are matched against the lowercased prompt.
        const REPAIR_MARKERS: &[&str] = &[
            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
        ];
        const REASONING_MARKERS: &[&str] = &[
            "analyze",
            "compare",
            "explain why",
            "step by step",
            "think through",
            "evaluate",
            "trade-off",
            "tradeoff",
            "pros and cons",
            "architecture",
            "design",
            "strategy",
            "optimize",
            "comprehensive",
        ];
        const SIMPLE_PATTERNS: &[&str] = &[
            "what is",
            "who is",
            "when did",
            "where is",
            "how many",
            "yes or no",
            "true or false",
            "name the",
            "list the",
            "define ",
        ];

        /// True when `text` contains at least one of `needles` as a substring.
        fn mentions_any(text: &str, needles: &[&str]) -> bool {
            needles.iter().any(|&needle| text.contains(needle))
        }

        let lowered = prompt.to_lowercase();

        if mentions_any(prompt, CODE_MARKERS) || mentions_any(&lowered, REPAIR_MARKERS) {
            return TaskComplexity::Code;
        }

        let approx_tokens = (prompt.split_whitespace().count() as f64 * 1.3) as usize;

        if mentions_any(&lowered, REASONING_MARKERS) || approx_tokens > 500 {
            return TaskComplexity::Complex;
        }
        if mentions_any(&lowered, SIMPLE_PATTERNS) || approx_tokens < 30 {
            return TaskComplexity::Simple;
        }
        TaskComplexity::Medium
    }

    /// Maps a complexity class to the model name the router would like to use.
    ///
    /// `Complex` defers to the hardware's `recommended_model`; the other
    /// classes use fixed names.
    fn target_for_complexity(&self, complexity: TaskComplexity) -> String {
        let fixed = match complexity {
            TaskComplexity::Complex => return self.hw.recommended_model.clone(),
            TaskComplexity::Simple => "Qwen3-0.6B",
            TaskComplexity::Medium => "Qwen3-1.7B",
            TaskComplexity::Code => "Qwen3-4B",
        };
        fixed.to_string()
    }

    /// Resolves `target` to a model that can be used right now.
    ///
    /// A model counts as usable when the first registry entry with that name
    /// is marked `downloaded` and its `quantized_size_mb` does not exceed the
    /// hardware's `max_model_mb`.
    ///
    /// Candidates are tried in this order:
    /// 1. `target` itself;
    /// 2. the entries of `TIERS` from `target`'s position to the end;
    /// 3. the entries of `TIERS` before `target`, nearest first.
    ///
    /// If `target` is not listed in `TIERS`, step 2 covers the whole list from
    /// its first entry and step 3 is empty. If no candidate is usable, `target`
    /// is returned unchanged even though it did not pass the check.
    fn best_available(&self, target: &str, registry: &ModelRegistry) -> String {
        let models = registry.list_models();

        // Only the first registry entry with a matching name is inspected.
        let usable = |name: &str| -> bool {
            matches!(
                models.iter().find(|m| m.name == name),
                Some(m) if m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb
            )
        };

        if usable(target) {
            return target.to_string();
        }

        // NOTE(review): an unknown target falls back to index 0, so the scan
        // starts at the first tier — confirm this is the intended behaviour.
        let pivot = TIERS.iter().position(|&tier| tier == target).unwrap_or(0);
        let (before, from_target) = TIERS.split_at(pivot);

        from_target
            .iter()
            .chain(before.iter().rev())
            .copied()
            .find(|&tier| usable(tier))
            .unwrap_or(target)
            .to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests exercise `TaskComplexity::assess`, which is
    // defined outside the visible part of this file. Nothing here calls
    // `ModelRouter::assess_complexity` directly — confirm that is intended.

    /// Asserts that `prompt` is classified as `expected`.
    #[track_caller]
    fn assert_classified(prompt: &str, expected: TaskComplexity) {
        assert_eq!(TaskComplexity::assess(prompt), expected);
    }

    /// A short factual question is classified as simple.
    #[test]
    fn simple_question() {
        assert_classified("What is the capital of France?", TaskComplexity::Simple);
    }

    /// The word "function" is enough to classify a prompt as code.
    #[test]
    fn code_with_function_keyword() {
        assert_classified("Write a function to sort a list", TaskComplexity::Code);
    }

    /// Reasoning vocabulary classifies a prompt as complex.
    #[test]
    fn complex_reasoning() {
        assert_classified(
            "Analyze the trade-offs between microservices and monolithic architecture for our use case",
            TaskComplexity::Complex,
        );
    }

    /// A longer prompt with no special keywords is classified as medium.
    #[test]
    fn medium_default() {
        assert_classified(
            "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life",
            TaskComplexity::Medium,
        );
    }

    /// A fenced code snippet classifies a prompt as code.
    #[test]
    fn code_with_backticks() {
        assert_classified(
            "Here is my code:\n```rust\nfn main() {}\n```\nFix it",
            TaskComplexity::Code,
        );
    }

    /// Repair vocabulary alone classifies a prompt as code.
    #[test]
    fn repair_marker() {
        assert_classified("Debug this failing test case", TaskComplexity::Code);
    }
}