1use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse difficulty bucket for a prompt, derived by [`TaskComplexity::assess`]
/// and used to pick required capabilities and an inference task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short factual question or very short prompt.
    Simple,
    /// Default bucket when no other heuristic fires.
    Medium,
    /// Prompt contains code or repair/debugging language.
    Code,
    /// Reasoning-heavy or long prompt.
    Complex,
}
33
34impl TaskComplexity {
35 pub fn assess(prompt: &str) -> Self {
42 let lower = prompt.to_lowercase();
43 let word_count = prompt.split_whitespace().count();
44 let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46 let has_code = Self::detect_code(prompt);
47
48 let repair_markers = [
49 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50 ];
51 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53 let reasoning_markers = [
54 "analyze",
55 "compare",
56 "explain why",
57 "step by step",
58 "think through",
59 "evaluate",
60 "trade-off",
61 "tradeoff",
62 "pros and cons",
63 "architecture",
64 "design",
65 "strategy",
66 "optimize",
67 "comprehensive",
68 ];
69 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71 let simple_patterns = [
72 "what is",
73 "who is",
74 "when did",
75 "where is",
76 "how many",
77 "yes or no",
78 "true or false",
79 "name the",
80 "list the",
81 "define ",
82 ];
83 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85 if has_code || has_repair {
86 TaskComplexity::Code
87 } else if has_reasoning || estimated_tokens > 500 {
88 TaskComplexity::Complex
89 } else if is_simple || estimated_tokens < 30 {
90 TaskComplexity::Simple
91 } else {
92 TaskComplexity::Medium
93 }
94 }
95
96 fn detect_code(prompt: &str) -> bool {
102 #[cfg(feature = "ast")]
104 {
105 if let Some(is_code) = Self::detect_code_ast(prompt) {
106 return is_code;
107 }
108 }
109
110 let code_markers = [
112 "```",
113 "fn ",
114 "def ",
115 "class ",
116 "import ",
117 "require(",
118 "async fn",
119 "pub fn",
120 "function ",
121 "const ",
122 "let ",
123 "var ",
124 "#include",
125 "package ",
126 "impl ",
127 ];
128 code_markers.iter().any(|m| prompt.contains(m))
129 }
130
131 #[cfg(feature = "ast")]
135 fn detect_code_ast(prompt: &str) -> Option<bool> {
136 let mut blocks = Vec::new();
138 let mut rest = prompt;
139 while let Some(start) = rest.find("```") {
140 let after_fence = &rest[start + 3..];
141 let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143 if let Some(end) = after_fence[code_start..].find("```") {
144 blocks.push(&after_fence[code_start..code_start + end]);
145 rest = &after_fence[code_start + end + 3..];
146 } else {
147 break;
148 }
149 }
150
151 if blocks.is_empty() {
152 return None; }
154
155 let languages = [
157 car_ast::Language::Rust,
158 car_ast::Language::Python,
159 car_ast::Language::TypeScript,
160 car_ast::Language::JavaScript,
161 car_ast::Language::Go,
162 ];
163
164 for block in &blocks {
165 let trimmed = block.trim();
166 if trimmed.is_empty() {
167 continue;
168 }
169
170 for lang in &languages {
171 if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172 if !parsed.symbols.is_empty() {
174 return Some(true);
175 }
176 }
177 }
178 }
179
180 Some(false)
183 }
184
185 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187 match self {
188 TaskComplexity::Simple => vec![ModelCapability::Generate],
189 TaskComplexity::Medium => vec![ModelCapability::Generate],
190 TaskComplexity::Code => vec![ModelCapability::Code],
191 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192 }
193 }
194
195 pub fn inference_task(&self) -> InferenceTask {
197 match self {
198 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199 TaskComplexity::Code => InferenceTask::Code,
200 TaskComplexity::Complex => InferenceTask::Reasoning,
201 }
202 }
203}
204
/// Tunables for [`AdaptiveRouter`] scoring, filtering, and cold-start policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Observed calls a model needs before its outcome profile fully
    /// replaces schema-based estimates in scoring.
    pub min_observations: u64,
    // NOTE(review): scoring in this file takes weights from
    // `workload.weights()`; confirm where these three config weights are
    // actually consumed.
    /// Weight of the quality term in the blended score.
    pub quality_weight: f64,
    /// Weight of the latency term in the blended score.
    pub latency_weight: f64,
    /// Weight of the cost term in the blended score.
    pub cost_weight: f64,
    /// Hard filter: drop models whose reported p50 latency (ms) exceeds this.
    pub max_latency_ms: Option<u64>,
    /// Hard filter: drop models costing more than this per 1K output tokens.
    pub max_cost_usd: Option<f64>,
    /// Favor local models. When false, local models are filtered out
    /// entirely (see `filter_candidates`), not merely de-prioritized.
    pub prefer_local: bool,
    /// Pseudo-count strength of the schema-score prior in Thompson sampling.
    pub prior_strength: f64,
    /// During cold start, prefer trusted remote models for quality-critical
    /// tasks until local models prove themselves.
    pub quality_first_cold_start: bool,
    /// Per-task observations a local model needs to count as "proven".
    pub bootstrap_min_task_observations: u64,
    /// Minimum EMA quality a local model needs to count as "proven".
    pub bootstrap_quality_floor: f64,
}
234
235impl Default for RoutingConfig {
236 fn default() -> Self {
237 Self {
238 min_observations: 2,
239 quality_weight: 0.45,
240 latency_weight: 0.4,
241 cost_weight: 0.15,
242 max_latency_ms: None,
243 max_cost_usd: None,
244 prefer_local: true,
245 prior_strength: 2.0,
246 quality_first_cold_start: true,
247 bootstrap_min_task_observations: 8,
248 bootstrap_quality_floor: 0.8,
249 }
250 }
251}
252
/// How the winning model was chosen; recorded on the decision for
/// observability.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// No usable outcome history; schema metadata decided.
    SchemaBased,
    /// Enough observations existed to trust the outcome profile.
    ProfileBased,
    /// Some — but too few — observations; still exploring.
    Exploration,
    /// Model was pinned explicitly (not produced by this file's scoring path).
    Explicit,
}
266
/// Result of a routing pass: which model to use, why, and what to try next.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Registry id of the selected model.
    pub model_id: String,
    /// Human-readable model name (falls back to the id when unknown).
    pub model_name: String,
    /// Task bucket this call will be recorded under.
    pub task: InferenceTask,
    /// Assessed prompt complexity.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of the choice.
    pub reason: String,
    /// Selection path that produced this decision.
    pub strategy: RoutingStrategy,
    /// Score of the winning candidate (0.5 when no score was available).
    pub predicted_quality: f64,
    /// Other candidate ids to try if the selection fails, best first.
    pub fallbacks: Vec<String>,
    /// Context window of the selected model (0 when unknown).
    pub context_length: usize,
    /// True when the prompt must be compacted to fit the model's context.
    pub needs_compaction: bool,
}
291
/// Capability- and outcome-aware model router.
///
/// Combines schema metadata, observed outcomes (via [`OutcomeTracker`]),
/// and Thompson sampling to pick a model per request; circuit breakers
/// skip models with recent repeated failures.
pub struct AdaptiveRouter {
    /// Host hardware limits (e.g. the largest loadable local model).
    hw: HardwareInfo,
    /// Tunable weights and thresholds.
    config: RoutingConfig,
    /// Per-model failure breakers; public so callers can report outcomes.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// Parameters for one routing call; build with [`RouteRequest::new`] and
/// override fields via struct-update syntax.
pub struct RouteRequest<'a> {
    /// The user prompt to classify and route.
    pub prompt: &'a str,
    /// Source of candidate models.
    pub registry: &'a UnifiedRegistry,
    /// Historical outcomes used for scoring and exclusions.
    pub tracker: &'a OutcomeTracker,
    /// Estimated prompt + response tokens; 0 disables context-fit logic.
    pub estimated_total_tokens: usize,
    /// Whether the request will use tool calling.
    pub has_tools: bool,
    /// Whether the request includes image input.
    pub has_vision: bool,
    /// Latency/quality trade-off profile used for scoring.
    pub workload: RoutingWorkload,
    /// Optional caller intent (task hint, required capabilities, prefs).
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
330impl<'a> RouteRequest<'a> {
331 pub fn new(
335 prompt: &'a str,
336 registry: &'a UnifiedRegistry,
337 tracker: &'a OutcomeTracker,
338 ) -> Self {
339 Self {
340 prompt,
341 registry,
342 tracker,
343 estimated_total_tokens: 0,
344 has_tools: false,
345 has_vision: false,
346 workload: RoutingWorkload::Interactive,
347 intent: None,
348 }
349 }
350}
351
352impl AdaptiveRouter {
353 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354 let circuit_breakers = Arc::new(Mutex::new(
355 CircuitBreakerRegistry::new(3, 300), ));
357 Self {
358 hw,
359 config,
360 circuit_breakers,
361 }
362 }
363
    /// Create a router using [`RoutingConfig::default`].
    pub fn with_default_config(hw: HardwareInfo) -> Self {
        Self::new(hw, RoutingConfig::default())
    }

