Skip to main content

car_inference/
adaptive_router.rs

1//! Adaptive model routing — three-phase routing with learned performance profiles.
2//!
3//! Phase 1: **Filter** — hard constraints (capability, availability, memory, cost).
4//! Phase 2: **Score** — blend quality, latency, and cost using observed profiles
5//!          or schema defaults on cold start.
6//! Phase 3: **Select** — Thompson Sampling (Beta distribution per model) for
7//!          natural exploration-exploitation balance. Models with fewer observations
8//!          have wider distributions, giving them chances to prove themselves.
9//!
10//! Replaces the hardcoded `ModelRouter` from `router.rs`.
11
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
24/// Prompt complexity assessment (migrated from router.rs).
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case")]
27pub enum TaskComplexity {
28    Simple,
29    Medium,
30    Code,
31    Complex,
32}
33
34impl TaskComplexity {
35    /// Assess complexity of a prompt string.
36    ///
37    /// Uses tree-sitter AST parsing (when the `ast` feature is enabled) for
38    /// accurate code detection: if a code block parses successfully as any
39    /// supported language, it's definitively code. Falls back to keyword
40    /// heuristics for prompts that mention code without containing code blocks.
41    pub fn assess(prompt: &str) -> Self {
42        let lower = prompt.to_lowercase();
43        let word_count = prompt.split_whitespace().count();
44        let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46        let has_code = Self::detect_code(prompt);
47
48        let repair_markers = [
49            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50        ];
51        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53        let reasoning_markers = [
54            "analyze",
55            "compare",
56            "explain why",
57            "step by step",
58            "think through",
59            "evaluate",
60            "trade-off",
61            "tradeoff",
62            "pros and cons",
63            "architecture",
64            "design",
65            "strategy",
66            "optimize",
67            "comprehensive",
68        ];
69        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71        let simple_patterns = [
72            "what is",
73            "who is",
74            "when did",
75            "where is",
76            "how many",
77            "yes or no",
78            "true or false",
79            "name the",
80            "list the",
81            "define ",
82        ];
83        let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85        if has_code || has_repair {
86            TaskComplexity::Code
87        } else if has_reasoning || estimated_tokens > 500 {
88            TaskComplexity::Complex
89        } else if is_simple || estimated_tokens < 30 {
90            TaskComplexity::Simple
91        } else {
92            TaskComplexity::Medium
93        }
94    }
95
96    /// Detect whether a prompt contains code.
97    ///
98    /// With `ast` feature: extracts code blocks (``` delimited), attempts to
99    /// parse each with tree-sitter. If any parses into symbols, it's real code.
100    /// Without `ast` feature: falls back to keyword heuristics.
101    fn detect_code(prompt: &str) -> bool {
102        // First try AST-based detection on code blocks
103        #[cfg(feature = "ast")]
104        {
105            if let Some(is_code) = Self::detect_code_ast(prompt) {
106                return is_code;
107            }
108        }
109
110        // Fallback: keyword heuristics
111        let code_markers = [
112            "```",
113            "fn ",
114            "def ",
115            "class ",
116            "import ",
117            "require(",
118            "async fn",
119            "pub fn",
120            "function ",
121            "const ",
122            "let ",
123            "var ",
124            "#include",
125            "package ",
126            "impl ",
127        ];
128        code_markers.iter().any(|m| prompt.contains(m))
129    }
130
131    /// AST-based code detection: parse code blocks with tree-sitter.
132    /// Returns Some(true) if code found, Some(false) if blocks exist but
133    /// don't parse, None if no code blocks found (fall through to heuristics).
134    #[cfg(feature = "ast")]
135    fn detect_code_ast(prompt: &str) -> Option<bool> {
136        // Extract code blocks between ``` markers
137        let mut blocks = Vec::new();
138        let mut rest = prompt;
139        while let Some(start) = rest.find("```") {
140            let after_fence = &rest[start + 3..];
141            // Skip optional language tag on the opening fence
142            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143            if let Some(end) = after_fence[code_start..].find("```") {
144                blocks.push(&after_fence[code_start..code_start + end]);
145                rest = &after_fence[code_start + end + 3..];
146            } else {
147                break;
148            }
149        }
150
151        if blocks.is_empty() {
152            return None; // No code blocks — let heuristics decide
153        }
154
155        // Try to parse each block with tree-sitter
156        let languages = [
157            car_ast::Language::Rust,
158            car_ast::Language::Python,
159            car_ast::Language::TypeScript,
160            car_ast::Language::JavaScript,
161            car_ast::Language::Go,
162        ];
163
164        for block in &blocks {
165            let trimmed = block.trim();
166            if trimmed.is_empty() {
167                continue;
168            }
169
170            for lang in &languages {
171                if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172                    // If it parsed into any symbols, it's definitely code
173                    if !parsed.symbols.is_empty() {
174                        return Some(true);
175                    }
176                }
177            }
178        }
179
180        // Had code blocks but none parsed into symbols — could be
181        // pseudocode, output, or unsupported language
182        Some(false)
183    }
184
185    /// Map complexity to required capabilities.
186    pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187        match self {
188            TaskComplexity::Simple => vec![ModelCapability::Generate],
189            TaskComplexity::Medium => vec![ModelCapability::Generate],
190            TaskComplexity::Code => vec![ModelCapability::Code],
191            TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192        }
193    }
194
195    /// Map complexity to the InferenceTask type.
196    pub fn inference_task(&self) -> InferenceTask {
197        match self {
198            TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199            TaskComplexity::Code => InferenceTask::Code,
200            TaskComplexity::Complex => InferenceTask::Reasoning,
201        }
202    }
203}
204
205/// Configuration for routing behavior.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct RoutingConfig {
208    /// Minimum observations before trusting a model's profile over schema defaults.
209    pub min_observations: u64,
210    /// Scoring weights (must sum to 1.0).
211    pub quality_weight: f64,
212    pub latency_weight: f64,
213    pub cost_weight: f64,
214    /// Hard constraint: maximum latency budget in ms.
215    pub max_latency_ms: Option<u64>,
216    /// Hard constraint: maximum cost per call in USD.
217    pub max_cost_usd: Option<f64>,
218    /// Prefer local models over remote (all else being equal).
219    pub prefer_local: bool,
220    /// Thompson Sampling prior strength. Higher = more weight on the Phase 2 score
221    /// as a prior, lower = more influenced by observed outcomes.
222    /// Equivalent to the number of "virtual" observations from the prior.
223    pub prior_strength: f64,
224    /// Prefer trusted remote models for quality-critical tasks until local models
225    /// have enough task-specific evidence to be promoted.
226    pub quality_first_cold_start: bool,
227    /// Minimum task-specific observations required before a local model can
228    /// compete with trusted remote models during cold start.
229    pub bootstrap_min_task_observations: u64,
230    /// Minimum task-specific EMA quality required before a local model can
231    /// displace trusted remote models during cold start.
232    pub bootstrap_quality_floor: f64,
233}
234
235impl Default for RoutingConfig {
236    fn default() -> Self {
237        Self {
238            min_observations: 2,
239            quality_weight: 0.45,
240            latency_weight: 0.4,
241            cost_weight: 0.15,
242            max_latency_ms: None,
243            max_cost_usd: None,
244            prefer_local: true,
245            prior_strength: 2.0,
246            quality_first_cold_start: true,
247            bootstrap_min_task_observations: 8,
248            bootstrap_quality_floor: 0.8,
249        }
250    }
251}
252
253/// How a model was selected.
254#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
255#[serde(rename_all = "snake_case")]
256pub enum RoutingStrategy {
257    /// Using declared schema capabilities (no observed data).
258    SchemaBased,
259    /// Using observed performance profiles (exploitation).
260    ProfileBased,
261    /// Deliberately trying an under-tested model (exploration).
262    Exploration,
263    /// User explicitly specified the model.
264    Explicit,
265}
266
267/// The result of adaptive routing.
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct AdaptiveRoutingDecision {
270    /// Selected model id.
271    pub model_id: String,
272    /// Selected model name (display).
273    pub model_name: String,
274    /// Task type.
275    pub task: InferenceTask,
276    /// Assessed complexity.
277    pub complexity: TaskComplexity,
278    /// Human-readable reason.
279    pub reason: String,
280    /// How the model was selected.
281    pub strategy: RoutingStrategy,
282    /// Predicted quality (0.0-1.0).
283    pub predicted_quality: f64,
284    /// Fallback chain (ordered list of alternative model ids).
285    pub fallbacks: Vec<String>,
286    /// Context window of the selected model (tokens). 0 = unknown.
287    pub context_length: usize,
288    /// Whether the prompt needs compaction to fit the selected model's context window.
289    pub needs_compaction: bool,
290}
291
292/// Adaptive router with three-phase model selection.
293pub struct AdaptiveRouter {
294    hw: HardwareInfo,
295    config: RoutingConfig,
296    /// Circuit breaker registry — blocks models after consecutive failures (#25).
297    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
298}
299
300/// All inputs to a routing decision packed into one struct so adding
301/// a new combinator (per-tenant routing, streaming awareness, …)
302/// only adds a field rather than another `route_*` sibling method.
303/// See [`AdaptiveRouter::route_with`].
304///
305/// Use [`RouteRequest::new`] for the common defaults, then mutate
306/// just the fields the caller cares about:
307///
308/// ```ignore
309/// let decision = router.route_with(RouteRequest {
310///     has_tools: true,
311///     intent: Some(&hint),
312///     ..RouteRequest::new(prompt, &registry, &tracker)
313/// });
314/// ```
315pub struct RouteRequest<'a> {
316    pub prompt: &'a str,
317    pub registry: &'a UnifiedRegistry,
318    pub tracker: &'a OutcomeTracker,
319    /// Estimated prompt-side token count for context-window-aware
320    /// scoring. `0` skips the compaction-headroom check.
321    pub estimated_total_tokens: usize,
322    pub has_tools: bool,
323    pub has_vision: bool,
324    pub workload: RoutingWorkload,
325    /// Caller-supplied intent hint. `prefer_local: true` overrides
326    /// `workload` to [`RoutingWorkload::LocalPreferred`].
327    pub intent: Option<&'a crate::intent::IntentHint>,
328}
329
330impl<'a> RouteRequest<'a> {
331    /// Build a request with the same defaults the bare
332    /// [`AdaptiveRouter::route`] uses: interactive workload, no tools,
333    /// no vision, no context-aware sizing, no intent.
334    pub fn new(
335        prompt: &'a str,
336        registry: &'a UnifiedRegistry,
337        tracker: &'a OutcomeTracker,
338    ) -> Self {
339        Self {
340            prompt,
341            registry,
342            tracker,
343            estimated_total_tokens: 0,
344            has_tools: false,
345            has_vision: false,
346            workload: RoutingWorkload::Interactive,
347            intent: None,
348        }
349    }
350}
351
352impl AdaptiveRouter {
353    pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354        let circuit_breakers = Arc::new(Mutex::new(
355            CircuitBreakerRegistry::new(3, 300), // 3 failures, 5 min cooldown
356        ));
357        Self {
358            hw,
359            config,
360            circuit_breakers,
361        }
362    }
363
364    pub fn with_default_config(hw: HardwareInfo) -> Self {
365        Self::new(hw, RoutingConfig::default())
366    }
367
368    pub fn config(&self) -> &RoutingConfig {
369        &self.config
370    }
371
372    pub fn set_config(&mut self, config: RoutingConfig) {
373        self.config = config;
374    }
375
376    /// Canonical entry point. The seven `route_*` sibling methods
377    /// each build a `RouteRequest` with their fixed defaults and
378    /// call this — adding a new combinator (per-tenant routing,
379    /// streaming awareness, …) only adds a field here, not another
380    /// public method. Closes #108.
381    pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382        // Intent-hint precedence over the caller-supplied workload:
383        // `prefer_fast` wins outright (voice fast track is the most
384        // latency-sensitive path we have); `prefer_local` is the
385        // long-standing override; absent either, the caller's
386        // workload stands.
387        let workload = match req.intent {
388            Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389            Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390            _ => req.workload,
391        };
392        self.route_inner_with_intent(
393            req.prompt,
394            req.registry,
395            req.tracker,
396            req.has_tools,
397            req.has_vision,
398            req.estimated_total_tokens,
399            workload,
400            req.intent,
401        )
402    }
403
404    /// Route a generation request to the best model.
405    /// If `has_tools` is true, requires ToolUse capability (#13).
406    pub fn route(
407        &self,
408        prompt: &str,
409        registry: &UnifiedRegistry,
410        tracker: &OutcomeTracker,
411    ) -> AdaptiveRoutingDecision {
412        self.route_with(RouteRequest::new(prompt, registry, tracker))
413    }
414
415    /// Route an "editor" request — cheap mechanical work (context compaction,
416    /// edit materialization, title generation, replanning). Uses
417    /// [`RoutingWorkload::Background`] so cost/quality weights favour cheaper
418    /// models, and lets hosting layers express the Aider-style architect/editor
419    /// split without plumbing a raw model-name override through every layer.
420    pub fn route_editor(
421        &self,
422        prompt: &str,
423        registry: &UnifiedRegistry,
424        tracker: &OutcomeTracker,
425    ) -> AdaptiveRoutingDecision {
426        self.route_with(RouteRequest {
427            workload: RoutingWorkload::Background,
428            ..RouteRequest::new(prompt, registry, tracker)
429        })
430    }
431
432    /// Route with tool_use requirement — filters to models that support structured tool calls.
433    pub fn route_with_tools(
434        &self,
435        prompt: &str,
436        registry: &UnifiedRegistry,
437        tracker: &OutcomeTracker,
438    ) -> AdaptiveRoutingDecision {
439        self.route_with(RouteRequest {
440            has_tools: true,
441            ..RouteRequest::new(prompt, registry, tracker)
442        })
443    }
444
445    /// Route with image input requirement — filters to models that support vision.
446    pub fn route_with_vision(
447        &self,
448        prompt: &str,
449        registry: &UnifiedRegistry,
450        tracker: &OutcomeTracker,
451        has_tools: bool,
452    ) -> AdaptiveRoutingDecision {
453        self.route_with(RouteRequest {
454            has_tools,
455            has_vision: true,
456            ..RouteRequest::new(prompt, registry, tracker)
457        })
458    }
459
460    /// Route with caller-supplied intent — see [`crate::IntentHint`].
461    /// The hint can override the auto-detected `InferenceTask`, add hard
462    /// `require` capability filters on top of the prompt-derived ones,
463    /// and bias the score profile toward local models when
464    /// `prefer_local` is set.
465    pub fn route_with_intent<'a>(
466        &self,
467        prompt: &'a str,
468        registry: &'a UnifiedRegistry,
469        tracker: &'a OutcomeTracker,
470        intent: &'a crate::intent::IntentHint,
471    ) -> AdaptiveRoutingDecision {
472        self.route_with(RouteRequest {
473            intent: Some(intent),
474            ..RouteRequest::new(prompt, registry, tracker)
475        })
476    }
477
478    /// Route with context awareness — estimates prompt tokens and prefers models
479    /// whose context window can fit the full prompt without compaction.
480    pub fn route_context_aware(
481        &self,
482        prompt: &str,
483        estimated_total_tokens: usize,
484        registry: &UnifiedRegistry,
485        tracker: &OutcomeTracker,
486        has_tools: bool,
487        has_vision: bool,
488        workload: RoutingWorkload,
489    ) -> AdaptiveRoutingDecision {
490        self.route_with(RouteRequest {
491            estimated_total_tokens,
492            has_tools,
493            has_vision,
494            workload,
495            ..RouteRequest::new(prompt, registry, tracker)
496        })
497    }
498
499    /// Context-aware routing with caller-supplied intent. Same context
500    /// math as [`Self::route_context_aware`]; the intent layers on
501    /// top — task override, additional `require` filters,
502    /// `prefer_local` workload override.
503    pub fn route_context_aware_with_intent<'a>(
504        &self,
505        prompt: &'a str,
506        estimated_total_tokens: usize,
507        registry: &'a UnifiedRegistry,
508        tracker: &'a OutcomeTracker,
509        has_tools: bool,
510        has_vision: bool,
511        workload: RoutingWorkload,
512        intent: &'a crate::intent::IntentHint,
513    ) -> AdaptiveRoutingDecision {
514        self.route_with(RouteRequest {
515            estimated_total_tokens,
516            has_tools,
517            has_vision,
518            workload,
519            intent: Some(intent),
520            ..RouteRequest::new(prompt, registry, tracker)
521        })
522    }
523
524    fn route_inner_with_intent(
525        &self,
526        prompt: &str,
527        registry: &UnifiedRegistry,
528        tracker: &OutcomeTracker,
529        has_tools: bool,
530        has_vision: bool,
531        estimated_total_tokens: usize,
532        workload: RoutingWorkload,
533        intent: Option<&crate::intent::IntentHint>,
534    ) -> AdaptiveRoutingDecision {
535        let complexity = TaskComplexity::assess(prompt);
536        // Caller intent overrides the prompt-derived task when supplied.
537        let task = intent
538            .and_then(|h| h.task)
539            .map(task_hint_to_inference_task)
540            .unwrap_or_else(|| complexity.inference_task());
541        let mut required_caps = complexity.required_capabilities();
542        if let Some(hint) = intent {
543            for cap in &hint.require {
544                if !required_caps.contains(cap) {
545                    required_caps.push(*cap);
546                }
547            }
548        }
549        if has_vision {
550            required_caps.push(ModelCapability::Vision);
551        }
552        if has_tools {
553            required_caps.push(ModelCapability::ToolUse);
554            // Detect multi-step prompts that need multiple tool calls in one response.
555            // Patterns: numbered lists ("1) ... 2) ..."), multiple instructions, explicit multi-edit.
556            if Self::needs_multi_tool_call(prompt) {
557                required_caps.push(ModelCapability::MultiToolCall);
558            }
559        }
560
561        // Phase 1: Filter candidates
562        let mut candidates = self.filter_candidates(&required_caps, registry, tracker, has_vision);
563
564        // Fallback: if requiring MultiToolCall eliminates all candidates, drop the
565        // requirement and let the best ToolUse model handle it with multiple round-trips.
566        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
567            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
568            candidates = self.filter_candidates(&required_caps, registry, tracker, has_vision);
569        }
570
571        if candidates.is_empty() {
572            // Nothing available — return the schema-based default
573            return self.cold_start_decision(complexity, task, registry, has_vision);
574        }
575
576        candidates = self.apply_quality_first_bootstrap_policy(
577            candidates, task, tracker, has_vision, has_tools, workload,
578        );
579
580        // Context-aware filtering: if we know the prompt size, prefer models that fit.
581        // Phase 1b: separate candidates into "fits" and "needs compaction" groups.
582        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
583            let mut fits = Vec::new();
584            let mut tight = Vec::new();
585            for m in &candidates {
586                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
587                    fits.push(m.clone());
588                } else {
589                    tight.push(m.clone());
590                }
591            }
592            (fits, tight)
593        } else {
594            (candidates.clone(), Vec::new())
595        };
596
597        // Prefer models that fit; fall back to compaction-required models if none fit
598        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
599            (fits, false)
600        } else if !needs_compaction_candidates.is_empty() {
601            tracing::info!(
602                prompt_tokens = estimated_total_tokens,
603                candidates = needs_compaction_candidates.len(),
604                "no model fits full prompt — compaction will be needed"
605            );
606            (needs_compaction_candidates.clone(), true)
607        } else {
608            (candidates.clone(), false)
609        };
610
611        // Phase 2: Score candidates (with context headroom bonus)
612        let scored = self.score_candidates_context_aware(
613            &scoring_candidates,
614            task,
615            tracker,
616            estimated_total_tokens,
617            workload,
618        );
619
620        // Phase 3: Thompson Sampling selection
621        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);
622
623        // Build fallback chain: prefer models that fit, then compaction candidates
624        let mut fallbacks: Vec<String> = scored
625            .iter()
626            .filter(|(id, _)| *id != selected_id)
627            .map(|(id, _)| id.clone())
628            .collect();
629        // Add compaction candidates to the end of the fallback chain
630        if !compaction_needed {
631            for m in &needs_compaction_candidates {
632                if m.id != selected_id && !fallbacks.contains(&m.id) {
633                    fallbacks.push(m.id.clone());
634                }
635            }
636        }
637
638        let predicted_quality = scored
639            .iter()
640            .find(|(id, _)| *id == selected_id)
641            .map(|(_, score)| *score)
642            .unwrap_or(0.5);
643
644        let selected_schema = registry
645            .get(&selected_id)
646            .or_else(|| registry.find_by_name(&selected_id));
647        let model_name = selected_schema
648            .map(|m| m.name.clone())
649            .unwrap_or_else(|| selected_id.clone());
650        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);
651
652        let needs_compact = compaction_needed
653            || (estimated_total_tokens > 0
654                && context_length > 0
655                && estimated_total_tokens > context_length);
656
657        let compaction_note = if needs_compact {
658            format!(
659                " [compaction needed: {}→{}tok]",
660                estimated_total_tokens, context_length
661            )
662        } else {
663            String::new()
664        };
665
666        let reason = format!(
667            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
668            complexity,
669            model_name,
670            strategy,
671            predicted_quality,
672            scoring_candidates.len(),
673            compaction_note,
674        );
675
676        AdaptiveRoutingDecision {
677            model_id: selected_id,
678            model_name,
679            task,
680            complexity,
681            reason,
682            strategy,
683            predicted_quality,
684            fallbacks,
685            context_length,
686            needs_compaction: needs_compact,
687        }
688    }
689
690    /// Route to the best embedding model.
691    pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
692        let embed_models = registry.query_by_capability(ModelCapability::Embed);
693        embed_models
694            .first()
695            .map(|m| m.name.clone())
696            .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
697    }
698
699    /// Route to the smallest available model (for classification).
700    pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
701        let gen_models = registry.query_by_capability(ModelCapability::Generate);
702        // Pick smallest by size
703        gen_models
704            .iter()
705            .filter(|m| m.is_local())
706            .min_by_key(|m| m.size_mb())
707            .map(|m| m.name.clone())
708            .unwrap_or_else(|| "Qwen3-0.6B".to_string())
709    }
710
// --- Scoring constants ---

