Skip to main content

car_inference/
adaptive_router.rs

//! Adaptive model routing — three-phase routing with learned performance profiles.
//!
//! Phase 1: **Filter** — hard constraints (capability, availability, memory, cost).
//! Phase 2: **Score** — blend quality, latency, and cost using observed profiles
//!          or schema defaults on cold start.
//! Phase 3: **Select** — Thompson Sampling (Beta distribution per model) for
//!          natural exploration-exploitation balance. Models with fewer observations
//!          have wider distributions, giving them chances to prove themselves.
//!
//! Replaces the hardcoded `ModelRouter` from `router.rs`.

12use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
24/// Prompt complexity assessment (migrated from router.rs).
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case")]
27pub enum TaskComplexity {
28    Simple,
29    Medium,
30    Code,
31    Complex,
32}
33
34impl TaskComplexity {
35    /// Assess complexity of a prompt string.
36    ///
37    /// Uses tree-sitter AST parsing (when the `ast` feature is enabled) for
38    /// accurate code detection: if a code block parses successfully as any
39    /// supported language, it's definitively code. Falls back to keyword
40    /// heuristics for prompts that mention code without containing code blocks.
41    pub fn assess(prompt: &str) -> Self {
42        let lower = prompt.to_lowercase();
43        let word_count = prompt.split_whitespace().count();
44        let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46        let has_code = Self::detect_code(prompt);
47
48        let repair_markers = [
49            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50        ];
51        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53        let reasoning_markers = [
54            "analyze",
55            "compare",
56            "explain why",
57            "step by step",
58            "think through",
59            "evaluate",
60            "trade-off",
61            "tradeoff",
62            "pros and cons",
63            "architecture",
64            "design",
65            "strategy",
66            "optimize",
67            "comprehensive",
68        ];
69        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71        let simple_patterns = [
72            "what is",
73            "who is",
74            "when did",
75            "where is",
76            "how many",
77            "yes or no",
78            "true or false",
79            "name the",
80            "list the",
81            "define ",
82        ];
83        let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85        if has_code || has_repair {
86            TaskComplexity::Code
87        } else if has_reasoning || estimated_tokens > 500 {
88            TaskComplexity::Complex
89        } else if is_simple || estimated_tokens < 30 {
90            TaskComplexity::Simple
91        } else {
92            TaskComplexity::Medium
93        }
94    }
95
96    /// Detect whether a prompt contains code.
97    ///
98    /// With `ast` feature: extracts code blocks (``` delimited), attempts to
99    /// parse each with tree-sitter. If any parses into symbols, it's real code.
100    /// Without `ast` feature: falls back to keyword heuristics.
101    fn detect_code(prompt: &str) -> bool {
102        // First try AST-based detection on code blocks
103        #[cfg(feature = "ast")]
104        {
105            if let Some(is_code) = Self::detect_code_ast(prompt) {
106                return is_code;
107            }
108        }
109
110        // Fallback: keyword heuristics
111        let code_markers = [
112            "```",
113            "fn ",
114            "def ",
115            "class ",
116            "import ",
117            "require(",
118            "async fn",
119            "pub fn",
120            "function ",
121            "const ",
122            "let ",
123            "var ",
124            "#include",
125            "package ",
126            "impl ",
127        ];
128        code_markers.iter().any(|m| prompt.contains(m))
129    }
130
131    /// AST-based code detection: parse code blocks with tree-sitter.
132    /// Returns Some(true) if code found, Some(false) if blocks exist but
133    /// don't parse, None if no code blocks found (fall through to heuristics).
134    #[cfg(feature = "ast")]
135    fn detect_code_ast(prompt: &str) -> Option<bool> {
136        // Extract code blocks between ``` markers
137        let mut blocks = Vec::new();
138        let mut rest = prompt;
139        while let Some(start) = rest.find("```") {
140            let after_fence = &rest[start + 3..];
141            // Skip optional language tag on the opening fence
142            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143            if let Some(end) = after_fence[code_start..].find("```") {
144                blocks.push(&after_fence[code_start..code_start + end]);
145                rest = &after_fence[code_start + end + 3..];
146            } else {
147                break;
148            }
149        }
150
151        if blocks.is_empty() {
152            return None; // No code blocks — let heuristics decide
153        }
154
155        // Try to parse each block with tree-sitter
156        let languages = [
157            car_ast::Language::Rust,
158            car_ast::Language::Python,
159            car_ast::Language::TypeScript,
160            car_ast::Language::JavaScript,
161            car_ast::Language::Go,
162        ];
163
164        for block in &blocks {
165            let trimmed = block.trim();
166            if trimmed.is_empty() {
167                continue;
168            }
169
170            for lang in &languages {
171                if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172                    // If it parsed into any symbols, it's definitely code
173                    if !parsed.symbols.is_empty() {
174                        return Some(true);
175                    }
176                }
177            }
178        }
179
180        // Had code blocks but none parsed into symbols — could be
181        // pseudocode, output, or unsupported language
182        Some(false)
183    }
184
185    /// Map complexity to required capabilities.
186    pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187        match self {
188            TaskComplexity::Simple => vec![ModelCapability::Generate],
189            TaskComplexity::Medium => vec![ModelCapability::Generate],
190            TaskComplexity::Code => vec![ModelCapability::Code],
191            TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192        }
193    }
194
195    /// Map complexity to the InferenceTask type.
196    pub fn inference_task(&self) -> InferenceTask {
197        match self {
198            TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199            TaskComplexity::Code => InferenceTask::Code,
200            TaskComplexity::Complex => InferenceTask::Reasoning,
201        }
202    }
203}
204
205/// Configuration for routing behavior.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct RoutingConfig {
208    /// Minimum observations before trusting a model's profile over schema defaults.
209    pub min_observations: u64,
210    /// Scoring weights (must sum to 1.0).
211    pub quality_weight: f64,
212    pub latency_weight: f64,
213    pub cost_weight: f64,
214    /// Hard constraint: maximum latency budget in ms.
215    pub max_latency_ms: Option<u64>,
216    /// Hard constraint: maximum cost per call in USD.
217    pub max_cost_usd: Option<f64>,
218    /// Prefer local models over remote (all else being equal).
219    pub prefer_local: bool,
220    /// Thompson Sampling prior strength. Higher = more weight on the Phase 2 score
221    /// as a prior, lower = more influenced by observed outcomes.
222    /// Equivalent to the number of "virtual" observations from the prior.
223    pub prior_strength: f64,
224    /// Prefer trusted remote models for quality-critical tasks until local models
225    /// have enough task-specific evidence to be promoted.
226    pub quality_first_cold_start: bool,
227    /// Minimum task-specific observations required before a local model can
228    /// compete with trusted remote models during cold start.
229    pub bootstrap_min_task_observations: u64,
230    /// Minimum task-specific EMA quality required before a local model can
231    /// displace trusted remote models during cold start.
232    pub bootstrap_quality_floor: f64,
233}
234
235impl Default for RoutingConfig {
236    fn default() -> Self {
237        Self {
238            min_observations: 2,
239            quality_weight: 0.45,
240            latency_weight: 0.4,
241            cost_weight: 0.15,
242            max_latency_ms: None,
243            max_cost_usd: None,
244            prefer_local: true,
245            prior_strength: 2.0,
246            quality_first_cold_start: true,
247            bootstrap_min_task_observations: 8,
248            bootstrap_quality_floor: 0.8,
249        }
250    }
251}
252
253/// How a model was selected.
254#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
255#[serde(rename_all = "snake_case")]
256pub enum RoutingStrategy {
257    /// Using declared schema capabilities (no observed data).
258    SchemaBased,
259    /// Using observed performance profiles (exploitation).
260    ProfileBased,
261    /// Deliberately trying an under-tested model (exploration).
262    Exploration,
263    /// User explicitly specified the model.
264    Explicit,
265}
266
267/// The result of adaptive routing.
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct AdaptiveRoutingDecision {
270    /// Selected model id.
271    pub model_id: String,
272    /// Selected model name (display).
273    pub model_name: String,
274    /// Task type.
275    pub task: InferenceTask,
276    /// Assessed complexity.
277    pub complexity: TaskComplexity,
278    /// Human-readable reason.
279    pub reason: String,
280    /// How the model was selected.
281    pub strategy: RoutingStrategy,
282    /// Predicted quality (0.0-1.0).
283    pub predicted_quality: f64,
284    /// Fallback chain (ordered list of alternative model ids).
285    pub fallbacks: Vec<String>,
286    /// Context window of the selected model (tokens). 0 = unknown.
287    pub context_length: usize,
288    /// Whether the prompt needs compaction to fit the selected model's context window.
289    pub needs_compaction: bool,
290}
291
292/// Adaptive router with three-phase model selection.
293pub struct AdaptiveRouter {
294    hw: HardwareInfo,
295    config: RoutingConfig,
296    /// Circuit breaker registry — blocks models after consecutive failures (#25).
297    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
298}
299
300/// All inputs to a routing decision packed into one struct so adding
301/// a new combinator (per-tenant routing, streaming awareness, …)
302/// only adds a field rather than another `route_*` sibling method.
303/// See [`AdaptiveRouter::route_with`].
304///
305/// Use [`RouteRequest::new`] for the common defaults, then mutate
306/// just the fields the caller cares about:
307///
308/// ```ignore
309/// let decision = router.route_with(RouteRequest {
310///     has_tools: true,
311///     intent: Some(&hint),
312///     ..RouteRequest::new(prompt, &registry, &tracker)
313/// });
314/// ```
315pub struct RouteRequest<'a> {
316    pub prompt: &'a str,
317    pub registry: &'a UnifiedRegistry,
318    pub tracker: &'a OutcomeTracker,
319    /// Estimated prompt-side token count for context-window-aware
320    /// scoring. `0` skips the compaction-headroom check.
321    pub estimated_total_tokens: usize,
322    pub has_tools: bool,
323    pub has_vision: bool,
324    pub workload: RoutingWorkload,
325    /// Caller-supplied intent hint. `prefer_local: true` overrides
326    /// `workload` to [`RoutingWorkload::LocalPreferred`].
327    pub intent: Option<&'a crate::intent::IntentHint>,
328}
329
330impl<'a> RouteRequest<'a> {
331    /// Build a request with the same defaults the bare
332    /// [`AdaptiveRouter::route`] uses: interactive workload, no tools,
333    /// no vision, no context-aware sizing, no intent.
334    pub fn new(
335        prompt: &'a str,
336        registry: &'a UnifiedRegistry,
337        tracker: &'a OutcomeTracker,
338    ) -> Self {
339        Self {
340            prompt,
341            registry,
342            tracker,
343            estimated_total_tokens: 0,
344            has_tools: false,
345            has_vision: false,
346            workload: RoutingWorkload::Interactive,
347            intent: None,
348        }
349    }
350}
351
352impl AdaptiveRouter {
353    pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354        let circuit_breakers = Arc::new(Mutex::new(
355            CircuitBreakerRegistry::new(3, 300), // 3 failures, 5 min cooldown
356        ));
357        Self {
358            hw,
359            config,
360            circuit_breakers,
361        }
362    }
363
364    pub fn with_default_config(hw: HardwareInfo) -> Self {
365        Self::new(hw, RoutingConfig::default())
366    }
367
368    pub fn config(&self) -> &RoutingConfig {
369        &self.config
370    }
371
372    pub fn set_config(&mut self, config: RoutingConfig) {
373        self.config = config;
374    }
375
376    /// Canonical entry point. The seven `route_*` sibling methods
377    /// each build a `RouteRequest` with their fixed defaults and
378    /// call this — adding a new combinator (per-tenant routing,
379    /// streaming awareness, …) only adds a field here, not another
380    /// public method. Closes #108.
381    pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382        // Intent-hint precedence over the caller-supplied workload:
383        // `prefer_fast` wins outright (voice fast track is the most
384        // latency-sensitive path we have); `prefer_local` is the
385        // long-standing override; absent either, the caller's
386        // workload stands.
387        let workload = match req.intent {
388            Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389            Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390            _ => req.workload,
391        };
392        self.route_inner_with_intent(
393            req.prompt,
394            req.registry,
395            req.tracker,
396            req.has_tools,
397            req.has_vision,
398            req.estimated_total_tokens,
399            workload,
400            req.intent,
401        )
402    }
403
404    /// Route a generation request to the best model.
405    /// If `has_tools` is true, requires ToolUse capability (#13).
406    pub fn route(
407        &self,
408        prompt: &str,
409        registry: &UnifiedRegistry,
410        tracker: &OutcomeTracker,
411    ) -> AdaptiveRoutingDecision {
412        self.route_with(RouteRequest::new(prompt, registry, tracker))
413    }
414
415    /// Route an "editor" request — cheap mechanical work (context compaction,
416    /// edit materialization, title generation, replanning). Uses
417    /// [`RoutingWorkload::Background`] so cost/quality weights favour cheaper
418    /// models, and lets hosting layers express the Aider-style architect/editor
419    /// split without plumbing a raw model-name override through every layer.
420    pub fn route_editor(
421        &self,
422        prompt: &str,
423        registry: &UnifiedRegistry,
424        tracker: &OutcomeTracker,
425    ) -> AdaptiveRoutingDecision {
426        self.route_with(RouteRequest {
427            workload: RoutingWorkload::Background,
428            ..RouteRequest::new(prompt, registry, tracker)
429        })
430    }
431
432    /// Route with tool_use requirement — filters to models that support structured tool calls.
433    pub fn route_with_tools(
434        &self,
435        prompt: &str,
436        registry: &UnifiedRegistry,
437        tracker: &OutcomeTracker,
438    ) -> AdaptiveRoutingDecision {
439        self.route_with(RouteRequest {
440            has_tools: true,
441            ..RouteRequest::new(prompt, registry, tracker)
442        })
443    }
444
445    /// Route with image input requirement — filters to models that support vision.
446    pub fn route_with_vision(
447        &self,
448        prompt: &str,
449        registry: &UnifiedRegistry,
450        tracker: &OutcomeTracker,
451        has_tools: bool,
452    ) -> AdaptiveRoutingDecision {
453        self.route_with(RouteRequest {
454            has_tools,
455            has_vision: true,
456            ..RouteRequest::new(prompt, registry, tracker)
457        })
458    }
459
460    /// Route with caller-supplied intent — see [`crate::IntentHint`].
461    /// The hint can override the auto-detected `InferenceTask`, add hard
462    /// `require` capability filters on top of the prompt-derived ones,
463    /// and bias the score profile toward local models when
464    /// `prefer_local` is set.
465    pub fn route_with_intent<'a>(
466        &self,
467        prompt: &'a str,
468        registry: &'a UnifiedRegistry,
469        tracker: &'a OutcomeTracker,
470        intent: &'a crate::intent::IntentHint,
471    ) -> AdaptiveRoutingDecision {
472        self.route_with(RouteRequest {
473            intent: Some(intent),
474            ..RouteRequest::new(prompt, registry, tracker)
475        })
476    }
477
478    /// Route with context awareness — estimates prompt tokens and prefers models
479    /// whose context window can fit the full prompt without compaction.
480    pub fn route_context_aware(
481        &self,
482        prompt: &str,
483        estimated_total_tokens: usize,
484        registry: &UnifiedRegistry,
485        tracker: &OutcomeTracker,
486        has_tools: bool,
487        has_vision: bool,
488        workload: RoutingWorkload,
489    ) -> AdaptiveRoutingDecision {
490        self.route_with(RouteRequest {
491            estimated_total_tokens,
492            has_tools,
493            has_vision,
494            workload,
495            ..RouteRequest::new(prompt, registry, tracker)
496        })
497    }
498
499    /// Context-aware routing with caller-supplied intent. Same context
500    /// math as [`Self::route_context_aware`]; the intent layers on
501    /// top — task override, additional `require` filters,
502    /// `prefer_local` workload override.
503    pub fn route_context_aware_with_intent<'a>(
504        &self,
505        prompt: &'a str,
506        estimated_total_tokens: usize,
507        registry: &'a UnifiedRegistry,
508        tracker: &'a OutcomeTracker,
509        has_tools: bool,
510        has_vision: bool,
511        workload: RoutingWorkload,
512        intent: &'a crate::intent::IntentHint,
513    ) -> AdaptiveRoutingDecision {
514        self.route_with(RouteRequest {
515            estimated_total_tokens,
516            has_tools,
517            has_vision,
518            workload,
519            intent: Some(intent),
520            ..RouteRequest::new(prompt, registry, tracker)
521        })
522    }
523
524    fn route_inner_with_intent(
525        &self,
526        prompt: &str,
527        registry: &UnifiedRegistry,
528        tracker: &OutcomeTracker,
529        has_tools: bool,
530        has_vision: bool,
531        estimated_total_tokens: usize,
532        workload: RoutingWorkload,
533        intent: Option<&crate::intent::IntentHint>,
534    ) -> AdaptiveRoutingDecision {
535        let complexity = TaskComplexity::assess(prompt);
536        // Caller intent overrides the prompt-derived task when supplied.
537        let task = intent
538            .and_then(|h| h.task)
539            .map(task_hint_to_inference_task)
540            .unwrap_or_else(|| complexity.inference_task());
541        let mut required_caps = complexity.required_capabilities();
542        if let Some(hint) = intent {
543            for cap in &hint.require {
544                if !required_caps.contains(cap) {
545                    required_caps.push(*cap);
546                }
547            }
548        }
549        if has_vision {
550            required_caps.push(ModelCapability::Vision);
551        }
552        if has_tools {
553            required_caps.push(ModelCapability::ToolUse);
554            // Detect multi-step prompts that need multiple tool calls in one response.
555            // Patterns: numbered lists ("1) ... 2) ..."), multiple instructions, explicit multi-edit.
556            if Self::needs_multi_tool_call(prompt) {
557                required_caps.push(ModelCapability::MultiToolCall);
558            }
559        }
560
561        // Phase 1: Filter candidates
562        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);
563
564        // Fallback: if requiring MultiToolCall eliminates all candidates, drop the
565        // requirement and let the best ToolUse model handle it with multiple round-trips.
566        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
567            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
568            candidates = self.filter_candidates(&required_caps, registry, tracker);
569        }
570
571        if candidates.is_empty() {
572            // Nothing available — return the schema-based default
573            return self.cold_start_decision(complexity, task, registry, has_vision);
574        }
575
576        candidates = self.apply_quality_first_bootstrap_policy(
577            candidates,
578            task,
579            tracker,
580            has_vision,
581            has_tools,
582            workload,
583        );
584
585        // Context-aware filtering: if we know the prompt size, prefer models that fit.
586        // Phase 1b: separate candidates into "fits" and "needs compaction" groups.
587        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
588            let mut fits = Vec::new();
589            let mut tight = Vec::new();
590            for m in &candidates {
591                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
592                    fits.push(m.clone());
593                } else {
594                    tight.push(m.clone());
595                }
596            }
597            (fits, tight)
598        } else {
599            (candidates.clone(), Vec::new())
600        };
601
602        // Prefer models that fit; fall back to compaction-required models if none fit
603        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
604            (fits, false)
605        } else if !needs_compaction_candidates.is_empty() {
606            tracing::info!(
607                prompt_tokens = estimated_total_tokens,
608                candidates = needs_compaction_candidates.len(),
609                "no model fits full prompt — compaction will be needed"
610            );
611            (needs_compaction_candidates.clone(), true)
612        } else {
613            (candidates.clone(), false)
614        };
615
616        // Phase 2: Score candidates (with context headroom bonus)
617        let scored = self.score_candidates_context_aware(
618            &scoring_candidates,
619            task,
620            tracker,
621            estimated_total_tokens,
622            workload,
623        );
624
625        // Phase 3: Thompson Sampling selection
626        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);
627
628        // Build fallback chain: prefer models that fit, then compaction candidates
629        let mut fallbacks: Vec<String> = scored
630            .iter()
631            .filter(|(id, _)| *id != selected_id)
632            .map(|(id, _)| id.clone())
633            .collect();
634        // Add compaction candidates to the end of the fallback chain
635        if !compaction_needed {
636            for m in &needs_compaction_candidates {
637                if m.id != selected_id && !fallbacks.contains(&m.id) {
638                    fallbacks.push(m.id.clone());
639                }
640            }
641        }
642
643        let predicted_quality = scored
644            .iter()
645            .find(|(id, _)| *id == selected_id)
646            .map(|(_, score)| *score)
647            .unwrap_or(0.5);
648
649        let selected_schema = registry
650            .get(&selected_id)
651            .or_else(|| registry.find_by_name(&selected_id));
652        let model_name = selected_schema
653            .map(|m| m.name.clone())
654            .unwrap_or_else(|| selected_id.clone());
655        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);
656
657        let needs_compact = compaction_needed
658            || (estimated_total_tokens > 0
659                && context_length > 0
660                && estimated_total_tokens > context_length);
661
662        let compaction_note = if needs_compact {
663            format!(
664                " [compaction needed: {}→{}tok]",
665                estimated_total_tokens, context_length
666            )
667        } else {
668            String::new()
669        };
670
671        let reason = format!(
672            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
673            complexity,
674            model_name,
675            strategy,
676            predicted_quality,
677            scoring_candidates.len(),
678            compaction_note,
679        );
680
681        AdaptiveRoutingDecision {
682            model_id: selected_id,
683            model_name,
684            task,
685            complexity,
686            reason,
687            strategy,
688            predicted_quality,
689            fallbacks,
690            context_length,
691            needs_compaction: needs_compact,
692        }
693    }
694
695    /// Route to the best embedding model.
696    pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
697        let embed_models = registry.query_by_capability(ModelCapability::Embed);
698        embed_models
699            .first()
700            .map(|m| m.name.clone())
701            .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
702    }
703
704    /// Route to the smallest available model (for classification).
705    pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
706        let gen_models = registry.query_by_capability(ModelCapability::Generate);
707        // Pick smallest by size
708        gen_models
709            .iter()
710            .filter(|m| m.is_local())
711            .min_by_key(|m| m.size_mb())
712            .map(|m| m.name.clone())
713            .unwrap_or_else(|| "Qwen3-0.6B".to_string())
714    }
715
716    // --- Internal phases ---
717
718    // --- Scoring constants ---
719
720    /// Latency ceiling: requests taking longer than this score 0.0.
721    const LATENCY_CEILING_MS: f64 = 10000.0;
722    /// TPS ceiling: models faster than this score 1.0.
723    const _TPS_CEILING: f64 = 150.0;
724    /// MoE throughput penalty for Candle: naive expert routing runs at ~10% of declared TPS.
725    const MOE_TPS_MULTIPLIER: f64 = 0.10;
726    /// MoE throughput multiplier for MLX: fused Metal kernels run at ~50% of declared TPS.
727    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
728    /// Cost ceiling: models costing more than this per 1K output tokens score 0.0.
729    const COST_CEILING_PER_1K: f64 = 0.1;
730    /// Local preference bonus added to the weighted score (before normalization).
731    const LOCAL_BONUS: f64 = 0.15;
732    /// True on platforms where local inference has a GPU/NPU backend
733    /// (Apple Silicon with MLX). On Intel Macs and on hosts built with
734    /// `--cfg=car_skip_mlx`, local inference falls through to candle/
735    /// CPU which is far slower than cloud for non-trivial models. The
736    /// router suppresses LOCAL_BONUS when this is false so cloud
737    /// models rank fairly instead of losing to a CPU-bound 4B that
738    /// will take 30s to first-token.
739    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
740    const HAS_GPU_BACKEND: bool = true;
741    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
742    const HAS_GPU_BACKEND: bool = false;
743    /// Extra bonus for MLX models on Apple Silicon (stacks with LOCAL_BONUS).
744    /// MLX gets fused Metal kernels and better memory layout vs Candle on Mac.
745    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
746    const MLX_BONUS: f64 = 0.10;
747    /// Extra bonus for system-owned models — no model file on disk, no
748    /// download to provision, framework-managed memory. Currently only
749    /// `apple/foundation:default` qualifies. Tag-driven so future
750    /// system-LLM integrations (e.g. Android AICore) inherit the
751    /// scoring without a code change. Stacks with LOCAL_BONUS but
752    /// **excludes** MLX_BONUS (system models aren't MLX); the net
753    /// effect is FoundationModels ranks roughly even with a
754    /// well-warmed MLX 4B for short fast-turn tasks instead of
755    /// strictly losing because its catalog `tokens_per_second` /
756    /// `size_mb` are null.
757    const SYSTEM_LLM_BONUS: f64 = 0.12;
758
759    // --- Internal phases ---
760
761    /// Phase 1: Filter by hard constraints.
762    fn filter_candidates(
763        &self,
764        required_caps: &[ModelCapability],
765        registry: &UnifiedRegistry,
766        tracker: &OutcomeTracker,
767    ) -> Vec<ModelSchema> {
768        registry
769            .list()
770            .into_iter()
771            .filter(|m| {
772                // Must have all required capabilities
773                if !required_caps.iter().all(|c| m.has_capability(*c)) {
774                    return false;
775                }
776                // Must be available (downloaded for local, API key set for remote)
777                if !m.available {
778                    return false;
779                }
780                // Local models must fit in memory (strict: >= excludes models at the limit)
781                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
782                    return false;
783                }
784                // Hard latency constraint
785                if let Some(max) = self.config.max_latency_ms {
786                    if let Some(p50) = m.performance.latency_p50_ms {
787                        if p50 > max {
788                            return false;
789                        }
790                    }
791                }
792                // Hard cost constraint
793                if let Some(max) = self.config.max_cost_usd {
794                    if m.cost_per_1k_output() > max {
795                        return false;
796                    }
797                }
798                // Hard exclusion: prefer_local=false excludes all local models (#12)
799                if !self.config.prefer_local && m.is_local() {
800                    return false;
801                }
802                // Hard exclusion: rate-limited models excluded for this session (#13)
803                if tracker.is_excluded(&m.id) {
804                    return false;
805                }
806                // Circuit breaker: skip models with consecutive failures (#25)
807                if let Ok(mut cb) = self.circuit_breakers.lock() {
808                    if !cb.allow_request(&m.id) {
809                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
810                        return false;
811                    }
812                }
813                true
814            })
815            .cloned()
816            .collect()
817    }
818
819    fn apply_quality_first_bootstrap_policy(
820        &self,
821        candidates: Vec<ModelSchema>,
822        task: InferenceTask,
823        tracker: &OutcomeTracker,
824        has_vision: bool,
825        has_tools: bool,
826        workload: RoutingWorkload,
827    ) -> Vec<ModelSchema> {
828        if !self.config.quality_first_cold_start
829            || !workload.is_latency_sensitive()
830            || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
831        {
832            return candidates;
833        }
834
835        let trusted_remote: Vec<ModelSchema> = candidates
836            .iter()
837            .filter(|model| self.is_trusted_quality_remote(model))
838            .cloned()
839            .collect();
840
841        if trusted_remote.is_empty() {
842            return candidates;
843        }
844
845        let proven_local: Vec<ModelSchema> = candidates
846            .iter()
847            .filter(|model| model.is_local() && self.is_local_model_proven_for_task(model, task, tracker))
848            .cloned()
849            .collect();
850
851        if !proven_local.is_empty() {
852            return proven_local;
853        }
854
855        trusted_remote
856    }
857
858    /// Phase 2: Score candidates with context awareness.
859    /// Applies a headroom bonus to models with more context window for the prompt.
860    /// When estimated_total_tokens is 0, no context bonus/penalty is applied.
861    fn score_candidates_context_aware(
862        &self,
863        candidates: &[ModelSchema],
864        task: InferenceTask,
865        tracker: &OutcomeTracker,
866        estimated_total_tokens: usize,
867        workload: RoutingWorkload,
868    ) -> Vec<(String, f64)> {
869        let mut scored: Vec<(String, f64)> = candidates
870            .iter()
871            .map(|m| {
872                let base_score = self.score_model(m, task, tracker, workload);
873                // Context headroom bonus: prefer models with more room to spare.
874                // Max bonus: 0.10 (at 4x headroom or more). No bonus if unknown.
875                let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
876                    let ratio = m.context_length as f64 / estimated_total_tokens as f64;
877                    if ratio >= 1.0 {
878                        (ratio.min(4.0) - 1.0) / 3.0 * 0.10 // 0.0 at exact fit, 0.10 at 4x
879                    } else {
880                        -0.15 // Penalty for models that require compaction
881                    }
882                } else {
883                    0.0
884                };
885                (m.id.clone(), base_score + headroom_bonus)
886            })
887            .collect();
888
889        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
890        scored
891    }
892
893    /// Score a single model. All sub-scores are in [0.0, 1.0].
894    /// Final score = weighted sum + local_bonus, so range is [0.0, ~1.15].
895    fn score_model(
896        &self,
897        model: &ModelSchema,
898        task: InferenceTask,
899        tracker: &OutcomeTracker,
900        workload: RoutingWorkload,
901    ) -> f64 {
902        let profile = tracker.profile(&model.id);
903        let schema_quality = self.schema_quality_estimate(model);
904        let schema_latency = self.schema_latency_estimate(model);
905        let (quality_weight, latency_weight, cost_weight) = workload.weights();
906
907        // Quality: blend schema estimate with observed data based on observation count.
908        // Both cold start and warm start use consistent blending.
909        let quality = match profile {
910            Some(p) if p.total_calls >= self.config.min_observations => p
911                .task_stats(task)
912                .map(|ts| ts.ema_quality)
913                .unwrap_or(p.ema_quality),
914            Some(p) if p.total_calls == 0 => p
915                .task_stats(task)
916                .map(|ts| ts.ema_quality)
917                .unwrap_or(p.ema_quality),
918            Some(p) if p.total_calls > 0 => {
919                let w = p.total_calls as f64 / self.config.min_observations as f64;
920                schema_quality * (1.0 - w) + p.ema_quality * w
921            }
922            _ => schema_quality,
923        };
924
925        // Latency: same blending as quality — don't trust a single observation more
926        // than schema estimates. This prevents routing oscillation on first few calls.
927        let latency = match profile {
928            Some(p) if p.total_calls >= self.config.min_observations => {
929                let avg = p
930                    .task_stats(task)
931                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
932                    .map(|ts| ts.avg_latency_ms)
933                    .unwrap_or_else(|| p.avg_latency_ms());
934                self.latency_ms_to_score(avg)
935            }
936            Some(p) if p.total_calls == 0 => p
937                .task_stats(task)
938                .filter(|ts| ts.avg_latency_ms > 0.0)
939                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
940                .unwrap_or(schema_latency),
941            Some(p) if p.total_calls > 0 => {
942                let observed = self.latency_ms_to_score(
943                    p.task_stats(task)
944                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
945                        .map(|ts| ts.avg_latency_ms)
946                        .unwrap_or_else(|| p.avg_latency_ms()),
947                );
948                let w = p.total_calls as f64 / self.config.min_observations as f64;
949                schema_latency * (1.0 - w) + observed * w
950            }
951            _ => schema_latency,
952        };
953
954        // Cost score (lower is better → invert)
955        let cost = if model.is_local() {
956            1.0
957        } else {
958            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
959        };
960
961        // Suppress LOCAL_BONUS on hosts without a GPU/NPU backend.
962        // Without it, an Intel Mac or a `car_skip_mlx` build would
963        // pick a CPU-bound local 4B over a comparable cloud model
964        // every time, then surprise the user with 30s+ first-token
965        // latency. Cloud loses on cost/privacy in normal scoring;
966        // dropping the local bonus lets it win on the latency that
967        // actually matters when the GPU isn't there. Tracing emits
968        // when this fires so the silent degradation becomes visible.
969        let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
970            Self::LOCAL_BONUS
971        } else {
972            if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
973                tracing::debug!(
974                    model = %model.id,
975                    "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
976                );
977            }
978            0.0
979        };
980        let workload_local_bonus = if model.is_local() {
981            workload.local_bonus()
982        } else {
983            0.0
984        };
985
986        // On Apple Silicon, prefer MLX models over Candle equivalents
987        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
988        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
989        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
990        let mlx_bonus = 0.0;
991
992        // vLLM-MLX bonus: continuous batching gives better multi-agent throughput
993        let vllm_mlx_bonus = if model.is_vllm_mlx() {
994            Self::LOCAL_BONUS + 0.05
995        } else {
996            0.0
997        };
998
999        // System-LLM bonus: catalog tag-driven so it generalizes beyond
1000        // FoundationModels. Models tagged `low_latency` AND `private`
1001        // are zero-cost system-owned LLMs (apple/foundation:default
1002        // today; AICore-on-Android etc. in the future). They don't
1003        // appear in `is_mlx()` but deserve to compete with MLX 4B on
1004        // routing — the catalog's tags carry that intent and the
1005        // router now honors it.
1006        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
1007            && model.tags.iter().any(|t| t == "private")
1008        {
1009            Self::SYSTEM_LLM_BONUS
1010        } else {
1011            0.0
1012        };
1013
1014        quality_weight * quality
1015            + latency_weight * latency
1016            + cost_weight * cost
1017            + local_bonus
1018            + workload_local_bonus
1019            + mlx_bonus
1020            + vllm_mlx_bonus
1021            + system_llm_bonus
1022    }
1023
1024    /// Convert latency in ms to a [0, 1] score. Used by both schema and observed paths
1025    /// so the scales are consistent (fixes Linus review issue #2).
1026    fn latency_ms_to_score(&self, ms: f64) -> f64 {
1027        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1028    }
1029
1030    /// Convert TPS to estimated latency in ms (for a typical 200-token response).
1031    fn tps_to_latency_ms(tps: f64) -> f64 {
1032        if tps <= 0.0 {
1033            return Self::LATENCY_CEILING_MS;
1034        }
1035        // Assume ~200 tokens per response as baseline
1036        (200.0 / tps) * 1000.0
1037    }
1038
1039    /// Detect whether a prompt likely needs multiple tool calls in a single response.
1040    /// Looks for numbered lists, multiple explicit instructions, multi-edit patterns.
1041    fn needs_multi_tool_call(prompt: &str) -> bool {
1042        let lower = prompt.to_lowercase();
1043
1044        // Numbered list patterns: "1) ... 2) ..." or "1. ... 2. ..."
1045        let has_numbered_list = {
1046            let mut count = 0u32;
1047            for i in 1..=5u32 {
1048                if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1049                    count += 1;
1050                }
1051            }
1052            count >= 2
1053        };
1054
1055        // Explicit multi-action keywords
1056        let multi_keywords = [
1057            "multiple edits",
1058            "several changes",
1059            "three changes",
1060            "two changes",
1061            "all of the following",
1062            "each of these",
1063            "do both",
1064            "do all",
1065            "and also",
1066            "additionally",
1067            "as well as",
1068            "then also",
1069        ];
1070        let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1071
1072        // Bullet point lists with action verbs
1073        let bullet_actions = lower.matches("- add ").count()
1074            + lower.matches("- update ").count()
1075            + lower.matches("- change ").count()
1076            + lower.matches("- remove ").count()
1077            + lower.matches("- fix ").count()
1078            + lower.matches("- edit ").count()
1079            + lower.matches("- implement ").count()
1080            + lower.matches("- create ").count();
1081        let has_bullet_list = bullet_actions >= 2;
1082
1083        has_numbered_list || has_multi_keywords || has_bullet_list
1084    }
1085
1086    /// Schema-based quality estimate (cold start).
1087    ///
1088    /// Diminishing returns on model size — the jump from 4B to 8B matters,
1089    /// but 8B to 30B is marginal. Remote models get a conservative estimate.
1090    fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1091        match model.size_mb() {
1092            0 => 0.5,             // remote: unknown, conservative
1093            s if s < 1000 => 0.4, // 0.6B
1094            s if s < 2000 => 0.5, // 1.7B
1095            s if s < 3000 => 0.6, // 4B
1096            s if s < 6000 => 0.7, // 8B
1097            _ => 0.75,            // 30B+: diminishing returns
1098        }
1099    }
1100
1101    /// Schema-based latency estimate (cold start).
1102    ///
1103    /// Converts declared TPS/p50 to the same ms-based score used by observed data,
1104    /// so there's no discontinuity when the first observation arrives.
1105    fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1106        let is_moe = model.tags.contains(&"moe".to_string());
1107
1108        if model.is_local() {
1109            if let Some(tps) = model.performance.tokens_per_second {
1110                let effective_tps = if is_moe {
1111                    let multiplier = if model.is_mlx() {
1112                        Self::MLX_MOE_TPS_MULTIPLIER
1113                    } else {
1114                        Self::MOE_TPS_MULTIPLIER
1115                    };
1116                    tps * multiplier
1117                } else {
1118                    tps
1119                };
1120                let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1121                return self.latency_ms_to_score(estimated_ms);
1122            }
1123            return 0.5; // local, no declared TPS
1124        }
1125
1126        // Remote: use declared p50 latency
1127        if let Some(p50) = model.performance.latency_p50_ms {
1128            return self.latency_ms_to_score(p50 as f64);
1129        }
1130        0.3 // remote, no declared latency
1131    }
1132
1133    fn is_quality_critical_bootstrap_task(
1134        &self,
1135        task: InferenceTask,
1136        has_vision: bool,
1137        has_tools: bool,
1138    ) -> bool {
1139        has_vision
1140            || has_tools
1141            || matches!(
1142                task,
1143                InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1144            )
1145    }
1146
1147    fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1148        model.is_remote()
1149            && matches!(
1150                model.provider.as_str(),
1151                "openai" | "anthropic" | "google"
1152            )
1153            && !model.has_capability(ModelCapability::SpeechToText)
1154            && !model.has_capability(ModelCapability::TextToSpeech)
1155    }
1156
1157    fn is_local_model_proven_for_task(
1158        &self,
1159        model: &ModelSchema,
1160        task: InferenceTask,
1161        tracker: &OutcomeTracker,
1162    ) -> bool {
1163        let Some(profile) = tracker.profile(&model.id) else {
1164            return false;
1165        };
1166        if let Some(task_stats) = profile.task_stats(task) {
1167            if task_stats.calls >= self.config.bootstrap_min_task_observations
1168                && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1169            {
1170                return true;
1171            }
1172        }
1173
1174        profile.total_calls >= self.config.bootstrap_min_task_observations
1175            && profile.ema_quality >= self.config.bootstrap_quality_floor
1176    }
1177
1178    /// Phase 3: Thompson Sampling selection.
1179    ///
1180    /// Each model gets a Beta(alpha, beta) distribution where:
1181    /// - alpha = prior_successes + observed_successes
1182    /// - beta = prior_failures + observed_failures
1183    ///
1184    /// The Phase 2 score serves as the prior mean, scaled by `prior_strength`.
1185    /// Models with few observations have wide distributions (natural exploration).
1186    /// Models with many observations have tight distributions (exploitation).
1187    fn select_with_thompson_sampling(
1188        &self,
1189        scored: &[(String, f64)],
1190        tracker: &OutcomeTracker,
1191    ) -> (String, RoutingStrategy) {
1192        if scored.is_empty() {
1193            return (String::new(), RoutingStrategy::SchemaBased);
1194        }
1195
1196        let mut rng = rand::rng();
1197        let mut best_sample = f64::NEG_INFINITY;
1198        let mut best_id = scored[0].0.clone();
1199        let mut best_strategy = RoutingStrategy::SchemaBased;
1200
1201        for (id, phase2_score) in scored {
1202            let profile = tracker.profile(id);
1203            let prior = self.config.prior_strength;
1204
1205            // Convert Phase 2 score (0.0-1.15) to a prior mean in [0, 1]
1206            let prior_mean = phase2_score.clamp(0.0, 1.0);
1207
1208            // Prior pseudo-counts from the Phase 2 score
1209            let prior_alpha = prior * prior_mean;
1210            let prior_beta = prior * (1.0 - prior_mean);
1211
1212            // Observed counts
1213            let (obs_alpha, obs_beta) = match profile {
1214                Some(p) => (p.success_count as f64, p.fail_count as f64),
1215                None => (0.0, 0.0),
1216            };
1217
1218            // Posterior Beta parameters
1219            let alpha = (prior_alpha + obs_alpha).max(0.01);
1220            let beta = (prior_beta + obs_beta).max(0.01);
1221
1222            // Sample from Beta(alpha, beta) using the Jöhnk algorithm
1223            let sample = sample_beta(&mut rng, alpha, beta);
1224
1225            if sample > best_sample {
1226                best_sample = sample;
1227                best_id = id.clone();
1228                best_strategy = match profile {
1229                    Some(p) if p.total_calls >= self.config.min_observations => {
1230                        RoutingStrategy::ProfileBased
1231                    }
1232                    Some(p) if p.total_calls > 0 => {
1233                        // Under-tested but has some data — exploration
1234                        RoutingStrategy::Exploration
1235                    }
1236                    _ => RoutingStrategy::SchemaBased,
1237                };
1238            }
1239        }
1240
1241        (best_id, best_strategy)
1242    }
1243
1244    /// Fallback decision when no candidates pass filtering.
1245    fn cold_start_decision(
1246        &self,
1247        complexity: TaskComplexity,
1248        task: InferenceTask,
1249        registry: &UnifiedRegistry,
1250        has_vision: bool,
1251    ) -> AdaptiveRoutingDecision {
1252        if has_vision {
1253            if let Some(model) = registry
1254                .query_by_capability(ModelCapability::Vision)
1255                .into_iter()
1256                .find(|model| model.available && self.is_trusted_quality_remote(model))
1257                .or_else(|| registry.query_by_capability(ModelCapability::Vision).first().copied())
1258            {
1259                return AdaptiveRoutingDecision {
1260                    model_id: model.id.clone(),
1261                    model_name: model.name.clone(),
1262                    task,
1263                    complexity,
1264                    reason: format!(
1265                        "{:?} task → {} (cold start, vision fallback)",
1266                        complexity, model.name
1267                    ),
1268                    strategy: RoutingStrategy::SchemaBased,
1269                    predicted_quality: 0.5,
1270                    fallbacks: vec![],
1271                    context_length: model.context_length,
1272                    needs_compaction: false,
1273                };
1274            }
1275        }
1276
1277        if self.config.quality_first_cold_start {
1278            let required_caps = complexity.required_capabilities();
1279            if let Some(model) = registry
1280                .list()
1281                .into_iter()
1282                .filter(|model| {
1283                    model.available
1284                        && required_caps.iter().all(|cap| model.has_capability(*cap))
1285                        && self.is_trusted_quality_remote(model)
1286                })
1287                .max_by(|a, b| {
1288                    self.schema_quality_estimate(a)
1289                        .partial_cmp(&self.schema_quality_estimate(b))
1290                        .unwrap_or(std::cmp::Ordering::Equal)
1291                })
1292            {
1293                return AdaptiveRoutingDecision {
1294                    model_id: model.id.clone(),
1295                    model_name: model.name.clone(),
1296                    task,
1297                    complexity,
1298                    reason: format!(
1299                        "{:?} task → {} (quality-first cold start)",
1300                        complexity, model.name
1301                    ),
1302                    strategy: RoutingStrategy::SchemaBased,
1303                    predicted_quality: self.schema_quality_estimate(model),
1304                    fallbacks: vec![],
1305                    context_length: model.context_length,
1306                    needs_compaction: false,
1307                };
1308            }
1309        }
1310
1311        // Fall back to the old complexity-based defaults
1312        let model_name = match complexity {
1313            TaskComplexity::Simple => "Qwen3-0.6B",
1314            TaskComplexity::Medium => "Qwen3-1.7B",
1315            TaskComplexity::Code => "Qwen3-4B",
1316            TaskComplexity::Complex => &self.hw.recommended_model,
1317        };
1318
1319        let model_id = registry
1320            .find_by_name(model_name)
1321            .map(|m| m.id.clone())
1322            .unwrap_or_else(|| model_name.to_string());
1323
1324        let context_length = registry
1325            .find_by_name(model_name)
1326            .map(|m| m.context_length)
1327            .unwrap_or(0);
1328
1329        AdaptiveRoutingDecision {
1330            model_id,
1331            model_name: model_name.to_string(),
1332            task,
1333            complexity,
1334            reason: format!(
1335                "{:?} task → {} (cold start, no candidates)",
1336                complexity, model_name
1337            ),
1338            strategy: RoutingStrategy::SchemaBased,
1339            predicted_quality: 0.5,
1340            fallbacks: vec![],
1341            context_length,
1342            needs_compaction: false,
1343        }
1344    }
1345}
1346
1347/// Map a caller-supplied [`crate::TaskHint`] to the engine's
1348/// [`InferenceTask`] enum. The intent surface uses the higher-level
1349/// hint vocabulary; the router operates on InferenceTask. Every
1350/// TaskHint variant maps to a distinct InferenceTask — variants that
1351/// would have silently collapsed to `Generate` were cut from the MVP.
1352fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1353    use crate::intent::TaskHint;
1354    match hint {
1355        TaskHint::Chat => InferenceTask::Generate,
1356        TaskHint::Classify => InferenceTask::Classify,
1357        TaskHint::Reasoning => InferenceTask::Reasoning,
1358        TaskHint::Code => InferenceTask::Code,
1359    }
1360}
1361
1362/// Sample from a Beta(alpha, beta) distribution.
1363///
1364/// Uses the Gamma distribution method: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1),
1365/// then X / (X + Y) ~ Beta(alpha, beta).
1366///
1367/// For Gamma sampling, uses Marsaglia and Tsang's method for alpha >= 1,
1368/// and Ahrens-Dieter for alpha < 1.
1369fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1370    let x = sample_gamma(rng, alpha);
1371    let y = sample_gamma(rng, beta);
1372    if x + y == 0.0 {
1373        0.5 // degenerate case
1374    } else {
1375        x / (x + y)
1376    }
1377}
1378
1379/// Sample from Gamma(shape, 1) using Marsaglia-Tsang for shape >= 1,
1380/// with Ahrens-Dieter boost for shape < 1.
1381fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1382    if shape < 1.0 {
1383        // Ahrens-Dieter: Gamma(a) = Gamma(a+1) * U^(1/a)
1384        let u: f64 = rng.random();
1385        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1386    }
1387
1388    // Marsaglia-Tsang method for shape >= 1
1389    let d = shape - 1.0 / 3.0;
1390    let c = 1.0 / (9.0 * d).sqrt();
1391
1392    loop {
1393        let x: f64 = loop {
1394            let n = sample_standard_normal(rng);
1395            if 1.0 + c * n > 0.0 {
1396                break n;
1397            }
1398        };
1399
1400        let v = (1.0 + c * x).powi(3);
1401        let u: f64 = rng.random();
1402
1403        if u < 1.0 - 0.0331 * x.powi(4) {
1404            return d * v;
1405        }
1406        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1407            return d * v;
1408        }
1409    }
1410}
1411
1412/// Sample from standard normal N(0,1) using Box-Muller transform.
1413fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1414    let u1: f64 = rng.random();
1415    let u2: f64 = rng.random();
1416    (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1417}
1418
1419#[cfg(test)]
1420mod tests {
1421    use super::*;
1422    use crate::outcome::InferredOutcome;
1423
    /// Hardware fixture: a 10-core, 32 GB macOS/aarch64 host with a Metal GPU backend
    /// and `Qwen3-8B` as its recommended model.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            gpu_devices: Vec::new(),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            // The memory filter rejects local models whose size is >= this budget.
            max_model_mb: 18000, // headroom above 30B-A3B's 17000MB
        }
    }
