Skip to main content

car_inference/
adaptive_router.rs

1//! Adaptive model routing — three-phase routing with learned performance profiles.
2//!
//! Phase 1: **Filter** — hard constraints (capability, availability, memory, latency, cost,
//!          local-model exclusion, rate-limit exclusion, circuit breakers).
4//! Phase 2: **Score** — blend quality, latency, and cost using observed profiles
5//!          or schema defaults on cold start.
6//! Phase 3: **Select** — Thompson Sampling (Beta distribution per model) for
7//!          natural exploration-exploitation balance. Models with fewer observations
8//!          have wider distributions, giving them chances to prove themselves.
9//!
10//! Replaces the hardcoded `ModelRouter` from `router.rs`.
11
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse prompt-complexity bucket produced by [`TaskComplexity::assess`]
/// (migrated from router.rs).
///
/// Each bucket maps to the capabilities a model must declare
/// ([`TaskComplexity::required_capabilities`]) and to an [`InferenceTask`]
/// ([`TaskComplexity::inference_task`]). Serialized in snake_case
/// (e.g. `"simple"`, `"code"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short or factual prompt (e.g. "what is …"); requires `Generate`.
    Simple,
    /// Prompt that matched no other bucket; requires `Generate`.
    Medium,
    /// Prompt containing code or asking for a fix/debug/refactor; requires `Code`.
    Code,
    /// Reasoning-heavy or very long prompt; requires `Reasoning`.
    Complex,
}
33
34impl TaskComplexity {
35    /// Assess complexity of a prompt string.
36    ///
37    /// Uses tree-sitter AST parsing (when the `ast` feature is enabled) for
38    /// accurate code detection: if a code block parses successfully as any
39    /// supported language, it's definitively code. Falls back to keyword
40    /// heuristics for prompts that mention code without containing code blocks.
41    pub fn assess(prompt: &str) -> Self {
42        let lower = prompt.to_lowercase();
43        let word_count = prompt.split_whitespace().count();
44        let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46        let has_code = Self::detect_code(prompt);
47
48        let repair_markers = [
49            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50        ];
51        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53        let reasoning_markers = [
54            "analyze",
55            "compare",
56            "explain why",
57            "step by step",
58            "think through",
59            "evaluate",
60            "trade-off",
61            "tradeoff",
62            "pros and cons",
63            "architecture",
64            "design",
65            "strategy",
66            "optimize",
67            "comprehensive",
68        ];
69        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71        let simple_patterns = [
72            "what is",
73            "who is",
74            "when did",
75            "where is",
76            "how many",
77            "yes or no",
78            "true or false",
79            "name the",
80            "list the",
81            "define ",
82        ];
83        let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85        if has_code || has_repair {
86            TaskComplexity::Code
87        } else if has_reasoning || estimated_tokens > 500 {
88            TaskComplexity::Complex
89        } else if is_simple || estimated_tokens < 30 {
90            TaskComplexity::Simple
91        } else {
92            TaskComplexity::Medium
93        }
94    }
95
96    /// Detect whether a prompt contains code.
97    ///
98    /// With `ast` feature: extracts code blocks (``` delimited), attempts to
99    /// parse each with tree-sitter. If any parses into symbols, it's real code.
100    /// Without `ast` feature: falls back to keyword heuristics.
101    fn detect_code(prompt: &str) -> bool {
102        // First try AST-based detection on code blocks
103        #[cfg(feature = "ast")]
104        {
105            if let Some(is_code) = Self::detect_code_ast(prompt) {
106                return is_code;
107            }
108        }
109
110        // Fallback: keyword heuristics
111        let code_markers = [
112            "```",
113            "fn ",
114            "def ",
115            "class ",
116            "import ",
117            "require(",
118            "async fn",
119            "pub fn",
120            "function ",
121            "const ",
122            "let ",
123            "var ",
124            "#include",
125            "package ",
126            "impl ",
127        ];
128        code_markers.iter().any(|m| prompt.contains(m))
129    }
130
131    /// AST-based code detection: parse code blocks with tree-sitter.
132    /// Returns Some(true) if code found, Some(false) if blocks exist but
133    /// don't parse, None if no code blocks found (fall through to heuristics).
134    #[cfg(feature = "ast")]
135    fn detect_code_ast(prompt: &str) -> Option<bool> {
136        // Extract code blocks between ``` markers
137        let mut blocks = Vec::new();
138        let mut rest = prompt;
139        while let Some(start) = rest.find("```") {
140            let after_fence = &rest[start + 3..];
141            // Skip optional language tag on the opening fence
142            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143            if let Some(end) = after_fence[code_start..].find("```") {
144                blocks.push(&after_fence[code_start..code_start + end]);
145                rest = &after_fence[code_start + end + 3..];
146            } else {
147                break;
148            }
149        }
150
151        if blocks.is_empty() {
152            return None; // No code blocks — let heuristics decide
153        }
154
155        // Try to parse each block with tree-sitter
156        let languages = [
157            car_ast::Language::Rust,
158            car_ast::Language::Python,
159            car_ast::Language::TypeScript,
160            car_ast::Language::JavaScript,
161            car_ast::Language::Go,
162        ];
163
164        for block in &blocks {
165            let trimmed = block.trim();
166            if trimmed.is_empty() {
167                continue;
168            }
169
170            for lang in &languages {
171                if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172                    // If it parsed into any symbols, it's definitely code
173                    if !parsed.symbols.is_empty() {
174                        return Some(true);
175                    }
176                }
177            }
178        }
179
180        // Had code blocks but none parsed into symbols — could be
181        // pseudocode, output, or unsupported language
182        Some(false)
183    }
184
185    /// Map complexity to required capabilities.
186    pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187        match self {
188            TaskComplexity::Simple => vec![ModelCapability::Generate],
189            TaskComplexity::Medium => vec![ModelCapability::Generate],
190            TaskComplexity::Code => vec![ModelCapability::Code],
191            TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192        }
193    }
194
195    /// Map complexity to the InferenceTask type.
196    pub fn inference_task(&self) -> InferenceTask {
197        match self {
198            TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199            TaskComplexity::Code => InferenceTask::Code,
200            TaskComplexity::Complex => InferenceTask::Reasoning,
201        }
202    }
203}
204
/// Configuration for routing behavior: scoring weights, hard constraints used
/// by Phase 1 filtering, the Thompson Sampling prior, and the quality-first
/// cold-start policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Minimum observations before trusting a model's profile over schema defaults.
    pub min_observations: u64,
    /// Weight of the quality component in the Phase 2 score.
    ///
    /// The three scoring weights (`quality_weight`, `latency_weight`,
    /// `cost_weight`) should sum to 1.0. The defaults do, but nothing in this
    /// struct enforces it.
    pub quality_weight: f64,
    /// Weight of the latency component in the Phase 2 score (see `quality_weight`).
    pub latency_weight: f64,
    /// Weight of the cost component in the Phase 2 score (see `quality_weight`).
    pub cost_weight: f64,
    /// Hard constraint: maximum latency budget in ms.
    ///
    /// Filtering compares this against a model's `performance.latency_p50_ms`.
    /// Models with no p50 figure are not excluded by this check.
    pub max_latency_ms: Option<u64>,
    /// Hard constraint: maximum cost in USD.
    ///
    /// Filtering compares this against a model's `cost_per_1k_output()`, i.e.
    /// the cost per 1K output tokens, not a per-call estimate.
    pub max_cost_usd: Option<f64>,
    /// Allow and prefer local models over remote ones.
    ///
    /// Setting this to `false` is a hard exclusion: Phase 1 filtering drops
    /// every local model (#12). NOTE(review): presumably it also drives the
    /// local-preference bonus in scoring — confirm.
    pub prefer_local: bool,
    /// Thompson Sampling prior strength. Higher = more weight on the Phase 2 score
    /// as a prior, lower = more influenced by observed outcomes.
    /// Equivalent to the number of "virtual" observations from the prior.
    pub prior_strength: f64,
    /// Prefer trusted remote models for quality-critical tasks until local models
    /// have enough task-specific evidence to be promoted.
    pub quality_first_cold_start: bool,
    /// Minimum task-specific observations required before a local model can
    /// compete with trusted remote models during cold start.
    pub bootstrap_min_task_observations: u64,
    /// Minimum task-specific EMA quality required before a local model can
    /// displace trusted remote models during cold start.
    pub bootstrap_quality_floor: f64,
}
234
235impl Default for RoutingConfig {
236    fn default() -> Self {
237        Self {
238            min_observations: 2,
239            quality_weight: 0.45,
240            latency_weight: 0.4,
241            cost_weight: 0.15,
242            max_latency_ms: None,
243            max_cost_usd: None,
244            prefer_local: true,
245            prior_strength: 2.0,
246            quality_first_cold_start: true,
247            bootstrap_min_task_observations: 8,
248            bootstrap_quality_floor: 0.8,
249        }
250    }
251}
252
/// How a model was selected, as reported in
/// [`AdaptiveRoutingDecision::strategy`]. Serialized in snake_case
/// (e.g. `"schema_based"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Using declared schema capabilities (no observed data).
    SchemaBased,
    /// Using observed performance profiles (exploitation).
    ProfileBased,
    /// Deliberately trying an under-tested model (exploration).
    Exploration,
    /// User explicitly specified the model.
    Explicit,
}
266
/// The result of adaptive routing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Selected model id.
    pub model_id: String,
    /// Selected model name (display). On the main routing path this falls
    /// back to the id when the model cannot be found in the registry.
    pub model_name: String,
    /// Task type: the intent hint's task override when one was supplied,
    /// otherwise derived from `complexity`.
    pub task: InferenceTask,
    /// Assessed complexity of the prompt. This is always prompt-derived,
    /// even when an intent hint overrides `task`.
    pub complexity: TaskComplexity,
    /// Human-readable reason.
    pub reason: String,
    /// How the model was selected.
    pub strategy: RoutingStrategy,
    /// Predicted quality (0.0-1.0).
    ///
    /// On the main routing path this is the selected model's Phase 2 score
    /// (0.5 if the model is missing from the scored list). NOTE(review):
    /// scoring may add bonuses — confirm the value really stays within 0.0-1.0.
    pub predicted_quality: f64,
    /// Fallback chain (ordered list of alternative model ids).
    ///
    /// Remaining scored candidates come first, followed by models that would
    /// need compaction.
    pub fallbacks: Vec<String>,
    /// Context window of the selected model (tokens). 0 = unknown.
    pub context_length: usize,
    /// Whether the prompt needs compaction to fit the selected model's context window.
    pub needs_compaction: bool,
}
291
/// Adaptive router with three-phase model selection
/// (filter → score → Thompson Sampling).
///
/// Construct with [`AdaptiveRouter::new`] or
/// [`AdaptiveRouter::with_default_config`].
pub struct AdaptiveRouter {
    /// Host hardware. `max_model_mb` bounds which local models pass Phase 1.
    hw: HardwareInfo,
    /// Scoring weights, hard constraints and cold-start policy.
    config: RoutingConfig,
    /// Circuit breaker registry — blocks models after consecutive failures (#25).
    ///
    /// `new` creates it with 3 failures and a 300 (5 min) cooldown. It sits
    /// behind `Arc<Mutex<_>>`, so callers may hold their own handle to it.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs to a routing decision, packed into one struct so that adding
/// a new combinator (per-tenant routing, streaming awareness, …) only adds
/// a field rather than another `route_*` sibling method.
/// See [`AdaptiveRouter::route_with`].
///
/// Use [`RouteRequest::new`] for the common defaults, then override just
/// the fields the caller cares about:
///
/// ```ignore
/// let decision = router.route_with(RouteRequest {
///     has_tools: true,
///     intent: Some(&hint),
///     ..RouteRequest::new(prompt, &registry, &tracker)
/// });
/// ```
pub struct RouteRequest<'a> {
    /// Prompt text, used for complexity assessment and multi-tool detection.
    pub prompt: &'a str,
    /// Model registry to route over.
    pub registry: &'a UnifiedRegistry,
    /// Outcome tracker, used for scoring, Thompson Sampling and
    /// session exclusions (e.g. rate-limited models).
    pub tracker: &'a OutcomeTracker,
    /// Estimated prompt-side token count for context-window-aware
    /// scoring. `0` skips the compaction-headroom check.
    pub estimated_total_tokens: usize,
    /// Require `ToolUse`.
    ///
    /// Prompts detected as multi-step also require `MultiToolCall`; that
    /// extra requirement is dropped if no model satisfies it.
    pub has_tools: bool,
    /// Require `Vision`.
    pub has_vision: bool,
    /// Workload class used for scoring and the cold-start policy. May be
    /// overridden by `intent`.
    pub workload: RoutingWorkload,
    /// Caller-supplied intent hint.
    ///
    /// It can override the prompt-derived task and add hard `require`
    /// capability filters. It also overrides `workload`:
    /// - `prefer_fast: true` forces [`RoutingWorkload::Fastest`];
    /// - otherwise `prefer_local: true` forces [`RoutingWorkload::LocalPreferred`].
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
330impl<'a> RouteRequest<'a> {
331    /// Build a request with the same defaults the bare
332    /// [`AdaptiveRouter::route`] uses: interactive workload, no tools,
333    /// no vision, no context-aware sizing, no intent.
334    pub fn new(
335        prompt: &'a str,
336        registry: &'a UnifiedRegistry,
337        tracker: &'a OutcomeTracker,
338    ) -> Self {
339        Self {
340            prompt,
341            registry,
342            tracker,
343            estimated_total_tokens: 0,
344            has_tools: false,
345            has_vision: false,
346            workload: RoutingWorkload::Interactive,
347            intent: None,
348        }
349    }
350}
351
352impl AdaptiveRouter {
353    pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354        let circuit_breakers = Arc::new(Mutex::new(
355            CircuitBreakerRegistry::new(3, 300), // 3 failures, 5 min cooldown
356        ));
357        Self {
358            hw,
359            config,
360            circuit_breakers,
361        }
362    }
363
364    pub fn with_default_config(hw: HardwareInfo) -> Self {
365        Self::new(hw, RoutingConfig::default())
366    }
367
368    pub fn config(&self) -> &RoutingConfig {
369        &self.config
370    }
371
372    pub fn set_config(&mut self, config: RoutingConfig) {
373        self.config = config;
374    }
375
376    /// Canonical entry point. The seven `route_*` sibling methods
377    /// each build a `RouteRequest` with their fixed defaults and
378    /// call this — adding a new combinator (per-tenant routing,
379    /// streaming awareness, …) only adds a field here, not another
380    /// public method. Closes #108.
381    pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382        // Intent-hint precedence over the caller-supplied workload:
383        // `prefer_fast` wins outright (voice fast track is the most
384        // latency-sensitive path we have); `prefer_local` is the
385        // long-standing override; absent either, the caller's
386        // workload stands.
387        let workload = match req.intent {
388            Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389            Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390            _ => req.workload,
391        };
392        self.route_inner_with_intent(
393            req.prompt,
394            req.registry,
395            req.tracker,
396            req.has_tools,
397            req.has_vision,
398            req.estimated_total_tokens,
399            workload,
400            req.intent,
401        )
402    }
403
404    /// Route a generation request to the best model.
405    /// If `has_tools` is true, requires ToolUse capability (#13).
406    pub fn route(
407        &self,
408        prompt: &str,
409        registry: &UnifiedRegistry,
410        tracker: &OutcomeTracker,
411    ) -> AdaptiveRoutingDecision {
412        self.route_with(RouteRequest::new(prompt, registry, tracker))
413    }
414
415    /// Route an "editor" request — cheap mechanical work (context compaction,
416    /// edit materialization, title generation, replanning). Uses
417    /// [`RoutingWorkload::Background`] so cost/quality weights favour cheaper
418    /// models, and lets hosting layers express the Aider-style architect/editor
419    /// split without plumbing a raw model-name override through every layer.
420    pub fn route_editor(
421        &self,
422        prompt: &str,
423        registry: &UnifiedRegistry,
424        tracker: &OutcomeTracker,
425    ) -> AdaptiveRoutingDecision {
426        self.route_with(RouteRequest {
427            workload: RoutingWorkload::Background,
428            ..RouteRequest::new(prompt, registry, tracker)
429        })
430    }
431
432    /// Route with tool_use requirement — filters to models that support structured tool calls.
433    pub fn route_with_tools(
434        &self,
435        prompt: &str,
436        registry: &UnifiedRegistry,
437        tracker: &OutcomeTracker,
438    ) -> AdaptiveRoutingDecision {
439        self.route_with(RouteRequest {
440            has_tools: true,
441            ..RouteRequest::new(prompt, registry, tracker)
442        })
443    }
444
445    /// Route with image input requirement — filters to models that support vision.
446    pub fn route_with_vision(
447        &self,
448        prompt: &str,
449        registry: &UnifiedRegistry,
450        tracker: &OutcomeTracker,
451        has_tools: bool,
452    ) -> AdaptiveRoutingDecision {
453        self.route_with(RouteRequest {
454            has_tools,
455            has_vision: true,
456            ..RouteRequest::new(prompt, registry, tracker)
457        })
458    }
459
460    /// Route with caller-supplied intent — see [`crate::IntentHint`].
461    /// The hint can override the auto-detected `InferenceTask`, add hard
462    /// `require` capability filters on top of the prompt-derived ones,
463    /// and bias the score profile toward local models when
464    /// `prefer_local` is set.
465    pub fn route_with_intent<'a>(
466        &self,
467        prompt: &'a str,
468        registry: &'a UnifiedRegistry,
469        tracker: &'a OutcomeTracker,
470        intent: &'a crate::intent::IntentHint,
471    ) -> AdaptiveRoutingDecision {
472        self.route_with(RouteRequest {
473            intent: Some(intent),
474            ..RouteRequest::new(prompt, registry, tracker)
475        })
476    }
477
478    /// Route with context awareness — estimates prompt tokens and prefers models
479    /// whose context window can fit the full prompt without compaction.
480    pub fn route_context_aware(
481        &self,
482        prompt: &str,
483        estimated_total_tokens: usize,
484        registry: &UnifiedRegistry,
485        tracker: &OutcomeTracker,
486        has_tools: bool,
487        has_vision: bool,
488        workload: RoutingWorkload,
489    ) -> AdaptiveRoutingDecision {
490        self.route_with(RouteRequest {
491            estimated_total_tokens,
492            has_tools,
493            has_vision,
494            workload,
495            ..RouteRequest::new(prompt, registry, tracker)
496        })
497    }
498
499    /// Context-aware routing with caller-supplied intent. Same context
500    /// math as [`Self::route_context_aware`]; the intent layers on
501    /// top — task override, additional `require` filters,
502    /// `prefer_local` workload override.
503    pub fn route_context_aware_with_intent<'a>(
504        &self,
505        prompt: &'a str,
506        estimated_total_tokens: usize,
507        registry: &'a UnifiedRegistry,
508        tracker: &'a OutcomeTracker,
509        has_tools: bool,
510        has_vision: bool,
511        workload: RoutingWorkload,
512        intent: &'a crate::intent::IntentHint,
513    ) -> AdaptiveRoutingDecision {
514        self.route_with(RouteRequest {
515            estimated_total_tokens,
516            has_tools,
517            has_vision,
518            workload,
519            intent: Some(intent),
520            ..RouteRequest::new(prompt, registry, tracker)
521        })
522    }
523
524    fn route_inner_with_intent(
525        &self,
526        prompt: &str,
527        registry: &UnifiedRegistry,
528        tracker: &OutcomeTracker,
529        has_tools: bool,
530        has_vision: bool,
531        estimated_total_tokens: usize,
532        workload: RoutingWorkload,
533        intent: Option<&crate::intent::IntentHint>,
534    ) -> AdaptiveRoutingDecision {
535        let complexity = TaskComplexity::assess(prompt);
536        // Caller intent overrides the prompt-derived task when supplied.
537        let task = intent
538            .and_then(|h| h.task)
539            .map(task_hint_to_inference_task)
540            .unwrap_or_else(|| complexity.inference_task());
541        let mut required_caps = complexity.required_capabilities();
542        if let Some(hint) = intent {
543            for cap in &hint.require {
544                if !required_caps.contains(cap) {
545                    required_caps.push(*cap);
546                }
547            }
548        }
549        if has_vision {
550            required_caps.push(ModelCapability::Vision);
551        }
552        if has_tools {
553            required_caps.push(ModelCapability::ToolUse);
554            // Detect multi-step prompts that need multiple tool calls in one response.
555            // Patterns: numbered lists ("1) ... 2) ..."), multiple instructions, explicit multi-edit.
556            if Self::needs_multi_tool_call(prompt) {
557                required_caps.push(ModelCapability::MultiToolCall);
558            }
559        }
560
561        // Phase 1: Filter candidates
562        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);
563
564        // Fallback: if requiring MultiToolCall eliminates all candidates, drop the
565        // requirement and let the best ToolUse model handle it with multiple round-trips.
566        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
567            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
568            candidates = self.filter_candidates(&required_caps, registry, tracker);
569        }
570
571        if candidates.is_empty() {
572            // Nothing available — return the schema-based default
573            return self.cold_start_decision(complexity, task, registry, has_vision);
574        }
575
576        candidates = self.apply_quality_first_bootstrap_policy(
577            candidates,
578            task,
579            tracker,
580            has_vision,
581            has_tools,
582            workload,
583        );
584
585        // Context-aware filtering: if we know the prompt size, prefer models that fit.
586        // Phase 1b: separate candidates into "fits" and "needs compaction" groups.
587        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
588            let mut fits = Vec::new();
589            let mut tight = Vec::new();
590            for m in &candidates {
591                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
592                    fits.push(m.clone());
593                } else {
594                    tight.push(m.clone());
595                }
596            }
597            (fits, tight)
598        } else {
599            (candidates.clone(), Vec::new())
600        };
601
602        // Prefer models that fit; fall back to compaction-required models if none fit
603        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
604            (fits, false)
605        } else if !needs_compaction_candidates.is_empty() {
606            tracing::info!(
607                prompt_tokens = estimated_total_tokens,
608                candidates = needs_compaction_candidates.len(),
609                "no model fits full prompt — compaction will be needed"
610            );
611            (needs_compaction_candidates.clone(), true)
612        } else {
613            (candidates.clone(), false)
614        };
615
616        // Phase 2: Score candidates (with context headroom bonus)
617        let scored = self.score_candidates_context_aware(
618            &scoring_candidates,
619            task,
620            tracker,
621            estimated_total_tokens,
622            workload,
623        );
624
625        // Phase 3: Thompson Sampling selection
626        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);
627
628        // Build fallback chain: prefer models that fit, then compaction candidates
629        let mut fallbacks: Vec<String> = scored
630            .iter()
631            .filter(|(id, _)| *id != selected_id)
632            .map(|(id, _)| id.clone())
633            .collect();
634        // Add compaction candidates to the end of the fallback chain
635        if !compaction_needed {
636            for m in &needs_compaction_candidates {
637                if m.id != selected_id && !fallbacks.contains(&m.id) {
638                    fallbacks.push(m.id.clone());
639                }
640            }
641        }
642
643        let predicted_quality = scored
644            .iter()
645            .find(|(id, _)| *id == selected_id)
646            .map(|(_, score)| *score)
647            .unwrap_or(0.5);
648
649        let selected_schema = registry
650            .get(&selected_id)
651            .or_else(|| registry.find_by_name(&selected_id));
652        let model_name = selected_schema
653            .map(|m| m.name.clone())
654            .unwrap_or_else(|| selected_id.clone());
655        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);
656
657        let needs_compact = compaction_needed
658            || (estimated_total_tokens > 0
659                && context_length > 0
660                && estimated_total_tokens > context_length);
661
662        let compaction_note = if needs_compact {
663            format!(
664                " [compaction needed: {}→{}tok]",
665                estimated_total_tokens, context_length
666            )
667        } else {
668            String::new()
669        };
670
671        let reason = format!(
672            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
673            complexity,
674            model_name,
675            strategy,
676            predicted_quality,
677            scoring_candidates.len(),
678            compaction_note,
679        );
680
681        AdaptiveRoutingDecision {
682            model_id: selected_id,
683            model_name,
684            task,
685            complexity,
686            reason,
687            strategy,
688            predicted_quality,
689            fallbacks,
690            context_length,
691            needs_compaction: needs_compact,
692        }
693    }
694
695    /// Route to the best embedding model.
696    pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
697        let embed_models = registry.query_by_capability(ModelCapability::Embed);
698        embed_models
699            .first()
700            .map(|m| m.name.clone())
701            .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
702    }
703
704    /// Route to the smallest available model (for classification).
705    pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
706        let gen_models = registry.query_by_capability(ModelCapability::Generate);
707        // Pick smallest by size
708        gen_models
709            .iter()
710            .filter(|m| m.is_local())
711            .min_by_key(|m| m.size_mb())
712            .map(|m| m.name.clone())
713            .unwrap_or_else(|| "Qwen3-0.6B".to_string())
714    }
715
716    // --- Internal phases ---
717
718    // --- Scoring constants ---
719
720    /// Latency ceiling: requests taking longer than this score 0.0.
721    const LATENCY_CEILING_MS: f64 = 10000.0;
722    /// TPS ceiling: models faster than this score 1.0.
723    const _TPS_CEILING: f64 = 150.0;
724    /// MoE throughput penalty for Candle: naive expert routing runs at ~10% of declared TPS.
725    const MOE_TPS_MULTIPLIER: f64 = 0.10;
726    /// MoE throughput multiplier for MLX: fused Metal kernels run at ~50% of declared TPS.
727    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
728    /// Cost ceiling: models costing more than this per 1K output tokens score 0.0.
729    const COST_CEILING_PER_1K: f64 = 0.1;
730    /// Local preference bonus added to the weighted score (before normalization).
731    const LOCAL_BONUS: f64 = 0.15;
732    /// Extra bonus for MLX models on Apple Silicon (stacks with LOCAL_BONUS).
733    /// MLX gets fused Metal kernels and better memory layout vs Candle on Mac.
734    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
735    const MLX_BONUS: f64 = 0.10;
736    /// Extra bonus for system-owned models — no model file on disk, no
737    /// download to provision, framework-managed memory. Currently only
738    /// `apple/foundation:default` qualifies. Tag-driven so future
739    /// system-LLM integrations (e.g. Android AICore) inherit the
740    /// scoring without a code change. Stacks with LOCAL_BONUS but
741    /// **excludes** MLX_BONUS (system models aren't MLX); the net
742    /// effect is FoundationModels ranks roughly even with a
743    /// well-warmed MLX 4B for short fast-turn tasks instead of
744    /// strictly losing because its catalog `tokens_per_second` /
745    /// `size_mb` are null.
746    const SYSTEM_LLM_BONUS: f64 = 0.12;
747
748    // --- Internal phases ---
749
750    /// Phase 1: Filter by hard constraints.
751    fn filter_candidates(
752        &self,
753        required_caps: &[ModelCapability],
754        registry: &UnifiedRegistry,
755        tracker: &OutcomeTracker,
756    ) -> Vec<ModelSchema> {
757        registry
758            .list()
759            .into_iter()
760            .filter(|m| {
761                // Must have all required capabilities
762                if !required_caps.iter().all(|c| m.has_capability(*c)) {
763                    return false;
764                }
765                // Must be available (downloaded for local, API key set for remote)
766                if !m.available {
767                    return false;
768                }
769                // Local models must fit in memory (strict: >= excludes models at the limit)
770                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
771                    return false;
772                }
773                // Hard latency constraint
774                if let Some(max) = self.config.max_latency_ms {
775                    if let Some(p50) = m.performance.latency_p50_ms {
776                        if p50 > max {
777                            return false;
778                        }
779                    }
780                }
781                // Hard cost constraint
782                if let Some(max) = self.config.max_cost_usd {
783                    if m.cost_per_1k_output() > max {
784                        return false;
785                    }
786                }
787                // Hard exclusion: prefer_local=false excludes all local models (#12)
788                if !self.config.prefer_local && m.is_local() {
789                    return false;
790                }
791                // Hard exclusion: rate-limited models excluded for this session (#13)
792                if tracker.is_excluded(&m.id) {
793                    return false;
794                }
795                // Circuit breaker: skip models with consecutive failures (#25)
796                if let Ok(mut cb) = self.circuit_breakers.lock() {
797                    if !cb.allow_request(&m.id) {
798                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
799                        return false;
800                    }
801                }
802                true
803            })
804            .cloned()
805            .collect()
806    }
807
808    fn apply_quality_first_bootstrap_policy(
809        &self,
810        candidates: Vec<ModelSchema>,
811        task: InferenceTask,
812        tracker: &OutcomeTracker,
813        has_vision: bool,
814        has_tools: bool,
815        workload: RoutingWorkload,
816    ) -> Vec<ModelSchema> {
817        if !self.config.quality_first_cold_start
818            || !workload.is_latency_sensitive()
819            || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
820        {
821            return candidates;
822        }
823
824        let trusted_remote: Vec<ModelSchema> = candidates
825            .iter()
826            .filter(|model| self.is_trusted_quality_remote(model))
827            .cloned()
828            .collect();
829
830        if trusted_remote.is_empty() {
831            return candidates;
832        }
833
834        let proven_local: Vec<ModelSchema> = candidates
835            .iter()
836            .filter(|model| model.is_local() && self.is_local_model_proven_for_task(model, task, tracker))
837            .cloned()
838            .collect();
839
840        if !proven_local.is_empty() {
841            return proven_local;
842        }
843
844        trusted_remote
845    }
846
847    /// Phase 2: Score candidates with context awareness.
848    /// Applies a headroom bonus to models with more context window for the prompt.
849    /// When estimated_total_tokens is 0, no context bonus/penalty is applied.
850    fn score_candidates_context_aware(
851        &self,
852        candidates: &[ModelSchema],
853        task: InferenceTask,
854        tracker: &OutcomeTracker,
855        estimated_total_tokens: usize,
856        workload: RoutingWorkload,
857    ) -> Vec<(String, f64)> {
858        let mut scored: Vec<(String, f64)> = candidates
859            .iter()
860            .map(|m| {
861                let base_score = self.score_model(m, task, tracker, workload);
862                // Context headroom bonus: prefer models with more room to spare.
863                // Max bonus: 0.10 (at 4x headroom or more). No bonus if unknown.
864                let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
865                    let ratio = m.context_length as f64 / estimated_total_tokens as f64;
866                    if ratio >= 1.0 {
867                        (ratio.min(4.0) - 1.0) / 3.0 * 0.10 // 0.0 at exact fit, 0.10 at 4x
868                    } else {
869                        -0.15 // Penalty for models that require compaction
870                    }
871                } else {
872                    0.0
873                };
874                (m.id.clone(), base_score + headroom_bonus)
875            })
876            .collect();
877
878        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
879        scored
880    }
881
882    /// Score a single model. All sub-scores are in [0.0, 1.0].
883    /// Final score = weighted sum + local_bonus, so range is [0.0, ~1.15].
884    fn score_model(
885        &self,
886        model: &ModelSchema,
887        task: InferenceTask,
888        tracker: &OutcomeTracker,
889        workload: RoutingWorkload,
890    ) -> f64 {
891        let profile = tracker.profile(&model.id);
892        let schema_quality = self.schema_quality_estimate(model);
893        let schema_latency = self.schema_latency_estimate(model);
894        let (quality_weight, latency_weight, cost_weight) = workload.weights();
895
896        // Quality: blend schema estimate with observed data based on observation count.
897        // Both cold start and warm start use consistent blending.
898        let quality = match profile {
899            Some(p) if p.total_calls >= self.config.min_observations => p
900                .task_stats(task)
901                .map(|ts| ts.ema_quality)
902                .unwrap_or(p.ema_quality),
903            Some(p) if p.total_calls == 0 => p
904                .task_stats(task)
905                .map(|ts| ts.ema_quality)
906                .unwrap_or(p.ema_quality),
907            Some(p) if p.total_calls > 0 => {
908                let w = p.total_calls as f64 / self.config.min_observations as f64;
909                schema_quality * (1.0 - w) + p.ema_quality * w
910            }
911            _ => schema_quality,
912        };
913
914        // Latency: same blending as quality — don't trust a single observation more
915        // than schema estimates. This prevents routing oscillation on first few calls.
916        let latency = match profile {
917            Some(p) if p.total_calls >= self.config.min_observations => {
918                let avg = p
919                    .task_stats(task)
920                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
921                    .map(|ts| ts.avg_latency_ms)
922                    .unwrap_or_else(|| p.avg_latency_ms());
923                self.latency_ms_to_score(avg)
924            }
925            Some(p) if p.total_calls == 0 => p
926                .task_stats(task)
927                .filter(|ts| ts.avg_latency_ms > 0.0)
928                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
929                .unwrap_or(schema_latency),
930            Some(p) if p.total_calls > 0 => {
931                let observed = self.latency_ms_to_score(
932                    p.task_stats(task)
933                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
934                        .map(|ts| ts.avg_latency_ms)
935                        .unwrap_or_else(|| p.avg_latency_ms()),
936                );
937                let w = p.total_calls as f64 / self.config.min_observations as f64;
938                schema_latency * (1.0 - w) + observed * w
939            }
940            _ => schema_latency,
941        };
942
943        // Cost score (lower is better → invert)
944        let cost = if model.is_local() {
945            1.0
946        } else {
947            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
948        };
949
950        let local_bonus = if self.config.prefer_local && model.is_local() {
951            Self::LOCAL_BONUS
952        } else {
953            0.0
954        };
955        let workload_local_bonus = if model.is_local() {
956            workload.local_bonus()
957        } else {
958            0.0
959        };
960
961        // On Apple Silicon, prefer MLX models over Candle equivalents
962        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
963        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
964        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
965        let mlx_bonus = 0.0;
966
967        // vLLM-MLX bonus: continuous batching gives better multi-agent throughput
968        let vllm_mlx_bonus = if model.is_vllm_mlx() {
969            Self::LOCAL_BONUS + 0.05
970        } else {
971            0.0
972        };
973
974        // System-LLM bonus: catalog tag-driven so it generalizes beyond
975        // FoundationModels. Models tagged `low_latency` AND `private`
976        // are zero-cost system-owned LLMs (apple/foundation:default
977        // today; AICore-on-Android etc. in the future). They don't
978        // appear in `is_mlx()` but deserve to compete with MLX 4B on
979        // routing — the catalog's tags carry that intent and the
980        // router now honors it.
981        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
982            && model.tags.iter().any(|t| t == "private")
983        {
984            Self::SYSTEM_LLM_BONUS
985        } else {
986            0.0
987        };
988
989        quality_weight * quality
990            + latency_weight * latency
991            + cost_weight * cost
992            + local_bonus
993            + workload_local_bonus
994            + mlx_bonus
995            + vllm_mlx_bonus
996            + system_llm_bonus
997    }
998
999    /// Convert latency in ms to a [0, 1] score. Used by both schema and observed paths
1000    /// so the scales are consistent (fixes Linus review issue #2).
1001    fn latency_ms_to_score(&self, ms: f64) -> f64 {
1002        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1003    }
1004
1005    /// Convert TPS to estimated latency in ms (for a typical 200-token response).
1006    fn tps_to_latency_ms(tps: f64) -> f64 {
1007        if tps <= 0.0 {
1008            return Self::LATENCY_CEILING_MS;
1009        }
1010        // Assume ~200 tokens per response as baseline
1011        (200.0 / tps) * 1000.0
1012    }
1013
1014    /// Detect whether a prompt likely needs multiple tool calls in a single response.
1015    /// Looks for numbered lists, multiple explicit instructions, multi-edit patterns.
1016    fn needs_multi_tool_call(prompt: &str) -> bool {
1017        let lower = prompt.to_lowercase();
1018
1019        // Numbered list patterns: "1) ... 2) ..." or "1. ... 2. ..."
1020        let has_numbered_list = {
1021            let mut count = 0u32;
1022            for i in 1..=5u32 {
1023                if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1024                    count += 1;
1025                }
1026            }
1027            count >= 2
1028        };
1029
1030        // Explicit multi-action keywords
1031        let multi_keywords = [
1032            "multiple edits",
1033            "several changes",
1034            "three changes",
1035            "two changes",
1036            "all of the following",
1037            "each of these",
1038            "do both",
1039            "do all",
1040            "and also",
1041            "additionally",
1042            "as well as",
1043            "then also",
1044        ];
1045        let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1046
1047        // Bullet point lists with action verbs
1048        let bullet_actions = lower.matches("- add ").count()
1049            + lower.matches("- update ").count()
1050            + lower.matches("- change ").count()
1051            + lower.matches("- remove ").count()
1052            + lower.matches("- fix ").count()
1053            + lower.matches("- edit ").count()
1054            + lower.matches("- implement ").count()
1055            + lower.matches("- create ").count();
1056        let has_bullet_list = bullet_actions >= 2;
1057
1058        has_numbered_list || has_multi_keywords || has_bullet_list
1059    }
1060
1061    /// Schema-based quality estimate (cold start).
1062    ///
1063    /// Diminishing returns on model size — the jump from 4B to 8B matters,
1064    /// but 8B to 30B is marginal. Remote models get a conservative estimate.
1065    fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1066        match model.size_mb() {
1067            0 => 0.5,             // remote: unknown, conservative
1068            s if s < 1000 => 0.4, // 0.6B
1069            s if s < 2000 => 0.5, // 1.7B
1070            s if s < 3000 => 0.6, // 4B
1071            s if s < 6000 => 0.7, // 8B
1072            _ => 0.75,            // 30B+: diminishing returns
1073        }
1074    }
1075
1076    /// Schema-based latency estimate (cold start).
1077    ///
1078    /// Converts declared TPS/p50 to the same ms-based score used by observed data,
1079    /// so there's no discontinuity when the first observation arrives.
1080    fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1081        let is_moe = model.tags.contains(&"moe".to_string());
1082
1083        if model.is_local() {
1084            if let Some(tps) = model.performance.tokens_per_second {
1085                let effective_tps = if is_moe {
1086                    let multiplier = if model.is_mlx() {
1087                        Self::MLX_MOE_TPS_MULTIPLIER
1088                    } else {
1089                        Self::MOE_TPS_MULTIPLIER
1090                    };
1091                    tps * multiplier
1092                } else {
1093                    tps
1094                };
1095                let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1096                return self.latency_ms_to_score(estimated_ms);
1097            }
1098            return 0.5; // local, no declared TPS
1099        }
1100
1101        // Remote: use declared p50 latency
1102        if let Some(p50) = model.performance.latency_p50_ms {
1103            return self.latency_ms_to_score(p50 as f64);
1104        }
1105        0.3 // remote, no declared latency
1106    }
1107
1108    fn is_quality_critical_bootstrap_task(
1109        &self,
1110        task: InferenceTask,
1111        has_vision: bool,
1112        has_tools: bool,
1113    ) -> bool {
1114        has_vision
1115            || has_tools
1116            || matches!(
1117                task,
1118                InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1119            )
1120    }
1121
1122    fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1123        model.is_remote()
1124            && matches!(
1125                model.provider.as_str(),
1126                "openai" | "anthropic" | "google"
1127            )
1128            && !model.has_capability(ModelCapability::SpeechToText)
1129            && !model.has_capability(ModelCapability::TextToSpeech)
1130    }
1131
1132    fn is_local_model_proven_for_task(
1133        &self,
1134        model: &ModelSchema,
1135        task: InferenceTask,
1136        tracker: &OutcomeTracker,
1137    ) -> bool {
1138        let Some(profile) = tracker.profile(&model.id) else {
1139            return false;
1140        };
1141        if let Some(task_stats) = profile.task_stats(task) {
1142            if task_stats.calls >= self.config.bootstrap_min_task_observations
1143                && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1144            {
1145                return true;
1146            }
1147        }
1148
1149        profile.total_calls >= self.config.bootstrap_min_task_observations
1150            && profile.ema_quality >= self.config.bootstrap_quality_floor
1151    }
1152
1153    /// Phase 3: Thompson Sampling selection.
1154    ///
1155    /// Each model gets a Beta(alpha, beta) distribution where:
1156    /// - alpha = prior_successes + observed_successes
1157    /// - beta = prior_failures + observed_failures
1158    ///
1159    /// The Phase 2 score serves as the prior mean, scaled by `prior_strength`.
1160    /// Models with few observations have wide distributions (natural exploration).
1161    /// Models with many observations have tight distributions (exploitation).
1162    fn select_with_thompson_sampling(
1163        &self,
1164        scored: &[(String, f64)],
1165        tracker: &OutcomeTracker,
1166    ) -> (String, RoutingStrategy) {
1167        if scored.is_empty() {
1168            return (String::new(), RoutingStrategy::SchemaBased);
1169        }
1170
1171        let mut rng = rand::rng();
1172        let mut best_sample = f64::NEG_INFINITY;
1173        let mut best_id = scored[0].0.clone();
1174        let mut best_strategy = RoutingStrategy::SchemaBased;
1175
1176        for (id, phase2_score) in scored {
1177            let profile = tracker.profile(id);
1178            let prior = self.config.prior_strength;
1179
1180            // Convert Phase 2 score (0.0-1.15) to a prior mean in [0, 1]
1181            let prior_mean = phase2_score.clamp(0.0, 1.0);
1182
1183            // Prior pseudo-counts from the Phase 2 score
1184            let prior_alpha = prior * prior_mean;
1185            let prior_beta = prior * (1.0 - prior_mean);
1186
1187            // Observed counts
1188            let (obs_alpha, obs_beta) = match profile {
1189                Some(p) => (p.success_count as f64, p.fail_count as f64),
1190                None => (0.0, 0.0),
1191            };
1192
1193            // Posterior Beta parameters
1194            let alpha = (prior_alpha + obs_alpha).max(0.01);
1195            let beta = (prior_beta + obs_beta).max(0.01);
1196
1197            // Sample from Beta(alpha, beta) using the Jöhnk algorithm
1198            let sample = sample_beta(&mut rng, alpha, beta);
1199
1200            if sample > best_sample {
1201                best_sample = sample;
1202                best_id = id.clone();
1203                best_strategy = match profile {
1204                    Some(p) if p.total_calls >= self.config.min_observations => {
1205                        RoutingStrategy::ProfileBased
1206                    }
1207                    Some(p) if p.total_calls > 0 => {
1208                        // Under-tested but has some data — exploration
1209                        RoutingStrategy::Exploration
1210                    }
1211                    _ => RoutingStrategy::SchemaBased,
1212                };
1213            }
1214        }
1215
1216        (best_id, best_strategy)
1217    }
1218
1219    /// Fallback decision when no candidates pass filtering.
1220    fn cold_start_decision(
1221        &self,
1222        complexity: TaskComplexity,
1223        task: InferenceTask,
1224        registry: &UnifiedRegistry,
1225        has_vision: bool,
1226    ) -> AdaptiveRoutingDecision {
1227        if has_vision {
1228            if let Some(model) = registry
1229                .query_by_capability(ModelCapability::Vision)
1230                .into_iter()
1231                .find(|model| model.available && self.is_trusted_quality_remote(model))
1232                .or_else(|| registry.query_by_capability(ModelCapability::Vision).first().copied())
1233            {
1234                return AdaptiveRoutingDecision {
1235                    model_id: model.id.clone(),
1236                    model_name: model.name.clone(),
1237                    task,
1238                    complexity,
1239                    reason: format!(
1240                        "{:?} task → {} (cold start, vision fallback)",
1241                        complexity, model.name
1242                    ),
1243                    strategy: RoutingStrategy::SchemaBased,
1244                    predicted_quality: 0.5,
1245                    fallbacks: vec![],
1246                    context_length: model.context_length,
1247                    needs_compaction: false,
1248                };
1249            }
1250        }
1251
1252        if self.config.quality_first_cold_start {
1253            let required_caps = complexity.required_capabilities();
1254            if let Some(model) = registry
1255                .list()
1256                .into_iter()
1257                .filter(|model| {
1258                    model.available
1259                        && required_caps.iter().all(|cap| model.has_capability(*cap))
1260                        && self.is_trusted_quality_remote(model)
1261                })
1262                .max_by(|a, b| {
1263                    self.schema_quality_estimate(a)
1264                        .partial_cmp(&self.schema_quality_estimate(b))
1265                        .unwrap_or(std::cmp::Ordering::Equal)
1266                })
1267            {
1268                return AdaptiveRoutingDecision {
1269                    model_id: model.id.clone(),
1270                    model_name: model.name.clone(),
1271                    task,
1272                    complexity,
1273                    reason: format!(
1274                        "{:?} task → {} (quality-first cold start)",
1275                        complexity, model.name
1276                    ),
1277                    strategy: RoutingStrategy::SchemaBased,
1278                    predicted_quality: self.schema_quality_estimate(model),
1279                    fallbacks: vec![],
1280                    context_length: model.context_length,
1281                    needs_compaction: false,
1282                };
1283            }
1284        }
1285
1286        // Fall back to the old complexity-based defaults
1287        let model_name = match complexity {
1288            TaskComplexity::Simple => "Qwen3-0.6B",
1289            TaskComplexity::Medium => "Qwen3-1.7B",
1290            TaskComplexity::Code => "Qwen3-4B",
1291            TaskComplexity::Complex => &self.hw.recommended_model,
1292        };
1293
1294        let model_id = registry
1295            .find_by_name(model_name)
1296            .map(|m| m.id.clone())
1297            .unwrap_or_else(|| model_name.to_string());
1298
1299        let context_length = registry
1300            .find_by_name(model_name)
1301            .map(|m| m.context_length)
1302            .unwrap_or(0);
1303
1304        AdaptiveRoutingDecision {
1305            model_id,
1306            model_name: model_name.to_string(),
1307            task,
1308            complexity,
1309            reason: format!(
1310                "{:?} task → {} (cold start, no candidates)",
1311                complexity, model_name
1312            ),
1313            strategy: RoutingStrategy::SchemaBased,
1314            predicted_quality: 0.5,
1315            fallbacks: vec![],
1316            context_length,
1317            needs_compaction: false,
1318        }
1319    }
1320}
1321
1322/// Map a caller-supplied [`crate::TaskHint`] to the engine's
1323/// [`InferenceTask`] enum. The intent surface uses the higher-level
1324/// hint vocabulary; the router operates on InferenceTask. Every
1325/// TaskHint variant maps to a distinct InferenceTask — variants that
1326/// would have silently collapsed to `Generate` were cut from the MVP.
1327fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1328    use crate::intent::TaskHint;
1329    match hint {
1330        TaskHint::Chat => InferenceTask::Generate,
1331        TaskHint::Classify => InferenceTask::Classify,
1332        TaskHint::Reasoning => InferenceTask::Reasoning,
1333        TaskHint::Code => InferenceTask::Code,
1334    }
1335}
1336
1337/// Sample from a Beta(alpha, beta) distribution.
1338///
1339/// Uses the Gamma distribution method: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1),
1340/// then X / (X + Y) ~ Beta(alpha, beta).
1341///
1342/// For Gamma sampling, uses Marsaglia and Tsang's method for alpha >= 1,
1343/// and Ahrens-Dieter for alpha < 1.
1344fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1345    let x = sample_gamma(rng, alpha);
1346    let y = sample_gamma(rng, beta);
1347    if x + y == 0.0 {
1348        0.5 // degenerate case
1349    } else {
1350        x / (x + y)
1351    }
1352}
1353
1354/// Sample from Gamma(shape, 1) using Marsaglia-Tsang for shape >= 1,
1355/// with Ahrens-Dieter boost for shape < 1.
1356fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1357    if shape < 1.0 {
1358        // Ahrens-Dieter: Gamma(a) = Gamma(a+1) * U^(1/a)
1359        let u: f64 = rng.random();
1360        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1361    }
1362
1363    // Marsaglia-Tsang method for shape >= 1
1364    let d = shape - 1.0 / 3.0;
1365    let c = 1.0 / (9.0 * d).sqrt();
1366
1367    loop {
1368        let x: f64 = loop {
1369            let n = sample_standard_normal(rng);
1370            if 1.0 + c * n > 0.0 {
1371                break n;
1372            }
1373        };
1374
1375        let v = (1.0 + c * x).powi(3);
1376        let u: f64 = rng.random();
1377
1378        if u < 1.0 - 0.0331 * x.powi(4) {
1379            return d * v;
1380        }
1381        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1382            return d * v;
1383        }
1384    }
1385}
1386
1387/// Sample from standard normal N(0,1) using Box-Muller transform.
1388fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1389    let u1: f64 = rng.random();
1390    let u2: f64 = rng.random();
1391    (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1392}
1393
1394#[cfg(test)]
1395mod tests {
1396    use super::*;
1397    use crate::outcome::InferredOutcome;
1398
1399    fn test_hw() -> HardwareInfo {
1400        HardwareInfo {
1401            os: "macos".into(),
1402            arch: "aarch64".into(),
1403            cpu_cores: 10,
1404            total_ram_mb: 32768,
1405            gpu_backend: crate::hardware::GpuBackend::Metal,
1406            gpu_memory_mb: Some(28672),
1407            gpu_devices: Vec::new(),
1408            recommended_model: "Qwen3-8B".into(),
1409            recommended_context: 8192,
1410            max_model_mb: 18000, // headroom above 30B-A3B's 17000MB
1411        }
1412    }
1413
1414    fn test_registry() -> UnifiedRegistry {
1415        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
1416        unsafe {
1417            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
1418        }
1419        // Create fake model dirs so the registry marks them as available
1420        for name in &[
1421            "Qwen3-0.6B",
1422            "Qwen3-1.7B",
1423            "Qwen3-4B",
1424            "Qwen3-8B",
1425            "Qwen3-Embedding-0.6B",
1426        ] {
1427            let dir = tmp.join(name);
1428            let _ = std::fs::create_dir_all(&dir);
1429            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
1430            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
1431        }
1432        let mut reg = UnifiedRegistry::new(tmp);
1433        reg.register(ModelSchema {
1434            id: "openai/gpt-5.4-mini:latest".into(),
1435            name: "gpt-5.4-mini".into(),
1436            provider: "openai".into(),
1437            family: "gpt-5.4".into(),
1438            version: "latest".into(),
1439            capabilities: vec![
1440                ModelCapability::Generate,
1441                ModelCapability::Code,
1442                ModelCapability::Reasoning,
1443                ModelCapability::ToolUse,
1444                ModelCapability::MultiToolCall,
1445                ModelCapability::Vision,
1446            ],
1447            context_length: 128_000,
1448            param_count: "api".into(),
1449            quantization: None,
1450            performance: Default::default(),
1451            cost: Default::default(),
1452            source: crate::schema::ModelSource::RemoteApi {
1453                endpoint: "https://api.openai.com/v1".into(),
1454                api_key_env: "OPENAI_API_KEY".into(),
1455                api_key_envs: vec![],
1456                api_version: None,
1457                protocol: crate::schema::ApiProtocol::OpenAiCompat,
1458            },
1459            tags: vec!["trusted-remote".into()],
1460            supported_params: vec![],
1461            public_benchmarks: vec![],
1462            available: true,
1463        });
1464        reg
1465    }
1466
1467    #[test]
1468    fn routes_simple_to_trusted_remote_during_cold_start() {
1469        let router = AdaptiveRouter::new(
1470            test_hw(),
1471            RoutingConfig {
1472                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1473                ..Default::default()
1474            },
1475        );
1476        let reg = test_registry();
1477        let tracker = OutcomeTracker::new();
1478
1479        let decision = router.route("What is 2+2?", &reg, &tracker);
1480        assert_eq!(decision.complexity, TaskComplexity::Simple);
1481        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
1482        // On cold start, quality-critical tasks should stay on trusted remote models.
1483        let schema = reg
1484            .find_by_name(&decision.model_name)
1485            .expect("selected model should exist in registry");
1486        assert!(
1487            !schema.is_local(),
1488            "simple task should route to trusted remote model during cold start"
1489        );
1490        assert!(matches!(
1491            schema.provider.as_str(),
1492            "openai" | "anthropic" | "google"
1493        ));
1494    }
1495
1496    #[test]
1497    fn routes_code_to_code_capable_remote_during_cold_start() {
1498        let router = AdaptiveRouter::new(
1499            test_hw(),
1500            RoutingConfig {
1501                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1502                ..Default::default()
1503            },
1504        );
1505        let reg = test_registry();
1506        let tracker = OutcomeTracker::new();
1507
1508        let decision = router.route(
1509            "Fix this function:\n```rust\nfn main() {}\n```",
1510            &reg,
1511            &tracker,
1512        );
1513        assert_eq!(decision.complexity, TaskComplexity::Code);
1514        assert_eq!(decision.task, InferenceTask::Code);
1515        // Must select a code-capable local model (not 0.6B which lacks Code)
1516        let schema = reg
1517            .find_by_name(&decision.model_name)
1518            .expect("model should exist");
1519        assert!(
1520            schema.has_capability(ModelCapability::Code),
1521            "selected model must support Code"
1522        );
1523        assert!(!schema.is_local(), "should route to trusted remote model");
1524    }
1525
1526    #[test]
1527    fn routes_images_to_vision_capable_model() {
1528        let router = AdaptiveRouter::new(
1529            test_hw(),
1530            RoutingConfig {
1531                prior_strength: 100.0,
1532                ..Default::default()
1533            },
1534        );
1535        let mut reg = test_registry();
1536        let tracker = OutcomeTracker::new();
1537
1538        reg.register(ModelSchema {
1539            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
1540            name: "Qwen3-VL-2B-mlx-vlm".into(),
1541            provider: "qwen".into(),
1542            family: "qwen3-vl".into(),
1543            version: "bf16".into(),
1544            capabilities: vec![
1545                ModelCapability::Generate,
1546                ModelCapability::Vision,
1547                ModelCapability::Grounding,
1548            ],
1549            context_length: 262_144,
1550            param_count: "2B".into(),
1551            quantization: None,
1552            performance: Default::default(),
1553            cost: Default::default(),
1554            source: crate::schema::ModelSource::Mlx {
1555                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
1556                hf_weight_file: None,
1557            },
1558            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
1559            supported_params: vec![],
1560            public_benchmarks: vec![],
1561            available: true,
1562        });
1563
1564        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
1565        let schema = reg
1566            .find_by_name(&decision.model_name)
1567            .expect("model should exist");
1568        assert!(
1569            schema.has_capability(ModelCapability::Vision),
1570            "selected model must support Vision"
1571        );
1572    }
1573
1574    #[test]
1575    fn profile_based_routing_favors_proven_model() {
1576        let router = AdaptiveRouter::new(
1577            test_hw(),
1578            RoutingConfig {
1579                prior_strength: 0.5, // weak prior, let observed data dominate
1580                min_observations: 3,
1581                ..Default::default()
1582            },
1583        );
1584        let reg = test_registry();
1585        let mut tracker = OutcomeTracker::new();
1586
1587        // Build a strong profile for Qwen3-8B on code tasks (fast + high quality)
1588        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1589        for _ in 0..20 {
1590            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
1591            tracker.record_complete(&trace, 500, 100, 50);
1592            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1593        }
1594
1595        // Thompson Sampling is stochastic — run multiple times, 8B should win majority
1596        let mut wins = 0;
1597        for _ in 0..20 {
1598            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
1599            assert_eq!(decision.complexity, TaskComplexity::Code);
1600            if decision.model_id == qwen_8b_id {
1601                wins += 1;
1602            }
1603        }
1604        assert!(
1605            wins >= 12,
1606            "proven model won only {wins}/20 times (expected >= 12)"
1607        );
1608    }
1609
1610    #[test]
1611    fn proven_local_model_can_displace_bootstrap_remote() {
1612        let router = AdaptiveRouter::new(
1613            test_hw(),
1614            RoutingConfig {
1615                prior_strength: 100.0,
1616                bootstrap_min_task_observations: 6,
1617                bootstrap_quality_floor: 0.8,
1618                ..Default::default()
1619            },
1620        );
1621        let reg = test_registry();
1622        let mut tracker = OutcomeTracker::new();
1623
1624        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1625        for _ in 0..12 {
1626            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
1627            tracker.record_complete(&trace, 300, 50, 20);
1628            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1629        }
1630
1631        let mut local_wins = 0;
1632        for _ in 0..20 {
1633            let decision = router.route("Summarize this design decision.", &reg, &tracker);
1634            let schema = reg.get(&decision.model_id).expect("selected model should exist");
1635            if schema.is_local() {
1636                local_wins += 1;
1637            }
1638        }
1639
1640        assert!(
1641            local_wins >= 12,
1642            "proven local model won only {local_wins}/20 times (expected >= 12)"
1643        );
1644    }
1645
1646    #[test]
1647    fn benchmark_prior_informs_background_routing() {
1648        let router = AdaptiveRouter::new(
1649            test_hw(),
1650            RoutingConfig {
1651                prior_strength: 100.0,
1652                ..Default::default()
1653            },
1654        );
1655        let reg = test_registry();
1656        let mut tracker = OutcomeTracker::new();
1657        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1658        profile.ema_quality = 0.95;
1659        tracker.import_profiles(vec![profile]);
1660
1661        let decision = router.route_context_aware(
1662            "Write a Python fibonacci function.",
1663            128,
1664            &reg,
1665            &tracker,
1666            false,
1667            false,
1668            RoutingWorkload::Background,
1669        );
1670
1671        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1672        assert!(schema.is_local(), "background routing should allow strong local benchmark priors to win");
1673    }
1674
1675    #[test]
1676    fn task_specific_benchmark_prior_informs_cold_start_routing() {
1677        let router = AdaptiveRouter::new(
1678            test_hw(),
1679            RoutingConfig {
1680                prior_strength: 100.0,
1681                ..Default::default()
1682            },
1683        );
1684        let reg = test_registry();
1685        let mut tracker = OutcomeTracker::new();
1686        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1687        profile.task_stats.insert(
1688            crate::outcome::InferenceTask::Code.to_string(),
1689            crate::outcome::TaskStats {
1690                ema_quality: 0.95,
1691                ..Default::default()
1692            },
1693        );
1694        tracker.import_profiles(vec![profile]);
1695
1696        let decision = router.route_context_aware(
1697            "Write a Python fibonacci function.",
1698            128,
1699            &reg,
1700            &tracker,
1701            false,
1702            false,
1703            RoutingWorkload::Background,
1704        );
1705
1706        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1707        assert!(
1708            schema.is_local(),
1709            "background routing should use task-specific cold-start priors for local code models"
1710        );
1711    }
1712
1713    #[test]
1714    fn task_specific_latency_prior_affects_cold_start_score() {
1715        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1716        let reg = test_registry();
1717        let model = reg
1718            .get("qwen/qwen3-8b:q4_k_m")
1719            .expect("local test model should exist");
1720
1721        let mut fast_tracker = OutcomeTracker::new();
1722        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
1723        fast_profile.task_stats.insert(
1724            crate::outcome::InferenceTask::Generate.to_string(),
1725            crate::outcome::TaskStats {
1726                ema_quality: 0.95,
1727                avg_latency_ms: 1200.0,
1728                ..Default::default()
1729            },
1730        );
1731        fast_tracker.import_profiles(vec![fast_profile]);
1732
1733        let mut slow_tracker = OutcomeTracker::new();
1734        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
1735        slow_profile.task_stats.insert(
1736            crate::outcome::InferenceTask::Generate.to_string(),
1737            crate::outcome::TaskStats {
1738                ema_quality: 0.95,
1739                avg_latency_ms: 120_000.0,
1740                ..Default::default()
1741            },
1742        );
1743        slow_tracker.import_profiles(vec![slow_profile]);
1744
1745        let fast_score = router.score_model(
1746            model,
1747            InferenceTask::Generate,
1748            &fast_tracker,
1749            RoutingWorkload::Interactive,
1750        );
1751        let slow_score = router.score_model(
1752            model,
1753            InferenceTask::Generate,
1754            &slow_tracker,
1755            RoutingWorkload::Interactive,
1756        );
1757
1758        assert!(
1759            fast_score > slow_score,
1760            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
1761        );
1762    }
1763
1764    #[test]
1765    fn interactive_workload_keeps_remote_bootstrap_bias() {
1766        let router = AdaptiveRouter::new(
1767            test_hw(),
1768            RoutingConfig {
1769                prior_strength: 100.0,
1770                ..Default::default()
1771            },
1772        );
1773        let reg = test_registry();
1774        let mut tracker = OutcomeTracker::new();
1775        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1776        profile.ema_quality = 0.95;
1777        tracker.import_profiles(vec![profile]);
1778
1779        let decision = router.route_context_aware(
1780            "Write a Python fibonacci function.",
1781            128,
1782            &reg,
1783            &tracker,
1784            false,
1785            false,
1786            RoutingWorkload::Interactive,
1787        );
1788
1789        let schema = reg.get(&decision.model_id).expect("selected model should exist");
1790        assert!(
1791            !schema.is_local(),
1792            "interactive routing should still prefer trusted remote models during cold start"
1793        );
1794    }
1795
1796    #[test]
1797    fn fallback_chain_has_alternatives() {
1798        let router = AdaptiveRouter::new(
1799            test_hw(),
1800            RoutingConfig {
1801                prior_strength: 100.0, // strong prior = exploit Phase 2 scores (deterministic-ish)
1802                ..Default::default()
1803            },
1804        );
1805        let reg = test_registry();
1806        let tracker = OutcomeTracker::new();
1807
1808        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
1809        assert!(!decision.fallbacks.is_empty());
1810        // Primary should not appear in fallbacks
1811        assert!(!decision.fallbacks.contains(&decision.model_id));
1812    }
1813
1814    #[test]
1815    fn latency_scoring_is_consistent() {
1816        // Verify that schema and observed latency produce comparable scores
1817        let router = AdaptiveRouter::with_default_config(test_hw());
1818
1819        // A model at 25 TPS → ~200/25*1000 = 8000ms → score = 1 - 8000/10000 = 0.2
1820        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
1821        // Same model observed at 8000ms → same formula
1822        let observed_score = router.latency_ms_to_score(8000.0);
1823        assert!(
1824            (schema_score - observed_score).abs() < 0.01,
1825            "schema ({schema_score}) and observed ({observed_score}) should match"
1826        );
1827    }
1828
1829    #[test]
1830    fn complexity_assessment() {
1831        assert_eq!(
1832            TaskComplexity::assess("What is the capital of France?"),
1833            TaskComplexity::Simple
1834        );
1835        assert_eq!(
1836            TaskComplexity::assess("Fix this broken test"),
1837            TaskComplexity::Code
1838        );
1839        assert_eq!(
1840            TaskComplexity::assess("Analyze the trade-offs between A and B"),
1841            TaskComplexity::Complex
1842        );
1843    }
1844
1845    #[test]
1846    fn beta_sampling_produces_valid_values() {
1847        let mut rng = rand::rng();
1848        // Sample 100 times from Beta(2, 5) — should be in [0, 1]
1849        for _ in 0..100 {
1850            let s = sample_beta(&mut rng, 2.0, 5.0);
1851            assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
1852        }
1853        // Beta(1, 1) = Uniform(0, 1) — mean should be ~0.5
1854        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
1855        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
1856        assert!(
1857            (mean - 0.5).abs() < 0.05,
1858            "Beta(1,1) mean {mean} should be ~0.5"
1859        );
1860    }
1861
1862    #[test]
1863    fn thompson_sampling_converges_to_best() {
1864        // A model with strong observed success should win most of the time
1865        let router = AdaptiveRouter::new(
1866            test_hw(),
1867            RoutingConfig {
1868                prior_strength: 1.0, // weak prior, let observations dominate
1869                ..Default::default()
1870            },
1871        );
1872        let reg = test_registry();
1873        let mut tracker = OutcomeTracker::new();
1874
1875        // Give Qwen3-4B 20 successes (strong signal)
1876        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
1877        for _ in 0..20 {
1878            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
1879            tracker.record_complete(&trace, 500, 100, 50);
1880            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1881        }
1882
1883        // Route 20 code tasks — strong model should win the majority
1884        let mut wins = 0;
1885        for _ in 0..20 {
1886            let decision = router.route("Fix this parser bug", &reg, &tracker);
1887            if decision.model_id == qwen_4b_id {
1888                wins += 1;
1889            }
1890        }
1891        assert!(
1892            wins >= 14,
1893            "strong model won only {wins}/20 times (expected >= 14)"
1894        );
1895    }
1896
1897    // ----- Intent surface (parslee-ai/car-releases#18) -----
1898
1899    #[test]
1900    fn intent_require_filters_out_models_lacking_capability() {
1901        // Asking for vision when no candidate has it should fall to
1902        // the cold-start decision rather than scoring incompatible
1903        // candidates. The fixture registry has no vision-capable
1904        // local models.
1905        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1906        let reg = test_registry();
1907        let tracker = OutcomeTracker::new();
1908
1909        let intent = crate::intent::IntentHint {
1910            require: vec![ModelCapability::Vision],
1911            ..Default::default()
1912        };
1913        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);
1914
1915        // When require filters out every candidate, the router falls
1916        // to the schema-based cold-start path. Asserting the strategy
1917        // (not just non-empty model_id) catches a regression where
1918        // future code might silently include filtered candidates.
1919        assert_eq!(
1920            decision.strategy,
1921            RoutingStrategy::SchemaBased,
1922            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
1923        );
1924    }
1925
1926    #[test]
1927    fn intent_default_does_not_override_task_or_caps() {
1928        // Thompson sampling makes per-call model selection
1929        // non-deterministic, so we can't compare model_ids directly.
1930        // What we can assert deterministically: a default IntentHint
1931        // must not change the task selection or the capability
1932        // requirements — those are functions of the prompt only when
1933        // no hint is supplied.
1934        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1935        let reg = test_registry();
1936        let tracker = OutcomeTracker::new();
1937
1938        let baseline = router.route("write a haiku", &reg, &tracker);
1939        let with_default = router.route_with_intent(
1940            "write a haiku",
1941            &reg,
1942            &tracker,
1943            &crate::intent::IntentHint::default(),
1944        );
1945
1946        assert_eq!(
1947            baseline.task, with_default.task,
1948            "default IntentHint must not change the prompt-derived task"
1949        );
1950        assert_eq!(
1951            baseline.complexity, with_default.complexity,
1952            "default IntentHint must not change the prompt-derived complexity"
1953        );
1954    }
1955
1956    #[test]
1957    fn intent_task_hint_overrides_prompt_complexity() {
1958        // A short prompt that complexity assessment would route as
1959        // Generate should land on Reasoning when the intent says so.
1960        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1961        let reg = test_registry();
1962        let tracker = OutcomeTracker::new();
1963
1964        let hint = crate::intent::IntentHint {
1965            task: Some(crate::intent::TaskHint::Reasoning),
1966            ..Default::default()
1967        };
1968        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);
1969
1970        assert_eq!(
1971            decision.task,
1972            InferenceTask::Reasoning,
1973            "TaskHint::Reasoning should override the prompt-derived task"
1974        );
1975    }
1976}