use std::sync::Arc;
use crate::provider::{
Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
};
/// Difficulty / modality class of a task, used to choose a provider and to score brains.
///
/// Derives `PartialEq`/`Eq`/`Hash` so tiers can be compared directly and used as map keys
/// (consistent with `Retry`, which is also comparable).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaskTier {
    /// Canonical name `"trivial"`.
    Trivial,
    /// Canonical name `"small"`.
    Small,
    /// Canonical name `"medium"`; also the fallback for unrecognized tier names.
    Medium,
    /// Canonical name `"hard"`.
    Hard,
    /// Canonical name `"vision"`.
    Vision,
}
impl TaskTier {
    /// Parses a tier name case-insensitively, ignoring surrounding whitespace.
    ///
    /// Unknown names never fail: they fall back to [`TaskTier::Medium`], so callers
    /// always get a usable tier.
    pub fn from_str(s: &str) -> Self {
        // Trim first so values like " hard " or "vision\n" are recognized instead of
        // silently degrading to the fallback tier.
        match s.trim().to_lowercase().as_str() {
            "trivial" => TaskTier::Trivial,
            "small" => TaskTier::Small,
            "medium" => TaskTier::Medium,
            "hard" => TaskTier::Hard,
            "vision" => TaskTier::Vision,
            _ => TaskTier::Medium,
        }
    }

    /// Returns the canonical lowercase name, the inverse of [`TaskTier::from_str`].
    ///
    /// The names are string literals, so the result is `'static` and outlives `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskTier::Trivial => "trivial",
            TaskTier::Small => "small",
            TaskTier::Medium => "medium",
            TaskTier::Hard => "hard",
            TaskTier::Vision => "vision",
        }
    }
}
/// Describes what a request needs from a brain; this is the input to `Router::select`.
#[derive(Clone, Debug)]
pub struct RoutingNeed {
    /// Difficulty / modality class of the task.
    pub tier: TaskTier,
    /// The request needs tool support (`BasicRouter` heavily penalizes brains without it).
    pub required_tools: bool,
    /// The request needs vision support (`BasicRouter` heavily penalizes brains without it).
    pub required_vision: bool,
    /// Ask for local providers only (`BasicRouter` then considers just `"local"`/`"ollama"`).
    pub prefer_local: bool,
}
/// Spending limits and amounts already spent, in USD, for the current day and session.
#[derive(Clone, Debug)]
pub struct BudgetState {
    /// Maximum spend allowed per day.
    pub daily_limit_usd: f64,
    /// Amount spent so far today.
    pub daily_spent_usd: f64,
    /// Maximum spend allowed for the current session.
    pub session_limit_usd: f64,
    /// Amount spent so far in the current session.
    pub session_spent_usd: f64,
}
impl BudgetState {
    /// USD left in today's allowance; never negative.
    pub fn remaining_daily(&self) -> f64 {
        Self::headroom(self.daily_limit_usd, self.daily_spent_usd)
    }

    /// USD left in the current session's allowance; never negative.
    pub fn remaining_session(&self) -> f64 {
        Self::headroom(self.session_limit_usd, self.session_spent_usd)
    }

    /// True once either the daily or the session allowance has nothing left.
    pub fn is_exhausted(&self) -> bool {
        [self.remaining_daily(), self.remaining_session()]
            .iter()
            .any(|&left| left <= 0.0)
    }

