1use std::sync::Arc;
2
3use crate::provider::{
4 Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
5};
6
/// Difficulty / modality bucket a task is routed by.
///
/// Fieldless, so it is cheap to copy and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskTier {
    Trivial,
    Small,
    Medium,
    Hard,
    Vision,
}
17
18impl TaskTier {
19 pub fn from_str(s: &str) -> Self {
20 match s.to_lowercase().as_str() {
21 "trivial" => TaskTier::Trivial,
22 "small" => TaskTier::Small,
23 "medium" => TaskTier::Medium,
24 "hard" => TaskTier::Hard,
25 "vision" => TaskTier::Vision,
26 _ => TaskTier::Medium,
27 }
28 }
29
30 pub fn as_str(&self) -> &str {
31 match self {
32 TaskTier::Trivial => "trivial",
33 TaskTier::Small => "small",
34 TaskTier::Medium => "medium",
35 TaskTier::Hard => "hard",
36 TaskTier::Vision => "vision",
37 }
38 }
39}
40
/// What a caller needs from a brain for one task; the input to [`Router::select`].
#[derive(Debug, Clone)]
pub struct RoutingNeed {
    /// Difficulty tier. `BasicRouter` uses it for the tier → provider policy
    /// lookup and for latency / context-window weighting when scoring.
    pub tier: TaskTier,
    /// The task needs tool calling; `BasicRouter` heavily penalises brains without it.
    pub required_tools: bool,
    /// The task needs image input; `BasicRouter` heavily penalises brains without it.
    pub required_vision: bool,
    /// Restricts `BasicRouter::select` to the `"local"` and `"ollama"` providers.
    pub prefer_local: bool,
}
48
/// Spend limits and running totals, in US dollars, that routing must respect.
#[derive(Debug, Clone)]
pub struct BudgetState {
    /// Maximum spend allowed for the daily window.
    pub daily_limit_usd: f64,
    /// Amount already spent in the daily window.
    pub daily_spent_usd: f64,
    /// Maximum spend allowed for the current session.
    pub session_limit_usd: f64,
    /// Amount already spent in the current session.
    pub session_spent_usd: f64,
}
56
57impl BudgetState {
58 pub fn remaining_daily(&self) -> f64 {
59 (self.daily_limit_usd - self.daily_spent_usd).max(0.0)
60 }
61
62 pub fn remaining_session(&self) -> f64 {
63 (self.session_limit_usd - self.session_spent_usd).max(0.0)
64 }
65
66 pub fn is_exhausted(&self) -> bool {
67 self.remaining_daily() <= 0.0 || self.remaining_session() <= 0.0
68 }
69}
70
/// Chooses which brains handle a request and how to react when one fails.
///
/// The `Send + Sync` bound lets one router instance be shared across threads / tasks.
pub trait Router: Send + Sync {
    /// Returns candidate brains for `need` under `budget`, in the order they
    /// should be tried. Presumably the caller walks the list, advancing on
    /// [`Retry::NextInChain`] (confirm against the caller). May be empty when no
    /// brain is eligible.
    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;

    /// Decides what to do after brain `b` failed with error `e`.
    fn on_error(&self, b: &dyn Brain, e: &BrainError) -> Retry;

