Skip to main content

sparrow/router/
mod.rs

1pub mod learned;
2
3use std::sync::Arc;
4
5use crate::provider::{
6    Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
7};
8
9// ─── Routing need ───────────────────────────────────────────────────────────────
10
/// Coarse class of a task, used as the key for routing decisions.
///
/// The lowercase name of each variant (see [`TaskTier::as_str`]) is also the
/// key used to look the tier up in the routing policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskTier {
    Trivial,
    Small,
    Medium,
    Hard,
    Vision,
}

impl TaskTier {
    /// Every tier, used to drive name lookup from a single table.
    const ALL: [TaskTier; 5] = [
        TaskTier::Trivial,
        TaskTier::Small,
        TaskTier::Medium,
        TaskTier::Hard,
        TaskTier::Vision,
    ];

    /// Parses a tier name, ignoring ASCII case.
    ///
    /// Leading and trailing characters that are not ASCII letters (whitespace,
    /// quotes, punctuation, markdown markers) are ignored, so `"Hard."` and
    /// `"\"vision\"\n"` parse as expected. Anything that is still not a known
    /// tier name yields [`TaskTier::Medium`]; this function never fails.
    // Infallible with a default, so it deliberately is not `FromStr::from_str`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        let word = s.trim_matches(|c: char| !c.is_ascii_alphabetic());
        Self::ALL
            .iter()
            .copied()
            .find(|tier| word.eq_ignore_ascii_case(tier.as_str()))
            .unwrap_or(TaskTier::Medium)
    }

    /// Returns the canonical lowercase name of the tier.
    ///
    /// The returned string is a literal, so it outlives the borrow of `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskTier::Trivial => "trivial",
            TaskTier::Small => "small",
            TaskTier::Medium => "medium",
            TaskTier::Hard => "hard",
            TaskTier::Vision => "vision",
        }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct RoutingNeed {
45    pub tier: TaskTier,
46    pub required_tools: bool,
47    pub required_vision: bool,
48    pub prefer_local: bool,
49}
50
/// Spending limits and running totals, in USD, for the current day and the
/// current session.
#[derive(Clone, Debug)]
pub struct BudgetState {
    /// Maximum spend allowed per day.
    pub daily_limit_usd: f64,
    /// Amount already spent today.
    pub daily_spent_usd: f64,
    /// Maximum spend allowed for this session.
    pub session_limit_usd: f64,
    /// Amount already spent in this session.
    pub session_spent_usd: f64,
}

impl BudgetState {
    /// Daily budget still available; overspending is reported as `0.0`, never
    /// as a negative amount.
    pub fn remaining_daily(&self) -> f64 {
        let headroom = self.daily_limit_usd - self.daily_spent_usd;
        headroom.max(0.0)
    }

    /// Session budget still available; overspending is reported as `0.0`,
    /// never as a negative amount.
    pub fn remaining_session(&self) -> f64 {
        let headroom = self.session_limit_usd - self.session_spent_usd;
        headroom.max(0.0)
    }