1438
    /// Registry fixture shared by the routing tests.
    ///
    /// - Writes placeholder `model.gguf` / `tokenizer.json` files for five Qwen3 models
    ///   under `/tmp/car-test-adaptive-router`, so the registry marks them as available.
    /// - Sets a dummy `OPENAI_API_KEY`.
    /// - Registers one remote OpenAI model (`gpt-5.4-mini`) supporting generate, code,
    ///   reasoning, tool use, multi-tool calls, and vision.
    ///
    /// NOTE(review): the directory is a fixed path shared by every test that calls this
    /// helper, and it assumes a Unix-style `/tmp`.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        // SAFETY: NOTE(review) — `set_var` is only sound while no other thread reads or
        // writes the process environment. The test harness runs tests on several threads
        // by default, so this relies on no concurrent test touching the environment at the
        // same moment. TODO confirm, or serialize the tests that call this helper.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
        }
        // Create fake model dirs so the registry marks them as available.
        // I/O errors are ignored (`let _ =`). A failed write would presumably show up
        // later as an unavailable model rather than a panic here.
        for name in &[
            "Qwen3-0.6B",
            "Qwen3-1.7B",
            "Qwen3-4B",
            "Qwen3-8B",
            "Qwen3-Embedding-0.6B",
        ] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        let mut reg = UnifiedRegistry::new(tmp);
        // Remote OpenAI model: the provider "openai" is on the trusted-remote list used by
        // `is_trusted_quality_remote`.
        reg.register(ModelSchema {
            id: "openai/gpt-5.4-mini:latest".into(),
            name: "gpt-5.4-mini".into(),
            provider: "openai".into(),
            family: "gpt-5.4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
                ModelCapability::Vision,
            ],
            context_length: 128_000,
            param_count: "api".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_key_envs: vec![],
                api_version: None,
                protocol: crate::schema::ApiProtocol::OpenAiCompat,
            },
            tags: vec!["trusted-remote".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });
        reg
    }
