Skip to main content

car_inference/
adaptive_router.rs

1//! Adaptive model routing — three-phase routing with learned performance profiles.
2//!
3//! Phase 1: **Filter** — hard constraints (capability, availability, memory, cost).
4//! Phase 2: **Score** — blend quality, latency, and cost using observed profiles
5//!          or schema defaults on cold start.
6//! Phase 3: **Select** — Thompson Sampling (Beta distribution per model) for
7//!          natural exploration-exploitation balance. Models with fewer observations
8//!          have wider distributions, giving them chances to prove themselves.
9//!
10//! Replaces the hardcoded `ModelRouter` from `router.rs`.
11
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Prompt complexity assessment (migrated from router.rs).
///
/// Produced by [`TaskComplexity::assess`] and used to derive both the
/// capabilities a model must have and the [`InferenceTask`] it will run.
/// Serialized with `snake_case` variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Lookup-style or very short prompt; needs only `Generate`.
    Simple,
    /// Default bucket when no other rule matches; needs only `Generate`.
    Medium,
    /// Prompt contains code or repair/debug wording; needs `Code`.
    Code,
    /// Prompt asks for reasoning or is long; needs `Reasoning`.
    Complex,
}
33
34impl TaskComplexity {
35    /// Assess complexity of a prompt string.
36    ///
37    /// Uses tree-sitter AST parsing (when the `ast` feature is enabled) for
38    /// accurate code detection: if a code block parses successfully as any
39    /// supported language, it's definitively code. Falls back to keyword
40    /// heuristics for prompts that mention code without containing code blocks.
41    pub fn assess(prompt: &str) -> Self {
42        let lower = prompt.to_lowercase();
43        let word_count = prompt.split_whitespace().count();
44        let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46        let has_code = Self::detect_code(prompt);
47
48        let repair_markers = [
49            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50        ];
51        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53        let reasoning_markers = [
54            "analyze",
55            "compare",
56            "explain why",
57            "step by step",
58            "think through",
59            "evaluate",
60            "trade-off",
61            "tradeoff",
62            "pros and cons",
63            "architecture",
64            "design",
65            "strategy",
66            "optimize",
67            "comprehensive",
68        ];
69        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71        let simple_patterns = [
72            "what is",
73            "who is",
74            "when did",
75            "where is",
76            "how many",
77            "yes or no",
78            "true or false",
79            "name the",
80            "list the",
81            "define ",
82        ];
83        let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85        if has_code || has_repair {
86            TaskComplexity::Code
87        } else if has_reasoning || estimated_tokens > 500 {
88            TaskComplexity::Complex
89        } else if is_simple || estimated_tokens < 30 {
90            TaskComplexity::Simple
91        } else {
92            TaskComplexity::Medium
93        }
94    }
95
96    /// Detect whether a prompt contains code.
97    ///
98    /// With `ast` feature: extracts code blocks (``` delimited), attempts to
99    /// parse each with tree-sitter. If any parses into symbols, it's real code.
100    /// Without `ast` feature: falls back to keyword heuristics.
101    fn detect_code(prompt: &str) -> bool {
102        // First try AST-based detection on code blocks
103        #[cfg(feature = "ast")]
104        {
105            if let Some(is_code) = Self::detect_code_ast(prompt) {
106                return is_code;
107            }
108        }
109
110        // Fallback: keyword heuristics
111        let code_markers = [
112            "```",
113            "fn ",
114            "def ",
115            "class ",
116            "import ",
117            "require(",
118            "async fn",
119            "pub fn",
120            "function ",
121            "const ",
122            "let ",
123            "var ",
124            "#include",
125            "package ",
126            "impl ",
127        ];
128        code_markers.iter().any(|m| prompt.contains(m))
129    }
130
131    /// AST-based code detection: parse code blocks with tree-sitter.
132    /// Returns Some(true) if code found, Some(false) if blocks exist but
133    /// don't parse, None if no code blocks found (fall through to heuristics).
134    #[cfg(feature = "ast")]
135    fn detect_code_ast(prompt: &str) -> Option<bool> {
136        // Extract code blocks between ``` markers
137        let mut blocks = Vec::new();
138        let mut rest = prompt;
139        while let Some(start) = rest.find("```") {
140            let after_fence = &rest[start + 3..];
141            // Skip optional language tag on the opening fence
142            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143            if let Some(end) = after_fence[code_start..].find("```") {
144                blocks.push(&after_fence[code_start..code_start + end]);
145                rest = &after_fence[code_start + end + 3..];
146            } else {
147                break;
148            }
149        }
150
151        if blocks.is_empty() {
152            return None; // No code blocks — let heuristics decide
153        }
154
155        // Try to parse each block with tree-sitter
156        let languages = [
157            car_ast::Language::Rust,
158            car_ast::Language::Python,
159            car_ast::Language::TypeScript,
160            car_ast::Language::JavaScript,
161            car_ast::Language::Go,
162        ];
163
164        for block in &blocks {
165            let trimmed = block.trim();
166            if trimmed.is_empty() {
167                continue;
168            }
169
170            for lang in &languages {
171                if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172                    // If it parsed into any symbols, it's definitely code
173                    if !parsed.symbols.is_empty() {
174                        return Some(true);
175                    }
176                }
177            }
178        }
179
180        // Had code blocks but none parsed into symbols — could be
181        // pseudocode, output, or unsupported language
182        Some(false)
183    }
184
185    /// Map complexity to required capabilities.
186    pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187        match self {
188            TaskComplexity::Simple => vec![ModelCapability::Generate],
189            TaskComplexity::Medium => vec![ModelCapability::Generate],
190            TaskComplexity::Code => vec![ModelCapability::Code],
191            TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192        }
193    }
194
195    /// Map complexity to the InferenceTask type.
196    pub fn inference_task(&self) -> InferenceTask {
197        match self {
198            TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199            TaskComplexity::Code => InferenceTask::Code,
200            TaskComplexity::Complex => InferenceTask::Reasoning,
201        }
202    }
203}
204
/// Configuration for routing behavior.
///
/// Covers the Phase 1 hard constraints (latency, cost, local exclusion),
/// the Phase 2 scoring weights, the Phase 3 Thompson Sampling prior, and
/// the quality-first cold-start policy. See the `Default` impl for the
/// out-of-the-box values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Minimum observations before trusting a model's profile over schema defaults.
    pub min_observations: u64,
    /// Scoring weight for quality. The three weights (quality, latency,
    /// cost) must sum to 1.0.
    pub quality_weight: f64,
    /// Scoring weight for latency (see `quality_weight`).
    pub latency_weight: f64,
    /// Scoring weight for cost (see `quality_weight`).
    pub cost_weight: f64,
    /// Hard constraint: maximum latency budget in ms. Checked against the
    /// model's declared p50 latency; models with no declared p50 are not
    /// filtered by this constraint.
    pub max_latency_ms: Option<u64>,
    /// Hard constraint: maximum cost per call in USD.
    /// NOTE(review): the filter compares this against `cost_per_1k_output()`
    /// — confirm whether the unit is per call or per 1K output tokens.
    pub max_cost_usd: Option<f64>,
    /// Prefer local models over remote (all else being equal).
    /// When `false`, local models are excluded outright in Phase 1.
    pub prefer_local: bool,
    /// Thompson Sampling prior strength. Higher = more weight on the Phase 2 score
    /// as a prior, lower = more influenced by observed outcomes.
    /// Equivalent to the number of "virtual" observations from the prior.
    pub prior_strength: f64,
    /// Prefer trusted remote models for quality-critical tasks until local models
    /// have enough task-specific evidence to be promoted.
    pub quality_first_cold_start: bool,
    /// Minimum task-specific observations required before a local model can
    /// compete with trusted remote models during cold start.
    pub bootstrap_min_task_observations: u64,
    /// Minimum task-specific EMA quality required before a local model can
    /// displace trusted remote models during cold start.
    pub bootstrap_quality_floor: f64,
}
234
235impl Default for RoutingConfig {
236    fn default() -> Self {
237        Self {
238            min_observations: 2,
239            quality_weight: 0.45,
240            latency_weight: 0.4,
241            cost_weight: 0.15,
242            max_latency_ms: None,
243            max_cost_usd: None,
244            prefer_local: true,
245            prior_strength: 2.0,
246            quality_first_cold_start: true,
247            bootstrap_min_task_observations: 8,
248            bootstrap_quality_floor: 0.8,
249        }
250    }
251}
252
/// How a model was selected.
///
/// Reported on every [`AdaptiveRoutingDecision`]; serialized with
/// `snake_case` variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Using declared schema capabilities (no observed data).
    SchemaBased,
    /// Using observed performance profiles (exploitation).
    ProfileBased,
    /// Deliberately trying an under-tested model (exploration).
    Exploration,
    /// User explicitly specified the model.
    Explicit,
}
266
/// The result of adaptive routing.
///
/// Carries the chosen model, why it was chosen, and the ordered
/// alternatives the caller may try if the chosen model fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Selected model id.
    pub model_id: String,
    /// Selected model name (display). Falls back to the model id when the
    /// registry has no entry for the selected model.
    pub model_name: String,
    /// Task type (prompt-derived, or overridden by the caller's intent hint).
    pub task: InferenceTask,
    /// Assessed complexity.
    pub complexity: TaskComplexity,
    /// Human-readable reason.
    pub reason: String,
    /// How the model was selected.
    pub strategy: RoutingStrategy,
    /// Predicted quality (0.0-1.0). Taken from the selected model's
    /// Phase 2 score; 0.5 when no score is found for it.
    pub predicted_quality: f64,
    /// Fallback chain (ordered list of alternative model ids).
    pub fallbacks: Vec<String>,
    /// Context window of the selected model (tokens). 0 = unknown.
    pub context_length: usize,
    /// Whether the prompt needs compaction to fit the selected model's context window.
    pub needs_compaction: bool,
}
291
/// Adaptive router with three-phase model selection.
///
/// Phase 1 filters by hard constraints, Phase 2 scores the survivors, and
/// Phase 3 picks one via Thompson Sampling.
pub struct AdaptiveRouter {
    /// Hardware description; its `max_model_mb` bounds which local models
    /// pass the Phase 1 memory check.
    hw: HardwareInfo,
    /// Active routing configuration (weights, hard constraints, bootstrap policy).
    config: RoutingConfig,
    /// Circuit breaker registry — blocks models after consecutive failures (#25).
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs to a routing decision packed into one struct so adding
/// a new combinator (per-tenant routing, streaming awareness, …)
/// only adds a field rather than another `route_*` sibling method.
/// See [`AdaptiveRouter::route_with`].
///
/// Use [`RouteRequest::new`] for the common defaults, then override
/// just the fields the caller cares about with struct-update syntax:
///
/// ```ignore
/// let decision = router.route_with(RouteRequest {
///     has_tools: true,
///     intent: Some(&hint),
///     ..RouteRequest::new(prompt, &registry, &tracker)
/// });
/// ```
pub struct RouteRequest<'a> {
    /// Prompt text; used for complexity assessment.
    pub prompt: &'a str,
    /// Model registry the candidates are drawn from.
    pub registry: &'a UnifiedRegistry,
    /// Outcome tracker supplying observed profiles and session exclusions.
    pub tracker: &'a OutcomeTracker,
    /// Estimated prompt-side token count for context-window-aware
    /// scoring. `0` skips the compaction-headroom check.
    pub estimated_total_tokens: usize,
    /// Request carries tool definitions; adds the `ToolUse` requirement.
    pub has_tools: bool,
    /// Request carries image input; adds the `Vision` requirement.
    pub has_vision: bool,
    /// Workload profile used when scoring candidates.
    pub workload: RoutingWorkload,
    /// Caller-supplied intent hint. `prefer_local: true` overrides
    /// `workload` to [`RoutingWorkload::LocalPreferred`].
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
330impl<'a> RouteRequest<'a> {
331    /// Build a request with the same defaults the bare
332    /// [`AdaptiveRouter::route`] uses: interactive workload, no tools,
333    /// no vision, no context-aware sizing, no intent.
334    pub fn new(
335        prompt: &'a str,
336        registry: &'a UnifiedRegistry,
337        tracker: &'a OutcomeTracker,
338    ) -> Self {
339        Self {
340            prompt,
341            registry,
342            tracker,
343            estimated_total_tokens: 0,
344            has_tools: false,
345            has_vision: false,
346            workload: RoutingWorkload::Interactive,
347            intent: None,
348        }
349    }
350}
351
352impl AdaptiveRouter {
353    pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354        let circuit_breakers = Arc::new(Mutex::new(
355            CircuitBreakerRegistry::new(3, 300), // 3 failures, 5 min cooldown
356        ));
357        Self {
358            hw,
359            config,
360            circuit_breakers,
361        }
362    }
363
364    pub fn with_default_config(hw: HardwareInfo) -> Self {
365        Self::new(hw, RoutingConfig::default())
366    }
367
368    pub fn config(&self) -> &RoutingConfig {
369        &self.config
370    }
371
372    pub fn set_config(&mut self, config: RoutingConfig) {
373        self.config = config;
374    }
375
376    /// Canonical entry point. The seven `route_*` sibling methods
377    /// each build a `RouteRequest` with their fixed defaults and
378    /// call this — adding a new combinator (per-tenant routing,
379    /// streaming awareness, …) only adds a field here, not another
380    /// public method. Closes #108.
381    pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382        // `prefer_local` on the intent hint takes precedence over
383        // the caller-supplied workload — matches the long-standing
384        // `route_with_intent` / `route_context_aware_with_intent`
385        // override semantics.
386        let workload = match req.intent {
387            Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
388            _ => req.workload,
389        };
390        self.route_inner_with_intent(
391            req.prompt,
392            req.registry,
393            req.tracker,
394            req.has_tools,
395            req.has_vision,
396            req.estimated_total_tokens,
397            workload,
398            req.intent,
399        )
400    }
401
402    /// Route a generation request to the best model.
403    /// If `has_tools` is true, requires ToolUse capability (#13).
404    pub fn route(
405        &self,
406        prompt: &str,
407        registry: &UnifiedRegistry,
408        tracker: &OutcomeTracker,
409    ) -> AdaptiveRoutingDecision {
410        self.route_with(RouteRequest::new(prompt, registry, tracker))
411    }
412
413    /// Route an "editor" request — cheap mechanical work (context compaction,
414    /// edit materialization, title generation, replanning). Uses
415    /// [`RoutingWorkload::Background`] so cost/quality weights favour cheaper
416    /// models, and lets hosting layers express the Aider-style architect/editor
417    /// split without plumbing a raw model-name override through every layer.
418    pub fn route_editor(
419        &self,
420        prompt: &str,
421        registry: &UnifiedRegistry,
422        tracker: &OutcomeTracker,
423    ) -> AdaptiveRoutingDecision {
424        self.route_with(RouteRequest {
425            workload: RoutingWorkload::Background,
426            ..RouteRequest::new(prompt, registry, tracker)
427        })
428    }
429
430    /// Route with tool_use requirement — filters to models that support structured tool calls.
431    pub fn route_with_tools(
432        &self,
433        prompt: &str,
434        registry: &UnifiedRegistry,
435        tracker: &OutcomeTracker,
436    ) -> AdaptiveRoutingDecision {
437        self.route_with(RouteRequest {
438            has_tools: true,
439            ..RouteRequest::new(prompt, registry, tracker)
440        })
441    }
442
443    /// Route with image input requirement — filters to models that support vision.
444    pub fn route_with_vision(
445        &self,
446        prompt: &str,
447        registry: &UnifiedRegistry,
448        tracker: &OutcomeTracker,
449        has_tools: bool,
450    ) -> AdaptiveRoutingDecision {
451        self.route_with(RouteRequest {
452            has_tools,
453            has_vision: true,
454            ..RouteRequest::new(prompt, registry, tracker)
455        })
456    }
457
458    /// Route with caller-supplied intent — see [`crate::IntentHint`].
459    /// The hint can override the auto-detected `InferenceTask`, add hard
460    /// `require` capability filters on top of the prompt-derived ones,
461    /// and bias the score profile toward local models when
462    /// `prefer_local` is set.
463    pub fn route_with_intent<'a>(
464        &self,
465        prompt: &'a str,
466        registry: &'a UnifiedRegistry,
467        tracker: &'a OutcomeTracker,
468        intent: &'a crate::intent::IntentHint,
469    ) -> AdaptiveRoutingDecision {
470        self.route_with(RouteRequest {
471            intent: Some(intent),
472            ..RouteRequest::new(prompt, registry, tracker)
473        })
474    }
475
476    /// Route with context awareness — estimates prompt tokens and prefers models
477    /// whose context window can fit the full prompt without compaction.
478    pub fn route_context_aware(
479        &self,
480        prompt: &str,
481        estimated_total_tokens: usize,
482        registry: &UnifiedRegistry,
483        tracker: &OutcomeTracker,
484        has_tools: bool,
485        has_vision: bool,
486        workload: RoutingWorkload,
487    ) -> AdaptiveRoutingDecision {
488        self.route_with(RouteRequest {
489            estimated_total_tokens,
490            has_tools,
491            has_vision,
492            workload,
493            ..RouteRequest::new(prompt, registry, tracker)
494        })
495    }
496
497    /// Context-aware routing with caller-supplied intent. Same context
498    /// math as [`Self::route_context_aware`]; the intent layers on
499    /// top — task override, additional `require` filters,
500    /// `prefer_local` workload override.
501    pub fn route_context_aware_with_intent<'a>(
502        &self,
503        prompt: &'a str,
504        estimated_total_tokens: usize,
505        registry: &'a UnifiedRegistry,
506        tracker: &'a OutcomeTracker,
507        has_tools: bool,
508        has_vision: bool,
509        workload: RoutingWorkload,
510        intent: &'a crate::intent::IntentHint,
511    ) -> AdaptiveRoutingDecision {
512        self.route_with(RouteRequest {
513            estimated_total_tokens,
514            has_tools,
515            has_vision,
516            workload,
517            intent: Some(intent),
518            ..RouteRequest::new(prompt, registry, tracker)
519        })
520    }
521
522    fn route_inner_with_intent(
523        &self,
524        prompt: &str,
525        registry: &UnifiedRegistry,
526        tracker: &OutcomeTracker,
527        has_tools: bool,
528        has_vision: bool,
529        estimated_total_tokens: usize,
530        workload: RoutingWorkload,
531        intent: Option<&crate::intent::IntentHint>,
532    ) -> AdaptiveRoutingDecision {
533        let complexity = TaskComplexity::assess(prompt);
534        // Caller intent overrides the prompt-derived task when supplied.
535        let task = intent
536            .and_then(|h| h.task)
537            .map(task_hint_to_inference_task)
538            .unwrap_or_else(|| complexity.inference_task());
539        let mut required_caps = complexity.required_capabilities();
540        if let Some(hint) = intent {
541            for cap in &hint.require {
542                if !required_caps.contains(cap) {
543                    required_caps.push(*cap);
544                }
545            }
546        }
547        if has_vision {
548            required_caps.push(ModelCapability::Vision);
549        }
550        if has_tools {
551            required_caps.push(ModelCapability::ToolUse);
552            // Detect multi-step prompts that need multiple tool calls in one response.
553            // Patterns: numbered lists ("1) ... 2) ..."), multiple instructions, explicit multi-edit.
554            if Self::needs_multi_tool_call(prompt) {
555                required_caps.push(ModelCapability::MultiToolCall);
556            }
557        }
558
559        // Phase 1: Filter candidates
560        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);
561
562        // Fallback: if requiring MultiToolCall eliminates all candidates, drop the
563        // requirement and let the best ToolUse model handle it with multiple round-trips.
564        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
565            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
566            candidates = self.filter_candidates(&required_caps, registry, tracker);
567        }
568
569        if candidates.is_empty() {
570            // Nothing available — return the schema-based default
571            return self.cold_start_decision(complexity, task, registry, has_vision);
572        }
573
574        candidates = self.apply_quality_first_bootstrap_policy(
575            candidates,
576            task,
577            tracker,
578            has_vision,
579            has_tools,
580            workload,
581        );
582
583        // Context-aware filtering: if we know the prompt size, prefer models that fit.
584        // Phase 1b: separate candidates into "fits" and "needs compaction" groups.
585        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
586            let mut fits = Vec::new();
587            let mut tight = Vec::new();
588            for m in &candidates {
589                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
590                    fits.push(m.clone());
591                } else {
592                    tight.push(m.clone());
593                }
594            }
595            (fits, tight)
596        } else {
597            (candidates.clone(), Vec::new())
598        };
599
600        // Prefer models that fit; fall back to compaction-required models if none fit
601        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
602            (fits, false)
603        } else if !needs_compaction_candidates.is_empty() {
604            tracing::info!(
605                prompt_tokens = estimated_total_tokens,
606                candidates = needs_compaction_candidates.len(),
607                "no model fits full prompt — compaction will be needed"
608            );
609            (needs_compaction_candidates.clone(), true)
610        } else {
611            (candidates.clone(), false)
612        };
613
614        // Phase 2: Score candidates (with context headroom bonus)
615        let scored = self.score_candidates_context_aware(
616            &scoring_candidates,
617            task,
618            tracker,
619            estimated_total_tokens,
620            workload,
621        );
622
623        // Phase 3: Thompson Sampling selection
624        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);
625
626        // Build fallback chain: prefer models that fit, then compaction candidates
627        let mut fallbacks: Vec<String> = scored
628            .iter()
629            .filter(|(id, _)| *id != selected_id)
630            .map(|(id, _)| id.clone())
631            .collect();
632        // Add compaction candidates to the end of the fallback chain
633        if !compaction_needed {
634            for m in &needs_compaction_candidates {
635                if m.id != selected_id && !fallbacks.contains(&m.id) {
636                    fallbacks.push(m.id.clone());
637                }
638            }
639        }
640
641        let predicted_quality = scored
642            .iter()
643            .find(|(id, _)| *id == selected_id)
644            .map(|(_, score)| *score)
645            .unwrap_or(0.5);
646
647        let selected_schema = registry
648            .get(&selected_id)
649            .or_else(|| registry.find_by_name(&selected_id));
650        let model_name = selected_schema
651            .map(|m| m.name.clone())
652            .unwrap_or_else(|| selected_id.clone());
653        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);
654
655        let needs_compact = compaction_needed
656            || (estimated_total_tokens > 0
657                && context_length > 0
658                && estimated_total_tokens > context_length);
659
660        let compaction_note = if needs_compact {
661            format!(
662                " [compaction needed: {}→{}tok]",
663                estimated_total_tokens, context_length
664            )
665        } else {
666            String::new()
667        };
668
669        let reason = format!(
670            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
671            complexity,
672            model_name,
673            strategy,
674            predicted_quality,
675            scoring_candidates.len(),
676            compaction_note,
677        );
678
679        AdaptiveRoutingDecision {
680            model_id: selected_id,
681            model_name,
682            task,
683            complexity,
684            reason,
685            strategy,
686            predicted_quality,
687            fallbacks,
688            context_length,
689            needs_compaction: needs_compact,
690        }
691    }
692
693    /// Route to the best embedding model.
694    pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
695        let embed_models = registry.query_by_capability(ModelCapability::Embed);
696        embed_models
697            .first()
698            .map(|m| m.name.clone())
699            .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
700    }
701
702    /// Route to the smallest available model (for classification).
703    pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
704        let gen_models = registry.query_by_capability(ModelCapability::Generate);
705        // Pick smallest by size
706        gen_models
707            .iter()
708            .filter(|m| m.is_local())
709            .min_by_key(|m| m.size_mb())
710            .map(|m| m.name.clone())
711            .unwrap_or_else(|| "Qwen3-0.6B".to_string())
712    }
713
714    // --- Internal phases ---
715
    // --- Scoring constants ---
    // Ceilings, multipliers and bonuses for Phase 2 scoring.

    /// Latency ceiling: requests taking longer than this score 0.0.
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// TPS ceiling: models faster than this score 1.0.
    // NOTE(review): the leading underscore suggests this constant is
    // currently unused — confirm before relying on it.
    const _TPS_CEILING: f64 = 150.0;
    /// MoE throughput penalty for Candle: naive expert routing runs at ~10% of declared TPS.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// MoE throughput multiplier for MLX: fused Metal kernels run at ~50% of declared TPS.
    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
    /// Cost ceiling: models costing more than this per 1K output tokens score 0.0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Local preference bonus added to the weighted score (before normalization).
    const LOCAL_BONUS: f64 = 0.15;
    /// Extra bonus for MLX models on Apple Silicon (stacks with LOCAL_BONUS).
    /// MLX gets fused Metal kernels and better memory layout vs Candle on Mac.
    /// Only compiled on macOS/aarch64 targets.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    const MLX_BONUS: f64 = 0.10;
    /// Extra bonus for system-owned models — no model file on disk, no
    /// download to provision, framework-managed memory. Currently only
    /// `apple/foundation:default` qualifies. Tag-driven so future
    /// system-LLM integrations (e.g. Android AICore) inherit the
    /// scoring without a code change. Stacks with LOCAL_BONUS but
    /// **excludes** MLX_BONUS (system models aren't MLX); the net
    /// effect is FoundationModels ranks roughly even with a
    /// well-warmed MLX 4B for short fast-turn tasks instead of
    /// strictly losing because its catalog `tokens_per_second` /
    /// `size_mb` are null.
    const SYSTEM_LLM_BONUS: f64 = 0.12;
