// sparrow/router/mod.rs

1use std::sync::Arc;
2
3use crate::provider::{
4    Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
5};
6
// ─── Routing need ───────────────────────────────────────────────────────────────
8
/// Difficulty/capability class of a task; drives model selection.
///
/// The lowercase names returned by [`TaskTier::as_str`] are also the keys of
/// the router's tier -> provider policy table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskTier {
    Trivial,
    Small,
    Medium,
    Hard,
    Vision,
}

impl TaskTier {
    /// Every tier, in declaration order.
    const ALL: [TaskTier; 5] = [
        TaskTier::Trivial,
        TaskTier::Small,
        TaskTier::Medium,
        TaskTier::Hard,
        TaskTier::Vision,
    ];

    /// Parses a tier name.
    ///
    /// Matching is case-insensitive and ignores leading/trailing whitespace and
    /// punctuation, so `"Hard"`, `" hard.\n"` and `"**hard**"` all yield
    /// [`TaskTier::Hard`]. Anything else (including `""`) falls back to
    /// [`TaskTier::Medium`]. Because it never fails, it does not fit the
    /// `Result`-returning `FromStr` signature.
    #[allow(clippy::should_implement_trait)] // infallible with a fallback; see above
    pub fn from_str(s: &str) -> Self {
        // Model output and config values often carry stray whitespace, quotes,
        // markdown emphasis or a trailing period around the actual word.
        let word = s.trim_matches(|c: char| !c.is_alphanumeric());
        Self::ALL
            .iter()
            .copied()
            .find(|tier| tier.as_str().eq_ignore_ascii_case(word))
            .unwrap_or(TaskTier::Medium)
    }

    /// Lowercase name of the tier; the inverse of [`TaskTier::from_str`].
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskTier::Trivial => "trivial",
            TaskTier::Small => "small",
            TaskTier::Medium => "medium",
            TaskTier::Hard => "hard",
            TaskTier::Vision => "vision",
        }
    }
}
40
41#[derive(Debug, Clone)]
42pub struct RoutingNeed {
43    pub tier: TaskTier,
44    pub required_tools: bool,
45    pub required_vision: bool,
46    pub prefer_local: bool,
47}
48
/// Spend so far against the daily and per-session USD limits.
#[derive(Debug, Clone)]
pub struct BudgetState {
    /// Daily spending limit, USD.
    pub daily_limit_usd: f64,
    /// Amount spent against the daily limit, USD.
    pub daily_spent_usd: f64,
    /// Per-session spending limit, USD.
    pub session_limit_usd: f64,
    /// Amount spent against the session limit, USD.
    pub session_spent_usd: f64,
}

impl BudgetState {
    /// Headroom left under `limit` after `spent`, clamped so it is never negative.
    fn headroom(limit: f64, spent: f64) -> f64 {
        f64::max(limit - spent, 0.0)
    }

    /// USD still available under the daily limit (never negative).
    pub fn remaining_daily(&self) -> f64 {
        Self::headroom(self.daily_limit_usd, self.daily_spent_usd)
    }

    /// USD still available under the session limit (never negative).
    pub fn remaining_session(&self) -> f64 {
        Self::headroom(self.session_limit_usd, self.session_spent_usd)
    }