    /// Returns `true` as soon as either the daily or the session budget has
    /// nothing left.
    pub fn is_exhausted(&self) -> bool {
        let daily_gone = self.remaining_daily() <= 0.0;
        let session_gone = self.remaining_session() <= 0.0;
        daily_gone || session_gone
    }
}
72
73// ─── Router trait ───────────────────────────────────────────────────────────────
74
75pub trait Router: Send + Sync {
76    /// Returns an ordered fallback chain of Brains.
77    /// Primary brain first, fallbacks in order.
78    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;
79
80    fn on_error(&self, b: &dyn Brain, e: &BrainError) -> Retry;
81
82    /// Look up a brain by its model ID across all registered providers.
83    /// Used by the WebView model-override path when the user picks a model
84    /// that isn't in the natural routing chain.
85    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
86}
87
/// Action a [`Router`] recommends after a brain call has failed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Retry {
    /// Move on to the next brain in the fallback chain.
    NextInChain,
    /// Give up on the request.
    Abort,
    /// Wait the given number of seconds, then retry.
    WaitAndRetry(u64),
}
94
95// ─── Basic Router implementation ────────────────────────────────────────────────
96
97use std::collections::HashMap;
98
99use crate::config::Config;
100
101pub struct BasicRouter {
102    /// provider_name -> list of (model_name, Brain)
103    providers: HashMap<String, Vec<Arc<dyn Brain>>>,
104    /// task tier -> preferred provider name
105    policy: HashMap<String, String>,
106    free_first: bool,
107    /// Global preferred provider override — when Some, every tier resolves to
108    /// this provider first (capability constraints still apply).
109    preferred_provider: Option<String>,
110}
111
112impl BasicRouter {
113    pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
114        let mut policy = HashMap::new();
115        for (k, v) in &config.routing.policy {
116            policy.insert(k.clone(), v.clone());
117        }
118        // Defaults
119        if !policy.contains_key("trivial") {
120            policy.insert("trivial".into(), "local".into());
121        }
122        if !policy.contains_key("hard") {
123            policy.insert("hard".into(), "anthropic".into());
124        }
125
126        Self {
127            providers,
128            policy,
129            free_first: config.routing.free_first,
130            preferred_provider: config.routing.preferred_provider.clone(),
131        }
132    }
133
134    /// Score a brain for a given need: higher is better.
135    fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
136        let caps = brain.caps();
137        let mut score: f64 = 0.0;
138
139        // Capability fit
140        if need.required_tools {
141            if caps.tools {
142                score += 50.0;
143            } else {
144                score -= 250.0;
145            }
146        }
147        if need.required_vision {
148            if caps.vision {
149                score += 50.0;
150            } else {
151                score -= 300.0;
152            }
153        }
154
155        // Cost preference: prefer cheaper/free models
156        let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
157        if est_cost == 0.0 {
158            score += 100.0; // free models get a big boost
159        } else if budget.remaining_session() < est_cost * 0.1 {
160            score -= 200.0; // too expensive for remaining budget
161        } else {
162            score -= est_cost * 10.0; // penalize expensive models
163        }
164
165        // Latency / capability preference — tier-aware.
166        // Cheap tiers want speed; hard tiers want capable (bigger, slower) models.
167        match need.tier {
168            TaskTier::Trivial | TaskTier::Small => match caps.latency {
169                LatencyClass::Fast => score += 15.0,
170                LatencyClass::Medium => score += 6.0,
171                LatencyClass::Slow => score += 0.0,
172            },
173            TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
174                LatencyClass::Slow => score += 18.0,
175                LatencyClass::Medium => score += 9.0,
176                LatencyClass::Fast => score += 0.0,
177            },
178        }
179
180        // Context window fit. Weighted more heavily for capable tiers so larger
181        // models (which we infer to have bigger windows) rise for hard tasks.
182        let ctx_weight = match need.tier {
183            TaskTier::Hard | TaskTier::Medium => 20_000.0,
184            _ => 10_000.0,
185        };
186        let ctx_cap = match need.tier {
187            TaskTier::Hard | TaskTier::Medium => 20.0,
188            _ => 10.0,
189        };
190        score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
191
192        score
193    }
194
195    fn resolve_provider(&self, need: &RoutingNeed) -> &str {
196        // If a global preferred provider is configured, use it for all tiers.
197        // The scoring loop will still add a +25 bonus for this provider so it
198        // bubbles to the top; capability fallback still kicks in if it lacks
199        // tools/vision support.
200        if let Some(ref pref) = self.preferred_provider {
201            return pref.as_str();
202        }
203        self.policy
204            .get(need.tier.as_str())
205            .map(|s| s.as_str())
206            .unwrap_or("anthropic")
207    }
208
209    /// Classify a task using a tiny model call (only for ambiguous cases).
210    /// §3.6: "Classification: heuristic + a tiny model call only if ambiguous."
211    pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
212        let prompt = format!(
213            "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
214            task
215        );
216
217        let req = BrainRequest {
218            system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
219            messages: vec![Msg {
220                role: "user".into(),
221                content: vec![ContentBlock::Text { text: prompt }],
222            }],
223            tools: vec![],
224            max_tokens: 10,
225            temperature: 0.0,
226            stop: vec![],
227            cache: PromptCacheConfig::disabled(),
228        };
229
230        match brain.complete(req).await {
231            Ok(mut stream) => {
232                use futures::StreamExt;
233                let mut result = String::new();
234                while let Some(ev) = stream.next().await {
235                    if let crate::provider::BrainEvent::TextDelta(t) = ev {
236                        result.push_str(&t);
237                    }
238                }
239                TaskTier::from_str(result.trim())
240            }
241            Err(_) => TaskTier::Medium, // fallback
242        }
243    }
244}
245
246impl Router for BasicRouter {
247    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
248        if budget.is_exhausted() && !need.prefer_local {
249            // Only free/local models remain
250            if let Some(local) = self.providers.get("local") {
251                return local.clone();
252            }
253            return vec![];
254        }
255
256        let preferred_provider = self.resolve_provider(need);
257        let preferred_is_local = preferred_provider == "local" || preferred_provider == "ollama";
258        let mut scored: Vec<(f64, String, Arc<dyn Brain>)> = Vec::new();
259
260        // Score all available brains
261        for (provider_name, brains) in &self.providers {
262            if need.prefer_local && provider_name != "local" && provider_name != "ollama" {
263                continue;
264            }
265            for brain in brains {
266                let mut s = Self::score(brain.as_ref(), need, budget);
267                if provider_name == preferred_provider {
268                    s += 25.0;
269                }
270                if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
271                    && (provider_name == "local" || provider_name == "ollama")
272                {
273                    s += 30.0;
274                }
275                scored.push((s, provider_name.clone(), brain.clone()));
276            }
277        }
278
279        // Stable sort by score descending (preserves insertion order for ties,
280        // so equal-scored discovered models keep a deterministic order).
281        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
282
283        // Build a PROVIDER-DIVERSE fallback chain. A single preferred provider
284        // (e.g. opencode-zen) can have dozens of discovered models that would
285        // otherwise fill every slot — leaving no fallback if that whole provider
286        // is down (no credits / empty responses). Cap models per provider so
287        // other providers (NVIDIA, ollama…) always get fallback slots.
288        const MAX_CHAIN: usize = 6;
289        const PER_PROVIDER_CAP: usize = 3;
290        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
291        let mut per_provider: HashMap<String, usize> = HashMap::new();
292        // (provider, brain) so we can reorder by provider below.
293        let mut result: Vec<(String, Arc<dyn Brain>)> = Vec::new();
294
295        // Pass 1: respect the per-provider cap → forces cross-provider diversity.
296        for (_, prov, brain) in &scored {
297            if result.len() >= MAX_CHAIN {
298                break;
299            }
300            let id = brain.id().to_string();
301            if seen.contains(&id) {
302                continue;
303            }
304            let count = per_provider.entry(prov.clone()).or_insert(0);
305            if *count >= PER_PROVIDER_CAP {
306                continue;
307            }
308            *count += 1;
309            seen.insert(id);
310            result.push((prov.clone(), brain.clone()));
311        }
312        // Pass 2: if too few providers to fill the chain, top up over the cap.
313        if result.len() < MAX_CHAIN {
314            for (_, prov, brain) in &scored {
315                if result.len() >= MAX_CHAIN {
316                    break;
317                }
318                let id = brain.id().to_string();
319                if seen.insert(id) {
320                    result.push((prov.clone(), brain.clone()));
321                }
322            }
323        }
324
325        // For trivial/small tasks with a local-preferred policy, put a local/free
326        // model first (works offline, $0).
327        if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
328            && (preferred_is_local || self.free_first)
329        {
330            if let Some(pos) = result.iter().position(|(prov, b)| {
331                (prov == "local" || prov == "ollama") || b.caps().cost_input_per_mtok == 0.0
332            }) {
333                let chosen = result.remove(pos);
334                result.insert(0, chosen);
335            }
336        }
337
338        result.into_iter().map(|(_, brain)| brain).collect()
339    }
340
341    fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
342        match e {
343            BrainError::RateLimit { retry_after } => {
344                if let Some(secs) = retry_after {
345                    if *secs <= 10 {
346                        Retry::WaitAndRetry(*secs)
347                    } else {
348                        Retry::NextInChain
349                    }
350                } else {
351                    Retry::NextInChain
352                }
353            }
354            BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
355            BrainError::Timeout => Retry::NextInChain,
356            BrainError::Refusal(_) => Retry::Abort,
357            _ => Retry::NextInChain,
358        }
359    }
360
361    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
362        for brains in self.providers.values() {
363            for b in brains {
364                if b.id() == model_id {
365                    return Some(b.clone());
366                }
367            }
368        }
369        None
370    }
371}