/// Latency ceiling: requests taking longer than this score 0.0.
const LATENCY_CEILING_MS: f64 = 10_000.0;
/// TPS ceiling: models faster than this score 1.0.
const _TPS_CEILING: f64 = 150.0;
/// MoE throughput penalty for Candle: naive expert routing runs at ~10% of
/// declared TPS.
const MOE_TPS_MULTIPLIER: f64 = 0.10;
/// MoE throughput multiplier for MLX: fused Metal kernels run at ~50% of
/// declared TPS.
const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
/// Cost ceiling: models costing more than this per 1K output tokens score 0.0.
const COST_CEILING_PER_1K: f64 = 0.1;
/// Local preference bonus added to the weighted score (before normalization).
const LOCAL_BONUS: f64 = 0.15;
/// `true` only where local inference has a GPU/NPU backend: Apple Silicon
/// with MLX, i.e. not built with `--cfg=car_skip_mlx`. On Intel Macs and
/// `car_skip_mlx` builds, local inference falls through to candle/CPU, which
/// is far slower than cloud for non-trivial models. `LOCAL_BONUS` is meant
/// to be suppressed when this is `false`, so cloud models rank fairly
/// instead of losing to a CPU-bound 4B that takes ~30s to first token.
const HAS_GPU_BACKEND: bool = cfg!(all(
    target_os = "macos",
    target_arch = "aarch64",
    not(car_skip_mlx)
));
/// Extra bonus for MLX models on Apple Silicon (stacks with `LOCAL_BONUS`).
/// MLX gets fused Metal kernels and better memory layout than Candle on Mac.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
const MLX_BONUS: f64 = 0.10;
/// Extra bonus for system-owned models: no model file on disk, nothing to
/// download, framework-managed memory. Currently only
/// `apple/foundation:default` qualifies. The bonus is tag-driven, so future
/// system-LLM integrations (e.g. Android AICore) pick it up without code
/// changes.
///
/// It stacks with `LOCAL_BONUS` but **not** `MLX_BONUS` (system models are
/// not MLX). Net effect: FoundationModels ranks roughly even with a
/// well-warmed MLX 4B for short fast-turn tasks, instead of always losing
/// because its catalog `tokens_per_second` / `size_mb` are null.
const SYSTEM_LLM_BONUS: f64 = 0.12;
753
754    // --- Internal phases ---
755
756    /// Phase 1: Filter by hard constraints.
757    fn filter_candidates(
758        &self,
759        required_caps: &[ModelCapability],
760        registry: &UnifiedRegistry,
761        tracker: &OutcomeTracker,
762        has_vision: bool,
763    ) -> Vec<ModelSchema> {
764        registry
765            .list()
766            .into_iter()
767            .filter(|m| {
768                // Must have all required capabilities
769                if !required_caps.iter().all(|c| m.has_capability(*c)) {
770                    return false;
771                }
772                // Hard exclusion (car-releases#50): `mlx-vlm-cli`
773                // models route exclusively through the mlx-vlm CLI
774                // for image inference and the backend hard-rejects
775                // text-only input ("text-only-on-mlx-vlm-id not
776                // implemented"). They advertise `generate`, so a
777                // text-only request (e.g. task=classify) would
778                // otherwise pass the capability filter and get
779                // dispatched, only to fail at the backend. A
780                // text-only request must never select one.
781                if !has_vision && m.tags.iter().any(|t| t == "mlx-vlm-cli") {
782                    return false;
783                }
784                // Must be available (downloaded for local, API key set for remote)
785                if !m.available {
786                    return false;
787                }
788                // Local models must fit in memory (strict: >= excludes models at the limit)
789                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
790                    return false;
791                }
792                // Hard latency constraint
793                if let Some(max) = self.config.max_latency_ms {
794                    if let Some(p50) = m.performance.latency_p50_ms {
795                        if p50 > max {
796                            return false;
797                        }
798                    }
799                }
800                // Hard cost constraint
801                if let Some(max) = self.config.max_cost_usd {
802                    if m.cost_per_1k_output() > max {
803                        return false;
804                    }
805                }
806                // Hard exclusion: prefer_local=false excludes all local models (#12)
807                if !self.config.prefer_local && m.is_local() {
808                    return false;
809                }
810                // Hard exclusion: rate-limited models excluded for this session (#13)
811                if tracker.is_excluded(&m.id) {
812                    return false;
813                }
814                // Circuit breaker: skip models with consecutive failures (#25)
815                if let Ok(mut cb) = self.circuit_breakers.lock() {
816                    if !cb.allow_request(&m.id) {
817                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
818                        return false;
819                    }
820                }
821                true
822            })
823            .cloned()
824            .collect()
825    }
826
827    fn apply_quality_first_bootstrap_policy(
828        &self,
829        candidates: Vec<ModelSchema>,
830        task: InferenceTask,
831        tracker: &OutcomeTracker,
832        has_vision: bool,
833        has_tools: bool,
834        workload: RoutingWorkload,
835    ) -> Vec<ModelSchema> {
836        if !self.config.quality_first_cold_start
837            || !workload.is_latency_sensitive()
838            || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
839        {
840            return candidates;
841        }
842
843        let trusted_remote: Vec<ModelSchema> = candidates
844            .iter()
845            .filter(|model| self.is_trusted_quality_remote(model))
846            .cloned()
847            .collect();
848
849        if trusted_remote.is_empty() {
850            return candidates;
851        }
852
853        let proven_local: Vec<ModelSchema> = candidates
854            .iter()
855            .filter(|model| {
856                model.is_local() && self.is_local_model_proven_for_task(model, task, tracker)
857            })
858            .cloned()
859            .collect();
860
861        if !proven_local.is_empty() {
862            return proven_local;
863        }
864
865        trusted_remote
866    }
867
868    /// Phase 2: Score candidates with context awareness.
869    /// Applies a headroom bonus to models with more context window for the prompt.
870    /// When estimated_total_tokens is 0, no context bonus/penalty is applied.
871    fn score_candidates_context_aware(
872        &self,
873        candidates: &[ModelSchema],
874        task: InferenceTask,
875        tracker: &OutcomeTracker,
876        estimated_total_tokens: usize,
877        workload: RoutingWorkload,
878    ) -> Vec<(String, f64)> {
879        let mut scored: Vec<(String, f64)> = candidates
880            .iter()
881            .map(|m| {
882                let base_score = self.score_model(m, task, tracker, workload);
883                // Context headroom bonus: prefer models with more room to spare.
884                // Max bonus: 0.10 (at 4x headroom or more). No bonus if unknown.
885                let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
886                    let ratio = m.context_length as f64 / estimated_total_tokens as f64;
887                    if ratio >= 1.0 {
888                        (ratio.min(4.0) - 1.0) / 3.0 * 0.10 // 0.0 at exact fit, 0.10 at 4x
889                    } else {
890                        -0.15 // Penalty for models that require compaction
891                    }
892                } else {
893                    0.0
894                };
895                (m.id.clone(), base_score + headroom_bonus)
896            })
897            .collect();
898
899        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
900        scored
901    }
902
903    /// Score a single model. All sub-scores are in [0.0, 1.0].
904    /// Final score = weighted sum + local_bonus, so range is [0.0, ~1.15].
905    /// Scoring weights `(quality, latency, cost)`, adjusted for the task.
906    ///
907    /// [`RoutingWorkload::weights`] is workload-shaped only — it has no
908    /// notion that the request is code generation. The default
909    /// interactive profile is `(0.45 quality, 0.40 latency, 0.15 cost)`,
910    /// where quality barely edges out latency, so a fast cheap general
911    /// model can out-score a code-tuned one on a `task=code` request
912    /// (car-releases#52). Code correctness is worth waiting a beat for:
913    /// under an interactive-class workload `task=code` is reweighted to
914    /// prioritise quality > speed > cost explicitly. `Fastest` (the
915    /// voice fast-track) is an explicit hard-latency demand and is left
916    /// untouched; `Batch`/`Background` already encode a deliberate
917    /// non-interactive caller choice and stay as-is.
918    fn task_aware_weights(
919        task: InferenceTask,
920        workload: RoutingWorkload,
921    ) -> (f64, f64, f64) {
922        match (task, workload) {
923            (
924                InferenceTask::Code,
925                RoutingWorkload::Interactive | RoutingWorkload::LocalPreferred,
926            ) => (0.62, 0.28, 0.10),
927            _ => workload.weights(),
928        }
929    }
930
931    fn score_model(
932        &self,
933        model: &ModelSchema,
934        task: InferenceTask,
935        tracker: &OutcomeTracker,
936        workload: RoutingWorkload,
937    ) -> f64 {
938        let profile = tracker.profile(&model.id);
939        let schema_quality = self.schema_quality_estimate(model);
940        let schema_latency = self.schema_latency_estimate(model);
941        let (quality_weight, latency_weight, cost_weight) =
942            Self::task_aware_weights(task, workload);
943
944        // Quality: blend schema estimate with observed data based on observation count.
945        // Both cold start and warm start use consistent blending.
946        let quality = match profile {
947            Some(p) if p.total_calls >= self.config.min_observations => p
948                .task_stats(task)
949                .map(|ts| ts.ema_quality)
950                .unwrap_or(p.ema_quality),
951            Some(p) if p.total_calls == 0 => p
952                .task_stats(task)
953                .map(|ts| ts.ema_quality)
954                .unwrap_or(p.ema_quality),
955            Some(p) if p.total_calls > 0 => {
956                let w = p.total_calls as f64 / self.config.min_observations as f64;
957                schema_quality * (1.0 - w) + p.ema_quality * w
958            }
959            _ => schema_quality,
960        };
961
962        // Latency: same blending as quality — don't trust a single observation more
963        // than schema estimates. This prevents routing oscillation on first few calls.
964        let latency = match profile {
965            Some(p) if p.total_calls >= self.config.min_observations => {
966                let avg = p
967                    .task_stats(task)
968                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
969                    .map(|ts| ts.avg_latency_ms)
970                    .unwrap_or_else(|| p.avg_latency_ms());
971                self.latency_ms_to_score(avg)
972            }
973            Some(p) if p.total_calls == 0 => p
974                .task_stats(task)
975                .filter(|ts| ts.avg_latency_ms > 0.0)
976                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
977                .unwrap_or(schema_latency),
978            Some(p) if p.total_calls > 0 => {
979                let observed = self.latency_ms_to_score(
980                    p.task_stats(task)
981                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
982                        .map(|ts| ts.avg_latency_ms)
983                        .unwrap_or_else(|| p.avg_latency_ms()),
984                );
985                let w = p.total_calls as f64 / self.config.min_observations as f64;
986                schema_latency * (1.0 - w) + observed * w
987            }
988            _ => schema_latency,
989        };
990
991        // Cost score (lower is better → invert)
992        let cost = if model.is_local() {
993            1.0
994        } else {
995            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
996        };
997
998        // Suppress LOCAL_BONUS on hosts without a GPU/NPU backend.
999        // Without it, an Intel Mac or a `car_skip_mlx` build would
1000        // pick a CPU-bound local 4B over a comparable cloud model
1001        // every time, then surprise the user with 30s+ first-token
1002        // latency. Cloud loses on cost/privacy in normal scoring;
1003        // dropping the local bonus lets it win on the latency that
1004        // actually matters when the GPU isn't there. Tracing emits
1005        // when this fires so the silent degradation becomes visible.
1006        let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
1007            Self::LOCAL_BONUS
1008        } else {
1009            if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
1010                tracing::debug!(
1011                    model = %model.id,
1012                    "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
1013                );
1014            }
1015            0.0
1016        };
1017        let workload_local_bonus = if model.is_local() {
1018            workload.local_bonus()
1019        } else {
1020            0.0
1021        };
1022
1023        // On Apple Silicon, prefer MLX models over Candle equivalents
1024        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1025        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
1026        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1027        let mlx_bonus = 0.0;
1028
1029        // vLLM-MLX bonus: continuous batching gives better multi-agent throughput
1030        let vllm_mlx_bonus = if model.is_vllm_mlx() {
1031            Self::LOCAL_BONUS + 0.05
1032        } else {
1033            0.0
1034        };
1035
1036        // System-LLM bonus: catalog tag-driven so it generalizes beyond
1037        // FoundationModels. Models tagged `low_latency` AND `private`
1038        // are zero-cost system-owned LLMs (apple/foundation:default
1039        // today; AICore-on-Android etc. in the future). They don't
1040        // appear in `is_mlx()` but deserve to compete with MLX 4B on
1041        // routing — the catalog's tags carry that intent and the
1042        // router now honors it.
1043        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
1044            && model.tags.iter().any(|t| t == "private")
1045        {
1046            Self::SYSTEM_LLM_BONUS
1047        } else {
1048            0.0
1049        };
1050
1051        quality_weight * quality
1052            + latency_weight * latency
1053            + cost_weight * cost
1054            + local_bonus
1055            + workload_local_bonus
1056            + mlx_bonus
1057            + vllm_mlx_bonus
1058            + system_llm_bonus
1059    }
1060
1061    /// Convert latency in ms to a [0, 1] score. Used by both schema and observed paths
1062    /// so the scales are consistent (fixes Linus review issue #2).
1063    fn latency_ms_to_score(&self, ms: f64) -> f64 {
1064        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1065    }
1066
1067    /// Convert TPS to estimated latency in ms (for a typical 200-token response).
1068    fn tps_to_latency_ms(tps: f64) -> f64 {
1069        if tps <= 0.0 {
1070            return Self::LATENCY_CEILING_MS;
1071        }
1072        // Assume ~200 tokens per response as baseline
1073        (200.0 / tps) * 1000.0
1074    }
1075
1076    /// Detect whether a prompt likely needs multiple tool calls in a single response.
1077    /// Looks for numbered lists, multiple explicit instructions, multi-edit patterns.
1078    fn needs_multi_tool_call(prompt: &str) -> bool {
1079        let lower = prompt.to_lowercase();
1080
1081        // Numbered list patterns: "1) ... 2) ..." or "1. ... 2. ..."
1082        let has_numbered_list = {
1083            let mut count = 0u32;
1084            for i in 1..=5u32 {
1085                if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1086                    count += 1;
1087                }
1088            }
1089            count >= 2
1090        };
1091
1092        // Explicit multi-action keywords
1093        let multi_keywords = [
1094            "multiple edits",
1095            "several changes",
1096            "three changes",
1097            "two changes",
1098            "all of the following",
1099            "each of these",
1100            "do both",
1101            "do all",
1102            "and also",
1103            "additionally",
1104            "as well as",
1105            "then also",
1106        ];
1107        let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1108
1109        // Bullet point lists with action verbs
1110        let bullet_actions = lower.matches("- add ").count()
1111            + lower.matches("- update ").count()
1112            + lower.matches("- change ").count()
1113            + lower.matches("- remove ").count()
1114            + lower.matches("- fix ").count()
1115            + lower.matches("- edit ").count()
1116            + lower.matches("- implement ").count()
1117            + lower.matches("- create ").count();
1118        let has_bullet_list = bullet_actions >= 2;
1119
1120        has_numbered_list || has_multi_keywords || has_bullet_list
1121    }
1122
1123    /// Schema-based quality estimate (cold start).
1124    ///
1125    /// Diminishing returns on model size — the jump from 4B to 8B matters,
1126    /// but 8B to 30B is marginal. Remote models get a conservative estimate.
1127    fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1128        match model.size_mb() {
1129            0 => 0.5,             // remote: unknown, conservative
1130            s if s < 1000 => 0.4, // 0.6B
1131            s if s < 2000 => 0.5, // 1.7B
1132            s if s < 3000 => 0.6, // 4B
1133            s if s < 6000 => 0.7, // 8B
1134            _ => 0.75,            // 30B+: diminishing returns
1135        }
1136    }
1137
1138    /// Schema-based latency estimate (cold start).
1139    ///
1140    /// Converts declared TPS/p50 to the same ms-based score used by observed data,
1141    /// so there's no discontinuity when the first observation arrives.
1142    fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1143        let is_moe = model.tags.contains(&"moe".to_string());
1144
1145        if model.is_local() {
1146            if let Some(tps) = model.performance.tokens_per_second {
1147                let effective_tps = if is_moe {
1148                    let multiplier = if model.is_mlx() {
1149                        Self::MLX_MOE_TPS_MULTIPLIER
1150                    } else {
1151                        Self::MOE_TPS_MULTIPLIER
1152                    };
1153                    tps * multiplier
1154                } else {
1155                    tps
1156                };
1157                let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1158                return self.latency_ms_to_score(estimated_ms);
1159            }
1160            return 0.5; // local, no declared TPS
1161        }
1162
1163        // Remote: use declared p50 latency
1164        if let Some(p50) = model.performance.latency_p50_ms {
1165            return self.latency_ms_to_score(p50 as f64);
1166        }
1167        0.3 // remote, no declared latency
1168    }
1169
1170    fn is_quality_critical_bootstrap_task(
1171        &self,
1172        task: InferenceTask,
1173        has_vision: bool,
1174        has_tools: bool,
1175    ) -> bool {
1176        has_vision
1177            || has_tools
1178            || matches!(
1179                task,
1180                InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1181            )
1182    }
1183
1184    fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1185        model.is_remote()
1186            && matches!(model.provider.as_str(), "openai" | "anthropic" | "google")
1187            && !model.has_capability(ModelCapability::SpeechToText)
1188            && !model.has_capability(ModelCapability::TextToSpeech)
1189    }
1190
1191    fn is_local_model_proven_for_task(
1192        &self,
1193        model: &ModelSchema,
1194        task: InferenceTask,
1195        tracker: &OutcomeTracker,
1196    ) -> bool {
1197        let Some(profile) = tracker.profile(&model.id) else {
1198            return false;
1199        };
1200        if let Some(task_stats) = profile.task_stats(task) {
1201            if task_stats.calls >= self.config.bootstrap_min_task_observations
1202                && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1203            {
1204                return true;
1205            }
1206        }
1207
1208        profile.total_calls >= self.config.bootstrap_min_task_observations
1209            && profile.ema_quality >= self.config.bootstrap_quality_floor
1210    }
1211
1212    /// Phase 3: Thompson Sampling selection.
1213    ///
1214    /// Each model gets a Beta(alpha, beta) distribution where:
1215    /// - alpha = prior_successes + observed_successes
1216    /// - beta = prior_failures + observed_failures
1217    ///
1218    /// The Phase 2 score serves as the prior mean, scaled by `prior_strength`.
1219    /// Models with few observations have wide distributions (natural exploration).
1220    /// Models with many observations have tight distributions (exploitation).
1221    fn select_with_thompson_sampling(
1222        &self,
1223        scored: &[(String, f64)],
1224        tracker: &OutcomeTracker,
1225    ) -> (String, RoutingStrategy) {
1226        if scored.is_empty() {
1227            return (String::new(), RoutingStrategy::SchemaBased);
1228        }
1229
1230        let mut rng = rand::rng();
1231        let mut best_sample = f64::NEG_INFINITY;
1232        let mut best_id = scored[0].0.clone();
1233        let mut best_strategy = RoutingStrategy::SchemaBased;
1234
1235        for (id, phase2_score) in scored {
1236            let profile = tracker.profile(id);
1237            let prior = self.config.prior_strength;
1238
1239            // Convert Phase 2 score (0.0-1.15) to a prior mean in [0, 1]
1240            let prior_mean = phase2_score.clamp(0.0, 1.0);
1241
1242            // Prior pseudo-counts from the Phase 2 score
1243            let prior_alpha = prior * prior_mean;
1244            let prior_beta = prior * (1.0 - prior_mean);
1245
1246            // Observed counts
1247            let (obs_alpha, obs_beta) = match profile {
1248                Some(p) => (p.success_count as f64, p.fail_count as f64),
1249                None => (0.0, 0.0),
1250            };
1251
1252            // Posterior Beta parameters
1253            let alpha = (prior_alpha + obs_alpha).max(0.01);
1254            let beta = (prior_beta + obs_beta).max(0.01);
1255
1256            // Sample from Beta(alpha, beta) using the Jöhnk algorithm
1257            let sample = sample_beta(&mut rng, alpha, beta);
1258
1259            if sample > best_sample {
1260                best_sample = sample;
1261                best_id = id.clone();
1262                best_strategy = match profile {
1263                    Some(p) if p.total_calls >= self.config.min_observations => {
1264                        RoutingStrategy::ProfileBased
1265                    }
1266                    Some(p) if p.total_calls > 0 => {
1267                        // Under-tested but has some data — exploration
1268                        RoutingStrategy::Exploration
1269                    }
1270                    _ => RoutingStrategy::SchemaBased,
1271                };
1272            }
1273        }
1274
1275        (best_id, best_strategy)
1276    }
1277
1278    /// Fallback decision when no candidates pass filtering.
1279    fn cold_start_decision(
1280        &self,
1281        complexity: TaskComplexity,
1282        task: InferenceTask,
1283        registry: &UnifiedRegistry,
1284        has_vision: bool,
1285    ) -> AdaptiveRoutingDecision {
1286        if has_vision {
1287            if let Some(model) = registry
1288                .query_by_capability(ModelCapability::Vision)
1289                .into_iter()
1290                .find(|model| model.available && self.is_trusted_quality_remote(model))
1291                .or_else(|| {
1292                    registry
1293                        .query_by_capability(ModelCapability::Vision)
1294                        .first()
1295                        .copied()
1296                })
1297            {
1298                return AdaptiveRoutingDecision {
1299                    model_id: model.id.clone(),
1300                    model_name: model.name.clone(),
1301                    task,
1302                    complexity,
1303                    reason: format!(
1304                        "{:?} task → {} (cold start, vision fallback)",
1305                        complexity, model.name
1306                    ),
1307                    strategy: RoutingStrategy::SchemaBased,
1308                    predicted_quality: 0.5,
1309                    fallbacks: vec![],
1310                    context_length: model.context_length,
1311                    needs_compaction: false,
1312                };
1313            }
1314        }
1315
1316        if self.config.quality_first_cold_start {
1317            let required_caps = complexity.required_capabilities();
1318            if let Some(model) = registry
1319                .list()
1320                .into_iter()
1321                .filter(|model| {
1322                    model.available
1323                        && required_caps.iter().all(|cap| model.has_capability(*cap))
1324                        && self.is_trusted_quality_remote(model)
1325                })
1326                .max_by(|a, b| {
1327                    self.schema_quality_estimate(a)
1328                        .partial_cmp(&self.schema_quality_estimate(b))
1329                        .unwrap_or(std::cmp::Ordering::Equal)
1330                })
1331            {
1332                return AdaptiveRoutingDecision {
1333                    model_id: model.id.clone(),
1334                    model_name: model.name.clone(),
1335                    task,
1336                    complexity,
1337                    reason: format!(
1338                        "{:?} task → {} (quality-first cold start)",
1339                        complexity, model.name
1340                    ),
1341                    strategy: RoutingStrategy::SchemaBased,
1342                    predicted_quality: self.schema_quality_estimate(model),
1343                    fallbacks: vec![],
1344                    context_length: model.context_length,
1345                    needs_compaction: false,
1346                };
1347            }
1348        }
1349
1350        // Fall back to the old complexity-based defaults
1351        let model_name = match complexity {
1352            TaskComplexity::Simple => "Qwen3-0.6B",
1353            TaskComplexity::Medium => "Qwen3-1.7B",
1354            TaskComplexity::Code => "Qwen3-4B",
1355            TaskComplexity::Complex => &self.hw.recommended_model,
1356        };
1357
1358        let model_id = registry
1359            .find_by_name(model_name)
1360            .map(|m| m.id.clone())
1361            .unwrap_or_else(|| model_name.to_string());
1362
1363        let context_length = registry
1364            .find_by_name(model_name)
1365            .map(|m| m.context_length)
1366            .unwrap_or(0);
1367
1368        AdaptiveRoutingDecision {
1369            model_id,
1370            model_name: model_name.to_string(),
1371            task,
1372            complexity,
1373            reason: format!(
1374                "{:?} task → {} (cold start, no candidates)",
1375                complexity, model_name
1376            ),
1377            strategy: RoutingStrategy::SchemaBased,
1378            predicted_quality: 0.5,
1379            fallbacks: vec![],
1380            context_length,
1381            needs_compaction: false,
1382        }
1383    }
1384}
1385
1386/// Map a caller-supplied [`crate::TaskHint`] to the engine's
1387/// [`InferenceTask`] enum. The intent surface uses the higher-level
1388/// hint vocabulary; the router operates on InferenceTask. Every
1389/// TaskHint variant maps to a distinct InferenceTask — variants that
1390/// would have silently collapsed to `Generate` were cut from the MVP.
1391fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1392    use crate::intent::TaskHint;
1393    match hint {
1394        TaskHint::Chat => InferenceTask::Generate,
1395        TaskHint::Classify => InferenceTask::Classify,
1396        TaskHint::Reasoning => InferenceTask::Reasoning,
1397        TaskHint::Code => InferenceTask::Code,
1398    }
1399}
1400
1401/// Sample from a Beta(alpha, beta) distribution.
1402///
1403/// Uses the Gamma distribution method: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1),
1404/// then X / (X + Y) ~ Beta(alpha, beta).
1405///
1406/// For Gamma sampling, uses Marsaglia and Tsang's method for alpha >= 1,
1407/// and Ahrens-Dieter for alpha < 1.
1408fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1409    let x = sample_gamma(rng, alpha);
1410    let y = sample_gamma(rng, beta);
1411    if x + y == 0.0 {
1412        0.5 // degenerate case
1413    } else {
1414        x / (x + y)
1415    }
1416}
1417
1418/// Sample from Gamma(shape, 1) using Marsaglia-Tsang for shape >= 1,
1419/// with Ahrens-Dieter boost for shape < 1.
1420fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1421    if shape < 1.0 {
1422        // Ahrens-Dieter: Gamma(a) = Gamma(a+1) * U^(1/a)
1423        let u: f64 = rng.random();
1424        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1425    }
1426
1427    // Marsaglia-Tsang method for shape >= 1
1428    let d = shape - 1.0 / 3.0;
1429    let c = 1.0 / (9.0 * d).sqrt();
1430
1431    loop {
1432        let x: f64 = loop {
1433            let n = sample_standard_normal(rng);
1434            if 1.0 + c * n > 0.0 {
1435                break n;
1436            }
1437        };
1438
1439        let v = (1.0 + c * x).powi(3);
1440        let u: f64 = rng.random();
1441
1442        if u < 1.0 - 0.0331 * x.powi(4) {
1443            return d * v;
1444        }
1445        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1446            return d * v;
1447        }
1448    }
1449}
1450
1451/// Sample from standard normal N(0,1) using Box-Muller transform.
1452fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1453    let u1: f64 = rng.random();
1454    let u2: f64 = rng.random();
1455    (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1456}
1457
1458#[cfg(test)]
1459mod tests {
1460    use super::*;
1461    use crate::outcome::InferredOutcome;
1462
1463    fn test_hw() -> HardwareInfo {
1464        HardwareInfo {
1465            os: "macos".into(),
1466            arch: "aarch64".into(),
1467            cpu_cores: 10,
1468            total_ram_mb: 32768,
1469            gpu_backend: crate::hardware::GpuBackend::Metal,
1470            gpu_memory_mb: Some(28672),
1471            gpu_devices: Vec::new(),
1472            recommended_model: "Qwen3-8B".into(),
1473            recommended_context: 8192,
1474            max_model_mb: 18000, // headroom above 30B-A3B's 17000MB
1475        }
1476    }
1477
1478    fn test_registry() -> UnifiedRegistry {
1479        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
1480        unsafe {
1481            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
1482        }
1483        // Create fake model dirs so the registry marks them as available
1484        for name in &[
1485            "Qwen3-0.6B",
1486            "Qwen3-1.7B",
1487            "Qwen3-4B",
1488            "Qwen3-8B",
1489            "Qwen3-Embedding-0.6B",
1490        ] {
1491            let dir = tmp.join(name);
1492            let _ = std::fs::create_dir_all(&dir);
1493            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
1494            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
1495        }
1496        let mut reg = UnifiedRegistry::new(tmp);
1497        reg.register(ModelSchema {
1498            id: "openai/gpt-5.4-mini:latest".into(),
1499            name: "gpt-5.4-mini".into(),
1500            provider: "openai".into(),
1501            family: "gpt-5.4".into(),
1502            version: "latest".into(),
1503            capabilities: vec![
1504                ModelCapability::Generate,
1505                ModelCapability::Code,
1506                ModelCapability::Reasoning,
1507                ModelCapability::ToolUse,
1508                ModelCapability::MultiToolCall,
1509                ModelCapability::Vision,
1510            ],
1511            context_length: 128_000,
1512            param_count: "api".into(),
1513            quantization: None,
1514            performance: Default::default(),
1515            cost: Default::default(),
1516            source: crate::schema::ModelSource::RemoteApi {
1517                endpoint: "https://api.openai.com/v1".into(),
1518                api_key_env: "OPENAI_API_KEY".into(),
1519                api_key_envs: vec![],
1520                api_version: None,
1521                protocol: crate::schema::ApiProtocol::OpenAiCompat,
1522            },
1523            tags: vec!["trusted-remote".into()],
1524            supported_params: vec![],
1525            public_benchmarks: vec![],
1526            trust_tier: crate::schema::TrustTier::Curated,
1527            deprecated: false,
1528            available: true,
1529        });
1530        reg
1531    }
1532
1533    #[test]
1534    fn routes_simple_to_trusted_remote_during_cold_start() {
1535        let router = AdaptiveRouter::new(
1536            test_hw(),
1537            RoutingConfig {
1538                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1539                ..Default::default()
1540            },
1541        );
1542        let reg = test_registry();
1543        let tracker = OutcomeTracker::new();
1544
1545        let decision = router.route("What is 2+2?", &reg, &tracker);
1546        assert_eq!(decision.complexity, TaskComplexity::Simple);
1547        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
1548        // On cold start, quality-critical tasks should stay on trusted remote models.
1549        let schema = reg
1550            .find_by_name(&decision.model_name)
1551            .expect("selected model should exist in registry");
1552        assert!(
1553            !schema.is_local(),
1554            "simple task should route to trusted remote model during cold start"
1555        );
1556        assert!(matches!(
1557            schema.provider.as_str(),
1558            "openai" | "anthropic" | "google"
1559        ));
1560    }
1561
1562    #[test]
1563    fn routes_code_to_code_capable_remote_during_cold_start() {
1564        let router = AdaptiveRouter::new(
1565            test_hw(),
1566            RoutingConfig {
1567                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1568                ..Default::default()
1569            },
1570        );
1571        let reg = test_registry();
1572        let tracker = OutcomeTracker::new();
1573
1574        let decision = router.route(
1575            "Fix this function:\n```rust\nfn main() {}\n```",
1576            &reg,
1577            &tracker,
1578        );
1579        assert_eq!(decision.complexity, TaskComplexity::Code);
1580        assert_eq!(decision.task, InferenceTask::Code);
1581        // Must select a code-capable local model (not 0.6B which lacks Code)
1582        let schema = reg
1583            .find_by_name(&decision.model_name)
1584            .expect("model should exist");
1585        assert!(
1586            schema.has_capability(ModelCapability::Code),
1587            "selected model must support Code"
1588        );
1589        assert!(!schema.is_local(), "should route to trusted remote model");
1590    }
1591
1592    #[test]
1593    fn routes_images_to_vision_capable_model() {
1594        let router = AdaptiveRouter::new(
1595            test_hw(),
1596            RoutingConfig {
1597                prior_strength: 100.0,
1598                ..Default::default()
1599            },
1600        );
1601        let mut reg = test_registry();
1602        let tracker = OutcomeTracker::new();
1603
1604        reg.register(ModelSchema {
1605            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
1606            name: "Qwen3-VL-2B-mlx-vlm".into(),
1607            provider: "qwen".into(),
1608            family: "qwen3-vl".into(),
1609            version: "bf16".into(),
1610            capabilities: vec![
1611                ModelCapability::Generate,
1612                ModelCapability::Vision,
1613                ModelCapability::Grounding,
1614            ],
1615            context_length: 262_144,
1616            param_count: "2B".into(),
1617            quantization: None,
1618            performance: Default::default(),
1619            cost: Default::default(),
1620            source: crate::schema::ModelSource::Mlx {
1621                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
1622                hf_weight_file: None,
1623            },
1624            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
1625            supported_params: vec![],
1626            public_benchmarks: vec![],
1627            trust_tier: crate::schema::TrustTier::Curated,
1628            deprecated: false,
1629            available: true,
1630        });
1631
1632        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
1633        let schema = reg
1634            .find_by_name(&decision.model_name)
1635            .expect("model should exist");
1636        assert!(
1637            schema.has_capability(ModelCapability::Vision),
1638            "selected model must support Vision"
1639        );
1640    }
1641
1642    #[test]
1643    fn profile_based_routing_favors_proven_model() {
1644        let router = AdaptiveRouter::new(
1645            test_hw(),
1646            RoutingConfig {
1647                prior_strength: 0.5, // weak prior, let observed data dominate
1648                min_observations: 3,
1649                ..Default::default()
1650            },
1651        );
1652        let reg = test_registry();
1653        let mut tracker = OutcomeTracker::new();
1654
1655        // Build a strong profile for Qwen3-8B on code tasks (fast + high quality)
1656        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1657        for _ in 0..20 {
1658            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
1659            tracker.record_complete(&trace, 500, 100, 50);
1660            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1661        }
1662
1663        // Thompson Sampling is stochastic — run multiple times, 8B should win majority
1664        let mut wins = 0;
1665        for _ in 0..20 {
1666            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
1667            assert_eq!(decision.complexity, TaskComplexity::Code);
1668            if decision.model_id == qwen_8b_id {
1669                wins += 1;
1670            }
1671        }
1672        assert!(
1673            wins >= 12,
1674            "proven model won only {wins}/20 times (expected >= 12)"
1675        );
1676    }
1677
1678    #[test]
1679    fn proven_local_model_can_displace_bootstrap_remote() {
1680        let router = AdaptiveRouter::new(
1681            test_hw(),
1682            RoutingConfig {
1683                prior_strength: 100.0,
1684                bootstrap_min_task_observations: 6,
1685                bootstrap_quality_floor: 0.8,
1686                ..Default::default()
1687            },
1688        );
1689        let reg = test_registry();
1690        let mut tracker = OutcomeTracker::new();
1691
1692        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1693        for _ in 0..12 {
1694            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
1695            tracker.record_complete(&trace, 300, 50, 20);
1696            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1697        }
1698
1699        let mut local_wins = 0;
1700        for _ in 0..20 {
1701            let decision = router.route("Summarize this design decision.", &reg, &tracker);
1702            let schema = reg
1703                .get(&decision.model_id)
1704                .expect("selected model should exist");
1705            if schema.is_local() {
1706                local_wins += 1;
1707            }
1708        }
1709
1710        assert!(
1711            local_wins >= 12,
1712            "proven local model won only {local_wins}/20 times (expected >= 12)"
1713        );
1714    }
1715
1716    #[test]
1717    fn benchmark_prior_informs_background_routing() {
1718        let router = AdaptiveRouter::new(
1719            test_hw(),
1720            RoutingConfig {
1721                prior_strength: 100.0,
1722                ..Default::default()
1723            },
1724        );
1725        let reg = test_registry();
1726        let mut tracker = OutcomeTracker::new();
1727        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1728        profile.ema_quality = 0.95;
1729        tracker.import_profiles(vec![profile]);
1730
1731        let decision = router.route_context_aware(
1732            "Write a Python fibonacci function.",
1733            128,
1734            &reg,
1735            &tracker,
1736            false,
1737            false,
1738            RoutingWorkload::Background,
1739        );
1740
1741        let schema = reg
1742            .get(&decision.model_id)
1743            .expect("selected model should exist");
1744        assert!(
1745            schema.is_local(),
1746            "background routing should allow strong local benchmark priors to win"
1747        );
1748    }
1749
1750    #[test]
1751    fn task_specific_benchmark_prior_informs_cold_start_routing() {
1752        let router = AdaptiveRouter::new(
1753            test_hw(),
1754            RoutingConfig {
1755                prior_strength: 100.0,
1756                ..Default::default()
1757            },
1758        );
1759        let reg = test_registry();
1760        let mut tracker = OutcomeTracker::new();
1761        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1762        profile.task_stats.insert(
1763            crate::outcome::InferenceTask::Code.to_string(),
1764            crate::outcome::TaskStats {
1765                ema_quality: 0.95,
1766                ..Default::default()
1767            },
1768        );
1769        tracker.import_profiles(vec![profile]);
1770
1771        let decision = router.route_context_aware(
1772            "Write a Python fibonacci function.",
1773            128,
1774            &reg,
1775            &tracker,
1776            false,
1777            false,
1778            RoutingWorkload::Background,
1779        );
1780
1781        let schema = reg
1782            .get(&decision.model_id)
1783            .expect("selected model should exist");
1784        assert!(
1785            schema.is_local(),
1786            "background routing should use task-specific cold-start priors for local code models"
1787        );
1788    }
1789
1790    #[test]
1791    fn task_specific_latency_prior_affects_cold_start_score() {
1792        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1793        let reg = test_registry();
1794        let model = reg
1795            .get("qwen/qwen3-8b:q4_k_m")
1796            .expect("local test model should exist");
1797
1798        let mut fast_tracker = OutcomeTracker::new();
1799        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
1800        fast_profile.task_stats.insert(
1801            crate::outcome::InferenceTask::Generate.to_string(),
1802            crate::outcome::TaskStats {
1803                ema_quality: 0.95,
1804                avg_latency_ms: 1200.0,
1805                ..Default::default()
1806            },
1807        );
1808        fast_tracker.import_profiles(vec![fast_profile]);
1809
1810        let mut slow_tracker = OutcomeTracker::new();
1811        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
1812        slow_profile.task_stats.insert(
1813            crate::outcome::InferenceTask::Generate.to_string(),
1814            crate::outcome::TaskStats {
1815                ema_quality: 0.95,
1816                avg_latency_ms: 120_000.0,
1817                ..Default::default()
1818            },
1819        );
1820        slow_tracker.import_profiles(vec![slow_profile]);
1821
1822        let fast_score = router.score_model(
1823            model,
1824            InferenceTask::Generate,
1825            &fast_tracker,
1826            RoutingWorkload::Interactive,
1827        );
1828        let slow_score = router.score_model(
1829            model,
1830            InferenceTask::Generate,
1831            &slow_tracker,
1832            RoutingWorkload::Interactive,
1833        );
1834
1835        assert!(
1836            fast_score > slow_score,
1837            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
1838        );
1839    }
1840
1841    #[test]
1842    fn interactive_workload_keeps_remote_bootstrap_bias() {
1843        let router = AdaptiveRouter::new(
1844            test_hw(),
1845            RoutingConfig {
1846                prior_strength: 100.0,
1847                ..Default::default()
1848            },
1849        );
1850        let reg = test_registry();
1851        let mut tracker = OutcomeTracker::new();
1852        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1853        profile.ema_quality = 0.95;
1854        tracker.import_profiles(vec![profile]);
1855
1856        let decision = router.route_context_aware(
1857            "Write a Python fibonacci function.",
1858            128,
1859            &reg,
1860            &tracker,
1861            false,
1862            false,
1863            RoutingWorkload::Interactive,
1864        );
1865
1866        let schema = reg
1867            .get(&decision.model_id)
1868            .expect("selected model should exist");
1869        assert!(
1870            !schema.is_local(),
1871            "interactive routing should still prefer trusted remote models during cold start"
1872        );
1873    }
1874
1875    #[test]
1876    fn code_interactive_weights_prioritise_quality_over_speed_over_cost() {
1877        // car-releases#52: task=code under the default interactive
1878        // workload must rank quality > speed > cost, with quality
1879        // clearly dominant — not the near-tie (0.45 q, 0.40 lat) the
1880        // generic interactive profile produces, which let a fast cheap
1881        // general model out-score a code-tuned one.
1882        let (q, lat, cost) = AdaptiveRouter::task_aware_weights(
1883            InferenceTask::Code,
1884            RoutingWorkload::Interactive,
1885        );
1886        assert!(q > lat && lat > cost, "expected quality > speed > cost, got ({q}, {lat}, {cost})");
1887        assert!((q + lat + cost - 1.0).abs() < 1e-9, "weights must sum to 1.0");
1888        let (default_q, default_lat, default_cost) = RoutingWorkload::Interactive.weights();
1889        assert!(
1890            q > default_q && lat < default_lat && cost < default_cost,
1891            "code weighting must lift quality and cut latency/cost vs the generic interactive profile"
1892        );
1893
1894        // LocalPreferred is interactive-class — same code reweighting.
1895        assert_eq!(
1896            AdaptiveRouter::task_aware_weights(
1897                InferenceTask::Code,
1898                RoutingWorkload::LocalPreferred,
1899            ),
1900            (q, lat, cost),
1901        );
1902
1903        // Fastest (voice fast-track) is an explicit hard-latency demand
1904        // — never overridden, even for code.
1905        assert_eq!(
1906            AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Fastest),
1907            RoutingWorkload::Fastest.weights(),
1908        );
1909        // Batch/Background encode a deliberate non-interactive choice.
1910        assert_eq!(
1911            AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Background),
1912            RoutingWorkload::Background.weights(),
1913        );
1914        // Non-code tasks are untouched under every workload.
1915        assert_eq!(
1916            AdaptiveRouter::task_aware_weights(
1917                InferenceTask::Generate,
1918                RoutingWorkload::Interactive,
1919            ),
1920            RoutingWorkload::Interactive.weights(),
1921        );
1922    }
1923
1924    #[test]
1925    fn fallback_chain_has_alternatives() {
1926        let router = AdaptiveRouter::new(
1927            test_hw(),
1928            RoutingConfig {
1929                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1930                ..Default::default()
1931            },
1932        );
1933        let reg = test_registry();
1934        let tracker = OutcomeTracker::new();
1935
1936        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
1937        assert!(!decision.fallbacks.is_empty());
1938        // Primary should not appear in fallbacks
1939        assert!(!decision.fallbacks.contains(&decision.model_id));
1940    }
1941
1942    #[test]
1943    fn latency_scoring_is_consistent() {
1944        // Verify that schema and observed latency produce comparable scores
1945        let router = AdaptiveRouter::with_default_config(test_hw());
1946
1947        // A model at 25 TPS → ~200/25*1000 = 8000ms → score = 1 - 8000/10000 = 0.2
1948        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
1949        // Same model observed at 8000ms → same formula
1950        let observed_score = router.latency_ms_to_score(8000.0);
1951        assert!(
1952            (schema_score - observed_score).abs() < 0.01,
1953            "schema ({schema_score}) and observed ({observed_score}) should match"
1954        );
1955    }
1956
1957    #[test]
1958    fn complexity_assessment() {
1959        assert_eq!(
1960            TaskComplexity::assess("What is the capital of France?"),
1961            TaskComplexity::Simple
1962        );
1963        assert_eq!(
1964            TaskComplexity::assess("Fix this broken test"),
1965            TaskComplexity::Code
1966        );
1967        assert_eq!(
1968            TaskComplexity::assess("Analyze the trade-offs between A and B"),
1969            TaskComplexity::Complex
1970        );
1971    }
1972
1973    #[test]
1974    fn beta_sampling_produces_valid_values() {
1975        let mut rng = rand::rng();
1976        // Sample 100 times from Beta(2, 5) — should be in [0, 1]
1977        for _ in 0..100 {
1978            let s = sample_beta(&mut rng, 2.0, 5.0);
1979            assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
1980        }
1981        // Beta(1, 1) = Uniform(0, 1) — mean should be ~0.5
1982        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
1983        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
1984        assert!(
1985            (mean - 0.5).abs() < 0.05,
1986            "Beta(1,1) mean {mean} should be ~0.5"
1987        );
1988    }
1989
1990    #[test]
1991    fn thompson_sampling_converges_to_best() {
1992        // A model with strong observed success should win most of the time
1993        let router = AdaptiveRouter::new(
1994            test_hw(),
1995            RoutingConfig {
1996                prior_strength: 1.0, // weak prior, let observations dominate
1997                ..Default::default()
1998            },
1999        );
2000        let reg = test_registry();
2001        let mut tracker = OutcomeTracker::new();
2002
2003        // Give Qwen3-4B 20 successes (strong signal)
2004        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
2005        for _ in 0..20 {
2006            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
2007            tracker.record_complete(&trace, 500, 100, 50);
2008            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
2009        }
2010
2011        // Route 20 code tasks — strong model should win the majority
2012        let mut wins = 0;
2013        for _ in 0..20 {
2014            let decision = router.route("Fix this parser bug", &reg, &tracker);
2015            if decision.model_id == qwen_4b_id {
2016                wins += 1;
2017            }
2018        }
2019        assert!(
2020            wins >= 14,
2021            "strong model won only {wins}/20 times (expected >= 14)"
2022        );
2023    }
2024
2025    // ----- Intent surface (parslee-ai/car-releases#18) -----
2026
2027    #[test]
2028    fn intent_require_filters_out_models_lacking_capability() {
2029        // Asking for vision when no candidate has it should fall to
2030        // the cold-start decision rather than scoring incompatible
2031        // candidates. The fixture registry has no vision-capable
2032        // local models.
2033        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
2034        let reg = test_registry();
2035        let tracker = OutcomeTracker::new();
2036
2037        let intent = crate::intent::IntentHint {
2038            require: vec![ModelCapability::Vision],
2039            ..Default::default()
2040        };
2041        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);
2042
2043        // When require filters out every candidate, the router falls
2044        // to the schema-based cold-start path. Asserting the strategy
2045        // (not just non-empty model_id) catches a regression where
2046        // future code might silently include filtered candidates.
2047        assert_eq!(
2048            decision.strategy,
2049            RoutingStrategy::SchemaBased,
2050            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
2051        );
2052    }
2053
2054    #[test]
2055    fn intent_default_does_not_override_task_or_caps() {
2056        // Thompson sampling makes per-call model selection
2057        // non-deterministic, so we can't compare model_ids directly.
2058        // What we can assert deterministically: a default IntentHint
2059        // must not change the task selection or the capability
2060        // requirements — those are functions of the prompt only when
2061        // no hint is supplied.
2062        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
2063        let reg = test_registry();
2064        let tracker = OutcomeTracker::new();
2065
2066        let baseline = router.route("write a haiku", &reg, &tracker);
2067        let with_default = router.route_with_intent(
2068            "write a haiku",
2069            &reg,
2070            &tracker,
2071            &crate::intent::IntentHint::default(),
2072        );
2073
2074        assert_eq!(
2075            baseline.task, with_default.task,
2076            "default IntentHint must not change the prompt-derived task"
2077        );
2078        assert_eq!(
2079            baseline.complexity, with_default.complexity,
2080            "default IntentHint must not change the prompt-derived complexity"
2081        );
2082    }
2083
2084    #[test]
2085    fn intent_task_hint_overrides_prompt_complexity() {
2086        // A short prompt that complexity assessment would route as
2087        // Generate should land on Reasoning when the intent says so.
2088        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
2089        let reg = test_registry();
2090        let tracker = OutcomeTracker::new();
2091
2092        let hint = crate::intent::IntentHint {
2093            task: Some(crate::intent::TaskHint::Reasoning),
2094            ..Default::default()
2095        };
2096        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);
2097
2098        assert_eq!(
2099            decision.task,
2100            InferenceTask::Reasoning,
2101            "TaskHint::Reasoning should override the prompt-derived task"
2102        );
2103    }
2104}