    /// True as soon as either limit has no headroom left.
    pub fn is_exhausted(&self) -> bool {
        let daily_open = self.remaining_daily() > 0.0;
        let session_open = self.remaining_session() > 0.0;
        !(daily_open && session_open)
    }
}
70
// ─── Router trait ───────────────────────────────────────────────────────────────
72
73pub trait Router: Send + Sync {
74    /// Returns an ordered fallback chain of Brains.
75    /// Primary brain first, fallbacks in order.
76    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;
77
78    fn on_error(&self, b: &dyn Brain, e: &BrainError) -> Retry;
79
80    /// Look up a brain by its model ID across all registered providers.
81    /// Used by the WebView model-override path when the user picks a model
82    /// that isn't in the natural routing chain.
83    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
84}
85
/// What the caller should do after a brain in the fallback chain fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Retry {
    /// Move on to the next brain in the chain.
    NextInChain,
    /// Stop trying; do not fall back.
    Abort,
    /// Wait the given number of seconds, then retry.
    WaitAndRetry(u64),
}
92
// ─── Basic Router implementation ────────────────────────────────────────────────
94
95use std::collections::HashMap;
96
97use crate::config::Config;
98
99pub struct BasicRouter {
100    /// provider_name -> list of (model_name, Brain)
101    providers: HashMap<String, Vec<Arc<dyn Brain>>>,
102    /// task tier -> preferred provider name
103    policy: HashMap<String, String>,
104    free_first: bool,
105    /// Global preferred provider override — when Some, every tier resolves to
106    /// this provider first (capability constraints still apply).
107    preferred_provider: Option<String>,
108}
109
110impl BasicRouter {
111    pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
112        let mut policy = HashMap::new();
113        for (k, v) in &config.routing.policy {
114            policy.insert(k.clone(), v.clone());
115        }
116        // Defaults
117        if !policy.contains_key("trivial") {
118            policy.insert("trivial".into(), "local".into());
119        }
120        if !policy.contains_key("hard") {
121            policy.insert("hard".into(), "anthropic".into());
122        }
123
124        Self {
125            providers,
126            policy,
127            free_first: config.routing.free_first,
128            preferred_provider: config.routing.preferred_provider.clone(),
129        }
130    }
131
132    /// Score a brain for a given need: higher is better.
133    fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
134        let caps = brain.caps();
135        let mut score: f64 = 0.0;
136
137        // Capability fit
138        if need.required_tools {
139            if caps.tools {
140                score += 50.0;
141            } else {
142                score -= 250.0;
143            }
144        }
145        if need.required_vision {
146            if caps.vision {
147                score += 50.0;
148            } else {
149                score -= 300.0;
150            }
151        }
152
153        // Cost preference: prefer cheaper/free models
154        let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
155        if est_cost == 0.0 {
156            score += 100.0; // free models get a big boost
157        } else if budget.remaining_session() < est_cost * 0.1 {
158            score -= 200.0; // too expensive for remaining budget
159        } else {
160            score -= est_cost * 10.0; // penalize expensive models
161        }
162
163        // Latency / capability preference — tier-aware.
164        // Cheap tiers want speed; hard tiers want capable (bigger, slower) models.
165        match need.tier {
166            TaskTier::Trivial | TaskTier::Small => match caps.latency {
167                LatencyClass::Fast => score += 15.0,
168                LatencyClass::Medium => score += 6.0,
169                LatencyClass::Slow => score += 0.0,
170            },
171            TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
172                LatencyClass::Slow => score += 18.0,
173                LatencyClass::Medium => score += 9.0,
174                LatencyClass::Fast => score += 0.0,
175            },
176        }
177
178        // Context window fit. Weighted more heavily for capable tiers so larger
179        // models (which we infer to have bigger windows) rise for hard tasks.
180        let ctx_weight = match need.tier {
181            TaskTier::Hard | TaskTier::Medium => 20_000.0,
182            _ => 10_000.0,
183        };
184        let ctx_cap = match need.tier {
185            TaskTier::Hard | TaskTier::Medium => 20.0,
186            _ => 10.0,
187        };
188        score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
189
190        score
191    }
192
193    fn resolve_provider(&self, need: &RoutingNeed) -> &str {
194        // If a global preferred provider is configured, use it for all tiers.
195        // The scoring loop will still add a +25 bonus for this provider so it
196        // bubbles to the top; capability fallback still kicks in if it lacks
197        // tools/vision support.
198        if let Some(ref pref) = self.preferred_provider {
199            return pref.as_str();
200        }
201        self.policy
202            .get(need.tier.as_str())
203            .map(|s| s.as_str())
204            .unwrap_or("anthropic")
205    }
206
207    /// Classify a task using a tiny model call (only for ambiguous cases).
208    /// §3.6: "Classification: heuristic + a tiny model call only if ambiguous."
209    pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
210        let prompt = format!(
211            "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
212            task
213        );
214
215        let req = BrainRequest {
216            system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
217            messages: vec![Msg {
218                role: "user".into(),
219                content: vec![ContentBlock::Text { text: prompt }],
220            }],
221            tools: vec![],
222            max_tokens: 10,
223            temperature: 0.0,
224            stop: vec![],
225            cache: PromptCacheConfig::disabled(),
226        };
227
228        match brain.complete(req).await {
229            Ok(mut stream) => {
230                use futures::StreamExt;
231                let mut result = String::new();
232                while let Some(ev) = stream.next().await {
233                    if let crate::provider::BrainEvent::TextDelta(t) = ev {
234                        result.push_str(&t);
235                    }
236                }
237                TaskTier::from_str(result.trim())
238            }
239            Err(_) => TaskTier::Medium, // fallback
240        }
241    }
242}
243
244impl Router for BasicRouter {
245    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
246        if budget.is_exhausted() && !need.prefer_local {
247            // Only free/local models remain
248            if let Some(local) = self.providers.get("local") {
249                return local.clone();
250            }
251            return vec![];
252        }
253
254        let preferred_provider = self.resolve_provider(need);
255        let preferred_is_local = preferred_provider == "local" || preferred_provider == "ollama";
256        let mut scored: Vec<(f64, String, Arc<dyn Brain>)> = Vec::new();
257
258        // Score all available brains
259        for (provider_name, brains) in &self.providers {
260            if need.prefer_local && provider_name != "local" && provider_name != "ollama" {
261                continue;
262            }
263            for brain in brains {
264                let mut s = Self::score(brain.as_ref(), need, budget);
265                if provider_name == preferred_provider {
266                    s += 25.0;
267                }
268                if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
269                    && (provider_name == "local" || provider_name == "ollama")
270                {
271                    s += 30.0;
272                }
273                scored.push((s, provider_name.clone(), brain.clone()));
274            }
275        }
276
277        // Stable sort by score descending (preserves insertion order for ties,
278        // so equal-scored discovered models keep a deterministic order).
279        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
280
281        // Build a PROVIDER-DIVERSE fallback chain. A single preferred provider
282        // (e.g. opencode-zen) can have dozens of discovered models that would
283        // otherwise fill every slot — leaving no fallback if that whole provider
284        // is down (no credits / empty responses). Cap models per provider so
285        // other providers (NVIDIA, ollama…) always get fallback slots.
286        const MAX_CHAIN: usize = 6;
287        const PER_PROVIDER_CAP: usize = 3;
288        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
289        let mut per_provider: HashMap<String, usize> = HashMap::new();
290        // (provider, brain) so we can reorder by provider below.
291        let mut result: Vec<(String, Arc<dyn Brain>)> = Vec::new();
292
293        // Pass 1: respect the per-provider cap → forces cross-provider diversity.
294        for (_, prov, brain) in &scored {
295            if result.len() >= MAX_CHAIN {
296                break;
297            }
298            let id = brain.id().to_string();
299            if seen.contains(&id) {
300                continue;
301            }
302            let count = per_provider.entry(prov.clone()).or_insert(0);
303            if *count >= PER_PROVIDER_CAP {
304                continue;
305            }
306            *count += 1;
307            seen.insert(id);
308            result.push((prov.clone(), brain.clone()));
309        }
310        // Pass 2: if too few providers to fill the chain, top up over the cap.
311        if result.len() < MAX_CHAIN {
312            for (_, prov, brain) in &scored {
313                if result.len() >= MAX_CHAIN {
314                    break;
315                }
316                let id = brain.id().to_string();
317                if seen.insert(id) {
318                    result.push((prov.clone(), brain.clone()));
319                }
320            }
321        }
322
323        // For trivial/small tasks with a local-preferred policy, put a local/free
324        // model first (works offline, $0).
325        if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
326            && (preferred_is_local || self.free_first)
327        {
328            if let Some(pos) = result.iter().position(|(prov, b)| {
329                (prov == "local" || prov == "ollama") || b.caps().cost_input_per_mtok == 0.0
330            }) {
331                let chosen = result.remove(pos);
332                result.insert(0, chosen);
333            }
334        }
335
336        result.into_iter().map(|(_, brain)| brain).collect()
337    }
338
339    fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
340        match e {
341            BrainError::RateLimit { retry_after } => {
342                if let Some(secs) = retry_after {
343                    if *secs <= 10 {
344                        Retry::WaitAndRetry(*secs)
345                    } else {
346                        Retry::NextInChain
347                    }
348                } else {
349                    Retry::NextInChain
350                }
351            }
352            BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
353            BrainError::Timeout => Retry::NextInChain,
354            BrainError::Refusal(_) => Retry::Abort,
355            _ => Retry::NextInChain,
356        }
357    }
358
359    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
360        for brains in self.providers.values() {
361            for b in brains {
362                if b.id() == model_id {
363                    return Some(b.clone());
364                }
365            }
366        }
367        None
368    }
369}