    /// Current routing configuration.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }

    /// Replace the routing configuration.
    pub fn set_config(&mut self, config: RoutingConfig) {
        self.config = config;
    }
375
376 pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382 let workload = match req.intent {
388 Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389 Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390 _ => req.workload,
391 };
392 self.route_inner_with_intent(
393 req.prompt,
394 req.registry,
395 req.tracker,
396 req.has_tools,
397 req.has_vision,
398 req.estimated_total_tokens,
399 workload,
400 req.intent,
401 )
402 }
403
    /// Route a prompt with all defaults (interactive workload, no tools,
    /// no vision, no context estimate).
    pub fn route(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest::new(prompt, registry, tracker))
    }
414
415 pub fn route_editor(
421 &self,
422 prompt: &str,
423 registry: &UnifiedRegistry,
424 tracker: &OutcomeTracker,
425 ) -> AdaptiveRoutingDecision {
426 self.route_with(RouteRequest {
427 workload: RoutingWorkload::Background,
428 ..RouteRequest::new(prompt, registry, tracker)
429 })
430 }
431
432 pub fn route_with_tools(
434 &self,
435 prompt: &str,
436 registry: &UnifiedRegistry,
437 tracker: &OutcomeTracker,
438 ) -> AdaptiveRoutingDecision {
439 self.route_with(RouteRequest {
440 has_tools: true,
441 ..RouteRequest::new(prompt, registry, tracker)
442 })
443 }
444
445 pub fn route_with_vision(
447 &self,
448 prompt: &str,
449 registry: &UnifiedRegistry,
450 tracker: &OutcomeTracker,
451 has_tools: bool,
452 ) -> AdaptiveRoutingDecision {
453 self.route_with(RouteRequest {
454 has_tools,
455 has_vision: true,
456 ..RouteRequest::new(prompt, registry, tracker)
457 })
458 }
459
    /// Route with an explicit intent hint, which may override the task,
    /// add required capabilities, or change the workload.
    pub fn route_with_intent<'a>(
        &self,
        prompt: &'a str,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
        intent: &'a crate::intent::IntentHint,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            intent: Some(intent),
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }
477
478 pub fn route_context_aware(
481 &self,
482 prompt: &str,
483 estimated_total_tokens: usize,
484 registry: &UnifiedRegistry,
485 tracker: &OutcomeTracker,
486 has_tools: bool,
487 has_vision: bool,
488 workload: RoutingWorkload,
489 ) -> AdaptiveRoutingDecision {
490 self.route_with(RouteRequest {
491 estimated_total_tokens,
492 has_tools,
493 has_vision,
494 workload,
495 ..RouteRequest::new(prompt, registry, tracker)
496 })
497 }
498
    /// Fully parameterized convenience wrapper: context estimate, tool and
    /// vision flags, workload, and an intent hint in one call.
    pub fn route_context_aware_with_intent<'a>(
        &self,
        prompt: &'a str,
        estimated_total_tokens: usize,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        workload: RoutingWorkload,
        intent: &'a crate::intent::IntentHint,
    ) -> AdaptiveRoutingDecision {
        self.route_with(RouteRequest {
            estimated_total_tokens,
            has_tools,
            has_vision,
            workload,
            intent: Some(intent),
            ..RouteRequest::new(prompt, registry, tracker)
        })
    }
