// car_inference/adaptive_router.rs
1//! Adaptive model routing — three-phase routing with learned performance profiles.
2//!
3//! Phase 1: **Filter** — hard constraints (capability, availability, memory, cost).
4//! Phase 2: **Score** — blend quality, latency, and cost using observed profiles
5//!          or schema defaults on cold start.
6//! Phase 3: **Explore** — epsilon-greedy bandit to gather data on under-tested models.
7//!
8//! Replaces the hardcoded `ModelRouter` from `router.rs`.
9
10use rand::Rng;
11use serde::{Deserialize, Serialize};
12
13use crate::hardware::HardwareInfo;
14use crate::outcome::{InferenceTask, OutcomeTracker};
15use crate::registry::UnifiedRegistry;
16use crate::schema::{ModelCapability, ModelSchema};
17
/// Prompt complexity assessment (migrated from router.rs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short factual lookups ("what is …") or very short prompts (< ~30 estimated tokens).
    Simple,
    /// Default bucket: not simple, not code, not reasoning-heavy.
    Medium,
    /// Prompt contains code (detected via AST or keyword heuristics) or repair
    /// language ("fix", "debug", …).
    Code,
    /// Reasoning-heavy prompts ("analyze", "trade-off", …) or long prompts
    /// (> ~500 estimated tokens).
    Complex,
}
27
28impl TaskComplexity {
29    /// Assess complexity of a prompt string.
30    ///
31    /// Uses tree-sitter AST parsing (when the `ast` feature is enabled) for
32    /// accurate code detection: if a code block parses successfully as any
33    /// supported language, it's definitively code. Falls back to keyword
34    /// heuristics for prompts that mention code without containing code blocks.
35    pub fn assess(prompt: &str) -> Self {
36        let lower = prompt.to_lowercase();
37        let word_count = prompt.split_whitespace().count();
38        let estimated_tokens = (word_count as f64 * 1.3) as usize;
39
40        let has_code = Self::detect_code(prompt);
41
42        let repair_markers = [
43            "fix", "repair", "debug", "refactor", "broken",
44            "failing", "error", "bug",
45        ];
46        let has_repair = repair_markers.iter().any(|m| lower.contains(m));
47
48        let reasoning_markers = [
49            "analyze", "compare", "explain why", "step by step",
50            "think through", "evaluate", "trade-off", "tradeoff",
51            "pros and cons", "architecture", "design", "strategy",
52            "optimize", "comprehensive",
53        ];
54        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
55
56        let simple_patterns = [
57            "what is", "who is", "when did", "where is",
58            "how many", "yes or no", "true or false", "name the",
59            "list the", "define ",
60        ];
61        let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
62
63        if has_code || has_repair {
64            TaskComplexity::Code
65        } else if has_reasoning || estimated_tokens > 500 {
66            TaskComplexity::Complex
67        } else if is_simple || estimated_tokens < 30 {
68            TaskComplexity::Simple
69        } else {
70            TaskComplexity::Medium
71        }
72    }
73
74    /// Detect whether a prompt contains code.
75    ///
76    /// With `ast` feature: extracts code blocks (``` delimited), attempts to
77    /// parse each with tree-sitter. If any parses into symbols, it's real code.
78    /// Without `ast` feature: falls back to keyword heuristics.
79    fn detect_code(prompt: &str) -> bool {
80        // First try AST-based detection on code blocks
81        #[cfg(feature = "ast")]
82        {
83            if let Some(is_code) = Self::detect_code_ast(prompt) {
84                return is_code;
85            }
86        }
87
88        // Fallback: keyword heuristics
89        let code_markers = [
90            "```", "fn ", "def ", "class ", "import ", "require(",
91            "async fn", "pub fn", "function ", "const ", "let ", "var ",
92            "#include", "package ", "impl ",
93        ];
94        code_markers.iter().any(|m| prompt.contains(m))
95    }
96
97    /// AST-based code detection: parse code blocks with tree-sitter.
98    /// Returns Some(true) if code found, Some(false) if blocks exist but
99    /// don't parse, None if no code blocks found (fall through to heuristics).
100    #[cfg(feature = "ast")]
101    fn detect_code_ast(prompt: &str) -> Option<bool> {
102        // Extract code blocks between ``` markers
103        let mut blocks = Vec::new();
104        let mut rest = prompt;
105        while let Some(start) = rest.find("```") {
106            let after_fence = &rest[start + 3..];
107            // Skip optional language tag on the opening fence
108            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
109            if let Some(end) = after_fence[code_start..].find("```") {
110                blocks.push(&after_fence[code_start..code_start + end]);
111                rest = &after_fence[code_start + end + 3..];
112            } else {
113                break;
114            }
115        }
116
117        if blocks.is_empty() {
118            return None; // No code blocks — let heuristics decide
119        }
120
121        // Try to parse each block with tree-sitter
122        let languages = [
123            car_ast::Language::Rust,
124            car_ast::Language::Python,
125            car_ast::Language::TypeScript,
126            car_ast::Language::JavaScript,
127            car_ast::Language::Go,
128        ];
129
130        for block in &blocks {
131            let trimmed = block.trim();
132            if trimmed.is_empty() { continue; }
133
134            for lang in &languages {
135                if let Some(parsed) = car_ast::parse(trimmed, *lang) {
136                    // If it parsed into any symbols, it's definitely code
137                    if !parsed.symbols.is_empty() {
138                        return Some(true);
139                    }
140                }
141            }
142        }
143
144        // Had code blocks but none parsed into symbols — could be
145        // pseudocode, output, or unsupported language
146        Some(false)
147    }
148
149    /// Map complexity to required capabilities.
150    pub fn required_capabilities(&self) -> Vec<ModelCapability> {
151        match self {
152            TaskComplexity::Simple => vec![ModelCapability::Generate],
153            TaskComplexity::Medium => vec![ModelCapability::Generate],
154            TaskComplexity::Code => vec![ModelCapability::Code],
155            TaskComplexity::Complex => vec![ModelCapability::Reasoning],
156        }
157    }
158
159    /// Map complexity to the InferenceTask type.
160    pub fn inference_task(&self) -> InferenceTask {
161        match self {
162            TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
163            TaskComplexity::Code => InferenceTask::Code,
164            TaskComplexity::Complex => InferenceTask::Reasoning,
165        }
166    }
167}
168
/// Configuration for routing behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Exploration rate for epsilon-greedy bandit (0.0 = always exploit, 1.0 = always explore).
    pub exploration_rate: f64,
    /// Minimum observations before trusting a model's profile over schema defaults.
    /// Also gates exploration: models with fewer calls are exploration targets.
    pub min_observations: u64,
    /// Scoring weights (must sum to 1.0).
    pub quality_weight: f64,
    pub latency_weight: f64,
    pub cost_weight: f64,
    /// Hard constraint: maximum latency budget in ms. `None` = no cap.
    pub max_latency_ms: Option<u64>,
    /// Hard constraint: maximum cost per call in USD. `None` = no cap.
    pub max_cost_usd: Option<f64>,
    /// Prefer local models over remote (all else being equal).
    pub prefer_local: bool,
}
187
impl Default for RoutingConfig {
    fn default() -> Self {
        Self {
            exploration_rate: 0.1, // explore ~10% of requests
            min_observations: 5,
            // Weights sum to 1.0: 0.45 + 0.4 + 0.15.
            quality_weight: 0.45,
            latency_weight: 0.4,
            cost_weight: 0.15,
            max_latency_ms: None, // no hard latency cap by default
            max_cost_usd: None,   // no hard cost cap by default
            prefer_local: true,
        }
    }
}
202
/// How a model was selected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Using declared schema capabilities (no observed data).
    SchemaBased,
    /// Using observed performance profiles (exploitation).
    ProfileBased,
    /// Deliberately trying an under-tested model (exploration):
    /// picked uniformly at random among candidates with fewer than
    /// `min_observations` recorded calls.
    Exploration,
    /// User explicitly specified the model.
    Explicit,
}
216
/// The result of adaptive routing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Selected model id.
    pub model_id: String,
    /// Selected model name (display).
    pub model_name: String,
    /// Task type.
    pub task: InferenceTask,
    /// Assessed complexity.
    pub complexity: TaskComplexity,
    /// Human-readable reason.
    pub reason: String,
    /// How the model was selected.
    pub strategy: RoutingStrategy,
    /// Predicted quality (0.0-1.0).
    pub predicted_quality: f64,
    /// Fallback chain: alternative model ids ordered best-score-first,
    /// with the selected model excluded.
    pub fallbacks: Vec<String>,
}
237
/// Adaptive router with three-phase model selection.
pub struct AdaptiveRouter {
    /// Detected hardware (memory ceiling, recommended model for big tasks).
    hw: HardwareInfo,
    /// Routing knobs: scoring weights, exploration rate, hard caps.
    config: RoutingConfig,
}
243
impl AdaptiveRouter {
    /// Create a router with explicit hardware info and routing config.
    pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
        Self { hw, config }
    }

    /// Create a router using `RoutingConfig::default()`.
    pub fn with_default_config(hw: HardwareInfo) -> Self {
        Self::new(hw, RoutingConfig::default())
    }

    /// Route a generation request to the best model.
    ///
    /// Pipeline: assess complexity → Phase 1 filter (hard constraints) →
    /// Phase 2 score (quality/latency/cost blend) → Phase 3 epsilon-greedy
    /// select. Falls back to `cold_start_decision` when nothing passes filtering.
    pub fn route(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> AdaptiveRoutingDecision {
        let complexity = TaskComplexity::assess(prompt);
        let task = complexity.inference_task();
        let required_caps = complexity.required_capabilities();

        // Phase 1: Filter candidates
        let candidates = self.filter_candidates(&required_caps, registry);

        if candidates.is_empty() {
            // Nothing available — return the schema-based default
            return self.cold_start_decision(complexity, task, registry);
        }

        // Phase 2: Score candidates
        let scored = self.score_candidates(&candidates, task, tracker);

        // Phase 3: Exploration vs exploitation
        let (selected_id, strategy) = self.select_with_exploration(&scored, tracker);

        // Build fallback chain: `scored` is sorted best-first, so fallbacks
        // come out in descending score order with the primary removed.
        let fallbacks: Vec<String> = scored.iter()
            .filter(|(id, _)| *id != selected_id)
            .map(|(id, _)| id.clone())
            .collect();

        // `selected_id` always comes from `scored`; 0.5 is a defensive default.
        let predicted_quality = scored.iter()
            .find(|(id, _)| *id == selected_id)
            .map(|(_, score)| *score)
            .unwrap_or(0.5);

        // Resolve display name; fall back to the raw id if the registry
        // cannot resolve it by id or name.
        let model_name = registry.get(&selected_id)
            .or_else(|| registry.find_by_name(&selected_id))
            .map(|m| m.name.clone())
            .unwrap_or_else(|| selected_id.clone());

        let reason = format!(
            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates)",
            complexity, model_name, strategy, predicted_quality, candidates.len()
        );

        AdaptiveRoutingDecision {
            model_id: selected_id,
            model_name,
            task,
            complexity,
            reason,
            strategy,
            predicted_quality,
            fallbacks,
        }
    }

    /// Route to the best embedding model.
    ///
    /// NOTE(review): returns the model *name* (unlike `route`, which returns
    /// an id in `model_id`), and picks the registry's first embed-capable
    /// model — preference is whatever order `query_by_capability` yields.
    pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
        let embed_models = registry.query_by_capability(ModelCapability::Embed);
        embed_models.first()
            .map(|m| m.name.clone())
            .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
    }

    /// Route to the smallest available model (for classification).
    /// Returns the model name; falls back to a hardcoded default when no
    /// local generate-capable model exists.
    pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
        let gen_models = registry.query_by_capability(ModelCapability::Generate);
        // Pick smallest by size
        gen_models.iter()
            .filter(|m| m.is_local())
            .min_by_key(|m| m.size_mb())
            .map(|m| m.name.clone())
            .unwrap_or_else(|| "Qwen3-0.6B".to_string())
    }

    // --- Scoring constants ---

    /// Latency ceiling: requests taking longer than this score 0.0.
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// TPS ceiling: models faster than this score 1.0.
    const TPS_CEILING: f64 = 150.0;
    /// MoE throughput penalty: naive expert routing on Metal/CPU runs at ~10% of declared TPS.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// Cost ceiling: models costing more than this per 1K output tokens score 0.0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Local preference bonus added to the weighted score (before normalization).
    const LOCAL_BONUS: f64 = 0.15;

    // --- Internal phases ---

    /// Phase 1: Filter by hard constraints.
    fn filter_candidates(
        &self,
        required_caps: &[ModelCapability],
        registry: &UnifiedRegistry,
    ) -> Vec<ModelSchema> {
        registry.list().into_iter()
            .filter(|m| {
                // Must have all required capabilities
                if !required_caps.iter().all(|c| m.has_capability(*c)) {
                    return false;
                }
                // Must be available (downloaded for local, API key set for remote)
                if !m.available {
                    return false;
                }
                // Local models must fit in memory (strict: >= excludes models at the limit)
                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
                    return false;
                }
                // Hard latency constraint
                if let Some(max) = self.config.max_latency_ms {
                    if let Some(p50) = m.performance.latency_p50_ms {
                        if p50 > max {
                            return false;
                        }
                    }
                }
                // Hard cost constraint
                if let Some(max) = self.config.max_cost_usd {
                    if m.cost_per_1k_output() > max {
                        return false;
                    }
                }
                true
            })
            .cloned()
            .collect()
    }

    /// Phase 2: Score candidates using profiles or schema defaults.
    /// Returns (model_id, score) pairs sorted descending by score.
    fn score_candidates(
        &self,
        candidates: &[ModelSchema],
        task: InferenceTask,
        tracker: &OutcomeTracker,
    ) -> Vec<(String, f64)> {
        let mut scored: Vec<(String, f64)> = candidates.iter()
            .map(|m| {
                let score = self.score_model(m, task, tracker);
                (m.id.clone(), score)
            })
            .collect();

        // Descending sort; NaN-safe via Ordering::Equal fallback.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored
    }

    /// Score a single model. All sub-scores are in [0.0, 1.0].
    /// Final score = weighted sum + local_bonus, so range is [0.0, ~1.15].
    fn score_model(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
    ) -> f64 {
        let profile = tracker.profile(&model.id);
        let schema_quality = self.schema_quality_estimate(model);
        let schema_latency = self.schema_latency_estimate(model);

        // Quality: blend schema estimate with observed data based on observation count.
        // Both cold start and warm start use consistent blending.
        let quality = match profile {
            // Warm: enough observations — trust per-task stats, falling back
            // to the model-wide EMA when this task has no data yet.
            Some(p) if p.total_calls >= self.config.min_observations => {
                p.task_stats(task).map(|ts| ts.ema_quality).unwrap_or(p.ema_quality)
            }
            // Warming up: linear blend, weighted by calls/min_observations.
            Some(p) if p.total_calls > 0 => {
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_quality * (1.0 - w) + p.ema_quality * w
            }
            _ => schema_quality,
        };

        // Latency: same blending as quality — don't trust a single observation more
        // than schema estimates. This prevents routing oscillation on first few calls.
        let latency = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => {
                let avg = p.avg_latency_ms();
                self.latency_ms_to_score(avg)
            }
            Some(p) if p.total_calls > 0 => {
                let observed = self.latency_ms_to_score(p.avg_latency_ms());
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_latency * (1.0 - w) + observed * w
            }
            _ => schema_latency,
        };

        // Cost score (lower is better → invert); local models are free.
        let cost = if model.is_local() {
            1.0
        } else {
            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
        };

        let local_bonus = if self.config.prefer_local && model.is_local() {
            Self::LOCAL_BONUS
        } else {
            0.0
        };

        self.config.quality_weight * quality
            + self.config.latency_weight * latency
            + self.config.cost_weight * cost
            + local_bonus
    }

    /// Convert latency in ms to a [0, 1] score (linear ramp down to
    /// `LATENCY_CEILING_MS`). Used by both the schema and observed paths so
    /// the two scales are identical — no discontinuity when the first
    /// observation replaces the schema estimate.
    fn latency_ms_to_score(&self, ms: f64) -> f64 {
        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
    }

    /// Convert TPS to estimated latency in ms (for a typical 200-token response).
    fn tps_to_latency_ms(tps: f64) -> f64 {
        if tps <= 0.0 { return Self::LATENCY_CEILING_MS; }
        // Assume ~200 tokens per response as baseline
        (200.0 / tps) * 1000.0
    }

    /// Schema-based quality estimate (cold start).
    ///
    /// Diminishing returns on model size — the jump from 4B to 8B matters,
    /// but 8B to 30B is marginal. Remote models get a conservative estimate.
    fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
        match model.size_mb() {
            0 => 0.5,             // remote: unknown, conservative
            s if s < 1000 => 0.4, // 0.6B
            s if s < 2000 => 0.5, // 1.7B
            s if s < 3000 => 0.6, // 4B
            s if s < 6000 => 0.7, // 8B
            _ => 0.75,            // 30B+: diminishing returns
        }
    }

    /// Schema-based latency estimate (cold start).
    ///
    /// Converts declared TPS/p50 to the same ms-based score used by observed data,
    /// so there's no discontinuity when the first observation arrives.
    fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
        // MoE models get a throughput haircut (see MOE_TPS_MULTIPLIER).
        let is_moe = model.tags.contains(&"moe".to_string());

        if model.is_local() {
            if let Some(tps) = model.performance.tokens_per_second {
                let effective_tps = if is_moe { tps * Self::MOE_TPS_MULTIPLIER } else { tps };
                let estimated_ms = Self::tps_to_latency_ms(effective_tps);
                return self.latency_ms_to_score(estimated_ms);
            }
            return 0.5; // local, no declared TPS
        }

        // Remote: use declared p50 latency
        if let Some(p50) = model.performance.latency_p50_ms {
            return self.latency_ms_to_score(p50 as f64);
        }
        0.3 // remote, no declared latency
    }

    /// Phase 3: Epsilon-greedy selection.
    ///
    /// Expects `scored` sorted descending; returns an empty id on empty input
    /// (callers check for empty candidates before reaching this point).
    fn select_with_exploration(
        &self,
        scored: &[(String, f64)],
        tracker: &OutcomeTracker,
    ) -> (String, RoutingStrategy) {
        if scored.is_empty() {
            return (String::new(), RoutingStrategy::SchemaBased);
        }

        let mut rng = rand::rng();

        // With probability epsilon, explore an under-tested model
        if rng.random::<f64>() < self.config.exploration_rate {
            // Find candidates with fewer than min_observations
            let under_tested: Vec<&str> = scored.iter()
                .filter(|(id, _)| {
                    tracker.profile(id)
                        .map(|p| p.total_calls < self.config.min_observations)
                        .unwrap_or(true) // never tested = definitely explore
                })
                .map(|(id, _)| id.as_str())
                .collect();

            if !under_tested.is_empty() {
                let idx = rng.random_range(0..under_tested.len());
                return (under_tested[idx].to_string(), RoutingStrategy::Exploration);
            }
        }

        // Exploit: pick the highest-scored model
        let best = &scored[0];
        let strategy = if tracker.profile(&best.0)
            .map(|p| p.total_calls >= self.config.min_observations)
            .unwrap_or(false)
        {
            RoutingStrategy::ProfileBased
        } else {
            RoutingStrategy::SchemaBased
        };

        (best.0.clone(), strategy)
    }

    /// Fallback decision when no candidates pass filtering.
    fn cold_start_decision(
        &self,
        complexity: TaskComplexity,
        task: InferenceTask,
        registry: &UnifiedRegistry,
    ) -> AdaptiveRoutingDecision {
        // Fall back to the old complexity-based defaults
        let model_name = match complexity {
            TaskComplexity::Simple => "Qwen3-0.6B",
            TaskComplexity::Medium => "Qwen3-1.7B",
            TaskComplexity::Code => "Qwen3-4B",
            TaskComplexity::Complex => &self.hw.recommended_model,
        };

        // Use the registry id when the name resolves; otherwise the name
        // doubles as the id.
        let model_id = registry.find_by_name(model_name)
            .map(|m| m.id.clone())
            .unwrap_or_else(|| model_name.to_string());

        AdaptiveRoutingDecision {
            model_id,
            model_name: model_name.to_string(),
            task,
            complexity,
            reason: format!("{:?} task → {} (cold start, no candidates)", complexity, model_name),
            strategy: RoutingStrategy::SchemaBased,
            predicted_quality: 0.5,
            fallbacks: vec![],
        }
    }
}
590
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outcome::InferredOutcome;

    /// Fixed hardware profile: Apple-Silicon-like box with an 18 GB model ceiling.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            max_model_mb: 18000, // headroom above 30B-A3B's 17000MB
        }
    }

    /// Registry backed by fake on-disk model dirs so availability checks pass.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        // Create fake model dirs so the registry marks them as available
        for name in &["Qwen3-0.6B", "Qwen3-1.7B", "Qwen3-4B", "Qwen3-8B", "Qwen3-Embedding-0.6B"] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        UnifiedRegistry::new(tmp)
    }

    #[test]
    fn routes_simple_to_balanced_local() {
        // exploration_rate = 0.0 makes routing deterministic for the test.
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            ..Default::default()
        });
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("What is 2+2?", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Simple);
        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
        // On cold start, the router balances quality vs latency.
        // The selected model should be local (not remote).
        let schema = reg.find_by_name(&decision.model_name);
        assert!(schema.is_some(), "selected model should exist in registry");
        assert!(schema.unwrap().is_local(), "simple task should route to local model");
    }

    #[test]
    fn routes_code_to_code_capable_local() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            ..Default::default()
        });
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Fix this function:\n```rust\nfn main() {}\n```", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Code);
        assert_eq!(decision.task, InferenceTask::Code);
        // 0.6B has no Code capability, so smallest code-capable model wins
        assert_eq!(decision.model_name, "Qwen3-1.7B");
    }

    #[test]
    fn profile_based_routing_selects_proven_model() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            min_observations: 3,
            ..Default::default()
        });
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Build a strong profile for Qwen3-8B on code tasks (fast + high quality)
        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..5 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50); // 500ms avg
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let decision = router.route("Fix this bug in the parser", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Code);
        // With strong observed data (quality ~0.95, latency 500ms), 8B should beat
        // the 1.7B's cold-start schema score
        assert_eq!(decision.model_name, "Qwen3-8B");
        assert_eq!(decision.strategy, RoutingStrategy::ProfileBased);
    }

    #[test]
    fn fallback_chain_has_alternatives() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            ..Default::default()
        });
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
        assert!(!decision.fallbacks.is_empty());
        // Primary should not appear in fallbacks
        assert!(!decision.fallbacks.contains(&decision.model_id));
    }

    #[test]
    fn latency_scoring_is_consistent() {
        // Verify that schema and observed latency produce comparable scores
        let router = AdaptiveRouter::with_default_config(test_hw());

        // A model at 25 TPS → ~200/25*1000 = 8000ms → score = 1 - 8000/10000 = 0.2
        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
        // Same model observed at 8000ms → same formula
        let observed_score = router.latency_ms_to_score(8000.0);
        assert!((schema_score - observed_score).abs() < 0.01,
            "schema ({schema_score}) and observed ({observed_score}) should match");
    }

    #[test]
    fn complexity_assessment() {
        assert_eq!(TaskComplexity::assess("What is the capital of France?"), TaskComplexity::Simple);
        assert_eq!(TaskComplexity::assess("Fix this broken test"), TaskComplexity::Code);
        assert_eq!(TaskComplexity::assess("Analyze the trade-offs between A and B"), TaskComplexity::Complex);
    }
}