    /// `limit - spent`, clamped at zero. `f64::max` returns the non-NaN operand, so a
    /// NaN difference also clamps to zero.
    fn headroom(limit: f64, spent: f64) -> f64 {
        f64::max(limit - spent, 0.0)
    }
}
/// Chooses which brains to try for a request and how to react when one of them fails.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait Router: Send + Sync {
    /// Returns the brains to try for `need`, in the order they should be attempted,
    /// taking the current `budget` into account. The list may be empty.
    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;

    /// Decides what to do after `brain` failed with `err`.
    fn on_error(&self, brain: &dyn Brain, err: &BrainError) -> Retry;

    /// Looks up a registered brain by its model id; `None` if no provider has it.
    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
}
/// What the caller should do after a brain returns an error (the result of `Router::on_error`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Retry {
    /// Give up on this brain and move on to the next one in the chain.
    NextInChain,
    /// Stop trying altogether.
    Abort,
    /// Wait, then presumably retry the same brain. `BasicRouter` fills this with the
    /// provider's `retry_after` value (presumably seconds — confirm with the caller).
    WaitAndRetry(u64),
}
use std::collections::HashMap;
use crate::config::Config;
/// Default [`Router`]: scores every registered brain against the request and returns a
/// short fallback chain that is capped per provider.
pub struct BasicRouter {
    /// Brains grouped by provider name (e.g. `"local"`, `"ollama"`, `"anthropic"`).
    providers: HashMap<String, Vec<Arc<dyn Brain>>>,
    /// Tier name (as returned by `TaskTier::as_str`) → preferred provider name.
    policy: HashMap<String, String>,
    /// For trivial/small tasks, move the first local or free brain to the front of the chain.
    free_first: bool,
    /// When set, overrides `policy` as the preferred provider for every tier.
    preferred_provider: Option<String>,
}
impl BasicRouter {
    /// Builds a router from the `routing` section of `config` and the brains registered
    /// under each provider name.
    ///
    /// The configured tier → provider policy is copied. Two defaults are then filled in,
    /// but only if the config leaves them unset: `trivial` → `local` and `hard` → `anthropic`.
    pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
        let mut policy = HashMap::new();
        for (k, v) in &config.routing.policy {
            policy.insert(k.clone(), v.clone());
        }
        policy
            .entry("trivial".to_string())
            .or_insert_with(|| "local".to_string());
        policy
            .entry("hard".to_string())
            .or_insert_with(|| "anthropic".to_string());
        Self {
            providers,
            policy,
            free_first: config.routing.free_first,
            preferred_provider: config.routing.preferred_provider.clone(),
        }
    }

    /// Scores how well `brain` fits `need` under `budget`; higher is better.
    ///
    /// Components:
    /// * capabilities — +50 for each required capability the brain has. A missing required
    ///   capability costs -250 (tools) or -300 (vision).
    /// * cost — free brains (input + output price both zero) get +100. Otherwise the score
    ///   drops by 200 if the remaining budget is below 10% of the brain's combined per-Mtok
    ///   price, and by 10 × that price if not.
    /// * latency — trivial/small tasks favour fast brains; larger tasks favour slow ones
    ///   (presumably as a proxy for model size).
    /// * context window — a bonus proportional to the window, capped at +10 (+20 for
    ///   medium/hard tasks).
    fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
        let caps = brain.caps();
        let mut score: f64 = 0.0;

        // Capability fit: a missing required capability must sink the brain below any
        // capable one, hence the large penalties.
        if need.required_tools {
            score += if caps.tools { 50.0 } else { -250.0 };
        }
        if need.required_vision {
            score += if caps.vision { 50.0 } else { -300.0 };
        }

        // Cost. Use the tighter of the daily and session allowances:
        // `BudgetState::is_exhausted` considers both, so near-exhaustion of either one
        // should also discourage pricey brains.
        let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
        let remaining = budget.remaining_daily().min(budget.remaining_session());
        if est_cost == 0.0 {
            score += 100.0;
        } else if remaining < est_cost * 0.1 {
            score -= 200.0;
        } else {
            score -= est_cost * 10.0;
        }

        match need.tier {
            TaskTier::Trivial | TaskTier::Small => match caps.latency {
                LatencyClass::Fast => score += 15.0,
                LatencyClass::Medium => score += 6.0,
                LatencyClass::Slow => score += 0.0,
            },
            TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
                LatencyClass::Slow => score += 18.0,
                LatencyClass::Medium => score += 9.0,
                LatencyClass::Fast => score += 0.0,
            },
        }

        // Larger context windows help, with a bonus that saturates at `ctx_cap`.
        let (ctx_weight, ctx_cap) = match need.tier {
            TaskTier::Hard | TaskTier::Medium => (20_000.0, 20.0),
            _ => (10_000.0, 10.0),
        };
        score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
        score
    }

    /// Provider to favour for `need`.
    ///
    /// An explicitly configured `preferred_provider` wins over the per-tier policy. Tiers
    /// with no policy entry default to `"anthropic"`.
    fn resolve_provider(&self, need: &RoutingNeed) -> &str {
        if let Some(ref pref) = self.preferred_provider {
            return pref.as_str();
        }
        self.policy
            .get(need.tier.as_str())
            .map(|s| s.as_str())
            .unwrap_or("anthropic")
    }

    /// Asks `brain` to classify `task` into a [`TaskTier`].
    ///
    /// Best-effort: if the request fails, or the reply names no tier, the result is
    /// [`TaskTier::Medium`].
    pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
        let prompt = format!(
            "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
            task
        );
        let req = BrainRequest {
            system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
            messages: vec![Msg {
                role: "user".into(),
                content: vec![ContentBlock::Text { text: prompt }],
            }],
            tools: vec![],
            max_tokens: 10,
            temperature: 0.0,
            stop: vec![],
            cache: PromptCacheConfig::disabled(),
        };
        match brain.complete(req).await {
            Ok(mut stream) => {
                use futures::StreamExt;
                let mut reply = String::new();
                while let Some(ev) = stream.next().await {
                    if let crate::provider::BrainEvent::TextDelta(t) = ev {
                        reply.push_str(&t);
                    }
                }
                Self::parse_tier_reply(&reply)
            }
            // Deliberately best-effort: an unreachable classifier must not fail the caller.
            Err(_) => TaskTier::Medium,
        }
    }

    /// Extracts a tier from a free-form classifier reply.
    ///
    /// Models often decorate a one-word answer ("Hard.", "**vision**", "Tier: small"). The
    /// reply is therefore split on non-alphabetic characters, and the first word that names
    /// a tier wins. A reply naming no tier falls back to [`TaskTier::Medium`].
    fn parse_tier_reply(reply: &str) -> TaskTier {
        reply
            .split(|c: char| !c.is_alphabetic())
            // `from_str` maps unknown words to Medium, so a word is a real tier name
            // only if it round-trips through `from_str` / `as_str`.
            .find(|word| TaskTier::from_str(word).as_str().eq_ignore_ascii_case(word))
            .map(TaskTier::from_str)
            .unwrap_or(TaskTier::Medium)
    }
}
impl Router for BasicRouter {
    /// Builds the ordered fallback chain of brains for `need`.
    ///
    /// * If the budget is exhausted and `need.prefer_local` is unset, the result is every
    ///   brain registered under a local provider (`"local"`, then `"ollama"`), unscored.
    ///   It is empty if there are none.
    /// * Otherwise each candidate is scored with `BasicRouter::score`. The resolved
    ///   preferred provider gets +25, and local providers get +30 on trivial/small tasks.
    ///   `need.prefer_local` restricts candidates to local providers.
    /// * The chain keeps at most `MAX_CHAIN` distinct brain ids. It first takes at most
    ///   `PER_PROVIDER_CAP` per provider, then tops up from the brains that were skipped.
    /// * For trivial/small tasks, when the preferred provider is local or `free_first` is
    ///   set, the first local or zero-cost brain is moved to the front.
    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
        /// Providers that run models locally; "local" and "ollama" are treated alike.
        fn is_local(provider: &str) -> bool {
            provider == "local" || provider == "ollama"
        }