523
    /// Core routing pipeline shared by every public `route_*` entry point.
    ///
    /// Steps: assess complexity → derive required capabilities (intent,
    /// vision, tools) → filter candidates → apply the quality-first
    /// bootstrap policy → handle context-window fit → score → Thompson-
    /// sampling selection → assemble the decision with fallbacks and
    /// compaction flags.
    fn route_inner_with_intent(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        estimated_total_tokens: usize,
        workload: RoutingWorkload,
        intent: Option<&crate::intent::IntentHint>,
    ) -> AdaptiveRoutingDecision {
        let complexity = TaskComplexity::assess(prompt);
        // An explicit task hint overrides the complexity-derived task.
        let task = intent
            .and_then(|h| h.task)
            .map(task_hint_to_inference_task)
            .unwrap_or_else(|| complexity.inference_task());
        let mut required_caps = complexity.required_capabilities();
        // Merge capabilities the intent hint explicitly requires (deduped).
        if let Some(hint) = intent {
            for cap in &hint.require {
                if !required_caps.contains(cap) {
                    required_caps.push(*cap);
                }
            }
        }
        if has_vision {
            required_caps.push(ModelCapability::Vision);
        }
        if has_tools {
            required_caps.push(ModelCapability::ToolUse);
            // Multi-step prompts additionally want multi-tool-call support.
            if Self::needs_multi_tool_call(prompt) {
                required_caps.push(ModelCapability::MultiToolCall);
            }
        }

        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);

        // MultiToolCall is a nice-to-have: relax it rather than fail routing.
        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
            candidates = self.filter_candidates(&required_caps, registry, tracker);
        }

        if candidates.is_empty() {
            return self.cold_start_decision(complexity, task, registry, has_vision);
        }

        candidates = self.apply_quality_first_bootstrap_policy(
            candidates, task, tracker, has_vision, has_tools, workload,
        );

        // Split candidates by whether the prompt fits their context window;
        // context_length == 0 means "unknown" and is treated as fitting.
        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
            let mut fits = Vec::new();
            let mut tight = Vec::new();
            for m in &candidates {
                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
                    fits.push(m.clone());
                } else {
                    tight.push(m.clone());
                }
            }
            (fits, tight)
        } else {
            (candidates.clone(), Vec::new())
        };

        // Prefer models that fit; otherwise accept too-small models and
        // plan compaction; otherwise fall back to the full candidate set.
        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
            (fits, false)
        } else if !needs_compaction_candidates.is_empty() {
            tracing::info!(
                prompt_tokens = estimated_total_tokens,
                candidates = needs_compaction_candidates.len(),
                "no model fits full prompt — compaction will be needed"
            );
            (needs_compaction_candidates.clone(), true)
        } else {
            (candidates.clone(), false)
        };

        let scored = self.score_candidates_context_aware(
            &scoring_candidates,
            task,
            tracker,
            estimated_total_tokens,
            workload,
        );

        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);

        // Fallbacks: every other scored candidate (best first), then the
        // too-small models as last resorts when compaction wasn't chosen.
        let mut fallbacks: Vec<String> = scored
            .iter()
            .filter(|(id, _)| *id != selected_id)
            .map(|(id, _)| id.clone())
            .collect();
        if !compaction_needed {
            for m in &needs_compaction_candidates {
                if m.id != selected_id && !fallbacks.contains(&m.id) {
                    fallbacks.push(m.id.clone());
                }
            }
        }

        let predicted_quality = scored
            .iter()
            .find(|(id, _)| *id == selected_id)
            .map(|(_, score)| *score)
            .unwrap_or(0.5);

        let selected_schema = registry
            .get(&selected_id)
            .or_else(|| registry.find_by_name(&selected_id));
        let model_name = selected_schema
            .map(|m| m.name.clone())
            .unwrap_or_else(|| selected_id.clone());
        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);

        // Compaction is also needed if the *selected* model is too small.
        let needs_compact = compaction_needed
            || (estimated_total_tokens > 0
                && context_length > 0
                && estimated_total_tokens > context_length);

        let compaction_note = if needs_compact {
            format!(
                " [compaction needed: {}→{}tok]",
                estimated_total_tokens, context_length
            )
        } else {
            String::new()
        };

        let reason = format!(
            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
            complexity,
            model_name,
            strategy,
            predicted_quality,
            scoring_candidates.len(),
            compaction_note,
        );

        AdaptiveRoutingDecision {
            model_id: selected_id,
            model_name,
            task,
            complexity,
            reason,
            strategy,
            predicted_quality,
            fallbacks,
            context_length,
            needs_compaction: needs_compact,
        }
    }