745
746    // --- Internal phases ---
747
748    /// Phase 1: Filter by hard constraints.
749    fn filter_candidates(
750        &self,
751        required_caps: &[ModelCapability],
752        registry: &UnifiedRegistry,
753        tracker: &OutcomeTracker,
754    ) -> Vec<ModelSchema> {
755        registry
756            .list()
757            .into_iter()
758            .filter(|m| {
759                // Must have all required capabilities
760                if !required_caps.iter().all(|c| m.has_capability(*c)) {
761                    return false;
762                }
763                // Must be available (downloaded for local, API key set for remote)
764                if !m.available {
765                    return false;
766                }
767                // Local models must fit in memory (strict: >= excludes models at the limit)
768                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
769                    return false;
770                }
771                // Hard latency constraint
772                if let Some(max) = self.config.max_latency_ms {
773                    if let Some(p50) = m.performance.latency_p50_ms {
774                        if p50 > max {
775                            return false;
776                        }
777                    }
778                }
779                // Hard cost constraint
780                if let Some(max) = self.config.max_cost_usd {
781                    if m.cost_per_1k_output() > max {
782                        return false;
783                    }
784                }
785                // Hard exclusion: prefer_local=false excludes all local models (#12)
786                if !self.config.prefer_local && m.is_local() {
787                    return false;
788                }
789                // Hard exclusion: rate-limited models excluded for this session (#13)
790                if tracker.is_excluded(&m.id) {
791                    return false;
792                }
793                // Circuit breaker: skip models with consecutive failures (#25)
794                if let Ok(mut cb) = self.circuit_breakers.lock() {
795                    if !cb.allow_request(&m.id) {
796                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
797                        return false;
798                    }
799                }
800                true
801            })
802            .cloned()
803            .collect()
804    }
805
806    fn apply_quality_first_bootstrap_policy(
807        &self,
808        candidates: Vec<ModelSchema>,
809        task: InferenceTask,
810        tracker: &OutcomeTracker,
811        has_vision: bool,
812        has_tools: bool,
813        workload: RoutingWorkload,
814    ) -> Vec<ModelSchema> {
815        if !self.config.quality_first_cold_start
816            || !workload.is_latency_sensitive()
817            || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
818        {
819            return candidates;
820        }
821
822        let trusted_remote: Vec<ModelSchema> = candidates
823            .iter()
824            .filter(|model| self.is_trusted_quality_remote(model))
825            .cloned()
826            .collect();
827
828        if trusted_remote.is_empty() {
829            return candidates;
830        }
831
832        let proven_local: Vec<ModelSchema> = candidates
833            .iter()
834            .filter(|model| model.is_local() && self.is_local_model_proven_for_task(model, task, tracker))
835            .cloned()
836            .collect();
837
838        if !proven_local.is_empty() {
839            return proven_local;
840        }
841
842        trusted_remote
843    }
844
845    /// Phase 2: Score candidates with context awareness.
846    /// Applies a headroom bonus to models with more context window for the prompt.
847    /// When estimated_total_tokens is 0, no context bonus/penalty is applied.
848    fn score_candidates_context_aware(
849        &self,
850        candidates: &[ModelSchema],
851        task: InferenceTask,
852        tracker: &OutcomeTracker,
853        estimated_total_tokens: usize,
854        workload: RoutingWorkload,
855    ) -> Vec<(String, f64)> {
856        let mut scored: Vec<(String, f64)> = candidates
857            .iter()
858            .map(|m| {
859                let base_score = self.score_model(m, task, tracker, workload);
860                // Context headroom bonus: prefer models with more room to spare.
861                // Max bonus: 0.10 (at 4x headroom or more). No bonus if unknown.
862                let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
863                    let ratio = m.context_length as f64 / estimated_total_tokens as f64;
864                    if ratio >= 1.0 {
865                        (ratio.min(4.0) - 1.0) / 3.0 * 0.10 // 0.0 at exact fit, 0.10 at 4x
866                    } else {
867                        -0.15 // Penalty for models that require compaction
868                    }
869                } else {
870                    0.0
871                };
872                (m.id.clone(), base_score + headroom_bonus)
873            })
874            .collect();
875
876        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
877        scored
878    }
879
880    /// Score a single model. All sub-scores are in [0.0, 1.0].
881    /// Final score = weighted sum + local_bonus, so range is [0.0, ~1.15].
882    fn score_model(
883        &self,
884        model: &ModelSchema,
885        task: InferenceTask,
886        tracker: &OutcomeTracker,
887        workload: RoutingWorkload,
888    ) -> f64 {
889        let profile = tracker.profile(&model.id);
890        let schema_quality = self.schema_quality_estimate(model);
891        let schema_latency = self.schema_latency_estimate(model);
892        let (quality_weight, latency_weight, cost_weight) = workload.weights();
893
894        // Quality: blend schema estimate with observed data based on observation count.
895        // Both cold start and warm start use consistent blending.
896        let quality = match profile {
897            Some(p) if p.total_calls >= self.config.min_observations => p
898                .task_stats(task)
899                .map(|ts| ts.ema_quality)
900                .unwrap_or(p.ema_quality),
901            Some(p) if p.total_calls == 0 => p
902                .task_stats(task)
903                .map(|ts| ts.ema_quality)
904                .unwrap_or(p.ema_quality),
905            Some(p) if p.total_calls > 0 => {
906                let w = p.total_calls as f64 / self.config.min_observations as f64;
907                schema_quality * (1.0 - w) + p.ema_quality * w
908            }
909            _ => schema_quality,
910        };
911
912        // Latency: same blending as quality — don't trust a single observation more
913        // than schema estimates. This prevents routing oscillation on first few calls.
914        let latency = match profile {
915            Some(p) if p.total_calls >= self.config.min_observations => {
916                let avg = p
917                    .task_stats(task)
918                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
919                    .map(|ts| ts.avg_latency_ms)
920                    .unwrap_or_else(|| p.avg_latency_ms());
921                self.latency_ms_to_score(avg)
922            }
923            Some(p) if p.total_calls == 0 => p
924                .task_stats(task)
925                .filter(|ts| ts.avg_latency_ms > 0.0)
926                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
927                .unwrap_or(schema_latency),
928            Some(p) if p.total_calls > 0 => {
929                let observed = self.latency_ms_to_score(
930                    p.task_stats(task)
931                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
932                        .map(|ts| ts.avg_latency_ms)
933                        .unwrap_or_else(|| p.avg_latency_ms()),
934                );
935                let w = p.total_calls as f64 / self.config.min_observations as f64;
936                schema_latency * (1.0 - w) + observed * w
937            }
938            _ => schema_latency,
939        };
940
941        // Cost score (lower is better → invert)
942        let cost = if model.is_local() {
943            1.0
944        } else {
945            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
946        };
947
948        let local_bonus = if self.config.prefer_local && model.is_local() {
949            Self::LOCAL_BONUS
950        } else {
951            0.0
952        };
953        let workload_local_bonus = if model.is_local() {
954            workload.local_bonus()
955        } else {
956            0.0
957        };
958
959        // On Apple Silicon, prefer MLX models over Candle equivalents
960        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
961        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
962        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
963        let mlx_bonus = 0.0;
964
965        // vLLM-MLX bonus: continuous batching gives better multi-agent throughput
966        let vllm_mlx_bonus = if model.is_vllm_mlx() {
967            Self::LOCAL_BONUS + 0.05
968        } else {
969            0.0
970        };
971
972        // System-LLM bonus: catalog tag-driven so it generalizes beyond
973        // FoundationModels. Models tagged `low_latency` AND `private`
974        // are zero-cost system-owned LLMs (apple/foundation:default
975        // today; AICore-on-Android etc. in the future). They don't
976        // appear in `is_mlx()` but deserve to compete with MLX 4B on
977        // routing — the catalog's tags carry that intent and the
978        // router now honors it.
979        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
980            && model.tags.iter().any(|t| t == "private")
981        {
982            Self::SYSTEM_LLM_BONUS
983        } else {
984            0.0
985        };
986
987        quality_weight * quality
988            + latency_weight * latency
989            + cost_weight * cost
990            + local_bonus
991            + workload_local_bonus
992            + mlx_bonus
993            + vllm_mlx_bonus
994            + system_llm_bonus
995    }
996
997    /// Convert latency in ms to a [0, 1] score. Used by both schema and observed paths
998    /// so the scales are consistent (fixes Linus review issue #2).
999    fn latency_ms_to_score(&self, ms: f64) -> f64 {
1000        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1001    }
1002
1003    /// Convert TPS to estimated latency in ms (for a typical 200-token response).
1004    fn tps_to_latency_ms(tps: f64) -> f64 {
1005        if tps <= 0.0 {
1006            return Self::LATENCY_CEILING_MS;
1007        }
1008        // Assume ~200 tokens per response as baseline
1009        (200.0 / tps) * 1000.0
1010    }
1011
1012    /// Detect whether a prompt likely needs multiple tool calls in a single response.
1013    /// Looks for numbered lists, multiple explicit instructions, multi-edit patterns.
1014    fn needs_multi_tool_call(prompt: &str) -> bool {
1015        let lower = prompt.to_lowercase();
1016
1017        // Numbered list patterns: "1) ... 2) ..." or "1. ... 2. ..."
1018        let has_numbered_list = {
1019            let mut count = 0u32;
1020            for i in 1..=5u32 {
1021                if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1022                    count += 1;
1023                }
1024            }
1025            count >= 2
1026        };
1027
1028        // Explicit multi-action keywords
1029        let multi_keywords = [
1030            "multiple edits",
1031            "several changes",
1032            "three changes",
1033            "two changes",
1034            "all of the following",
1035            "each of these",
1036            "do both",
1037            "do all",
1038            "and also",
1039            "additionally",
1040            "as well as",
1041            "then also",
1042        ];
1043        let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1044
1045        // Bullet point lists with action verbs
1046        let bullet_actions = lower.matches("- add ").count()
1047            + lower.matches("- update ").count()
1048            + lower.matches("- change ").count()
1049            + lower.matches("- remove ").count()
1050            + lower.matches("- fix ").count()
1051            + lower.matches("- edit ").count()
1052            + lower.matches("- implement ").count()
1053            + lower.matches("- create ").count();
1054        let has_bullet_list = bullet_actions >= 2;
1055
1056        has_numbered_list || has_multi_keywords || has_bullet_list
1057    }
1058
1059    /// Schema-based quality estimate (cold start).
1060    ///
1061    /// Diminishing returns on model size — the jump from 4B to 8B matters,
1062    /// but 8B to 30B is marginal. Remote models get a conservative estimate.
1063    fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1064        match model.size_mb() {
1065            0 => 0.5,             // remote: unknown, conservative
1066            s if s < 1000 => 0.4, // 0.6B
1067            s if s < 2000 => 0.5, // 1.7B
1068            s if s < 3000 => 0.6, // 4B
1069            s if s < 6000 => 0.7, // 8B
1070            _ => 0.75,            // 30B+: diminishing returns
1071        }
1072    }
1073
1074    /// Schema-based latency estimate (cold start).
1075    ///
1076    /// Converts declared TPS/p50 to the same ms-based score used by observed data,
1077    /// so there's no discontinuity when the first observation arrives.
1078    fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1079        let is_moe = model.tags.contains(&"moe".to_string());
1080
1081        if model.is_local() {
1082            if let Some(tps) = model.performance.tokens_per_second {
1083                let effective_tps = if is_moe {
1084                    let multiplier = if model.is_mlx() {
1085                        Self::MLX_MOE_TPS_MULTIPLIER
1086                    } else {
1087                        Self::MOE_TPS_MULTIPLIER
1088                    };
1089                    tps * multiplier
1090                } else {
1091                    tps
1092                };
1093                let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1094                return self.latency_ms_to_score(estimated_ms);
1095            }
1096            return 0.5; // local, no declared TPS
1097        }
1098
1099        // Remote: use declared p50 latency
1100        if let Some(p50) = model.performance.latency_p50_ms {
1101            return self.latency_ms_to_score(p50 as f64);
1102        }
1103        0.3 // remote, no declared latency
1104    }
1105
1106    fn is_quality_critical_bootstrap_task(
1107        &self,
1108        task: InferenceTask,
1109        has_vision: bool,
1110        has_tools: bool,
1111    ) -> bool {
1112        has_vision
1113            || has_tools
1114            || matches!(
1115                task,
1116                InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1117            )
1118    }
1119
1120    fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1121        model.is_remote()
1122            && matches!(
1123                model.provider.as_str(),
1124                "openai" | "anthropic" | "google"
1125            )
1126            && !model.has_capability(ModelCapability::SpeechToText)
1127            && !model.has_capability(ModelCapability::TextToSpeech)
1128    }
1129
1130    fn is_local_model_proven_for_task(
1131        &self,
1132        model: &ModelSchema,
1133        task: InferenceTask,
1134        tracker: &OutcomeTracker,
1135    ) -> bool {
1136        let Some(profile) = tracker.profile(&model.id) else {
1137            return false;
1138        };
1139        if let Some(task_stats) = profile.task_stats(task) {
1140            if task_stats.calls >= self.config.bootstrap_min_task_observations
1141                && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1142            {
1143                return true;
1144            }
1145        }
1146
1147        profile.total_calls >= self.config.bootstrap_min_task_observations
1148            && profile.ema_quality >= self.config.bootstrap_quality_floor
1149    }
1150
1151    /// Phase 3: Thompson Sampling selection.
1152    ///
1153    /// Each model gets a Beta(alpha, beta) distribution where:
1154    /// - alpha = prior_successes + observed_successes
1155    /// - beta = prior_failures + observed_failures
1156    ///
1157    /// The Phase 2 score serves as the prior mean, scaled by `prior_strength`.
1158    /// Models with few observations have wide distributions (natural exploration).
1159    /// Models with many observations have tight distributions (exploitation).
1160    fn select_with_thompson_sampling(
1161        &self,
1162        scored: &[(String, f64)],
1163        tracker: &OutcomeTracker,
1164    ) -> (String, RoutingStrategy) {
1165        if scored.is_empty() {
1166            return (String::new(), RoutingStrategy::SchemaBased);
1167        }
1168
1169        let mut rng = rand::rng();
1170        let mut best_sample = f64::NEG_INFINITY;
1171        let mut best_id = scored[0].0.clone();
1172        let mut best_strategy = RoutingStrategy::SchemaBased;
1173
1174        for (id, phase2_score) in scored {
1175            let profile = tracker.profile(id);
1176            let prior = self.config.prior_strength;
1177
1178            // Convert Phase 2 score (0.0-1.15) to a prior mean in [0, 1]
1179            let prior_mean = phase2_score.clamp(0.0, 1.0);
1180
1181            // Prior pseudo-counts from the Phase 2 score
1182            let prior_alpha = prior * prior_mean;
1183            let prior_beta = prior * (1.0 - prior_mean);
1184
1185            // Observed counts
1186            let (obs_alpha, obs_beta) = match profile {
1187                Some(p) => (p.success_count as f64, p.fail_count as f64),
1188                None => (0.0, 0.0),
1189            };
1190
1191            // Posterior Beta parameters
1192            let alpha = (prior_alpha + obs_alpha).max(0.01);
1193            let beta = (prior_beta + obs_beta).max(0.01);
1194
1195            // Sample from Beta(alpha, beta) using the Jöhnk algorithm
1196            let sample = sample_beta(&mut rng, alpha, beta);
1197
1198            if sample > best_sample {
1199                best_sample = sample;
1200                best_id = id.clone();
1201                best_strategy = match profile {
1202                    Some(p) if p.total_calls >= self.config.min_observations => {
1203                        RoutingStrategy::ProfileBased
1204                    }
1205                    Some(p) if p.total_calls > 0 => {
1206                        // Under-tested but has some data — exploration
1207                        RoutingStrategy::Exploration
1208                    }
1209                    _ => RoutingStrategy::SchemaBased,
1210                };
1211            }
1212        }
1213
1214        (best_id, best_strategy)
1215    }
1216
1217    /// Fallback decision when no candidates pass filtering.
1218    fn cold_start_decision(
1219        &self,
1220        complexity: TaskComplexity,
1221        task: InferenceTask,
1222        registry: &UnifiedRegistry,
1223        has_vision: bool,
1224    ) -> AdaptiveRoutingDecision {
1225        if has_vision {
1226            if let Some(model) = registry
1227                .query_by_capability(ModelCapability::Vision)
1228                .into_iter()
1229                .find(|model| model.available && self.is_trusted_quality_remote(model))
1230                .or_else(|| registry.query_by_capability(ModelCapability::Vision).first().copied())
1231            {
1232                return AdaptiveRoutingDecision {
1233                    model_id: model.id.clone(),
1234                    model_name: model.name.clone(),
1235                    task,
1236                    complexity,
1237                    reason: format!(
1238                        "{:?} task → {} (cold start, vision fallback)",
1239                        complexity, model.name
1240                    ),
1241                    strategy: RoutingStrategy::SchemaBased,
1242                    predicted_quality: 0.5,
1243                    fallbacks: vec![],
1244                    context_length: model.context_length,
1245                    needs_compaction: false,
1246                };
1247            }
1248        }
1249
1250        if self.config.quality_first_cold_start {
1251            let required_caps = complexity.required_capabilities();
1252            if let Some(model) = registry
1253                .list()
1254                .into_iter()
1255                .filter(|model| {
1256                    model.available
1257                        && required_caps.iter().all(|cap| model.has_capability(*cap))
1258                        && self.is_trusted_quality_remote(model)
1259                })
1260                .max_by(|a, b| {
1261                    self.schema_quality_estimate(a)
1262                        .partial_cmp(&self.schema_quality_estimate(b))
1263                        .unwrap_or(std::cmp::Ordering::Equal)
1264                })
1265            {
1266                return AdaptiveRoutingDecision {
1267                    model_id: model.id.clone(),
1268                    model_name: model.name.clone(),
1269                    task,
1270                    complexity,
1271                    reason: format!(
1272                        "{:?} task → {} (quality-first cold start)",
1273                        complexity, model.name
1274                    ),
1275                    strategy: RoutingStrategy::SchemaBased,
1276                    predicted_quality: self.schema_quality_estimate(model),
1277                    fallbacks: vec![],
1278                    context_length: model.context_length,
1279                    needs_compaction: false,
1280                };
1281            }
1282        }
1283
1284        // Fall back to the old complexity-based defaults
1285        let model_name = match complexity {
1286            TaskComplexity::Simple => "Qwen3-0.6B",
1287            TaskComplexity::Medium => "Qwen3-1.7B",
1288            TaskComplexity::Code => "Qwen3-4B",
1289            TaskComplexity::Complex => &self.hw.recommended_model,
1290        };
1291
1292        let model_id = registry
1293            .find_by_name(model_name)
1294            .map(|m| m.id.clone())
1295            .unwrap_or_else(|| model_name.to_string());
1296
1297        let context_length = registry
1298            .find_by_name(model_name)
1299            .map(|m| m.context_length)
1300            .unwrap_or(0);
1301
1302        AdaptiveRoutingDecision {
1303            model_id,
1304            model_name: model_name.to_string(),
1305            task,
1306            complexity,
1307            reason: format!(
1308                "{:?} task → {} (cold start, no candidates)",
1309                complexity, model_name
1310            ),
1311            strategy: RoutingStrategy::SchemaBased,
1312            predicted_quality: 0.5,
1313            fallbacks: vec![],
1314            context_length,
1315            needs_compaction: false,
1316        }
1317    }
1318}
1319
1320/// Map a caller-supplied [`crate::TaskHint`] to the engine's
1321/// [`InferenceTask`] enum. The intent surface uses the higher-level
1322/// hint vocabulary; the router operates on InferenceTask. Every
1323/// TaskHint variant maps to a distinct InferenceTask — variants that
1324/// would have silently collapsed to `Generate` were cut from the MVP.
1325fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1326    use crate::intent::TaskHint;
1327    match hint {
1328        TaskHint::Chat => InferenceTask::Generate,
1329        TaskHint::Classify => InferenceTask::Classify,
1330        TaskHint::Reasoning => InferenceTask::Reasoning,
1331        TaskHint::Code => InferenceTask::Code,
1332    }
1333}
1334
1335/// Sample from a Beta(alpha, beta) distribution.
1336///
1337/// Uses the Gamma distribution method: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1),
1338/// then X / (X + Y) ~ Beta(alpha, beta).
1339///
1340/// For Gamma sampling, uses Marsaglia and Tsang's method for alpha >= 1,
1341/// and Ahrens-Dieter for alpha < 1.
1342fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1343    let x = sample_gamma(rng, alpha);
1344    let y = sample_gamma(rng, beta);
1345    if x + y == 0.0 {
1346        0.5 // degenerate case
1347    } else {
1348        x / (x + y)
1349    }
1350}
1351
1352/// Sample from Gamma(shape, 1) using Marsaglia-Tsang for shape >= 1,
1353/// with Ahrens-Dieter boost for shape < 1.
1354fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1355    if shape < 1.0 {
1356        // Ahrens-Dieter: Gamma(a) = Gamma(a+1) * U^(1/a)
1357        let u: f64 = rng.random();
1358        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1359    }
1360
1361    // Marsaglia-Tsang method for shape >= 1
1362    let d = shape - 1.0 / 3.0;
1363    let c = 1.0 / (9.0 * d).sqrt();
1364
1365    loop {
1366        let x: f64 = loop {
1367            let n = sample_standard_normal(rng);
1368            if 1.0 + c * n > 0.0 {
1369                break n;
1370            }
1371        };
1372
1373        let v = (1.0 + c * x).powi(3);
1374        let u: f64 = rng.random();
1375
1376        if u < 1.0 - 0.0331 * x.powi(4) {
1377            return d * v;
1378        }
1379        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1380            return d * v;
1381        }
1382    }
1383}
1384
1385/// Sample from standard normal N(0,1) using Box-Muller transform.
1386fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1387    let u1: f64 = rng.random();
1388    let u2: f64 = rng.random();
1389    (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1390}
1391
1392#[cfg(test)]
1393mod tests {
1394    use super::*;
1395    use crate::outcome::InferredOutcome;
1396
1397    fn test_hw() -> HardwareInfo {
1398        HardwareInfo {
1399            os: "macos".into(),
1400            arch: "aarch64".into(),
1401            cpu_cores: 10,
1402            total_ram_mb: 32768,
1403            gpu_backend: crate::hardware::GpuBackend::Metal,
1404            gpu_memory_mb: Some(28672),
1405            gpu_devices: Vec::new(),
1406            recommended_model: "Qwen3-8B".into(),
1407            recommended_context: 8192,
1408            max_model_mb: 18000, // headroom above 30B-A3B's 17000MB
1409        }
1410    }
1411
1412    fn test_registry() -> UnifiedRegistry {
1413        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
1414        unsafe {
1415            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
1416        }
1417        // Create fake model dirs so the registry marks them as available
1418        for name in &[
1419            "Qwen3-0.6B",
1420            "Qwen3-1.7B",
1421            "Qwen3-4B",
1422            "Qwen3-8B",
1423            "Qwen3-Embedding-0.6B",
1424        ] {
1425            let dir = tmp.join(name);
1426            let _ = std::fs::create_dir_all(&dir);
1427            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
1428            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
1429        }
1430        let mut reg = UnifiedRegistry::new(tmp);
1431        reg.register(ModelSchema {
1432            id: "openai/gpt-5.4-mini:latest".into(),
1433            name: "gpt-5.4-mini".into(),
1434            provider: "openai".into(),
1435            family: "gpt-5.4".into(),
1436            version: "latest".into(),
1437            capabilities: vec![
1438                ModelCapability::Generate,
1439                ModelCapability::Code,
1440                ModelCapability::Reasoning,
1441                ModelCapability::ToolUse,
1442                ModelCapability::MultiToolCall,
1443                ModelCapability::Vision,
1444            ],
1445            context_length: 128_000,
1446            param_count: "api".into(),
1447            quantization: None,
1448            performance: Default::default(),
1449            cost: Default::default(),
1450            source: crate::schema::ModelSource::RemoteApi {
1451                endpoint: "https://api.openai.com/v1".into(),
1452                api_key_env: "OPENAI_API_KEY".into(),
1453                api_key_envs: vec![],
1454                api_version: None,
1455                protocol: crate::schema::ApiProtocol::OpenAiCompat,
1456            },
1457            tags: vec!["trusted-remote".into()],
1458            supported_params: vec![],
1459            public_benchmarks: vec![],
1460            available: true,
1461        });
1462        reg
1463    }
1464
1465    #[test]
1466    fn routes_simple_to_trusted_remote_during_cold_start() {
1467        let router = AdaptiveRouter::new(
1468            test_hw(),
1469            RoutingConfig {
1470                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1471                ..Default::default()
1472            },
1473        );
1474        let reg = test_registry();
1475        let tracker = OutcomeTracker::new();
1476
1477        let decision = router.route("What is 2+2?", &reg, &tracker);
1478        assert_eq!(decision.complexity, TaskComplexity::Simple);
1479        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
1480        // On cold start, quality-critical tasks should stay on trusted remote models.
1481        let schema = reg
1482            .find_by_name(&decision.model_name)
1483            .expect("selected model should exist in registry");
1484        assert!(
1485            !schema.is_local(),
1486            "simple task should route to trusted remote model during cold start"
1487        );
1488        assert!(matches!(
1489            schema.provider.as_str(),
1490            "openai" | "anthropic" | "google"
1491        ));
1492    }
1493
1494    #[test]
1495    fn routes_code_to_code_capable_remote_during_cold_start() {
1496        let router = AdaptiveRouter::new(
1497            test_hw(),
1498            RoutingConfig {
1499                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1500                ..Default::default()
1501            },
1502        );
1503        let reg = test_registry();
1504        let tracker = OutcomeTracker::new();
1505
1506        let decision = router.route(
1507            "Fix this function:\n```rust\nfn main() {}\n```",
1508            &reg,
1509            &tracker,
1510        );
1511        assert_eq!(decision.complexity, TaskComplexity::Code);
1512        assert_eq!(decision.task, InferenceTask::Code);
1513        // Must select a code-capable local model (not 0.6B which lacks Code)
1514        let schema = reg
1515            .find_by_name(&decision.model_name)
1516            .expect("model should exist");
1517        assert!(
1518            schema.has_capability(ModelCapability::Code),
1519            "selected model must support Code"
1520        );
1521        assert!(!schema.is_local(), "should route to trusted remote model");
1522    }
1523
1524    #[test]
1525    fn routes_images_to_vision_capable_model() {
1526        let router = AdaptiveRouter::new(
1527            test_hw(),
1528            RoutingConfig {
1529                prior_strength: 100.0,
1530                ..Default::default()
1531            },
1532        );
1533        let mut reg = test_registry();
1534        let tracker = OutcomeTracker::new();
1535
1536        reg.register(ModelSchema {
1537            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
1538            name: "Qwen3-VL-2B-mlx-vlm".into(),
1539            provider: "qwen".into(),
1540            family: "qwen3-vl".into(),
1541            version: "bf16".into(),
1542            capabilities: vec![
1543                ModelCapability::Generate,
1544                ModelCapability::Vision,
1545                ModelCapability::Grounding,
1546            ],
1547            context_length: 262_144,
1548            param_count: "2B".into(),
1549            quantization: None,
1550            performance: Default::default(),
1551            cost: Default::default(),
1552            source: crate::schema::ModelSource::Mlx {
1553                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
1554                hf_weight_file: None,
1555            },
1556            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
1557            supported_params: vec![],
1558            public_benchmarks: vec![],
1559            available: true,
1560        });
1561
1562        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
1563        let schema = reg
1564            .find_by_name(&decision.model_name)
1565            .expect("model should exist");
1566        assert!(
1567            schema.has_capability(ModelCapability::Vision),
1568            "selected model must support Vision"
1569        );
1570    }
1571
1572    #[test]
1573    fn profile_based_routing_favors_proven_model() {
1574        let router = AdaptiveRouter::new(
1575            test_hw(),
1576            RoutingConfig {
1577                prior_strength: 0.5, // weak prior, let observed data dominate
1578                min_observations: 3,
1579                ..Default::default()
1580            },
1581        );
1582        let reg = test_registry();
1583        let mut tracker = OutcomeTracker::new();
1584
1585        // Build a strong profile for Qwen3-8B on code tasks (fast + high quality)
1586        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1587        for _ in 0..20 {
1588            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
1589            tracker.record_complete(&trace, 500, 100, 50);
1590            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1591        }
1592
1593        // Thompson Sampling is stochastic — run multiple times, 8B should win majority
1594        let mut wins = 0;
1595        for _ in 0..20 {
1596            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
1597            assert_eq!(decision.complexity, TaskComplexity::Code);
1598            if decision.model_id == qwen_8b_id {
1599                wins += 1;
1600            }
1601        }
1602        assert!(
1603            wins >= 12,
1604            "proven model won only {wins}/20 times (expected >= 12)"
1605        );
1606    }
1607
1608    #[test]
1609    fn proven_local_model_can_displace_bootstrap_remote() {
1610        let router = AdaptiveRouter::new(
1611            test_hw(),
1612            RoutingConfig {
1613                prior_strength: 100.0,
1614                bootstrap_min_task_observations: 6,
1615                bootstrap_quality_floor: 0.8,
1616                ..Default::default()
1617            },
1618        );
1619        let reg = test_registry();
1620        let mut tracker = OutcomeTracker::new();
1621
1622        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1623        for _ in 0..12 {
1624            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
1625            tracker.record_complete(&trace, 300, 50, 20);
1626            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1627        }
1628
1629        let mut local_wins = 0;
1630        for _ in 0..20 {
1631            let decision = router.route("Summarize this design decision.", &reg, &tracker);
1632            let schema = reg.get(&decision.model_id).expect("selected model should exist");
1633            if schema.is_local() {
1634                local_wins += 1;
1635            }
1636        }
1637
1638        assert!(
1639            local_wins >= 12,
1640            "proven local model won only {local_wins}/20 times (expected >= 12)"
1641        );
1642    }
1643
1644    #[test]
1645    fn benchmark_prior_informs_background_routing() {
1646        let router = AdaptiveRouter::new(
1647            test_hw(),
1648            RoutingConfig {
1649                prior_strength: 100.0,
1650                ..Default::default()
1651            },
1652        );
1653        let reg = test_registry();
1654        let mut tracker = OutcomeTracker::new();
1655        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1656        profile.ema_quality = 0.95;
1657        tracker.import_profiles(vec![profile]);
1658
1659        let decision = router.route_context_aware(
1660            "Write a Python fibonacci function.",
1661            128,
1662            &reg,
1663            &tracker,
1664            false,
1665            false,
1666            RoutingWorkload::Background,
1667        );
1668
1669        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1670        assert!(schema.is_local(), "background routing should allow strong local benchmark priors to win");
1671    }
1672
1673    #[test]
1674    fn task_specific_benchmark_prior_informs_cold_start_routing() {
1675        let router = AdaptiveRouter::new(
1676            test_hw(),
1677            RoutingConfig {
1678                prior_strength: 100.0,
1679                ..Default::default()
1680            },
1681        );
1682        let reg = test_registry();
1683        let mut tracker = OutcomeTracker::new();
1684        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1685        profile.task_stats.insert(
1686            crate::outcome::InferenceTask::Code.to_string(),
1687            crate::outcome::TaskStats {
1688                ema_quality: 0.95,
1689                ..Default::default()
1690            },
1691        );
1692        tracker.import_profiles(vec![profile]);
1693
1694        let decision = router.route_context_aware(
1695            "Write a Python fibonacci function.",
1696            128,
1697            &reg,
1698            &tracker,
1699            false,
1700            false,
1701            RoutingWorkload::Background,
1702        );
1703
1704        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1705        assert!(
1706            schema.is_local(),
1707            "background routing should use task-specific cold-start priors for local code models"
1708        );
1709    }
1710
1711    #[test]
1712    fn task_specific_latency_prior_affects_cold_start_score() {
1713        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1714        let reg = test_registry();
1715        let model = reg
1716            .get("qwen/qwen3-8b:q4_k_m")
1717            .expect("local test model should exist");
1718
1719        let mut fast_tracker = OutcomeTracker::new();
1720        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
1721        fast_profile.task_stats.insert(
1722            crate::outcome::InferenceTask::Generate.to_string(),
1723            crate::outcome::TaskStats {
1724                ema_quality: 0.95,
1725                avg_latency_ms: 1200.0,
1726                ..Default::default()
1727            },
1728        );
1729        fast_tracker.import_profiles(vec![fast_profile]);
1730
1731        let mut slow_tracker = OutcomeTracker::new();
1732        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
1733        slow_profile.task_stats.insert(
1734            crate::outcome::InferenceTask::Generate.to_string(),
1735            crate::outcome::TaskStats {
1736                ema_quality: 0.95,
1737                avg_latency_ms: 120_000.0,
1738                ..Default::default()
1739            },
1740        );
1741        slow_tracker.import_profiles(vec![slow_profile]);
1742
1743        let fast_score = router.score_model(
1744            model,
1745            InferenceTask::Generate,
1746            &fast_tracker,
1747            RoutingWorkload::Interactive,
1748        );
1749        let slow_score = router.score_model(
1750            model,
1751            InferenceTask::Generate,
1752            &slow_tracker,
1753            RoutingWorkload::Interactive,
1754        );
1755
1756        assert!(
1757            fast_score > slow_score,
1758            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
1759        );
1760    }
1761
1762    #[test]
1763    fn interactive_workload_keeps_remote_bootstrap_bias() {
1764        let router = AdaptiveRouter::new(
1765            test_hw(),
1766            RoutingConfig {
1767                prior_strength: 100.0,
1768                ..Default::default()
1769            },
1770        );
1771        let reg = test_registry();
1772        let mut tracker = OutcomeTracker::new();
1773        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1774        profile.ema_quality = 0.95;
1775        tracker.import_profiles(vec![profile]);
1776
1777        let decision = router.route_context_aware(
1778            "Write a Python fibonacci function.",
1779            128,
1780            &reg,
1781            &tracker,
1782            false,
1783            false,
1784            RoutingWorkload::Interactive,
1785        );
1786
1787        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1788        assert!(
1789            !schema.is_local(),
1790            "interactive routing should still prefer trusted remote models during cold start"
1791        );
1792    }
1793
1794    #[test]
1795    fn fallback_chain_has_alternatives() {
1796        let router = AdaptiveRouter::new(
1797            test_hw(),
1798            RoutingConfig {
1799                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1800                ..Default::default()
1801            },
1802        );
1803        let reg = test_registry();
1804        let tracker = OutcomeTracker::new();
1805
1806        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
1807        assert!(!decision.fallbacks.is_empty());
1808        // Primary should not appear in fallbacks
1809        assert!(!decision.fallbacks.contains(&decision.model_id));
1810    }
1811
1812    #[test]
1813    fn latency_scoring_is_consistent() {
1814        // Verify that schema and observed latency produce comparable scores
1815        let router = AdaptiveRouter::with_default_config(test_hw());
1816
1817        // A model at 25 TPS → ~200/25*1000 = 8000ms → score = 1 - 8000/10000 = 0.2
1818        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
1819        // Same model observed at 8000ms → same formula
1820        let observed_score = router.latency_ms_to_score(8000.0);
1821        assert!(
1822            (schema_score - observed_score).abs() < 0.01,
1823            "schema ({schema_score}) and observed ({observed_score}) should match"
1824        );
1825    }
1826
1827    #[test]
1828    fn complexity_assessment() {
1829        assert_eq!(
1830            TaskComplexity::assess("What is the capital of France?"),
1831            TaskComplexity::Simple
1832        );
1833        assert_eq!(
1834            TaskComplexity::assess("Fix this broken test"),
1835            TaskComplexity::Code
1836        );
1837        assert_eq!(
1838            TaskComplexity::assess("Analyze the trade-offs between A and B"),
1839            TaskComplexity::Complex
1840        );
1841    }
1842
1843    #[test]
1844    fn beta_sampling_produces_valid_values() {
1845        let mut rng = rand::rng();
1846        // Sample 100 times from Beta(2, 5) — should be in [0, 1]
1847        for _ in 0..100 {
1848            let s = sample_beta(&mut rng, 2.0, 5.0);
1849            assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
1850        }
1851        // Beta(1, 1) = Uniform(0, 1) — mean should be ~0.5
1852        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
1853        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
1854        assert!(
1855            (mean - 0.5).abs() < 0.05,
1856            "Beta(1,1) mean {mean} should be ~0.5"
1857        );
1858    }
1859
1860    #[test]
1861    fn thompson_sampling_converges_to_best() {
1862        // A model with strong observed success should win most of the time
1863        let router = AdaptiveRouter::new(
1864            test_hw(),
1865            RoutingConfig {
1866                prior_strength: 1.0, // weak prior, let observations dominate
1867                ..Default::default()
1868            },
1869        );
1870        let reg = test_registry();
1871        let mut tracker = OutcomeTracker::new();
1872
1873        // Give Qwen3-4B 20 successes (strong signal)
1874        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
1875        for _ in 0..20 {
1876            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
1877            tracker.record_complete(&trace, 500, 100, 50);
1878            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1879        }
1880
1881        // Route 20 code tasks — strong model should win the majority
1882        let mut wins = 0;
1883        for _ in 0..20 {
1884            let decision = router.route("Fix this parser bug", &reg, &tracker);
1885            if decision.model_id == qwen_4b_id {
1886                wins += 1;
1887            }
1888        }
1889        assert!(
1890            wins >= 14,
1891            "strong model won only {wins}/20 times (expected >= 14)"
1892        );
1893    }
1894
1895    // ----- Intent surface (parslee-ai/car-releases#18) -----
1896
1897    #[test]
1898    fn intent_require_filters_out_models_lacking_capability() {
1899        // Asking for vision when no candidate has it should fall to
1900        // the cold-start decision rather than scoring incompatible
1901        // candidates. The fixture registry has no vision-capable
1902        // local models.
1903        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1904        let reg = test_registry();
1905        let tracker = OutcomeTracker::new();
1906
1907        let intent = crate::intent::IntentHint {
1908            require: vec![ModelCapability::Vision],
1909            ..Default::default()
1910        };
1911        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);
1912
1913        // When require filters out every candidate, the router falls
1914        // to the schema-based cold-start path. Asserting the strategy
1915        // (not just non-empty model_id) catches a regression where
1916        // future code might silently include filtered candidates.
1917        assert_eq!(
1918            decision.strategy,
1919            RoutingStrategy::SchemaBased,
1920            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
1921        );
1922    }
1923
1924    #[test]
1925    fn intent_default_does_not_override_task_or_caps() {
1926        // Thompson sampling makes per-call model selection
1927        // non-deterministic, so we can't compare model_ids directly.
1928        // What we can assert deterministically: a default IntentHint
1929        // must not change the task selection or the capability
1930        // requirements — those are functions of the prompt only when
1931        // no hint is supplied.
1932        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1933        let reg = test_registry();
1934        let tracker = OutcomeTracker::new();
1935
1936        let baseline = router.route("write a haiku", &reg, &tracker);
1937        let with_default = router.route_with_intent(
1938            "write a haiku",
1939            &reg,
1940            &tracker,
1941            &crate::intent::IntentHint::default(),
1942        );
1943
1944        assert_eq!(
1945            baseline.task, with_default.task,
1946            "default IntentHint must not change the prompt-derived task"
1947        );
1948        assert_eq!(
1949            baseline.complexity, with_default.complexity,
1950            "default IntentHint must not change the prompt-derived complexity"
1951        );
1952    }
1953
1954    #[test]
1955    fn intent_task_hint_overrides_prompt_complexity() {
1956        // A short prompt that complexity assessment would route as
1957        // Generate should land on Reasoning when the intent says so.
1958        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1959        let reg = test_registry();
1960        let tracker = OutcomeTracker::new();
1961
1962        let hint = crate::intent::IntentHint {
1963            task: Some(crate::intent::TaskHint::Reasoning),
1964            ..Default::default()
1965        };
1966        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);
1967
1968        assert_eq!(
1969            decision.task,
1970            InferenceTask::Reasoning,
1971            "TaskHint::Reasoning should override the prompt-derived task"
1972        );
1973    }
1974}