        const MAX_CHAIN: usize = 6;
        const PER_PROVIDER_CAP: usize = 3;

        // Out of budget: only local brains remain usable. Both local provider names are
        // consulted, so an Ollama-only setup does not end up with an empty chain.
        if budget.is_exhausted() && !need.prefer_local {
            return ["local", "ollama"]
                .iter()
                .filter_map(|p| self.providers.get(*p))
                .flatten()
                .cloned()
                .collect();
        }

        let preferred_provider = self.resolve_provider(need);
        let preferred_is_local = is_local(preferred_provider);
        let small_task = matches!(need.tier, TaskTier::Trivial | TaskTier::Small);

        let mut scored: Vec<(f64, &str, &Arc<dyn Brain>)> = Vec::new();
        for (provider_name, brains) in &self.providers {
            let local = is_local(provider_name);
            if need.prefer_local && !local {
                continue;
            }
            for brain in brains {
                let mut s = Self::score(brain.as_ref(), need, budget);
                if provider_name == preferred_provider {
                    s += 25.0;
                }
                if small_task && local {
                    s += 30.0;
                }
                // A NaN score (e.g. from NaN pricing) would poison the ordering; rank it last.
                if s.is_nan() {
                    s = f64::NEG_INFINITY;
                }
                scored.push((s, provider_name.as_str(), brain));
            }
        }