689
690 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
692 let embed_models = registry.query_by_capability(ModelCapability::Embed);
693 embed_models
694 .first()
695 .map(|m| m.name.clone())
696 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
697 }
698
699 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
701 let gen_models = registry.query_by_capability(ModelCapability::Generate);
702 gen_models
704 .iter()
705 .filter(|m| m.is_local())
706 .min_by_key(|m| m.size_mb())
707 .map(|m| m.name.clone())
708 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
709 }
710
    /// Latency at or above this maps to a score of 0.0 (0 ms maps to 1.0).
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// Unused at present (leading underscore); kept for future use.
    const _TPS_CEILING: f64 = 150.0;
    /// Derate advertised tokens/sec for MoE models on non-MLX runtimes.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// Milder MoE derate on MLX.
    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
    /// Cost (USD per 1K output tokens) at which the cost score reaches 0.0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Additive scoring bonus for local models (requires a GPU backend).
    const LOCAL_BONUS: f64 = 0.15;
    // Only Apple-silicon macOS builds (without car_skip_mlx) count as
    // having a local GPU backend here.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const HAS_GPU_BACKEND: bool = true;
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    const HAS_GPU_BACKEND: bool = false;
    /// Extra nudge for MLX-native models on Apple silicon.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const MLX_BONUS: f64 = 0.10;
    /// Bonus for models tagged both "low_latency" and "private".
    const SYSTEM_LLM_BONUS: f64 = 0.12;
753
    /// Filter the registry down to viable candidates for this request.
    ///
    /// Drops models that: lack a required capability, are unavailable, are
    /// local but exceed this host's memory budget, violate configured
    /// latency/cost ceilings, are excluded by the outcome tracker, or have
    /// an open circuit breaker.
    fn filter_candidates(
        &self,
        required_caps: &[ModelCapability],
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> Vec<ModelSchema> {
        registry
            .list()
            .into_iter()
            .filter(|m| {
                if !required_caps.iter().all(|c| m.has_capability(*c)) {
                    return false;
                }
                if !m.available {
                    return false;
                }
                // Local models must fit within this host's memory budget.
                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
                    return false;
                }
                // Latency ceiling only applies when the model reports a p50.
                if let Some(max) = self.config.max_latency_ms {
                    if let Some(p50) = m.performance.latency_p50_ms {
                        if p50 > max {
                            return false;
                        }
                    }
                }
                if let Some(max) = self.config.max_cost_usd {
                    if m.cost_per_1k_output() > max {
                        return false;
                    }
                }
                // NOTE(review): prefer_local == false *excludes* local models
                // entirely (cloud-only mode) rather than merely dropping
                // their scoring bonus — confirm this is intended.
                if !self.config.prefer_local && m.is_local() {
                    return false;
                }
                if tracker.is_excluded(&m.id) {
                    return false;
                }
                // Skip models whose breaker is open; a poisoned lock is
                // treated as "allow" (the check is silently skipped).
                if let Ok(mut cb) = self.circuit_breakers.lock() {
                    if !cb.allow_request(&m.id) {
                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
                        return false;
                    }
                }
                true
            })
            .cloned()
            .collect()
    }
813
814 fn apply_quality_first_bootstrap_policy(
815 &self,
816 candidates: Vec<ModelSchema>,
817 task: InferenceTask,
818 tracker: &OutcomeTracker,
819 has_vision: bool,
820 has_tools: bool,
821 workload: RoutingWorkload,
822 ) -> Vec<ModelSchema> {
823 if !self.config.quality_first_cold_start
824 || !workload.is_latency_sensitive()
825 || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
826 {
827 return candidates;
828 }
829
830 let trusted_remote: Vec<ModelSchema> = candidates
831 .iter()
832 .filter(|model| self.is_trusted_quality_remote(model))
833 .cloned()
834 .collect();
835
836 if trusted_remote.is_empty() {
837 return candidates;
838 }
839
840 let proven_local: Vec<ModelSchema> = candidates
841 .iter()
842 .filter(|model| {
843 model.is_local() && self.is_local_model_proven_for_task(model, task, tracker)
844 })
845 .cloned()
846 .collect();
847
848 if !proven_local.is_empty() {
849 return proven_local;
850 }
851
852 trusted_remote
853 }
854
855 fn score_candidates_context_aware(
859 &self,
860 candidates: &[ModelSchema],
861 task: InferenceTask,
862 tracker: &OutcomeTracker,
863 estimated_total_tokens: usize,
864 workload: RoutingWorkload,
865 ) -> Vec<(String, f64)> {
866 let mut scored: Vec<(String, f64)> = candidates
867 .iter()
868 .map(|m| {
869 let base_score = self.score_model(m, task, tracker, workload);
870 let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
873 let ratio = m.context_length as f64 / estimated_total_tokens as f64;
874 if ratio >= 1.0 {
875 (ratio.min(4.0) - 1.0) / 3.0 * 0.10 } else {
877 -0.15 }
879 } else {
880 0.0
881 };
882 (m.id.clone(), base_score + headroom_bonus)
883 })
884 .collect();
885
886 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
887 scored
888 }
889
    /// Blend quality, latency, and cost into one score for `model`, plus
    /// additive locality/runtime bonuses.
    ///
    /// Quality and latency each interpolate between schema-based estimates
    /// and observed outcomes: pure schema with no profile, a linear blend
    /// while observations are below `min_observations`, and pure
    /// observations at or past it.
    fn score_model(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
        workload: RoutingWorkload,
    ) -> f64 {
        let profile = tracker.profile(&model.id);
        let schema_quality = self.schema_quality_estimate(model);
        let schema_latency = self.schema_latency_estimate(model);
        // Term weights come from the workload, not RoutingConfig.
        let (quality_weight, latency_weight, cost_weight) = workload.weights();

        // Quality: per-task EMA when present, else the model-wide EMA;
        // blended with the schema prior below min_observations.
        let quality = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            Some(p) if p.total_calls > 0 => {
                // Weight of observations grows linearly with call count.
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_quality * (1.0 - w) + p.ema_quality * w
            }
            _ => schema_quality,
        };

        // Latency: same interpolation scheme, mapped onto [0, 1] via the
        // latency ceiling.
        let latency = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => {
                let avg = p
                    .task_stats(task)
                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                    .map(|ts| ts.avg_latency_ms)
                    .unwrap_or_else(|| p.avg_latency_ms());
                self.latency_ms_to_score(avg)
            }
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .filter(|ts| ts.avg_latency_ms > 0.0)
                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
                .unwrap_or(schema_latency),
            Some(p) if p.total_calls > 0 => {
                let observed = self.latency_ms_to_score(
                    p.task_stats(task)
                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                        .map(|ts| ts.avg_latency_ms)
                        .unwrap_or_else(|| p.avg_latency_ms()),
                );
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_latency * (1.0 - w) + observed * w
            }
            _ => schema_latency,
        };

        // Cost: local models are treated as free; remote cost is normalized
        // against the per-1K ceiling and clamped into [0, 1].
        let cost = if model.is_local() {
            1.0
        } else {
            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
        };

        // The local bonus only applies when this host has a GPU backend;
        // otherwise log why cloud models will outrank local ones.
        let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
            Self::LOCAL_BONUS
        } else {
            if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
                tracing::debug!(
                    model = %model.id,
                    "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
                );
            }
            0.0
        };
        let workload_local_bonus = if model.is_local() {
            workload.local_bonus()
        } else {
            0.0
        };

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let mlx_bonus = 0.0;

        let vllm_mlx_bonus = if model.is_vllm_mlx() {
            Self::LOCAL_BONUS + 0.05
        } else {
            0.0
        };

        // Models tagged as the private low-latency "system" LLM get a nudge.
        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
            && model.tags.iter().any(|t| t == "private")
        {
            Self::SYSTEM_LLM_BONUS
        } else {
            0.0
        };

        quality_weight * quality
            + latency_weight * latency
            + cost_weight * cost
            + local_bonus
            + workload_local_bonus
            + mlx_bonus
            + vllm_mlx_bonus
            + system_llm_bonus
    }