1491
1492    #[test]
1493    fn routes_simple_to_trusted_remote_during_cold_start() {
1494        let router = AdaptiveRouter::new(
1495            test_hw(),
1496            RoutingConfig {
1497                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1498                ..Default::default()
1499            },
1500        );
1501        let reg = test_registry();
1502        let tracker = OutcomeTracker::new();
1503
1504        let decision = router.route("What is 2+2?", &reg, &tracker);
1505        assert_eq!(decision.complexity, TaskComplexity::Simple);
1506        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
1507        // On cold start, quality-critical tasks should stay on trusted remote models.
1508        let schema = reg
1509            .find_by_name(&decision.model_name)
1510            .expect("selected model should exist in registry");
1511        assert!(
1512            !schema.is_local(),
1513            "simple task should route to trusted remote model during cold start"
1514        );
1515        assert!(matches!(
1516            schema.provider.as_str(),
1517            "openai" | "anthropic" | "google"
1518        ));
1519    }
1520
1521    #[test]
1522    fn routes_code_to_code_capable_remote_during_cold_start() {
1523        let router = AdaptiveRouter::new(
1524            test_hw(),
1525            RoutingConfig {
1526                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1527                ..Default::default()
1528            },
1529        );
1530        let reg = test_registry();
1531        let tracker = OutcomeTracker::new();
1532
1533        let decision = router.route(
1534            "Fix this function:\n```rust\nfn main() {}\n```",
1535            &reg,
1536            &tracker,
1537        );
1538        assert_eq!(decision.complexity, TaskComplexity::Code);
1539        assert_eq!(decision.task, InferenceTask::Code);
1540        // Must select a code-capable local model (not 0.6B which lacks Code)
1541        let schema = reg
1542            .find_by_name(&decision.model_name)
1543            .expect("model should exist");
1544        assert!(
1545            schema.has_capability(ModelCapability::Code),
1546            "selected model must support Code"
1547        );
1548        assert!(!schema.is_local(), "should route to trusted remote model");
1549    }
1550
1551    #[test]
1552    fn routes_images_to_vision_capable_model() {
1553        let router = AdaptiveRouter::new(
1554            test_hw(),
1555            RoutingConfig {
1556                prior_strength: 100.0,
1557                ..Default::default()
1558            },
1559        );
1560        let mut reg = test_registry();
1561        let tracker = OutcomeTracker::new();
1562
1563        reg.register(ModelSchema {
1564            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
1565            name: "Qwen3-VL-2B-mlx-vlm".into(),
1566            provider: "qwen".into(),
1567            family: "qwen3-vl".into(),
1568            version: "bf16".into(),
1569            capabilities: vec![
1570                ModelCapability::Generate,
1571                ModelCapability::Vision,
1572                ModelCapability::Grounding,
1573            ],
1574            context_length: 262_144,
1575            param_count: "2B".into(),
1576            quantization: None,
1577            performance: Default::default(),
1578            cost: Default::default(),
1579            source: crate::schema::ModelSource::Mlx {
1580                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
1581                hf_weight_file: None,
1582            },
1583            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
1584            supported_params: vec![],
1585            public_benchmarks: vec![],
1586            available: true,
1587        });
1588
1589        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
1590        let schema = reg
1591            .find_by_name(&decision.model_name)
1592            .expect("model should exist");
1593        assert!(
1594            schema.has_capability(ModelCapability::Vision),
1595            "selected model must support Vision"
1596        );
1597    }
1598
1599    #[test]
1600    fn profile_based_routing_favors_proven_model() {
1601        let router = AdaptiveRouter::new(
1602            test_hw(),
1603            RoutingConfig {
1604                prior_strength: 0.5, // weak prior, let observed data dominate
1605                min_observations: 3,
1606                ..Default::default()
1607            },
1608        );
1609        let reg = test_registry();
1610        let mut tracker = OutcomeTracker::new();
1611
1612        // Build a strong profile for Qwen3-8B on code tasks (fast + high quality)
1613        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1614        for _ in 0..20 {
1615            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
1616            tracker.record_complete(&trace, 500, 100, 50);
1617            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1618        }
1619
1620        // Thompson Sampling is stochastic — run multiple times, 8B should win majority
1621        let mut wins = 0;
1622        for _ in 0..20 {
1623            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
1624            assert_eq!(decision.complexity, TaskComplexity::Code);
1625            if decision.model_id == qwen_8b_id {
1626                wins += 1;
1627            }
1628        }
1629        assert!(
1630            wins >= 12,
1631            "proven model won only {wins}/20 times (expected >= 12)"
1632        );
1633    }
1634
1635    #[test]
1636    fn proven_local_model_can_displace_bootstrap_remote() {
1637        let router = AdaptiveRouter::new(
1638            test_hw(),
1639            RoutingConfig {
1640                prior_strength: 100.0,
1641                bootstrap_min_task_observations: 6,
1642                bootstrap_quality_floor: 0.8,
1643                ..Default::default()
1644            },
1645        );
1646        let reg = test_registry();
1647        let mut tracker = OutcomeTracker::new();
1648
1649        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1650        for _ in 0..12 {
1651            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
1652            tracker.record_complete(&trace, 300, 50, 20);
1653            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1654        }
1655
1656        let mut local_wins = 0;
1657        for _ in 0..20 {
1658            let decision = router.route("Summarize this design decision.", &reg, &tracker);
1659            let schema = reg.get(&decision.model_id).expect("selected model should exist");
1660            if schema.is_local() {
1661                local_wins += 1;
1662            }
1663        }
1664
1665        assert!(
1666            local_wins >= 12,
1667            "proven local model won only {local_wins}/20 times (expected >= 12)"
1668        );
1669    }
1670
1671    #[test]
1672    fn benchmark_prior_informs_background_routing() {
1673        let router = AdaptiveRouter::new(
1674            test_hw(),
1675            RoutingConfig {
1676                prior_strength: 100.0,
1677                ..Default::default()
1678            },
1679        );
1680        let reg = test_registry();
1681        let mut tracker = OutcomeTracker::new();
1682        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1683        profile.ema_quality = 0.95;
1684        tracker.import_profiles(vec![profile]);
1685
1686        let decision = router.route_context_aware(
1687            "Write a Python fibonacci function.",
1688            128,
1689            &reg,
1690            &tracker,
1691            false,
1692            false,
1693            RoutingWorkload::Background,
1694        );
1695
1696        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1697        assert!(schema.is_local(), "background routing should allow strong local benchmark priors to win");
1698    }
1699
1700    #[test]
1701    fn task_specific_benchmark_prior_informs_cold_start_routing() {
1702        let router = AdaptiveRouter::new(
1703            test_hw(),
1704            RoutingConfig {
1705                prior_strength: 100.0,
1706                ..Default::default()
1707            },
1708        );
1709        let reg = test_registry();
1710        let mut tracker = OutcomeTracker::new();
1711        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1712        profile.task_stats.insert(
1713            crate::outcome::InferenceTask::Code.to_string(),
1714            crate::outcome::TaskStats {
1715                ema_quality: 0.95,
1716                ..Default::default()
1717            },
1718        );
1719        tracker.import_profiles(vec![profile]);
1720
1721        let decision = router.route_context_aware(
1722            "Write a Python fibonacci function.",
1723            128,
1724            &reg,
1725            &tracker,
1726            false,
1727            false,
1728            RoutingWorkload::Background,
1729        );
1730
1731        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1732        assert!(
1733            schema.is_local(),
1734            "background routing should use task-specific cold-start priors for local code models"
1735        );
1736    }
1737
1738    #[test]
1739    fn task_specific_latency_prior_affects_cold_start_score() {
1740        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1741        let reg = test_registry();
1742        let model = reg
1743            .get("qwen/qwen3-8b:q4_k_m")
1744            .expect("local test model should exist");
1745
1746        let mut fast_tracker = OutcomeTracker::new();
1747        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
1748        fast_profile.task_stats.insert(
1749            crate::outcome::InferenceTask::Generate.to_string(),
1750            crate::outcome::TaskStats {
1751                ema_quality: 0.95,
1752                avg_latency_ms: 1200.0,
1753                ..Default::default()
1754            },
1755        );
1756        fast_tracker.import_profiles(vec![fast_profile]);
1757
1758        let mut slow_tracker = OutcomeTracker::new();
1759        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
1760        slow_profile.task_stats.insert(
1761            crate::outcome::InferenceTask::Generate.to_string(),
1762            crate::outcome::TaskStats {
1763                ema_quality: 0.95,
1764                avg_latency_ms: 120_000.0,
1765                ..Default::default()
1766            },
1767        );
1768        slow_tracker.import_profiles(vec![slow_profile]);
1769
1770        let fast_score = router.score_model(
1771            model,
1772            InferenceTask::Generate,
1773            &fast_tracker,
1774            RoutingWorkload::Interactive,
1775        );
1776        let slow_score = router.score_model(
1777            model,
1778            InferenceTask::Generate,
1779            &slow_tracker,
1780            RoutingWorkload::Interactive,
1781        );
1782
1783        assert!(
1784            fast_score > slow_score,
1785            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
1786        );
1787    }
1788
1789    #[test]
1790    fn interactive_workload_keeps_remote_bootstrap_bias() {
1791        let router = AdaptiveRouter::new(
1792            test_hw(),
1793            RoutingConfig {
1794                prior_strength: 100.0,
1795                ..Default::default()
1796            },
1797        );
1798        let reg = test_registry();
1799        let mut tracker = OutcomeTracker::new();
1800        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1801        profile.ema_quality = 0.95;
1802        tracker.import_profiles(vec![profile]);
1803
1804        let decision = router.route_context_aware(
1805            "Write a Python fibonacci function.",
1806            128,
1807            &reg,
1808            &tracker,
1809            false,
1810            false,
1811            RoutingWorkload::Interactive,
1812        );
1813
1814        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1815        assert!(
1816            !schema.is_local(),
1817            "interactive routing should still prefer trusted remote models during cold start"
1818        );
1819    }
1820
1821    #[test]
1822    fn fallback_chain_has_alternatives() {
1823        let router = AdaptiveRouter::new(
1824            test_hw(),
1825            RoutingConfig {
1826                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1827                ..Default::default()
1828            },
1829        );
1830        let reg = test_registry();
1831        let tracker = OutcomeTracker::new();
1832
1833        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
1834        assert!(!decision.fallbacks.is_empty());
1835        // Primary should not appear in fallbacks
1836        assert!(!decision.fallbacks.contains(&decision.model_id));
1837    }
1838
1839    #[test]
1840    fn latency_scoring_is_consistent() {
1841        // Verify that schema and observed latency produce comparable scores
1842        let router = AdaptiveRouter::with_default_config(test_hw());
1843
1844        // A model at 25 TPS → ~200/25*1000 = 8000ms → score = 1 - 8000/10000 = 0.2
1845        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
1846        // Same model observed at 8000ms → same formula
1847        let observed_score = router.latency_ms_to_score(8000.0);
1848        assert!(
1849            (schema_score - observed_score).abs() < 0.01,
1850            "schema ({schema_score}) and observed ({observed_score}) should match"
1851        );
1852    }
1853
1854    #[test]
1855    fn complexity_assessment() {
1856        assert_eq!(
1857            TaskComplexity::assess("What is the capital of France?"),
1858            TaskComplexity::Simple
1859        );
1860        assert_eq!(
1861            TaskComplexity::assess("Fix this broken test"),
1862            TaskComplexity::Code
1863        );
1864        assert_eq!(
1865            TaskComplexity::assess("Analyze the trade-offs between A and B"),
1866            TaskComplexity::Complex
1867        );
1868    }
1869
1870    #[test]
1871    fn beta_sampling_produces_valid_values() {
1872        let mut rng = rand::rng();
1873        // Sample 100 times from Beta(2, 5) — should be in [0, 1]
1874        for _ in 0..100 {
1875            let s = sample_beta(&mut rng, 2.0, 5.0);
1876            assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
1877        }
1878        // Beta(1, 1) = Uniform(0, 1) — mean should be ~0.5
1879        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
1880        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
1881        assert!(
1882            (mean - 0.5).abs() < 0.05,
1883            "Beta(1,1) mean {mean} should be ~0.5"
1884        );
1885    }
1886
1887    #[test]
1888    fn thompson_sampling_converges_to_best() {
1889        // A model with strong observed success should win most of the time
1890        let router = AdaptiveRouter::new(
1891            test_hw(),
1892            RoutingConfig {
1893                prior_strength: 1.0, // weak prior, let observations dominate
1894                ..Default::default()
1895            },
1896        );
1897        let reg = test_registry();
1898        let mut tracker = OutcomeTracker::new();
1899
1900        // Give Qwen3-4B 20 successes (strong signal)
1901        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
1902        for _ in 0..20 {
1903            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
1904            tracker.record_complete(&trace, 500, 100, 50);
1905            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1906        }
1907
1908        // Route 20 code tasks — strong model should win the majority
1909        let mut wins = 0;
1910        for _ in 0..20 {
1911            let decision = router.route("Fix this parser bug", &reg, &tracker);
1912            if decision.model_id == qwen_4b_id {
1913                wins += 1;
1914            }
1915        }
1916        assert!(
1917            wins >= 14,
1918            "strong model won only {wins}/20 times (expected >= 14)"
1919        );
1920    }
1921
1922    // ----- Intent surface (parslee-ai/car-releases#18) -----
1923
1924    #[test]
1925    fn intent_require_filters_out_models_lacking_capability() {
1926        // Asking for vision when no candidate has it should fall to
1927        // the cold-start decision rather than scoring incompatible
1928        // candidates. The fixture registry has no vision-capable
1929        // local models.
1930        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1931        let reg = test_registry();
1932        let tracker = OutcomeTracker::new();
1933
1934        let intent = crate::intent::IntentHint {
1935            require: vec![ModelCapability::Vision],
1936            ..Default::default()
1937        };
1938        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);
1939
1940        // When require filters out every candidate, the router falls
1941        // to the schema-based cold-start path. Asserting the strategy
1942        // (not just non-empty model_id) catches a regression where
1943        // future code might silently include filtered candidates.
1944        assert_eq!(
1945            decision.strategy,
1946            RoutingStrategy::SchemaBased,
1947            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
1948        );
1949    }
1950
1951    #[test]
1952    fn intent_default_does_not_override_task_or_caps() {
1953        // Thompson sampling makes per-call model selection
1954        // non-deterministic, so we can't compare model_ids directly.
1955        // What we can assert deterministically: a default IntentHint
1956        // must not change the task selection or the capability
1957        // requirements — those are functions of the prompt only when
1958        // no hint is supplied.
1959        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1960        let reg = test_registry();
1961        let tracker = OutcomeTracker::new();
1962
1963        let baseline = router.route("write a haiku", &reg, &tracker);
1964        let with_default = router.route_with_intent(
1965            "write a haiku",
1966            &reg,
1967            &tracker,
1968            &crate::intent::IntentHint::default(),
1969        );
1970
1971        assert_eq!(
1972            baseline.task, with_default.task,
1973            "default IntentHint must not change the prompt-derived task"
1974        );
1975        assert_eq!(
1976            baseline.complexity, with_default.complexity,
1977            "default IntentHint must not change the prompt-derived complexity"
1978        );
1979    }
1980
1981    #[test]
1982    fn intent_task_hint_overrides_prompt_complexity() {
1983        // A short prompt that complexity assessment would route as
1984        // Generate should land on Reasoning when the intent says so.
1985        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1986        let reg = test_registry();
1987        let tracker = OutcomeTracker::new();
1988
1989        let hint = crate::intent::IntentHint {
1990            task: Some(crate::intent::TaskHint::Reasoning),
1991            ..Default::default()
1992        };
1993        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);
1994
1995        assert_eq!(
1996            decision.task,
1997            InferenceTask::Reasoning,
1998            "TaskHint::Reasoning should override the prompt-derived task"
1999        );
2000    }
2001}