    /// Looks up a configured brain by its model id (as returned by `Brain::id`).
    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
}
85
/// Recovery action returned by [`Router::on_error`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Retry {
    /// Give up on the failed brain and move to the next one in the selected chain.
    NextInChain,
    /// Stop; do not try any further brains.
    Abort,
    /// Wait the given number of seconds (taken from a rate-limit `retry_after`
    /// in `BasicRouter`), then retry. NOTE(review): presumably the same brain
    /// is retried; confirm against the caller.
    WaitAndRetry(u64),
}
92
93use std::collections::HashMap;
96
97use crate::config::Config;
98
/// Score-based [`Router`] configured from the `routing` section of [`Config`].
pub struct BasicRouter {
    /// Available brains, keyed by provider name (e.g. `"local"`, `"ollama"`, `"anthropic"`).
    providers: HashMap<String, Vec<Arc<dyn Brain>>>,
    /// Tier name (as produced by `TaskTier::as_str`) → preferred provider name.
    policy: HashMap<String, String>,
    /// When set, trivial/small tasks move a local or free brain to the front of the chain.
    free_first: bool,
    /// When set, overrides the per-tier `policy` for every request.
    preferred_provider: Option<String>,
}
109
110impl BasicRouter {
111 pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
112 let mut policy = HashMap::new();
113 for (k, v) in &config.routing.policy {
114 policy.insert(k.clone(), v.clone());
115 }
116 if !policy.contains_key("trivial") {
118 policy.insert("trivial".into(), "local".into());
119 }
120 if !policy.contains_key("hard") {
121 policy.insert("hard".into(), "anthropic".into());
122 }
123
124 Self {
125 providers,
126 policy,
127 free_first: config.routing.free_first,
128 preferred_provider: config.routing.preferred_provider.clone(),
129 }
130 }
131
132 fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
134 let caps = brain.caps();
135 let mut score: f64 = 0.0;
136
137 if need.required_tools {
139 if caps.tools {
140 score += 50.0;
141 } else {
142 score -= 250.0;
143 }
144 }
145 if need.required_vision {
146 if caps.vision {
147 score += 50.0;
148 } else {
149 score -= 300.0;
150 }
151 }
152
153 let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
155 if est_cost == 0.0 {
156 score += 100.0; } else if budget.remaining_session() < est_cost * 0.1 {
158 score -= 200.0; } else {
160 score -= est_cost * 10.0; }
162
163 match need.tier {
166 TaskTier::Trivial | TaskTier::Small => match caps.latency {
167 LatencyClass::Fast => score += 15.0,
168 LatencyClass::Medium => score += 6.0,
169 LatencyClass::Slow => score += 0.0,
170 },
171 TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
172 LatencyClass::Slow => score += 18.0,
173 LatencyClass::Medium => score += 9.0,
174 LatencyClass::Fast => score += 0.0,
175 },
176 }
177
178 let ctx_weight = match need.tier {
181 TaskTier::Hard | TaskTier::Medium => 20_000.0,
182 _ => 10_000.0,
183 };
184 let ctx_cap = match need.tier {
185 TaskTier::Hard | TaskTier::Medium => 20.0,
186 _ => 10.0,
187 };
188 score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
189
190 score
191 }
192
193 fn resolve_provider(&self, need: &RoutingNeed) -> &str {
194 if let Some(ref pref) = self.preferred_provider {
199 return pref.as_str();
200 }
201 self.policy
202 .get(need.tier.as_str())
203 .map(|s| s.as_str())
204 .unwrap_or("anthropic")
205 }
206
207 pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
210 let prompt = format!(
211 "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
212 task
213 );
214
215 let req = BrainRequest {
216 system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
217 messages: vec![Msg {
218 role: "user".into(),
219 content: vec![ContentBlock::Text { text: prompt }],
220 }],
221 tools: vec![],
222 max_tokens: 10,
223 temperature: 0.0,
224 stop: vec![],
225 cache: PromptCacheConfig::disabled(),
226 };
227
228 match brain.complete(req).await {
229 Ok(mut stream) => {
230 use futures::StreamExt;
231 let mut result = String::new();
232 while let Some(ev) = stream.next().await {
233 if let crate::provider::BrainEvent::TextDelta(t) = ev {
234 result.push_str(&t);
235 }
236 }
237 TaskTier::from_str(result.trim())
238 }
239 Err(_) => TaskTier::Medium, }
241 }
242}
243
244impl Router for BasicRouter {
245 fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
246 if budget.is_exhausted() && !need.prefer_local {
247 if let Some(local) = self.providers.get("local") {
249 return local.clone();
250 }
251 return vec![];
252 }
253
254 let preferred_provider = self.resolve_provider(need);
255 let preferred_is_local = preferred_provider == "local" || preferred_provider == "ollama";
256 let mut scored: Vec<(f64, String, Arc<dyn Brain>)> = Vec::new();
257
258 for (provider_name, brains) in &self.providers {
260 if need.prefer_local && provider_name != "local" && provider_name != "ollama" {
261 continue;
262 }
263 for brain in brains {
264 let mut s = Self::score(brain.as_ref(), need, budget);
265 if provider_name == preferred_provider {
266 s += 25.0;
267 }
268 if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
269 && (provider_name == "local" || provider_name == "ollama")
270 {
271 s += 30.0;
272 }
273 scored.push((s, provider_name.clone(), brain.clone()));
274 }
275 }
276
277 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
280
281 const MAX_CHAIN: usize = 6;
287 const PER_PROVIDER_CAP: usize = 3;
288 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
289 let mut per_provider: HashMap<String, usize> = HashMap::new();
290 let mut result: Vec<(String, Arc<dyn Brain>)> = Vec::new();
292
293 for (_, prov, brain) in &scored {
295 if result.len() >= MAX_CHAIN {
296 break;
297 }
298 let id = brain.id().to_string();
299 if seen.contains(&id) {
300 continue;
301 }
302 let count = per_provider.entry(prov.clone()).or_insert(0);
303 if *count >= PER_PROVIDER_CAP {
304 continue;
305 }
306 *count += 1;
307 seen.insert(id);
308 result.push((prov.clone(), brain.clone()));
309 }
310 if result.len() < MAX_CHAIN {
312 for (_, prov, brain) in &scored {
313 if result.len() >= MAX_CHAIN {
314 break;
315 }
316 let id = brain.id().to_string();
317 if seen.insert(id) {
318 result.push((prov.clone(), brain.clone()));
319 }
320 }
321 }
322
323 if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
326 && (preferred_is_local || self.free_first)
327 {
328 if let Some(pos) = result.iter().position(|(prov, b)| {
329 (prov == "local" || prov == "ollama") || b.caps().cost_input_per_mtok == 0.0
330 }) {
331 let chosen = result.remove(pos);
332 result.insert(0, chosen);
333 }
334 }
335
336 result.into_iter().map(|(_, brain)| brain).collect()
337 }
338
339 fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
340 match e {
341 BrainError::RateLimit { retry_after } => {
342 if let Some(secs) = retry_after {
343 if *secs <= 10 {
344 Retry::WaitAndRetry(*secs)
345 } else {
346 Retry::NextInChain
347 }
348 } else {
349 Retry::NextInChain
350 }
351 }
352 BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
353 BrainError::Timeout => Retry::NextInChain,
354 BrainError::Refusal(_) => Retry::Abort,
355 _ => Retry::NextInChain,
356 }
357 }
358
359 fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
360 for brains in self.providers.values() {
361 for b in brains {
362 if b.id() == model_id {
363 return Some(b.clone());
364 }
365 }
366 }
367 None
368 }
369}