1020
1021 fn latency_ms_to_score(&self, ms: f64) -> f64 {
1024 (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1025 }
1026
1027 fn tps_to_latency_ms(tps: f64) -> f64 {
1029 if tps <= 0.0 {
1030 return Self::LATENCY_CEILING_MS;
1031 }
1032 (200.0 / tps) * 1000.0
1034 }
1035
1036 fn needs_multi_tool_call(prompt: &str) -> bool {
1039 let lower = prompt.to_lowercase();
1040
1041 let has_numbered_list = {
1043 let mut count = 0u32;
1044 for i in 1..=5u32 {
1045 if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1046 count += 1;
1047 }
1048 }
1049 count >= 2
1050 };
1051
1052 let multi_keywords = [
1054 "multiple edits",
1055 "several changes",
1056 "three changes",
1057 "two changes",
1058 "all of the following",
1059 "each of these",
1060 "do both",
1061 "do all",
1062 "and also",
1063 "additionally",
1064 "as well as",
1065 "then also",
1066 ];
1067 let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1068
1069 let bullet_actions = lower.matches("- add ").count()
1071 + lower.matches("- update ").count()
1072 + lower.matches("- change ").count()
1073 + lower.matches("- remove ").count()
1074 + lower.matches("- fix ").count()
1075 + lower.matches("- edit ").count()
1076 + lower.matches("- implement ").count()
1077 + lower.matches("- create ").count();
1078 let has_bullet_list = bullet_actions >= 2;
1079
1080 has_numbered_list || has_multi_keywords || has_bullet_list
1081 }
1082
1083 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1088 match model.size_mb() {
1089 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
1096 }
1097
1098 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1103 let is_moe = model.tags.contains(&"moe".to_string());
1104
1105 if model.is_local() {
1106 if let Some(tps) = model.performance.tokens_per_second {
1107 let effective_tps = if is_moe {
1108 let multiplier = if model.is_mlx() {
1109 Self::MLX_MOE_TPS_MULTIPLIER
1110 } else {
1111 Self::MOE_TPS_MULTIPLIER
1112 };
1113 tps * multiplier
1114 } else {
1115 tps
1116 };
1117 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1118 return self.latency_ms_to_score(estimated_ms);
1119 }
1120 return 0.5; }
1122
1123 if let Some(p50) = model.performance.latency_p50_ms {
1125 return self.latency_ms_to_score(p50 as f64);
1126 }
1127 0.3 }
1129
1130 fn is_quality_critical_bootstrap_task(
1131 &self,
1132 task: InferenceTask,
1133 has_vision: bool,
1134 has_tools: bool,
1135 ) -> bool {
1136 has_vision
1137 || has_tools
1138 || matches!(
1139 task,
1140 InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1141 )
1142 }
1143
1144 fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1145 model.is_remote()
1146 && matches!(model.provider.as_str(), "openai" | "anthropic" | "google")
1147 && !model.has_capability(ModelCapability::SpeechToText)
1148 && !model.has_capability(ModelCapability::TextToSpeech)
1149 }
1150
1151 fn is_local_model_proven_for_task(
1152 &self,
1153 model: &ModelSchema,
1154 task: InferenceTask,
1155 tracker: &OutcomeTracker,
1156 ) -> bool {
1157 let Some(profile) = tracker.profile(&model.id) else {
1158 return false;
1159 };
1160 if let Some(task_stats) = profile.task_stats(task) {
1161 if task_stats.calls >= self.config.bootstrap_min_task_observations
1162 && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1163 {
1164 return true;
1165 }
1166 }
1167
1168 profile.total_calls >= self.config.bootstrap_min_task_observations
1169 && profile.ema_quality >= self.config.bootstrap_quality_floor
1170 }
1171
    /// Pick a model from the scored list via Thompson sampling.
    ///
    /// Each candidate gets a Beta distribution whose prior encodes its
    /// blended score (scaled to `prior_strength` pseudo-counts) and whose
    /// evidence is its observed success/fail counts; the highest sampled
    /// draw wins. The reported strategy reflects how much history backed
    /// the winner.
    fn select_with_thompson_sampling(
        &self,
        scored: &[(String, f64)],
        tracker: &OutcomeTracker,
    ) -> (String, RoutingStrategy) {
        if scored.is_empty() {
            return (String::new(), RoutingStrategy::SchemaBased);
        }

        let mut rng = rand::rng();
        let mut best_sample = f64::NEG_INFINITY;
        let mut best_id = scored[0].0.clone();
        let mut best_strategy = RoutingStrategy::SchemaBased;

        for (id, phase2_score) in scored {
            let profile = tracker.profile(id);
            let prior = self.config.prior_strength;

            // The blended score serves as the prior mean of the Beta.
            let prior_mean = phase2_score.clamp(0.0, 1.0);

            let prior_alpha = prior * prior_mean;
            let prior_beta = prior * (1.0 - prior_mean);

            // Observed outcomes add real counts on top of the prior.
            let (obs_alpha, obs_beta) = match profile {
                Some(p) => (p.success_count as f64, p.fail_count as f64),
                None => (0.0, 0.0),
            };

            // Floor the parameters so Beta(alpha, beta) stays well-defined.
            let alpha = (prior_alpha + obs_alpha).max(0.01);
            let beta = (prior_beta + obs_beta).max(0.01);

            let sample = sample_beta(&mut rng, alpha, beta);

            if sample > best_sample {
                best_sample = sample;
                best_id = id.clone();
                best_strategy = match profile {
                    Some(p) if p.total_calls >= self.config.min_observations => {
                        RoutingStrategy::ProfileBased
                    }
                    Some(p) if p.total_calls > 0 => {
                        RoutingStrategy::Exploration
                    }
                    _ => RoutingStrategy::SchemaBased,
                };
            }
        }

        (best_id, best_strategy)
    }
