Skip to main content

sparrow/router/
mod.rs

1pub mod learned;
2
3use std::sync::Arc;
4
5use crate::provider::{
6    Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
7};
8
9// ─── Routing need ───────────────────────────────────────────────────────────────
10
/// Coarse difficulty/kind bucket assigned to a task before routing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskTier {
    Trivial,
    Small,
    Medium,
    Hard,
    Vision,
}

impl TaskTier {
    /// Every tier paired with its canonical lowercase label.
    const LABELS: [(&'static str, TaskTier); 5] = [
        ("trivial", TaskTier::Trivial),
        ("small", TaskTier::Small),
        ("medium", TaskTier::Medium),
        ("hard", TaskTier::Hard),
        ("vision", TaskTier::Vision),
    ];

    /// Parses a tier label.
    ///
    /// Matching ignores ASCII case and surrounding whitespace, so raw model
    /// output such as `" Hard\n"` is understood. Anything unrecognised
    /// (including the empty string) falls back to [`TaskTier::Medium`].
    pub fn from_str(s: &str) -> Self {
        let wanted = s.trim();
        Self::LABELS
            .iter()
            .find(|(label, _)| label.eq_ignore_ascii_case(wanted))
            .map(|(_, tier)| tier.clone())
            .unwrap_or(TaskTier::Medium)
    }

    /// Returns the canonical lowercase label for this tier.
    ///
    /// The label is a string literal, so the result is `'static` and may
    /// outlive the `TaskTier` value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskTier::Trivial => "trivial",
            TaskTier::Small => "small",
            TaskTier::Medium => "medium",
            TaskTier::Hard => "hard",
            TaskTier::Vision => "vision",
        }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct RoutingNeed {
45    pub tier: TaskTier,
46    pub required_tools: bool,
47    pub required_vision: bool,
48    pub prefer_local: bool,
49}
50
/// Spending limits and running totals, in US dollars, for the current day and
/// the current session.
#[derive(Debug, Clone)]
pub struct BudgetState {
    pub daily_limit_usd: f64,
    pub daily_spent_usd: f64,
    pub session_limit_usd: f64,
    pub session_spent_usd: f64,
}

impl BudgetState {
    /// Dollars left before the daily limit is reached; never negative.
    pub fn remaining_daily(&self) -> f64 {
        let left = self.daily_limit_usd - self.daily_spent_usd;
        f64::max(left, 0.0)
    }

    /// Dollars left before the session limit is reached; never negative.
    pub fn remaining_session(&self) -> f64 {
        let left = self.session_limit_usd - self.session_spent_usd;
        f64::max(left, 0.0)
    }