        // Highest score first. `total_cmp` is a total order; `partial_cmp` falling back to
        // Equal is not, and `sort_by` may panic on such comparators. Ties are broken by
        // provider name so the chain does not depend on HashMap iteration order. The sort is
        // stable, so equal-scored brains of one provider keep their configured order.
        scored.sort_by(|a, b| b.0.total_cmp(&a.0).then_with(|| a.1.cmp(b.1)));

        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
        let mut per_provider: HashMap<&str, usize> = HashMap::new();
        let mut chain: Vec<(&str, &Arc<dyn Brain>)> = Vec::with_capacity(MAX_CHAIN);

        // First pass: best brains, at most PER_PROVIDER_CAP per provider, for diversity.
        for &(_, prov, brain) in &scored {
            if chain.len() >= MAX_CHAIN {
                break;
            }
            let id = brain.id().to_string();
            if seen.contains(&id) {
                continue;
            }
            let count = per_provider.entry(prov).or_insert(0);
            if *count >= PER_PROVIDER_CAP {
                continue;
            }
            *count += 1;
            seen.insert(id);
            chain.push((prov, brain));
        }

        // Second pass: if the per-provider cap left the chain short, top up ignoring it.
        for &(_, prov, brain) in &scored {
            if chain.len() >= MAX_CHAIN {
                break;
            }
            if seen.insert(brain.id().to_string()) {
                chain.push((prov, brain));
            }
        }

        // For cheap tasks, put the first local or free brain first. "Free" means both input
        // and output prices are zero, matching the zero-cost bonus in `score`.
        if small_task && (preferred_is_local || self.free_first) {
            if let Some(pos) = chain.iter().position(|&(prov, b)| {
                is_local(prov) || {
                    let caps = b.caps();
                    caps.cost_input_per_mtok + caps.cost_output_per_mtok == 0.0
                }
            }) {
                // Moves chain[pos] to the front and shifts the entries before it back by one.
                chain[..=pos].rotate_right(1);
            }
        }

        chain.into_iter().map(|(_, brain)| Arc::clone(brain)).collect()
    }

    /// Maps a brain error to a retry decision.
    ///
    /// * A rate limit whose `retry_after` is at most `MAX_INLINE_WAIT` is waited out.
    /// * A refusal aborts.
    /// * Everything else moves on to the next brain in the chain: longer or unspecified
    ///   rate limits, 5xx and other server errors, and timeouts.
    fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
        // Longest provider-suggested back-off worth waiting out instead of failing over.
        const MAX_INLINE_WAIT: u64 = 10;
        match e {
            BrainError::RateLimit {
                retry_after: Some(secs),
            } if *secs <= MAX_INLINE_WAIT => Retry::WaitAndRetry(*secs),
            // Long or unspecified back-off: fail over rather than stall.
            BrainError::RateLimit { .. } => Retry::NextInChain,
            BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
            BrainError::Timeout => Retry::NextInChain,
            BrainError::Refusal(_) => Retry::Abort,
            // Anything else, including non-5xx `ServerError`s, also fails over.
            _ => Retry::NextInChain,
        }
    }

    /// Finds a registered brain whose `id()` equals `model_id`.
    ///
    /// If several providers register the same id, which one is returned depends on
    /// HashMap iteration order.
    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
        self.providers
            .values()
            .flatten()
            .find(|b| b.id() == model_id)
            .cloned()
    }
}