1237
    /// Fallback decision when capability filtering left no candidates.
    ///
    /// Preference order: any vision-capable model for vision requests
    /// (trusted remote first), then — under the quality-first policy — the
    /// highest-estimated-quality trusted remote covering the complexity's
    /// capabilities, and finally a size-tiered default model name.
    fn cold_start_decision(
        &self,
        complexity: TaskComplexity,
        task: InferenceTask,
        registry: &UnifiedRegistry,
        has_vision: bool,
    ) -> AdaptiveRoutingDecision {
        if has_vision {
            // Vision is non-negotiable: pick any vision model, preferring
            // an available trusted remote over the first registered one.
            if let Some(model) = registry
                .query_by_capability(ModelCapability::Vision)
                .into_iter()
                .find(|model| model.available && self.is_trusted_quality_remote(model))
                .or_else(|| {
                    registry
                        .query_by_capability(ModelCapability::Vision)
                        .first()
                        .copied()
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (cold start, vision fallback)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        if self.config.quality_first_cold_start {
            let required_caps = complexity.required_capabilities();
            // Best trusted remote (by schema quality estimate) covering the
            // required capabilities.
            if let Some(model) = registry
                .list()
                .into_iter()
                .filter(|model| {
                    model.available
                        && required_caps.iter().all(|cap| model.has_capability(*cap))
                        && self.is_trusted_quality_remote(model)
                })
                .max_by(|a, b| {
                    self.schema_quality_estimate(a)
                        .partial_cmp(&self.schema_quality_estimate(b))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (quality-first cold start)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: self.schema_quality_estimate(model),
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        // Last resort: a default model name tiered by complexity.
        let model_name = match complexity {
            TaskComplexity::Simple => "Qwen3-0.6B",
            TaskComplexity::Medium => "Qwen3-1.7B",
            TaskComplexity::Code => "Qwen3-4B",
            TaskComplexity::Complex => &self.hw.recommended_model,
        };

        let model_id = registry
            .find_by_name(model_name)
            .map(|m| m.id.clone())
            .unwrap_or_else(|| model_name.to_string());

        let context_length = registry
            .find_by_name(model_name)
            .map(|m| m.context_length)
            .unwrap_or(0);

        AdaptiveRoutingDecision {
            model_id,
            model_name: model_name.to_string(),
            task,
            complexity,
            reason: format!(
                "{:?} task → {} (cold start, no candidates)",
                complexity, model_name
            ),
            strategy: RoutingStrategy::SchemaBased,
            predicted_quality: 0.5,
            fallbacks: vec![],
            context_length,
            needs_compaction: false,
        }
    }
1344}
1345
/// Translate a caller-supplied [`crate::intent::TaskHint`] into the
/// outcome tracker's [`InferenceTask`] bucket.
fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
    use crate::intent::TaskHint;
    match hint {
        TaskHint::Chat => InferenceTask::Generate,
        TaskHint::Classify => InferenceTask::Classify,
        TaskHint::Reasoning => InferenceTask::Reasoning,
        TaskHint::Code => InferenceTask::Code,
    }
}
1360
1361fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1369 let x = sample_gamma(rng, alpha);
1370 let y = sample_gamma(rng, beta);
1371 if x + y == 0.0 {
1372 0.5 } else {
1374 x / (x + y)
1375 }
1376}
1377
/// Draw from Gamma(shape, 1) using the Marsaglia–Tsang squeeze method.
///
/// For shape < 1 the boosting identity is used:
/// Gamma(a) = Gamma(a + 1) * U^(1/a) with U ~ Uniform(0, 1).
fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
    if shape < 1.0 {
        let u: f64 = rng.random();
        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
    }

    // Marsaglia–Tsang constants: d = shape - 1/3, c = 1 / sqrt(9d).
    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();

    loop {
        // Draw normals until 1 + c*n is positive so v = (1 + c*n)^3 > 0.
        let x: f64 = loop {
            let n = sample_standard_normal(rng);
            if 1.0 + c * n > 0.0 {
                break n;
            }
        };

        let v = (1.0 + c * x).powi(3);
        let u: f64 = rng.random();

        // Fast squeeze acceptance first, then the exact log test.
        if u < 1.0 - 0.0331 * x.powi(4) {
            return d * v;
        }
        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            return d * v;
        }
    }
}
1410
1411fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1413 let u1: f64 = rng.random();
1414 let u2: f64 = rng.random();
1415 (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1416}
1417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outcome::InferredOutcome;

    /// Hardware profile resembling an Apple Silicon laptop with a Metal
    /// GPU — roomy enough that every local test model is loadable.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            gpu_devices: Vec::new(),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            max_model_mb: 18000,
        }
    }

    /// Registry seeded with fake on-disk local Qwen models plus one
    /// trusted remote OpenAI model whose API key is injected via env.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        // `set_var` is `unsafe` in edition 2024 because the process env
        // is global. NOTE(review): assumes no concurrent test reads or
        // mutates the environment at the same time — confirm.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
        }
        // Write fake weight/tokenizer files so the local models register
        // as available when the registry scans this directory.
        for name in &[
            "Qwen3-0.6B",
            "Qwen3-1.7B",
            "Qwen3-4B",
            "Qwen3-8B",
            "Qwen3-Embedding-0.6B",
        ] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        let mut reg = UnifiedRegistry::new(tmp);
        reg.register(ModelSchema {
            id: "openai/gpt-5.4-mini:latest".into(),
            name: "gpt-5.4-mini".into(),
            provider: "openai".into(),
            family: "gpt-5.4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
                ModelCapability::Vision,
            ],
            context_length: 128_000,
            param_count: "api".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_key_envs: vec![],
                api_version: None,
                protocol: crate::schema::ApiProtocol::OpenAiCompat,
            },
            tags: vec!["trusted-remote".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });
        reg
    }

    // With a strong prior and no observations (cold start), a simple
    // prompt should land on a trusted remote model via schema routing.
    #[test]
    fn routes_simple_to_trusted_remote_during_cold_start() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("What is 2+2?", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Simple);
        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("selected model should exist in registry");
        assert!(
            !schema.is_local(),
            "simple task should route to trusted remote model during cold start"
        );
        assert!(matches!(
            schema.provider.as_str(),
            "openai" | "anthropic" | "google"
        ));
    }

    // Cold-start code prompts must go to a Code-capable trusted remote.
    #[test]
    fn routes_code_to_code_capable_remote_during_cold_start() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route(
            "Fix this function:\n```rust\nfn main() {}\n```",
            &reg,
            &tracker,
        );
        assert_eq!(decision.complexity, TaskComplexity::Code);
        assert_eq!(decision.task, InferenceTask::Code);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("model should exist");
        assert!(
            schema.has_capability(ModelCapability::Code),
            "selected model must support Code"
        );
        assert!(!schema.is_local(), "should route to trusted remote model");
    }

    // Vision routing must pick a model that declares the Vision capability.
    #[test]
    fn routes_images_to_vision_capable_model() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let mut reg = test_registry();
        let tracker = OutcomeTracker::new();

        // Add a local vision-capable model alongside the remote one.
        reg.register(ModelSchema {
            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
            name: "Qwen3-VL-2B-mlx-vlm".into(),
            provider: "qwen".into(),
            family: "qwen3-vl".into(),
            version: "bf16".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Vision,
                ModelCapability::Grounding,
            ],
            context_length: 262_144,
            param_count: "2B".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::Mlx {
                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
                hf_weight_file: None,
            },
            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });

        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("model should exist");
        assert!(
            schema.has_capability(ModelCapability::Vision),
            "selected model must support Vision"
        );
    }

    // With a weak prior, a model with many recorded successes should win
    // the Thompson-sampled routing most of the time.
    #[test]
    fn profile_based_routing_favors_proven_model() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 0.5,
                min_observations: 3,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Record 20 accepted Code outcomes for the local 8B model.
        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..20 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        // Routing is stochastic; require a clear majority, not unanimity.
        let mut wins = 0;
        for _ in 0..20 {
            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
            assert_eq!(decision.complexity, TaskComplexity::Code);
            if decision.model_id == qwen_8b_id {
                wins += 1;
            }
        }
        assert!(
            wins >= 12,
            "proven model won only {wins}/20 times (expected >= 12)"
        );
    }

    // Enough high-quality local observations should overcome the
    // remote bootstrap bias even under a strong prior.
    #[test]
    fn proven_local_model_can_displace_bootstrap_remote() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                bootstrap_min_task_observations: 6,
                bootstrap_quality_floor: 0.8,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..12 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
            tracker.record_complete(&trace, 300, 50, 20);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let mut local_wins = 0;
        for _ in 0..20 {
            let decision = router.route("Summarize this design decision.", &reg, &tracker);
            let schema = reg
                .get(&decision.model_id)
                .expect("selected model should exist");
            if schema.is_local() {
                local_wins += 1;
            }
        }

        assert!(
            local_wins >= 12,
            "proven local model won only {local_wins}/20 times (expected >= 12)"
        );
    }

    // A strong imported benchmark prior (no live observations) should
    // let a local model win background-workload routing.
    #[test]
    fn benchmark_prior_informs_background_routing() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.ema_quality = 0.95;
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Background,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            schema.is_local(),
            "background routing should allow strong local benchmark priors to win"
        );
    }

    // Same as above, but the prior is attached to the Code task stats
    // rather than the model-wide EMA.
    #[test]
    fn task_specific_benchmark_prior_informs_cold_start_routing() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.task_stats.insert(
            crate::outcome::InferenceTask::Code.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                ..Default::default()
            },
        );
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Background,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            schema.is_local(),
            "background routing should use task-specific cold-start priors for local code models"
        );
    }

    // For the same quality, a much lower task latency prior must yield
    // a strictly higher interactive cold-start score.
    #[test]
    fn task_specific_latency_prior_affects_cold_start_score() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let model = reg
            .get("qwen/qwen3-8b:q4_k_m")
            .expect("local test model should exist");

        let mut fast_tracker = OutcomeTracker::new();
        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
        fast_profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms: 1200.0,
                ..Default::default()
            },
        );
        fast_tracker.import_profiles(vec![fast_profile]);

        let mut slow_tracker = OutcomeTracker::new();
        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
        slow_profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms: 120_000.0,
                ..Default::default()
            },
        );
        slow_tracker.import_profiles(vec![slow_profile]);

        let fast_score = router.score_model(
            model,
            InferenceTask::Generate,
            &fast_tracker,
            RoutingWorkload::Interactive,
        );
        let slow_score = router.score_model(
            model,
            InferenceTask::Generate,
            &slow_tracker,
            RoutingWorkload::Interactive,
        );

        assert!(
            fast_score > slow_score,
            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
        );
    }

    // Interactive workloads keep the remote bootstrap bias even when a
    // local model carries a strong model-wide benchmark prior.
    #[test]
    fn interactive_workload_keeps_remote_bootstrap_bias() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.ema_quality = 0.95;
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Interactive,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            !schema.is_local(),
            "interactive routing should still prefer trusted remote models during cold start"
        );
    }

    // Every decision must carry non-empty fallbacks that exclude the
    // primary selection.
    #[test]
    fn fallback_chain_has_alternatives() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
        assert!(!decision.fallbacks.is_empty());
        assert!(!decision.fallbacks.contains(&decision.model_id));
    }

    // Schema-derived throughput and an observed latency describing the
    // same speed must score (nearly) identically.
    #[test]
    fn latency_scoring_is_consistent() {
        let router = AdaptiveRouter::with_default_config(test_hw());

        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
        let observed_score = router.latency_ms_to_score(8000.0);
        assert!(
            (schema_score - observed_score).abs() < 0.01,
            "schema ({schema_score}) and observed ({observed_score}) should match"
        );
    }

    // Spot-check the prompt heuristics for each complexity bucket.
    #[test]
    fn complexity_assessment() {
        assert_eq!(
            TaskComplexity::assess("What is the capital of France?"),
            TaskComplexity::Simple
        );
        assert_eq!(
            TaskComplexity::assess("Fix this broken test"),
            TaskComplexity::Code
        );
        assert_eq!(
            TaskComplexity::assess("Analyze the trade-offs between A and B"),
            TaskComplexity::Complex
        );
    }

    // Beta samples must stay in [0, 1]; Beta(1, 1) is uniform so its
    // sample mean should sit near 0.5.
    #[test]
    fn beta_sampling_produces_valid_values() {
        let mut rng = rand::rng();
        for _ in 0..100 {
            let s = sample_beta(&mut rng, 2.0, 5.0);
            assert!((0.0..=1.0).contains(&s), "sample {s} out of [0,1] range");
        }
        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            (mean - 0.5).abs() < 0.05,
            "Beta(1,1) mean {mean} should be ~0.5"
        );
    }

    // With an almost-flat prior, Thompson sampling should converge on
    // the model with a strong observed track record.
    #[test]
    fn thompson_sampling_converges_to_best() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 1.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
        for _ in 0..20 {
            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let mut wins = 0;
        for _ in 0..20 {
            let decision = router.route("Fix this parser bug", &reg, &tracker);
            if decision.model_id == qwen_4b_id {
                wins += 1;
            }
        }
        assert!(
            wins >= 14,
            "strong model won only {wins}/20 times (expected >= 14)"
        );
    }

    // A hard `require` that no profiled candidate satisfies must fall
    // back to schema-based cold-start routing.
    #[test]
    fn intent_require_filters_out_models_lacking_capability() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let intent = crate::intent::IntentHint {
            require: vec![ModelCapability::Vision],
            ..Default::default()
        };
        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);

        assert_eq!(
            decision.strategy,
            RoutingStrategy::SchemaBased,
            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
        );
    }

    // A default (empty) IntentHint must be a no-op relative to plain routing.
    #[test]
    fn intent_default_does_not_override_task_or_caps() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let baseline = router.route("write a haiku", &reg, &tracker);
        let with_default = router.route_with_intent(
            "write a haiku",
            &reg,
            &tracker,
            &crate::intent::IntentHint::default(),
        );

        assert_eq!(
            baseline.task, with_default.task,
            "default IntentHint must not change the prompt-derived task"
        );
        assert_eq!(
            baseline.complexity, with_default.complexity,
            "default IntentHint must not change the prompt-derived complexity"
        );
    }

    // An explicit TaskHint wins over whatever the prompt heuristics infer.
    #[test]
    fn intent_task_hint_overrides_prompt_complexity() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let hint = crate::intent::IntentHint {
            task: Some(crate::intent::TaskHint::Reasoning),
            ..Default::default()
        };
        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);

        assert_eq!(
            decision.task,
            InferenceTask::Reasoning,
            "TaskHint::Reasoning should override the prompt-derived task"
        );
    }
}