    /// `true` once either the daily or the session allowance has no money left.
    pub fn is_exhausted(&self) -> bool {
        // Both remainders are clamped to >= 0, so the smaller one decides.
        let tightest = self.remaining_daily().min(self.remaining_session());
        tightest <= 0.0
    }
}
72
73// ─── Router trait ───────────────────────────────────────────────────────────────
74
75pub trait Router: Send + Sync {
76    /// Returns an ordered fallback chain of Brains.
77    /// Primary brain first, fallbacks in order.
78    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;
79
80    fn on_error(&self, b: &dyn Brain, e: &BrainError) -> Retry;
81
82    /// Look up a brain by its model ID across all registered providers.
83    /// Used by the WebView model-override path when the user picks a model
84    /// that isn't in the natural routing chain.
85    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
86}
87
/// Outcome of [`Router::on_error`]: how the caller should proceed after a
/// brain failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Retry {
    /// Move on to the next brain in the fallback chain.
    NextInChain,
    /// Stop; do not try any further brain.
    Abort,
    /// Wait the given number of seconds, then retry.
    WaitAndRetry(u64),
}
94
95// ─── Basic Router implementation ────────────────────────────────────────────────
96
97use std::collections::HashMap;
98
99use crate::config::Config;
100
101pub struct BasicRouter {
102    /// provider_name -> list of (model_name, Brain)
103    providers: HashMap<String, Vec<Arc<dyn Brain>>>,
104    /// task tier -> preferred provider name
105    policy: HashMap<String, String>,
106    free_first: bool,
107    /// Global preferred provider override — when Some, every tier resolves to
108    /// this provider first (capability constraints still apply).
109    preferred_provider: Option<String>,
110    preferred_model: Option<String>,
111    /// \"auto\" or \"manual\" — see RoutingConfig.routing_mode.
112    routing_mode: String,
113}
114
115impl BasicRouter {
116    pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
117        let mut policy = HashMap::new();
118        for (k, v) in &config.routing.policy {
119            policy.insert(k.clone(), v.clone());
120        }
121        // Defaults
122        if !policy.contains_key("trivial") {
123            policy.insert("trivial".into(), "local".into());
124        }
125        if !policy.contains_key("hard") {
126            policy.insert("hard".into(), "anthropic".into());
127        }
128
129        Self {
130            providers,
131            policy,
132            free_first: config.routing.free_first,
133            preferred_provider: config.routing.preferred_provider.clone(),
134            preferred_model: config.routing.preferred_model.clone(),
135            routing_mode: config.routing.routing_mode.clone(),
136        }
137    }
138
139    /// Score a brain for a given need: higher is better.
140    fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
141        let caps = brain.caps();
142        let mut score: f64 = 0.0;
143
144        // Capability fit
145        if need.required_tools {
146            if caps.tools {
147                score += 50.0;
148            } else {
149                score -= 250.0;
150            }
151        }
152        if need.required_vision {
153            if caps.vision {
154                score += 50.0;
155            } else {
156                score -= 300.0;
157            }
158        }
159
160        // Cost preference: prefer cheaper/free models
161        let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
162        if est_cost == 0.0 {
163            score += 100.0; // free models get a big boost
164        } else if budget.remaining_session() < est_cost * 0.1 {
165            score -= 200.0; // too expensive for remaining budget
166        } else {
167            score -= est_cost * 10.0; // penalize expensive models
168        }
169
170        // Latency / capability preference — tier-aware.
171        // Cheap tiers want speed; hard tiers want capable (bigger, slower) models.
172        match need.tier {
173            TaskTier::Trivial | TaskTier::Small => match caps.latency {
174                LatencyClass::Fast => score += 15.0,
175                LatencyClass::Medium => score += 6.0,
176                LatencyClass::Slow => score += 0.0,
177            },
178            TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
179                LatencyClass::Slow => score += 18.0,
180                LatencyClass::Medium => score += 9.0,
181                LatencyClass::Fast => score += 0.0,
182            },
183        }
184
185        // Context window fit. Weighted more heavily for capable tiers so larger
186        // models (which we infer to have bigger windows) rise for hard tasks.
187        let ctx_weight = match need.tier {
188            TaskTier::Hard | TaskTier::Medium => 20_000.0,
189            _ => 10_000.0,
190        };
191        let ctx_cap = match need.tier {
192            TaskTier::Hard | TaskTier::Medium => 20.0,
193            _ => 10.0,
194        };
195        score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
196
197        score
198    }
199
200    fn resolve_provider(&self, need: &RoutingNeed) -> &str {
201        // Manual mode: always use the user's chosen provider, no fallback ever.
202        if self.routing_mode == "manual" {
203            if let Some(ref pref) = self.preferred_provider {
204                return pref.as_str();
205            }
206        }
207        // If a global preferred provider is configured, use it for all tiers.
208        // The scoring loop will still add a +25 bonus for this provider so it
209        // bubbles to the top; capability fallback still kicks in if it lacks
210        // tools/vision support.
211        if let Some(ref pref) = self.preferred_provider {
212            return pref.as_str();
213        }
214        self.policy
215            .get(need.tier.as_str())
216            .map(|s| s.as_str())
217            .unwrap_or("anthropic")
218    }
219
220    /// Classify a task using a tiny model call (only for ambiguous cases).
221    /// §3.6: "Classification: heuristic + a tiny model call only if ambiguous."
222    pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
223        let prompt = format!(
224            "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
225            task
226        );
227
228        let req = BrainRequest {
229            system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
230            messages: vec![Msg {
231                role: "user".into(),
232                content: vec![ContentBlock::Text { text: prompt }],
233            }],
234            tools: vec![],
235            max_tokens: 10,
236            temperature: 0.0,
237            stop: vec![],
238            cache: PromptCacheConfig::disabled(),
239        };
240
241        match brain.complete(req).await {
242            Ok(mut stream) => {
243                use futures::StreamExt;
244                let mut result = String::new();
245                while let Some(ev) = stream.next().await {
246                    if let crate::provider::BrainEvent::TextDelta(t) = ev {
247                        result.push_str(&t);
248                    }
249                }
250                TaskTier::from_str(result.trim())
251            }
252            Err(_) => TaskTier::Medium, // fallback
253        }
254    }
255}
256
257impl Router for BasicRouter {
258    fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
259        // Manual mode with a specific model: use EXACTLY that model, nothing else.
260        if self.routing_mode == "manual" {
261            if let Some(ref model) = self.preferred_model {
262                for (_, brains) in &self.providers {
263                    for brain in brains {
264                        if brain.id() == *model {
265                            return vec![brain.clone()];
266                        }
267                    }
268                }
269                // Model not found — return empty, let engine report the error
270                return vec![];
271            }
272        }
273
274        if budget.is_exhausted() && !need.prefer_local {
275            // Only free/local models remain
276            if let Some(local) = self.providers.get("local") {
277                return local.clone();
278            }
279            return vec![];
280        }
281
282        let preferred_provider = self.resolve_provider(need);
283        let preferred_is_local = preferred_provider == "local" || preferred_provider == "ollama";
284        let mut scored: Vec<(f64, String, Arc<dyn Brain>)> = Vec::new();
285
286        // Score all available brains
287        for (provider_name, brains) in &self.providers {
288            if need.prefer_local && provider_name != "local" && provider_name != "ollama" {
289                continue;
290            }
291            for brain in brains {
292                let mut s = Self::score(brain.as_ref(), need, budget);
293                if provider_name == preferred_provider {
294                    s += 25.0;
295                }
296                if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
297                    && (provider_name == "local" || provider_name == "ollama")
298                {
299                    s += 30.0;
300                }
301                scored.push((s, provider_name.clone(), brain.clone()));
302            }
303        }
304
305        // Stable sort by score descending (preserves insertion order for ties,
306        // so equal-scored discovered models keep a deterministic order).
307        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
308
309        // Build a PROVIDER-DIVERSE fallback chain. A single preferred provider
310        // (e.g. opencode-zen) can have dozens of discovered models that would
311        // otherwise fill every slot — leaving no fallback if that whole provider
312        // is down (no credits / empty responses). Cap models per provider so
313        // other providers (NVIDIA, ollama…) always get fallback slots.
314        const MAX_CHAIN: usize = 6;
315        const PER_PROVIDER_CAP: usize = 3;
316        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
317        let mut per_provider: HashMap<String, usize> = HashMap::new();
318        // (provider, brain) so we can reorder by provider below.
319        let mut result: Vec<(String, Arc<dyn Brain>)> = Vec::new();
320
321        // Pass 1: respect the per-provider cap → forces cross-provider diversity.
322        for (_, prov, brain) in &scored {
323            if result.len() >= MAX_CHAIN {
324                break;
325            }
326            let id = brain.id().to_string();
327            if seen.contains(&id) {
328                continue;
329            }
330            let count = per_provider.entry(prov.clone()).or_insert(0);
331            if *count >= PER_PROVIDER_CAP {
332                continue;
333            }
334            *count += 1;
335            seen.insert(id);
336            result.push((prov.clone(), brain.clone()));
337        }
338        // Pass 2: if too few providers to fill the chain, top up over the cap.
339        if result.len() < MAX_CHAIN {
340            for (_, prov, brain) in &scored {
341                if result.len() >= MAX_CHAIN {
342                    break;
343                }
344                let id = brain.id().to_string();
345                if seen.insert(id) {
346                    result.push((prov.clone(), brain.clone()));
347                }
348            }
349        }
350
351        // For trivial/small tasks with a local-preferred policy, put a local/free
352        // model first (works offline, $0).
353        if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
354            && (preferred_is_local || self.free_first)
355            && self.routing_mode != "manual"
356        {
357            if let Some(pos) = result.iter().position(|(prov, b)| {
358                (prov == "local" || prov == "ollama") || b.caps().cost_input_per_mtok == 0.0
359            }) {
360                let chosen = result.remove(pos);
361                result.insert(0, chosen);
362            }
363        }
364
365        result.into_iter().map(|(_, brain)| brain).collect()
366    }
367
368    fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
369        match e {
370            BrainError::RateLimit { retry_after } => {
371                if let Some(secs) = retry_after {
372                    if *secs <= 10 {
373                        Retry::WaitAndRetry(*secs)
374                    } else {
375                        Retry::NextInChain
376                    }
377                } else {
378                    Retry::NextInChain
379                }
380            }
381            BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
382            BrainError::Timeout => Retry::NextInChain,
383            BrainError::Refusal(_) => Retry::Abort,
384            _ => Retry::NextInChain,
385        }
386    }
387
388    fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
389        for brains in self.providers.values() {
390            for b in brains {
391                if b.id() == model_id {
392                    return Some(b.clone());
393                }
394            }
395        }
396        None
397    }
398}