
car_inference/adaptive_router.rs

//! Adaptive model routing — three-phase routing with learned performance profiles.
//!
//! Phase 1: **Filter** — hard constraints (capability, availability, memory, cost).
//! Phase 2: **Score** — blend quality, latency, and cost using observed profiles
//!          or schema defaults on cold start.
//! Phase 3: **Select** — Thompson Sampling (Beta distribution per model) for
//!          natural exploration-exploitation balance. Models with fewer observations
//!          have wider distributions, giving them chances to prove themselves.
//!
//! Replaces the hardcoded `ModelRouter` from `router.rs`.
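//!
//! A minimal usage sketch; `HardwareInfo::detect()` and the registry/tracker
//! setup are assumptions about the host application, not part of this module:
//!
//! ```ignore
//! let router = AdaptiveRouter::with_default_config(HardwareInfo::detect());
//! let decision = router.route(prompt, &registry, &tracker);
//! tracing::info!("routing to {}: {}", decision.model_name, decision.reason);
//! ```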

use rand::Rng;
use serde::{Deserialize, Serialize};

use std::sync::{Arc, Mutex};

use crate::hardware::HardwareInfo;
use crate::outcome::{InferenceTask, OutcomeTracker};
use crate::registry::UnifiedRegistry;
use crate::routing_ext::CircuitBreakerRegistry;
use crate::schema::{ModelCapability, ModelSchema};
use crate::tasks::RoutingWorkload;

/// Prompt complexity assessment (migrated from router.rs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    Simple,
    Medium,
    Code,
    Complex,
}

impl TaskComplexity {
    /// Assess complexity of a prompt string.
    ///
    /// Uses tree-sitter AST parsing (when the `ast` feature is enabled) for
    /// accurate code detection: if a code block parses successfully as any
    /// supported language, it's definitively code. Falls back to keyword
    /// heuristics for prompts that mention code without containing code blocks.
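    ///
    /// Illustrative classifications (hypothetical prompts):
    ///
    /// ```ignore
    /// assert_eq!(TaskComplexity::assess("What is a mutex?"), TaskComplexity::Simple);
    /// assert_eq!(TaskComplexity::assess("Fix the failing test in parser.rs"), TaskComplexity::Code);
    /// assert_eq!(TaskComplexity::assess("Analyze the trade-offs between these designs"), TaskComplexity::Complex);
    /// ```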
    pub fn assess(prompt: &str) -> Self {
        let lower = prompt.to_lowercase();
        let word_count = prompt.split_whitespace().count();
        let estimated_tokens = (word_count as f64 * 1.3) as usize;

        let has_code = Self::detect_code(prompt);

        let repair_markers = [
            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
        ];
        let has_repair = repair_markers.iter().any(|m| lower.contains(m));

        let reasoning_markers = [
            "analyze",
            "compare",
            "explain why",
            "step by step",
            "think through",
            "evaluate",
            "trade-off",
            "tradeoff",
            "pros and cons",
            "architecture",
            "design",
            "strategy",
            "optimize",
            "comprehensive",
        ];
        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));

        let simple_patterns = [
            "what is",
            "who is",
            "when did",
            "where is",
            "how many",
            "yes or no",
            "true or false",
            "name the",
            "list the",
            "define ",
        ];
        let is_simple = simple_patterns.iter().any(|p| lower.contains(p));

        if has_code || has_repair {
            TaskComplexity::Code
        } else if has_reasoning || estimated_tokens > 500 {
            TaskComplexity::Complex
        } else if is_simple || estimated_tokens < 30 {
            TaskComplexity::Simple
        } else {
            TaskComplexity::Medium
        }
    }

    /// Detect whether a prompt contains code.
    ///
    /// With `ast` feature: extracts code blocks (``` delimited), attempts to
    /// parse each with tree-sitter. If any parses into symbols, it's real code.
    /// Without `ast` feature: falls back to keyword heuristics.
    fn detect_code(prompt: &str) -> bool {
        // First try AST-based detection on code blocks
        #[cfg(feature = "ast")]
        {
            if let Some(is_code) = Self::detect_code_ast(prompt) {
                return is_code;
            }
        }

        // Fallback: keyword heuristics
        let code_markers = [
            "```",
            "fn ",
            "def ",
            "class ",
            "import ",
            "require(",
            "async fn",
            "pub fn",
            "function ",
            "const ",
            "let ",
            "var ",
            "#include",
            "package ",
            "impl ",
        ];
        code_markers.iter().any(|m| prompt.contains(m))
    }

    /// AST-based code detection: parse code blocks with tree-sitter.
    /// Returns Some(true) if code found, Some(false) if blocks exist but
    /// don't parse, None if no code blocks found (fall through to heuristics).
    #[cfg(feature = "ast")]
    fn detect_code_ast(prompt: &str) -> Option<bool> {
        // Extract code blocks between ``` markers
        let mut blocks = Vec::new();
        let mut rest = prompt;
        while let Some(start) = rest.find("```") {
            let after_fence = &rest[start + 3..];
            // Skip optional language tag on the opening fence
            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
            if let Some(end) = after_fence[code_start..].find("```") {
                blocks.push(&after_fence[code_start..code_start + end]);
                rest = &after_fence[code_start + end + 3..];
            } else {
                break;
            }
        }

        if blocks.is_empty() {
            return None; // No code blocks — let heuristics decide
        }

        // Try to parse each block with tree-sitter
        let languages = [
            car_ast::Language::Rust,
            car_ast::Language::Python,
            car_ast::Language::TypeScript,
            car_ast::Language::JavaScript,
            car_ast::Language::Go,
        ];

        for block in &blocks {
            let trimmed = block.trim();
            if trimmed.is_empty() {
                continue;
            }

            for lang in &languages {
                if let Some(parsed) = car_ast::parse(trimmed, *lang) {
                    // If it parsed into any symbols, it's definitely code
                    if !parsed.symbols.is_empty() {
                        return Some(true);
                    }
                }
            }
        }

        // Had code blocks but none parsed into symbols — could be
        // pseudocode, output, or unsupported language
        Some(false)
    }

    /// Map complexity to required capabilities.
    pub fn required_capabilities(&self) -> Vec<ModelCapability> {
        match self {
            TaskComplexity::Simple => vec![ModelCapability::Generate],
            TaskComplexity::Medium => vec![ModelCapability::Generate],
            TaskComplexity::Code => vec![ModelCapability::Code],
            TaskComplexity::Complex => vec![ModelCapability::Reasoning],
        }
    }

    /// Map complexity to the InferenceTask type.
    pub fn inference_task(&self) -> InferenceTask {
        match self {
            TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
            TaskComplexity::Code => InferenceTask::Code,
            TaskComplexity::Complex => InferenceTask::Reasoning,
        }
    }
}

/// Configuration for routing behavior.
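///
/// Typically built from defaults with struct-update syntax (illustrative values):
///
/// ```ignore
/// let config = RoutingConfig {
///     max_latency_ms: Some(2_000),
///     prefer_local: false,
///     ..RoutingConfig::default()
/// };
/// ```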
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Minimum observations before trusting a model's profile over schema defaults.
    pub min_observations: u64,
    /// Scoring weights (must sum to 1.0).
    pub quality_weight: f64,
    pub latency_weight: f64,
    pub cost_weight: f64,
    /// Hard constraint: maximum latency budget in ms.
    pub max_latency_ms: Option<u64>,
    /// Hard constraint: maximum cost per call in USD.
    pub max_cost_usd: Option<f64>,
    /// Prefer local models over remote (all else being equal).
    pub prefer_local: bool,
    /// Thompson Sampling prior strength. Higher = more weight on the Phase 2 score
    /// as a prior, lower = more influenced by observed outcomes.
    /// Equivalent to the number of "virtual" observations from the prior.
    pub prior_strength: f64,
    /// Prefer trusted remote models for quality-critical tasks until local models
    /// have enough task-specific evidence to be promoted.
    pub quality_first_cold_start: bool,
    /// Minimum task-specific observations required before a local model can
    /// compete with trusted remote models during cold start.
    pub bootstrap_min_task_observations: u64,
    /// Minimum task-specific EMA quality required before a local model can
    /// displace trusted remote models during cold start.
    pub bootstrap_quality_floor: f64,
}

impl Default for RoutingConfig {
    fn default() -> Self {
        Self {
            min_observations: 2,
            quality_weight: 0.45,
            latency_weight: 0.4,
            cost_weight: 0.15,
            max_latency_ms: None,
            max_cost_usd: None,
            prefer_local: true,
            prior_strength: 2.0,
            quality_first_cold_start: true,
            bootstrap_min_task_observations: 8,
            bootstrap_quality_floor: 0.8,
        }
    }
}

/// How a model was selected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Using declared schema capabilities (no observed data).
    SchemaBased,
    /// Using observed performance profiles (exploitation).
    ProfileBased,
    /// Deliberately trying an under-tested model (exploration).
    Exploration,
    /// User explicitly specified the model.
    Explicit,
}

/// The result of adaptive routing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Selected model id.
    pub model_id: String,
    /// Selected model name (display).
    pub model_name: String,
    /// Task type.
    pub task: InferenceTask,
    /// Assessed complexity.
    pub complexity: TaskComplexity,
    /// Human-readable reason.
    pub reason: String,
    /// How the model was selected.
    pub strategy: RoutingStrategy,
    /// Predicted quality score. Nominally 0.0-1.0; scoring bonuses can push it slightly above 1.0.
    pub predicted_quality: f64,
    /// Fallback chain (ordered list of alternative model ids).
    pub fallbacks: Vec<String>,
    /// Context window of the selected model (tokens). 0 = unknown.
    pub context_length: usize,
    /// Whether the prompt needs compaction to fit the selected model's context window.
    pub needs_compaction: bool,
}

/// Adaptive router with three-phase model selection.
pub struct AdaptiveRouter {
    hw: HardwareInfo,
    config: RoutingConfig,
    /// Circuit breaker registry — blocks models after consecutive failures (#25).
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}

/// All inputs to a routing decision packed into one struct so adding
/// a new combinator (per-tenant routing, streaming awareness, …)
/// only adds a field rather than another `route_*` sibling method.
/// See [`AdaptiveRouter::route_with`].
///
/// Use [`RouteRequest::new`] for the common defaults, then mutate
/// just the fields the caller cares about:
///
/// ```ignore
/// let decision = router.route_with(RouteRequest {
///     has_tools: true,
///     intent: Some(&hint),
///     ..RouteRequest::new(prompt, &registry, &tracker)
/// });
/// ```
pub struct RouteRequest<'a> {
    pub prompt: &'a str,
    pub registry: &'a UnifiedRegistry,
    pub tracker: &'a OutcomeTracker,
    /// Estimated prompt-side token count for context-window-aware
    /// scoring. `0` skips the compaction-headroom check.
    pub estimated_total_tokens: usize,
    pub has_tools: bool,
    pub has_vision: bool,
    pub workload: RoutingWorkload,
    /// Caller-supplied intent hint. `prefer_local: true` overrides
    /// `workload` to [`RoutingWorkload::LocalPreferred`].
    pub intent: Option<&'a crate::intent::IntentHint>,
}

impl<'a> RouteRequest<'a> {
    /// Build a request with the same defaults the bare
    /// [`AdaptiveRouter::route`] uses: interactive workload, no tools,
    /// no vision, no context-aware sizing, no intent.
    pub fn new(
        prompt: &'a str,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
    ) -> Self {
        Self {
            prompt,
            registry,
            tracker,
            estimated_total_tokens: 0,
            has_tools: false,
            has_vision: false,
            workload: RoutingWorkload::Interactive,
            intent: None,
        }
    }
}

impl AdaptiveRouter {
    pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
        let circuit_breakers = Arc::new(Mutex::new(
            CircuitBreakerRegistry::new(3, 300), // 3 failures, 5 min cooldown
        ));
        Self {
            hw,
            config,
            circuit_breakers,
        }
    }

    pub fn with_default_config(hw: HardwareInfo) -> Self {
        Self::new(hw, RoutingConfig::default())
    }

    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }

    pub fn set_config(&mut self, config: RoutingConfig) {
        self.config = config;
    }

    /// Canonical entry point. The seven `route_*` sibling methods
    /// each build a `RouteRequest` with their fixed defaults and
    /// call this — adding a new combinator (per-tenant routing,
    /// streaming awareness, …) only adds a field here, not another
    /// public method. Closes #108.
    pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
        // Intent-hint precedence over the caller-supplied workload:
        // `prefer_fast` wins outright (voice fast track is the most
        // latency-sensitive path we have); `prefer_local` is the
        // long-standing override; absent either, the caller's
        // workload stands.
        let workload = match req.intent {
            Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
            Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
            _ => req.workload,
        };
        self.route_inner_with_intent(
            req.prompt,
            req.registry,
            req.tracker,
            req.has_tools,
            req.has_vision,
            req.estimated_total_tokens,
            workload,
            req.intent,
        )
    }

    /// Route a generation request to the best model using the
    /// [`RouteRequest::new`] defaults: interactive workload, no tools, no vision.
    pub fn route(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest::new(prompt, registry, tracker))
    }

    /// Route an "editor" request — cheap mechanical work (context compaction,
    /// edit materialization, title generation, replanning). Uses
    /// [`RoutingWorkload::Background`] so cost/quality weights favour cheaper
    /// models, and lets hosting layers express the Aider-style architect/editor
    /// split without plumbing a raw model-name override through every layer.
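    ///
    /// A hypothetical architect/editor pairing (prompt names illustrative):
    ///
    /// ```ignore
    /// let plan = router.route(architect_prompt, &registry, &tracker);
    /// let edits = router.route_editor(edit_prompt, &registry, &tracker);
    /// ```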
    pub fn route_editor(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            workload: RoutingWorkload::Background,
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }

    /// Route with tool_use requirement — filters to models that support structured tool calls.
    pub fn route_with_tools(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            has_tools: true,
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }

    /// Route with image input requirement — filters to models that support vision.
    pub fn route_with_vision(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            has_tools,
            has_vision: true,
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }

    /// Route with caller-supplied intent — see [`crate::IntentHint`].
    /// The hint can override the auto-detected `InferenceTask`, add hard
    /// `require` capability filters on top of the prompt-derived ones,
    /// and bias the score profile toward local models when
    /// `prefer_local` is set.
    pub fn route_with_intent<'a>(
        &self,
        prompt: &'a str,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
        intent: &'a crate::intent::IntentHint,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            intent: Some(intent),
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }

    /// Route with context awareness — estimates prompt tokens and prefers models
    /// whose context window can fit the full prompt without compaction.
    pub fn route_context_aware(
        &self,
        prompt: &str,
        estimated_total_tokens: usize,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        workload: RoutingWorkload,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            estimated_total_tokens,
            has_tools,
            has_vision,
            workload,
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }

    /// Context-aware routing with caller-supplied intent. Same context
    /// math as [`Self::route_context_aware`]; the intent layers on
    /// top — task override, additional `require` filters,
    /// `prefer_local` workload override.
    pub fn route_context_aware_with_intent<'a>(
        &self,
        prompt: &'a str,
        estimated_total_tokens: usize,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        workload: RoutingWorkload,
        intent: &'a crate::intent::IntentHint,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            estimated_total_tokens,
            has_tools,
            has_vision,
            workload,
            intent: Some(intent),
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }

    fn route_inner_with_intent(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        estimated_total_tokens: usize,
        workload: RoutingWorkload,
        intent: Option<&crate::intent::IntentHint>,
    ) -> AdaptiveRoutingDecision {
        let complexity = TaskComplexity::assess(prompt);
        // Caller intent overrides the prompt-derived task when supplied.
        let task = intent
            .and_then(|h| h.task)
            .map(task_hint_to_inference_task)
            .unwrap_or_else(|| complexity.inference_task());
        let mut required_caps = complexity.required_capabilities();
        if let Some(hint) = intent {
            for cap in &hint.require {
                if !required_caps.contains(cap) {
                    required_caps.push(*cap);
                }
            }
        }
        if has_vision {
            required_caps.push(ModelCapability::Vision);
        }
        if has_tools {
            required_caps.push(ModelCapability::ToolUse);
            // Detect multi-step prompts that need multiple tool calls in one response.
            // Patterns: numbered lists ("1) ... 2) ..."), multiple instructions, explicit multi-edit.
            if Self::needs_multi_tool_call(prompt) {
                required_caps.push(ModelCapability::MultiToolCall);
            }
        }

        // Phase 1: Filter candidates
        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);

        // Fallback: if requiring MultiToolCall eliminates all candidates, drop the
        // requirement and let the best ToolUse model handle it with multiple round-trips.
        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
            candidates = self.filter_candidates(&required_caps, registry, tracker);
        }

        if candidates.is_empty() {
            // Nothing available — return the schema-based default
            return self.cold_start_decision(complexity, task, registry, has_vision);
        }

        candidates = self.apply_quality_first_bootstrap_policy(
            candidates, task, tracker, has_vision, has_tools, workload,
        );

        // Context-aware filtering: if we know the prompt size, prefer models that fit.
        // Phase 1b: separate candidates into "fits" and "needs compaction" groups.
        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
            let mut fits = Vec::new();
            let mut tight = Vec::new();
            for m in &candidates {
                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
                    fits.push(m.clone());
                } else {
                    tight.push(m.clone());
                }
            }
            (fits, tight)
        } else {
            (candidates.clone(), Vec::new())
        };

        // Prefer models that fit; fall back to compaction-required models if none fit
        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
            (fits, false)
        } else if !needs_compaction_candidates.is_empty() {
            tracing::info!(
                prompt_tokens = estimated_total_tokens,
                candidates = needs_compaction_candidates.len(),
                "no model fits full prompt — compaction will be needed"
            );
            (needs_compaction_candidates.clone(), true)
        } else {
            (candidates.clone(), false)
        };

        // Phase 2: Score candidates (with context headroom bonus)
        let scored = self.score_candidates_context_aware(
            &scoring_candidates,
            task,
            tracker,
            estimated_total_tokens,
            workload,
        );

        // Phase 3: Thompson Sampling selection
        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);

        // Build fallback chain: prefer models that fit, then compaction candidates
        let mut fallbacks: Vec<String> = scored
            .iter()
            .filter(|(id, _)| *id != selected_id)
            .map(|(id, _)| id.clone())
            .collect();
        // Add compaction candidates to the end of the fallback chain
        if !compaction_needed {
            for m in &needs_compaction_candidates {
                if m.id != selected_id && !fallbacks.contains(&m.id) {
                    fallbacks.push(m.id.clone());
                }
            }
        }

        let predicted_quality = scored
            .iter()
            .find(|(id, _)| *id == selected_id)
            .map(|(_, score)| *score)
            .unwrap_or(0.5);

        let selected_schema = registry
            .get(&selected_id)
            .or_else(|| registry.find_by_name(&selected_id));
        let model_name = selected_schema
            .map(|m| m.name.clone())
            .unwrap_or_else(|| selected_id.clone());
        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);

        let needs_compact = compaction_needed
            || (estimated_total_tokens > 0
                && context_length > 0
                && estimated_total_tokens > context_length);

        let compaction_note = if needs_compact {
            format!(
                " [compaction needed: {}→{}tok]",
                estimated_total_tokens, context_length
            )
        } else {
            String::new()
        };

        let reason = format!(
            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
            complexity,
            model_name,
            strategy,
            predicted_quality,
            scoring_candidates.len(),
            compaction_note,
        );

        AdaptiveRoutingDecision {
            model_id: selected_id,
            model_name,
            task,
            complexity,
            reason,
            strategy,
            predicted_quality,
            fallbacks,
            context_length,
            needs_compaction: needs_compact,
        }
    }

    /// Route to the best embedding model.
    pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
        let embed_models = registry.query_by_capability(ModelCapability::Embed);
        embed_models
            .first()
            .map(|m| m.name.clone())
            .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
    }

    /// Route to the smallest available model (for classification).
    pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
        let gen_models = registry.query_by_capability(ModelCapability::Generate);
        // Pick smallest by size
        gen_models
            .iter()
            .filter(|m| m.is_local())
            .min_by_key(|m| m.size_mb())
            .map(|m| m.name.clone())
            .unwrap_or_else(|| "Qwen3-0.6B".to_string())
    }

    // --- Scoring constants ---

    /// Latency ceiling: requests taking longer than this score 0.0.
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// TPS ceiling: models faster than this score 1.0.
    const _TPS_CEILING: f64 = 150.0;
    /// MoE throughput penalty for Candle: naive expert routing runs at ~10% of declared TPS.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// MoE throughput multiplier for MLX: fused Metal kernels run at ~50% of declared TPS.
    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
    /// Cost ceiling: models costing more than this per 1K output tokens score 0.0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Local preference bonus added to the weighted score (before normalization).
    const LOCAL_BONUS: f64 = 0.15;
    /// True on platforms where local inference has a GPU/NPU backend
    /// (Apple Silicon with MLX). On Intel Macs and on hosts built with
    /// `--cfg=car_skip_mlx`, local inference falls through to candle/
    /// CPU which is far slower than cloud for non-trivial models. The
    /// router suppresses LOCAL_BONUS when this is false so cloud
    /// models rank fairly instead of losing to a CPU-bound 4B that
    /// will take 30s to first-token.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const HAS_GPU_BACKEND: bool = true;
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    const HAS_GPU_BACKEND: bool = false;
    /// Extra bonus for MLX models on Apple Silicon (stacks with LOCAL_BONUS).
    /// MLX gets fused Metal kernels and better memory layout vs Candle on Mac.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const MLX_BONUS: f64 = 0.10;
    /// Extra bonus for system-owned models — no model file on disk, no
    /// download to provision, framework-managed memory. Currently only
    /// `apple/foundation:default` qualifies. Tag-driven so future
    /// system-LLM integrations (e.g. Android AICore) inherit the
    /// scoring without a code change. Stacks with LOCAL_BONUS but
    /// **excludes** MLX_BONUS (system models aren't MLX); the net
    /// effect is FoundationModels ranks roughly even with a
    /// well-warmed MLX 4B for short fast-turn tasks instead of
    /// strictly losing because its catalog `tokens_per_second` /
    /// `size_mb` are null.
    const SYSTEM_LLM_BONUS: f64 = 0.12;

    // --- Internal phases ---

    /// Phase 1: Filter by hard constraints.
    fn filter_candidates(
        &self,
        required_caps: &[ModelCapability],
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> Vec<ModelSchema> {
        registry
            .list()
            .into_iter()
            .filter(|m| {
                // Must have all required capabilities
                if !required_caps.iter().all(|c| m.has_capability(*c)) {
                    return false;
                }
                // Must be available (downloaded for local, API key set for remote)
                if !m.available {
                    return false;
                }
                // Local models must fit in memory (strict: >= excludes models at the limit)
                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
                    return false;
                }
                // Hard latency constraint
                if let Some(max) = self.config.max_latency_ms {
                    if let Some(p50) = m.performance.latency_p50_ms {
                        if p50 > max {
                            return false;
                        }
                    }
                }
                // Hard cost constraint
                if let Some(max) = self.config.max_cost_usd {
                    if m.cost_per_1k_output() > max {
                        return false;
                    }
                }
                // Hard exclusion: prefer_local=false excludes all local models (#12)
                if !self.config.prefer_local && m.is_local() {
                    return false;
                }
                // Hard exclusion: rate-limited models excluded for this session (#13)
                if tracker.is_excluded(&m.id) {
                    return false;
                }
                // Circuit breaker: skip models with consecutive failures (#25)
                if let Ok(mut cb) = self.circuit_breakers.lock() {
                    if !cb.allow_request(&m.id) {
                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
                        return false;
                    }
                }
                true
            })
            .cloned()
            .collect()
    }

    fn apply_quality_first_bootstrap_policy(
        &self,
        candidates: Vec<ModelSchema>,
        task: InferenceTask,
        tracker: &OutcomeTracker,
        has_vision: bool,
        has_tools: bool,
        workload: RoutingWorkload,
    ) -> Vec<ModelSchema> {
        if !self.config.quality_first_cold_start
            || !workload.is_latency_sensitive()
            || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
        {
            return candidates;
        }

        let trusted_remote: Vec<ModelSchema> = candidates
            .iter()
            .filter(|model| self.is_trusted_quality_remote(model))
            .cloned()
            .collect();

        if trusted_remote.is_empty() {
            return candidates;
        }

        let proven_local: Vec<ModelSchema> = candidates
            .iter()
            .filter(|model| {
                model.is_local() && self.is_local_model_proven_for_task(model, task, tracker)
            })
            .cloned()
            .collect();

        if !proven_local.is_empty() {
            return proven_local;
        }

        trusted_remote
    }

    /// Phase 2: Score candidates with context awareness.
    /// Applies a headroom bonus to models with more context window for the prompt.
    /// When estimated_total_tokens is 0, no context bonus/penalty is applied.
    fn score_candidates_context_aware(
        &self,
        candidates: &[ModelSchema],
        task: InferenceTask,
        tracker: &OutcomeTracker,
        estimated_total_tokens: usize,
        workload: RoutingWorkload,
    ) -> Vec<(String, f64)> {
        let mut scored: Vec<(String, f64)> = candidates
            .iter()
            .map(|m| {
                let base_score = self.score_model(m, task, tracker, workload);
                // Context headroom bonus: prefer models with more room to spare.
                // Max bonus: 0.10 (at 4x headroom or more). No bonus if unknown.
                let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
                    let ratio = m.context_length as f64 / estimated_total_tokens as f64;
                    if ratio >= 1.0 {
                        (ratio.min(4.0) - 1.0) / 3.0 * 0.10 // 0.0 at exact fit, 0.10 at 4x
                    } else {
                        -0.15 // Penalty for models that require compaction
                    }
                } else {
                    0.0
                };
                (m.id.clone(), base_score + headroom_bonus)
            })
            .collect();

        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored
    }

    /// Score a single model. All sub-scores are in [0.0, 1.0].
    /// Final score = weighted sum plus runtime bonuses (local, workload, MLX,
    /// system-LLM), so it can exceed 1.0.
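    ///
    /// Blending sketch with the default `min_observations = 2` (illustrative
    /// numbers): after one observed call, `w = 1 / 2 = 0.5`, so a schema
    /// quality of 0.6 and an observed EMA quality of 0.9 blend to
    /// `0.6 * 0.5 + 0.9 * 0.5 = 0.75`.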
    fn score_model(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
        workload: RoutingWorkload,
    ) -> f64 {
        let profile = tracker.profile(&model.id);
        let schema_quality = self.schema_quality_estimate(model);
        let schema_latency = self.schema_latency_estimate(model);
        let (quality_weight, latency_weight, cost_weight) = workload.weights();

        // Quality: blend schema estimate with observed data based on observation count.
        // Both cold start and warm start use consistent blending.
        let quality = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            Some(p) if p.total_calls > 0 => {
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_quality * (1.0 - w) + p.ema_quality * w
            }
            _ => schema_quality,
        };

        // Latency: same blending as quality — don't trust a single observation more
        // than schema estimates. This prevents routing oscillation on first few calls.
        let latency = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => {
                let avg = p
                    .task_stats(task)
                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                    .map(|ts| ts.avg_latency_ms)
                    .unwrap_or_else(|| p.avg_latency_ms());
                self.latency_ms_to_score(avg)
            }
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .filter(|ts| ts.avg_latency_ms > 0.0)
                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
                .unwrap_or(schema_latency),
            Some(p) if p.total_calls > 0 => {
                let observed = self.latency_ms_to_score(
                    p.task_stats(task)
                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                        .map(|ts| ts.avg_latency_ms)
                        .unwrap_or_else(|| p.avg_latency_ms()),
                );
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_latency * (1.0 - w) + observed * w
            }
            _ => schema_latency,
        };

        // Cost score (lower is better → invert)
        let cost = if model.is_local() {
            1.0
        } else {
            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
        };

        // Suppress LOCAL_BONUS on hosts without a GPU/NPU backend.
        // Without it, an Intel Mac or a `car_skip_mlx` build would
        // pick a CPU-bound local 4B over a comparable cloud model
        // every time, then surprise the user with 30s+ first-token
        // latency. Cloud loses on cost/privacy in normal scoring;
        // dropping the local bonus lets it win on the latency that
        // actually matters when the GPU isn't there. Tracing emits
        // when this fires so the silent degradation becomes visible.
        let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
            Self::LOCAL_BONUS
        } else {
            if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
                tracing::debug!(
                    model = %model.id,
                    "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
                );
            }
            0.0
        };
        let workload_local_bonus = if model.is_local() {
            workload.local_bonus()
        } else {
            0.0
        };

        // On Apple Silicon, prefer MLX models over Candle equivalents
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let mlx_bonus = 0.0;

        // vLLM-MLX bonus: continuous batching gives better multi-agent throughput
        let vllm_mlx_bonus = if model.is_vllm_mlx() {
            Self::LOCAL_BONUS + 0.05
        } else {
            0.0
        };

        // System-LLM bonus: catalog tag-driven so it generalizes beyond
        // FoundationModels. Models tagged `low_latency` AND `private`
        // are zero-cost system-owned LLMs (apple/foundation:default
        // today; AICore-on-Android etc. in the future). They don't
        // appear in `is_mlx()` but deserve to compete with MLX 4B on
        // routing — the catalog's tags carry that intent and the
        // router now honors it.
        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
            && model.tags.iter().any(|t| t == "private")
        {
            Self::SYSTEM_LLM_BONUS
        } else {
            0.0
        };

        quality_weight * quality
            + latency_weight * latency
            + cost_weight * cost
            + local_bonus
            + workload_local_bonus
            + mlx_bonus
            + vllm_mlx_bonus
            + system_llm_bonus
    }

    /// Convert latency in ms to a [0, 1] score. Used by both schema and observed paths
    /// so the scales are consistent (fixes Linus review issue #2).
    fn latency_ms_to_score(&self, ms: f64) -> f64 {
        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
    }

    /// Convert TPS to estimated latency in ms (for a typical 200-token response).
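    ///
    /// E.g. a declared 50 tok/s estimates `(200.0 / 50.0) * 1000.0 = 4000` ms;
    /// non-positive TPS falls back to `LATENCY_CEILING_MS`.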
    fn tps_to_latency_ms(tps: f64) -> f64 {
        if tps <= 0.0 {
            return Self::LATENCY_CEILING_MS;
        }
        // Assume ~200 tokens per response as baseline
        (200.0 / tps) * 1000.0
    }

    /// Detect whether a prompt likely needs multiple tool calls in a single response.
    /// Looks for numbered lists, multiple explicit instructions, multi-edit patterns.
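    ///
    /// Illustrative matches (hypothetical prompts):
    ///
    /// ```ignore
    /// assert!(AdaptiveRouter::needs_multi_tool_call("1) add a test 2) update the docs"));
    /// assert!(AdaptiveRouter::needs_multi_tool_call("- add logging\n- fix the race condition"));
    /// assert!(!AdaptiveRouter::needs_multi_tool_call("rename this variable"));
    /// ```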
    fn needs_multi_tool_call(prompt: &str) -> bool {
        let lower = prompt.to_lowercase();

        // Numbered list patterns: "1) ... 2) ..." or "1. ... 2. ..."
        let has_numbered_list = {
            let mut count = 0u32;
            for i in 1..=5u32 {
                if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
                    count += 1;
                }
            }
            count >= 2
        };

        // Explicit multi-action keywords
        let multi_keywords = [
            "multiple edits",
            "several changes",
            "three changes",
            "two changes",
            "all of the following",
            "each of these",
            "do both",
            "do all",
            "and also",
            "additionally",
            "as well as",
            "then also",
        ];
        let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));

        // Bullet point lists with action verbs
        let bullet_actions = lower.matches("- add ").count()
            + lower.matches("- update ").count()
            + lower.matches("- change ").count()
            + lower.matches("- remove ").count()
            + lower.matches("- fix ").count()
            + lower.matches("- edit ").count()
            + lower.matches("- implement ").count()
            + lower.matches("- create ").count();
        let has_bullet_list = bullet_actions >= 2;

        has_numbered_list || has_multi_keywords || has_bullet_list
    }

    /// Schema-based quality estimate (cold start).
    ///
    /// Diminishing returns on model size — the jump from 4B to 8B matters,
    /// but 8B to 30B is marginal. Remote models get a conservative estimate.
    fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
        match model.size_mb() {
            0 => 0.5,             // remote: unknown, conservative
            s if s < 1000 => 0.4, // 0.6B
            s if s < 2000 => 0.5, // 1.7B
            s if s < 3000 => 0.6, // 4B
            s if s < 6000 => 0.7, // 8B
            _ => 0.75,            // 30B+: diminishing returns
        }
    }

    /// Schema-based latency estimate (cold start).
    ///
    /// Converts declared TPS/p50 to the same ms-based score used by observed data,
    /// so there's no discontinuity when the first observation arrives.
    fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
        let is_moe = model.tags.contains(&"moe".to_string());

        if model.is_local() {
            if let Some(tps) = model.performance.tokens_per_second {
                let effective_tps = if is_moe {
                    let multiplier = if model.is_mlx() {
                        Self::MLX_MOE_TPS_MULTIPLIER
                    } else {
                        Self::MOE_TPS_MULTIPLIER
                    };
                    tps * multiplier
                } else {
                    tps
                };
                let estimated_ms = Self::tps_to_latency_ms(effective_tps);
                return self.latency_ms_to_score(estimated_ms);
            }
            return 0.5; // local, no declared TPS
        }

        // Remote: use declared p50 latency
        if let Some(p50) = model.performance.latency_p50_ms {
            return self.latency_ms_to_score(p50 as f64);
        }
        0.3 // remote, no declared latency
    }

    fn is_quality_critical_bootstrap_task(
        &self,
        task: InferenceTask,
        has_vision: bool,
        has_tools: bool,
    ) -> bool {
        has_vision
            || has_tools
            || matches!(
                task,
                InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
            )
    }

    fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
        model.is_remote()
            && matches!(model.provider.as_str(), "openai" | "anthropic" | "google")
            && !model.has_capability(ModelCapability::SpeechToText)
            && !model.has_capability(ModelCapability::TextToSpeech)
    }

    fn is_local_model_proven_for_task(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
    ) -> bool {
        let Some(profile) = tracker.profile(&model.id) else {
            return false;
        };
        if let Some(task_stats) = profile.task_stats(task) {
            if task_stats.calls >= self.config.bootstrap_min_task_observations
                && task_stats.ema_quality >= self.config.bootstrap_quality_floor
            {
                return true;
            }
        }

        profile.total_calls >= self.config.bootstrap_min_task_observations
            && profile.ema_quality >= self.config.bootstrap_quality_floor
    }

    /// Phase 3: Thompson Sampling selection.
    ///
    /// Each model gets a Beta(alpha, beta) distribution where:
    /// - alpha = prior_successes + observed_successes
    /// - beta = prior_failures + observed_failures
    ///
    /// The Phase 2 score serves as the prior mean, scaled by `prior_strength`.
    /// Models with few observations have wide distributions (natural exploration).
    /// Models with many observations have tight distributions (exploitation).
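    ///
    /// Worked sketch (illustrative numbers): with `prior_strength = 2.0` and a
    /// Phase 2 score of 0.8, the prior is Beta(1.6, 0.4). After 5 observed
    /// successes and 1 failure the posterior is Beta(6.6, 1.4) with mean
    /// 6.6 / 8.0 = 0.825, and it tightens as observations accumulate.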
    fn select_with_thompson_sampling(
        &self,
        scored: &[(String, f64)],
        tracker: &OutcomeTracker,
    ) -> (String, RoutingStrategy) {
        if scored.is_empty() {
            return (String::new(), RoutingStrategy::SchemaBased);
        }

        let mut rng = rand::rng();
        let mut best_sample = f64::NEG_INFINITY;
        let mut best_id = scored[0].0.clone();
        let mut best_strategy = RoutingStrategy::SchemaBased;

        for (id, phase2_score) in scored {
            let profile = tracker.profile(id);
            let prior = self.config.prior_strength;

            // Convert the Phase 2 score (which can exceed 1.0 with bonuses) to a prior mean in [0, 1]
            let prior_mean = phase2_score.clamp(0.0, 1.0);

            // Prior pseudo-counts from the Phase 2 score
            let prior_alpha = prior * prior_mean;
            let prior_beta = prior * (1.0 - prior_mean);

            // Observed counts
            let (obs_alpha, obs_beta) = match profile {
                Some(p) => (p.success_count as f64, p.fail_count as f64),
                None => (0.0, 0.0),
            };

            // Posterior Beta parameters
            let alpha = (prior_alpha + obs_alpha).max(0.01);
            let beta = (prior_beta + obs_beta).max(0.01);

            // Sample from Beta(alpha, beta) via the Gamma-ratio method (see sample_beta)
            let sample = sample_beta(&mut rng, alpha, beta);

            if sample > best_sample {
                best_sample = sample;
                best_id = id.clone();
                best_strategy = match profile {
                    Some(p) if p.total_calls >= self.config.min_observations => {
                        RoutingStrategy::ProfileBased
                    }
                    Some(p) if p.total_calls > 0 => {
                        // Under-tested but has some data — exploration
                        RoutingStrategy::Exploration
                    }
                    _ => RoutingStrategy::SchemaBased,
                };
            }
        }

        (best_id, best_strategy)
    }

    /// Fallback decision when no candidates pass filtering.
    fn cold_start_decision(
        &self,
        complexity: TaskComplexity,
        task: InferenceTask,
        registry: &UnifiedRegistry,
        has_vision: bool,
    ) -> AdaptiveRoutingDecision {
        if has_vision {
            if let Some(model) = registry
                .query_by_capability(ModelCapability::Vision)
                .into_iter()
                .find(|model| model.available && self.is_trusted_quality_remote(model))
                .or_else(|| {
                    registry
                        .query_by_capability(ModelCapability::Vision)
                        .first()
                        .copied()
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (cold start, vision fallback)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        if self.config.quality_first_cold_start {
            let required_caps = complexity.required_capabilities();
            if let Some(model) = registry
                .list()
                .into_iter()
                .filter(|model| {
                    model.available
                        && required_caps.iter().all(|cap| model.has_capability(*cap))
                        && self.is_trusted_quality_remote(model)
                })
                .max_by(|a, b| {
                    self.schema_quality_estimate(a)
                        .partial_cmp(&self.schema_quality_estimate(b))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (quality-first cold start)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: self.schema_quality_estimate(model),
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        // Fall back to the old complexity-based defaults
        let model_name = match complexity {
            TaskComplexity::Simple => "Qwen3-0.6B",
            TaskComplexity::Medium => "Qwen3-1.7B",
            TaskComplexity::Code => "Qwen3-4B",
            TaskComplexity::Complex => &self.hw.recommended_model,
        };

        let model_id = registry
            .find_by_name(model_name)
            .map(|m| m.id.clone())
            .unwrap_or_else(|| model_name.to_string());

        let context_length = registry
            .find_by_name(model_name)
            .map(|m| m.context_length)
            .unwrap_or(0);

        AdaptiveRoutingDecision {
            model_id,
            model_name: model_name.to_string(),
            task,
            complexity,
            reason: format!(
                "{:?} task → {} (cold start, no candidates)",
                complexity, model_name
            ),
            strategy: RoutingStrategy::SchemaBased,
            predicted_quality: 0.5,
            fallbacks: vec![],
            context_length,
            needs_compaction: false,
        }
    }
}

/// Map a caller-supplied [`crate::TaskHint`] to the engine's
/// [`InferenceTask`] enum. The intent surface uses the higher-level
/// hint vocabulary; the router operates on InferenceTask. Every
/// TaskHint variant maps to a distinct InferenceTask — variants that
/// would have silently collapsed to `Generate` were cut from the MVP.
fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
    use crate::intent::TaskHint;
    match hint {
        TaskHint::Chat => InferenceTask::Generate,
        TaskHint::Classify => InferenceTask::Classify,
        TaskHint::Reasoning => InferenceTask::Reasoning,
        TaskHint::Code => InferenceTask::Code,
    }
}

/// Sample from a Beta(alpha, beta) distribution.
///
/// Uses the Gamma distribution method: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1),
/// then X / (X + Y) ~ Beta(alpha, beta).
///
/// For Gamma sampling, uses Marsaglia and Tsang's method for alpha >= 1,
/// and Ahrens-Dieter for alpha < 1.
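///
/// A usage sketch:
///
/// ```ignore
/// let mut rng = rand::rng();
/// let s = sample_beta(&mut rng, 2.0, 5.0); // mean = 2 / (2 + 5) ≈ 0.29
/// assert!((0.0..=1.0).contains(&s));
/// ```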
1368fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1369    let x = sample_gamma(rng, alpha);
1370    let y = sample_gamma(rng, beta);
1371    if x + y == 0.0 {
1372        0.5 // degenerate case
1373    } else {
1374        x / (x + y)
1375    }
1376}

/// Sample from Gamma(shape, 1) using Marsaglia-Tsang for shape >= 1,
/// with the Ahrens-Dieter boost for shape < 1.
fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
    if shape < 1.0 {
        // Ahrens-Dieter boost: Gamma(a) = Gamma(a+1) * U^(1/a)
        let u: f64 = rng.random();
        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
    }

    // Marsaglia-Tsang method for shape >= 1
    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();

    loop {
        // Draw a normal variate, rejecting the rare tail where (1 + c*x)^3
        // would be non-positive (an invalid Gamma candidate).
        let x: f64 = loop {
            let n = sample_standard_normal(rng);
            if 1.0 + c * n > 0.0 {
                break n;
            }
        };

        let v = (1.0 + c * x).powi(3);
        let u: f64 = rng.random();

        // Fast squeeze test: accepts the vast majority of candidates
        // without evaluating a logarithm.
        if u < 1.0 - 0.0331 * x.powi(4) {
            return d * v;
        }
        // Full acceptance test in log space.
        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            return d * v;
        }
    }
}

/// Sample from standard normal N(0,1) using the Box-Muller transform.
fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
    let u1: f64 = rng.random();
    let u2: f64 = rng.random();
    // Clamp u1 away from zero so ln() stays finite; only the cosine variate
    // of the Box-Muller pair is used (the sine variate is discarded).
    (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::outcome::InferredOutcome;

    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            gpu_devices: Vec::new(),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            max_model_mb: 18000, // headroom above 30B-A3B's 17000MB
        }
    }

    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
        }
        // Create fake model dirs so the registry marks them as available
        for name in &[
            "Qwen3-0.6B",
            "Qwen3-1.7B",
            "Qwen3-4B",
            "Qwen3-8B",
            "Qwen3-Embedding-0.6B",
        ] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        let mut reg = UnifiedRegistry::new(tmp);
        reg.register(ModelSchema {
            id: "openai/gpt-5.4-mini:latest".into(),
            name: "gpt-5.4-mini".into(),
            provider: "openai".into(),
            family: "gpt-5.4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
                ModelCapability::Vision,
            ],
            context_length: 128_000,
            param_count: "api".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_key_envs: vec![],
                api_version: None,
                protocol: crate::schema::ApiProtocol::OpenAiCompat,
            },
            tags: vec!["trusted-remote".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });
        reg
    }

    #[test]
    fn routes_simple_to_trusted_remote_during_cold_start() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("What is 2+2?", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Simple);
        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
        // On cold start, quality-critical tasks should stay on trusted remote models.
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("selected model should exist in registry");
        assert!(
            !schema.is_local(),
            "simple task should route to trusted remote model during cold start"
        );
        assert!(matches!(
            schema.provider.as_str(),
            "openai" | "anthropic" | "google"
        ));
    }

    #[test]
    fn routes_code_to_code_capable_remote_during_cold_start() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route(
            "Fix this function:\n```rust\nfn main() {}\n```",
            &reg,
            &tracker,
        );
        assert_eq!(decision.complexity, TaskComplexity::Code);
        assert_eq!(decision.task, InferenceTask::Code);
        // Must select a code-capable model (Qwen3-0.6B lacks the Code capability)
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("model should exist");
        assert!(
            schema.has_capability(ModelCapability::Code),
            "selected model must support Code"
        );
        assert!(!schema.is_local(), "should route to trusted remote model");
    }

    #[test]
    fn routes_images_to_vision_capable_model() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let mut reg = test_registry();
        let tracker = OutcomeTracker::new();

        reg.register(ModelSchema {
            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
            name: "Qwen3-VL-2B-mlx-vlm".into(),
            provider: "qwen".into(),
            family: "qwen3-vl".into(),
            version: "bf16".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Vision,
                ModelCapability::Grounding,
            ],
            context_length: 262_144,
            param_count: "2B".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::Mlx {
                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
                hf_weight_file: None,
            },
            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });

        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("model should exist");
        assert!(
            schema.has_capability(ModelCapability::Vision),
            "selected model must support Vision"
        );
    }

    #[test]
    fn profile_based_routing_favors_proven_model() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 0.5, // weak prior, let observed data dominate
                min_observations: 3,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Build a strong profile for Qwen3-8B on code tasks (fast + high quality)
        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..20 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        // Thompson Sampling is stochastic — run multiple times, 8B should win majority
        let mut wins = 0;
        for _ in 0..20 {
            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
            assert_eq!(decision.complexity, TaskComplexity::Code);
            if decision.model_id == qwen_8b_id {
                wins += 1;
            }
        }
        assert!(
            wins >= 12,
            "proven model won only {wins}/20 times (expected >= 12)"
        );
    }

    #[test]
    fn proven_local_model_can_displace_bootstrap_remote() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                bootstrap_min_task_observations: 6,
                bootstrap_quality_floor: 0.8,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..12 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
            tracker.record_complete(&trace, 300, 50, 20);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let mut local_wins = 0;
        for _ in 0..20 {
            let decision = router.route("Summarize this design decision.", &reg, &tracker);
            let schema = reg
                .get(&decision.model_id)
                .expect("selected model should exist");
            if schema.is_local() {
                local_wins += 1;
            }
        }

        assert!(
            local_wins >= 12,
            "proven local model won only {local_wins}/20 times (expected >= 12)"
        );
    }

    #[test]
    fn benchmark_prior_informs_background_routing() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.ema_quality = 0.95;
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Background,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            schema.is_local(),
            "background routing should allow strong local benchmark priors to win"
        );
    }

    #[test]
    fn task_specific_benchmark_prior_informs_cold_start_routing() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.task_stats.insert(
            crate::outcome::InferenceTask::Code.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                ..Default::default()
            },
        );
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Background,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            schema.is_local(),
            "background routing should use task-specific cold-start priors for local code models"
        );
    }

    #[test]
    fn task_specific_latency_prior_affects_cold_start_score() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let model = reg
            .get("qwen/qwen3-8b:q4_k_m")
            .expect("local test model should exist");

        let mut fast_tracker = OutcomeTracker::new();
        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
        fast_profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms: 1200.0,
                ..Default::default()
            },
        );
        fast_tracker.import_profiles(vec![fast_profile]);

        let mut slow_tracker = OutcomeTracker::new();
        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
        slow_profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms: 120_000.0,
                ..Default::default()
            },
        );
        slow_tracker.import_profiles(vec![slow_profile]);

        let fast_score = router.score_model(
            model,
            InferenceTask::Generate,
            &fast_tracker,
            RoutingWorkload::Interactive,
        );
        let slow_score = router.score_model(
            model,
            InferenceTask::Generate,
            &slow_tracker,
            RoutingWorkload::Interactive,
        );

        assert!(
            fast_score > slow_score,
            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
        );
    }

    #[test]
    fn interactive_workload_keeps_remote_bootstrap_bias() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.ema_quality = 0.95;
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Interactive,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            !schema.is_local(),
            "interactive routing should still prefer trusted remote models during cold start"
        );
    }

    #[test]
    fn fallback_chain_has_alternatives() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
        assert!(!decision.fallbacks.is_empty());
        // Primary should not appear in fallbacks
        assert!(!decision.fallbacks.contains(&decision.model_id));
    }

    #[test]
    fn latency_scoring_is_consistent() {
        // Verify that schema-derived (TPS) and observed (ms) latencies produce comparable scores
        let router = AdaptiveRouter::with_default_config(test_hw());

        // A model at 25 TPS → ~200/25*1000 = 8000ms → score = 1 - 8000/10000 = 0.2
        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
        // Same model observed at 8000ms → same formula
        let observed_score = router.latency_ms_to_score(8000.0);
        assert!(
            (schema_score - observed_score).abs() < 0.01,
            "schema ({schema_score}) and observed ({observed_score}) should match"
        );
    }

    #[test]
    fn complexity_assessment() {
        assert_eq!(
            TaskComplexity::assess("What is the capital of France?"),
            TaskComplexity::Simple
        );
        assert_eq!(
            TaskComplexity::assess("Fix this broken test"),
            TaskComplexity::Code
        );
        assert_eq!(
            TaskComplexity::assess("Analyze the trade-offs between A and B"),
            TaskComplexity::Complex
        );
    }
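
    // Boundary sketches for the token-length heuristics in `assess`
    // (illustrative word counts chosen to clear the documented thresholds;
    // the filler word "word" matches none of the keyword markers).
    #[test]
    fn complexity_assessment_token_thresholds() {
        // ~600 words ≈ 780 estimated tokens → Complex even without markers
        let long_prompt = "word ".repeat(600);
        assert_eq!(TaskComplexity::assess(&long_prompt), TaskComplexity::Complex);
        // ~50 words ≈ 65 estimated tokens → neither Simple nor Complex
        let medium_prompt = "word ".repeat(50);
        assert_eq!(TaskComplexity::assess(&medium_prompt), TaskComplexity::Medium);
    }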

    #[test]
    fn beta_sampling_produces_valid_values() {
        let mut rng = rand::rng();
        // Sample 100 times from Beta(2, 5) — should be in [0, 1]
        for _ in 0..100 {
            let s = sample_beta(&mut rng, 2.0, 5.0);
            assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
        }
        // Beta(1, 1) = Uniform(0, 1) — mean should be ~0.5
        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            (mean - 0.5).abs() < 0.05,
            "Beta(1,1) mean {mean} should be ~0.5"
        );
    }
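
    // Statistical sanity sketches for the samplers. Tolerances are
    // deliberately loose: at these sample sizes the standard error of
    // each estimate sits well below the asserted bound.
    #[test]
    fn gamma_sampling_matches_expected_mean() {
        let mut rng = rand::rng();
        // Gamma(shape, 1) has mean `shape`; 0.5 exercises the
        // Ahrens-Dieter boost, 2.5 the Marsaglia-Tsang path.
        for shape in [0.5, 2.5] {
            let n = 20_000;
            let mean = (0..n).map(|_| sample_gamma(&mut rng, shape)).sum::<f64>() / n as f64;
            assert!(
                (mean - shape).abs() < 0.1,
                "Gamma({shape}, 1) sample mean {mean} should be ~{shape}"
            );
        }
    }

    #[test]
    fn normal_sampling_is_roughly_standard() {
        let mut rng = rand::rng();
        let n = 20_000;
        let samples: Vec<f64> = (0..n).map(|_| sample_standard_normal(&mut rng)).collect();
        let mean = samples.iter().sum::<f64>() / n as f64;
        let var = samples.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.05, "N(0,1) sample mean {mean} should be ~0");
        assert!(
            (var - 1.0).abs() < 0.1,
            "N(0,1) sample variance {var} should be ~1"
        );
    }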

    #[test]
    fn thompson_sampling_converges_to_best() {
        // A model with strong observed success should win most of the time
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 1.0, // weak prior, let observations dominate
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Give Qwen3-4B 20 successes (strong signal)
        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
        for _ in 0..20 {
            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        // Route 20 code tasks — strong model should win the majority
        let mut wins = 0;
        for _ in 0..20 {
            let decision = router.route("Fix this parser bug", &reg, &tracker);
            if decision.model_id == qwen_4b_id {
                wins += 1;
            }
        }
        assert!(
            wins >= 14,
            "strong model won only {wins}/20 times (expected >= 14)"
        );
    }

    // ----- Intent surface (parslee-ai/car-releases#18) -----

    #[test]
    fn intent_require_filters_out_models_lacking_capability() {
        // Asking for vision when no candidate has it should fall to
        // the cold-start decision rather than scoring incompatible
        // candidates. The fixture registry has no vision-capable
        // local models.
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let intent = crate::intent::IntentHint {
            require: vec![ModelCapability::Vision],
            ..Default::default()
        };
        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);

        // When require filters out every candidate, the router falls
        // to the schema-based cold-start path. Asserting the strategy
        // (not just non-empty model_id) catches a regression where
        // future code might silently include filtered candidates.
        assert_eq!(
            decision.strategy,
            RoutingStrategy::SchemaBased,
            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
        );
    }

    #[test]
    fn intent_default_does_not_override_task_or_caps() {
        // Thompson sampling makes per-call model selection
        // non-deterministic, so we can't compare model_ids directly.
        // What we can assert deterministically: a default IntentHint
        // must not change the task selection or the capability
        // requirements — those are functions of the prompt only when
        // no hint is supplied.
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let baseline = router.route("write a haiku", &reg, &tracker);
        let with_default = router.route_with_intent(
            "write a haiku",
            &reg,
            &tracker,
            &crate::intent::IntentHint::default(),
        );

        assert_eq!(
            baseline.task, with_default.task,
            "default IntentHint must not change the prompt-derived task"
        );
        assert_eq!(
            baseline.complexity, with_default.complexity,
            "default IntentHint must not change the prompt-derived complexity"
        );
    }

    #[test]
    fn intent_task_hint_overrides_prompt_complexity() {
        // A short prompt that complexity assessment would route as
        // Generate should land on Reasoning when the intent says so.
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let hint = crate::intent::IntentHint {
            task: Some(crate::intent::TaskHint::Reasoning),
            ..Default::default()
        };
        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);

        assert_eq!(
            decision.task,
            InferenceTask::Reasoning,
            "TaskHint::Reasoning should override the prompt-derived task"
        );
    }
}