1use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse classification of a prompt, produced by [`TaskComplexity::assess`].
///
/// The variant decides which model capability is required and which
/// `InferenceTask` the request is routed as. Serialized in `snake_case`
/// (e.g. `"simple"`, `"code"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short prompts or simple factual questions; requires only `Generate`.
    Simple,
    /// Prompts that match no other bucket; requires only `Generate`.
    Medium,
    /// Prompts that contain code or ask for a fix/debug; requires `Code`.
    Code,
    /// Prompts with reasoning keywords or a long estimated length; requires `Reasoning`.
    Complex,
}
33
34impl TaskComplexity {
35 pub fn assess(prompt: &str) -> Self {
42 let lower = prompt.to_lowercase();
43 let word_count = prompt.split_whitespace().count();
44 let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46 let has_code = Self::detect_code(prompt);
47
48 let repair_markers = [
49 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50 ];
51 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53 let reasoning_markers = [
54 "analyze",
55 "compare",
56 "explain why",
57 "step by step",
58 "think through",
59 "evaluate",
60 "trade-off",
61 "tradeoff",
62 "pros and cons",
63 "architecture",
64 "design",
65 "strategy",
66 "optimize",
67 "comprehensive",
68 ];
69 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71 let simple_patterns = [
72 "what is",
73 "who is",
74 "when did",
75 "where is",
76 "how many",
77 "yes or no",
78 "true or false",
79 "name the",
80 "list the",
81 "define ",
82 ];
83 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85 if has_code || has_repair {
86 TaskComplexity::Code
87 } else if has_reasoning || estimated_tokens > 500 {
88 TaskComplexity::Complex
89 } else if is_simple || estimated_tokens < 30 {
90 TaskComplexity::Simple
91 } else {
92 TaskComplexity::Medium
93 }
94 }
95
96 fn detect_code(prompt: &str) -> bool {
102 #[cfg(feature = "ast")]
104 {
105 if let Some(is_code) = Self::detect_code_ast(prompt) {
106 return is_code;
107 }
108 }
109
110 let code_markers = [
112 "```",
113 "fn ",
114 "def ",
115 "class ",
116 "import ",
117 "require(",
118 "async fn",
119 "pub fn",
120 "function ",
121 "const ",
122 "let ",
123 "var ",
124 "#include",
125 "package ",
126 "impl ",
127 ];
128 code_markers.iter().any(|m| prompt.contains(m))
129 }
130
131 #[cfg(feature = "ast")]
135 fn detect_code_ast(prompt: &str) -> Option<bool> {
136 let mut blocks = Vec::new();
138 let mut rest = prompt;
139 while let Some(start) = rest.find("```") {
140 let after_fence = &rest[start + 3..];
141 let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143 if let Some(end) = after_fence[code_start..].find("```") {
144 blocks.push(&after_fence[code_start..code_start + end]);
145 rest = &after_fence[code_start + end + 3..];
146 } else {
147 break;
148 }
149 }
150
151 if blocks.is_empty() {
152 return None; }
154
155 let languages = [
157 car_ast::Language::Rust,
158 car_ast::Language::Python,
159 car_ast::Language::TypeScript,
160 car_ast::Language::JavaScript,
161 car_ast::Language::Go,
162 ];
163
164 for block in &blocks {
165 let trimmed = block.trim();
166 if trimmed.is_empty() {
167 continue;
168 }
169
170 for lang in &languages {
171 if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172 if !parsed.symbols.is_empty() {
174 return Some(true);
175 }
176 }
177 }
178 }
179
180 Some(false)
183 }
184
185 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187 match self {
188 TaskComplexity::Simple => vec![ModelCapability::Generate],
189 TaskComplexity::Medium => vec![ModelCapability::Generate],
190 TaskComplexity::Code => vec![ModelCapability::Code],
191 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192 }
193 }
194
195 pub fn inference_task(&self) -> InferenceTask {
197 match self {
198 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199 TaskComplexity::Code => InferenceTask::Code,
200 TaskComplexity::Complex => InferenceTask::Reasoning,
201 }
202 }
203}
204
/// Tunable parameters for [`AdaptiveRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Number of recorded calls at which a model's observed profile fully
    /// replaces its schema-derived estimates during scoring. Below this (but
    /// above zero) observed and schema values are blended linearly.
    pub min_observations: u64,
    /// Relative weight of quality.
    /// NOTE(review): the scoring code visible in this file takes its weights
    /// from `RoutingWorkload::weights()`, not from this field — confirm where
    /// this is consumed.
    pub quality_weight: f64,
    /// Relative weight of latency. NOTE(review): see `quality_weight`.
    pub latency_weight: f64,
    /// Relative weight of cost. NOTE(review): see `quality_weight`.
    pub cost_weight: f64,
    /// When set, models whose declared p50 latency exceeds this value are
    /// filtered out. Models without a declared p50 are kept.
    pub max_latency_ms: Option<u64>,
    /// When set, models whose cost per 1k output tokens exceeds this value
    /// are filtered out.
    pub max_cost_usd: Option<f64>,
    /// When `false`, local models are excluded from candidates. When `true`,
    /// local models receive a score bonus on hosts with a GPU backend.
    pub prefer_local: bool,
    /// Weight (pseudo-observation count) of the score-derived prior in
    /// Thompson sampling, relative to recorded successes and failures.
    pub prior_strength: f64,
    /// Enables the quality-first policy: bootstrap filtering toward trusted
    /// remote / proven local models, and preferring a trusted remote model
    /// when no candidate survives filtering.
    pub quality_first_cold_start: bool,
    /// Minimum number of calls a local model needs before it can count as
    /// "proven" under the quality-first bootstrap policy.
    pub bootstrap_min_task_observations: u64,
    /// Minimum EMA quality a local model needs to count as "proven" under
    /// the quality-first bootstrap policy.
    pub bootstrap_quality_floor: f64,
}
234
impl Default for RoutingConfig {
    /// Default configuration: local models allowed, no latency or cost caps,
    /// quality-first cold start enabled.
    fn default() -> Self {
        Self {
            min_observations: 2,
            quality_weight: 0.45,
            latency_weight: 0.4,
            cost_weight: 0.15,
            // No hard latency or cost filters by default.
            max_latency_ms: None,
            max_cost_usd: None,
            prefer_local: true,
            prior_strength: 2.0,
            quality_first_cold_start: true,
            // A local model needs 8 calls at EMA quality >= 0.8 to count as proven.
            bootstrap_min_task_observations: 8,
            bootstrap_quality_floor: 0.8,
        }
    }
}
252
/// How the selected model was chosen, as reported on a routing decision.
///
/// Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// The selected model has no recorded calls (or no profile); the choice
    /// rests on schema-derived estimates. Also used for cold-start decisions.
    SchemaBased,
    /// The selected model has at least `min_observations` recorded calls.
    ProfileBased,
    /// The selected model has some recorded calls, but fewer than
    /// `min_observations`.
    Exploration,
    /// NOTE(review): not produced by the code visible in this file —
    /// presumably set by callers that pin a model explicitly; confirm.
    Explicit,
}
266
/// Result of a routing call: the chosen model plus the context needed to
/// execute, explain and fall back from that choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Registry id of the selected model (or a bare model name when the
    /// cold-start default is not present in the registry).
    pub model_id: String,
    /// Display name of the selected model; falls back to the id when the
    /// model cannot be found in the registry.
    pub model_name: String,
    /// Inference task the request is routed as (from an intent hint when
    /// given, otherwise derived from `complexity`).
    pub task: InferenceTask,
    /// Prompt complexity as assessed from the prompt text.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of the decision.
    pub reason: String,
    /// How the model was chosen.
    pub strategy: RoutingStrategy,
    /// Score of the selected model from candidate scoring (bonuses included),
    /// or a fixed/schema-derived estimate on cold-start paths.
    pub predicted_quality: f64,
    /// Other candidate model ids in descending score order, followed by
    /// candidates whose context window is too small for the prompt.
    /// Empty for cold-start decisions.
    pub fallbacks: Vec<String>,
    /// Context window of the selected model; `0` when unknown.
    pub context_length: usize,
    /// `true` when the estimated prompt size exceeds the selected model's
    /// context window, so the caller must compact before sending.
    pub needs_compaction: bool,
}
291
/// Routes prompts to models using schema estimates, observed outcome
/// profiles and Thompson sampling.
pub struct AdaptiveRouter {
    /// Host hardware; supplies the local model size limit and the
    /// recommended model used as a cold-start default.
    hw: HardwareInfo,
    /// Active routing configuration.
    config: RoutingConfig,
    /// Per-model circuit breakers consulted during candidate filtering.
    /// Public and reference-counted so it can be shared outside the router.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs to a single routing call; consumed by `AdaptiveRouter::route_with`.
///
/// Build with [`RouteRequest::new`] and override fields with struct-update
/// syntax.
pub struct RouteRequest<'a> {
    /// Prompt text used for complexity assessment and multi-tool detection.
    pub prompt: &'a str,
    /// Registry the candidate models are drawn from.
    pub registry: &'a UnifiedRegistry,
    /// Outcome history used for scoring, exclusion and sampling.
    pub tracker: &'a OutcomeTracker,
    /// Estimated total prompt size in tokens; `0` means unknown and disables
    /// all context-fit checks.
    pub estimated_total_tokens: usize,
    /// Request uses tools: adds the `ToolUse` requirement (and
    /// `MultiToolCall` when the prompt looks multi-step).
    pub has_tools: bool,
    /// Request includes image input: adds the `Vision` requirement.
    pub has_vision: bool,
    /// Workload profile; supplies scoring weights. Overridden when `intent`
    /// sets `prefer_fast` or `prefer_local`.
    pub workload: RoutingWorkload,
    /// Optional caller-supplied hint (task, extra required capabilities,
    /// speed/locality preferences).
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
impl<'a> RouteRequest<'a> {
    /// Creates a request with default options: token estimate unknown (`0`),
    /// no tools, no vision, `Interactive` workload and no intent hint.
    pub fn new(
        prompt: &'a str,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
    ) -> Self {
        Self {
            prompt,
            registry,
            tracker,
            // 0 = unknown; context-fit checks are skipped.
            estimated_total_tokens: 0,
            has_tools: false,
            has_vision: false,
            workload: RoutingWorkload::Interactive,
            intent: None,
        }
    }
}
351
352impl AdaptiveRouter {
353 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354 let circuit_breakers = Arc::new(Mutex::new(
355 CircuitBreakerRegistry::new(3, 300), ));
357 Self {
358 hw,
359 config,
360 circuit_breakers,
361 }
362 }
363
364 pub fn with_default_config(hw: HardwareInfo) -> Self {
365 Self::new(hw, RoutingConfig::default())
366 }
367
    /// Returns the active routing configuration.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }
371
    /// Replaces the routing configuration; takes effect on the next routing call.
    pub fn set_config(&mut self, config: RoutingConfig) {
        self.config = config;
    }
375
376 pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382 let workload = match req.intent {
388 Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389 Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390 _ => req.workload,
391 };
392 self.route_inner_with_intent(
393 req.prompt,
394 req.registry,
395 req.tracker,
396 req.has_tools,
397 req.has_vision,
398 req.estimated_total_tokens,
399 workload,
400 req.intent,
401 )
402 }
403
404 pub fn route(
407 &self,
408 prompt: &str,
409 registry: &UnifiedRegistry,
410 tracker: &OutcomeTracker,
411 ) -> AdaptiveRoutingDecision {
412 self.route_with(RouteRequest::new(prompt, registry, tracker))
413 }
414
415 pub fn route_editor(
421 &self,
422 prompt: &str,
423 registry: &UnifiedRegistry,
424 tracker: &OutcomeTracker,
425 ) -> AdaptiveRoutingDecision {
426 self.route_with(RouteRequest {
427 workload: RoutingWorkload::Background,
428 ..RouteRequest::new(prompt, registry, tracker)
429 })
430 }
431
432 pub fn route_with_tools(
434 &self,
435 prompt: &str,
436 registry: &UnifiedRegistry,
437 tracker: &OutcomeTracker,
438 ) -> AdaptiveRoutingDecision {
439 self.route_with(RouteRequest {
440 has_tools: true,
441 ..RouteRequest::new(prompt, registry, tracker)
442 })
443 }
444
445 pub fn route_with_vision(
447 &self,
448 prompt: &str,
449 registry: &UnifiedRegistry,
450 tracker: &OutcomeTracker,
451 has_tools: bool,
452 ) -> AdaptiveRoutingDecision {
453 self.route_with(RouteRequest {
454 has_tools,
455 has_vision: true,
456 ..RouteRequest::new(prompt, registry, tracker)
457 })
458 }
459
460 pub fn route_with_intent<'a>(
466 &self,
467 prompt: &'a str,
468 registry: &'a UnifiedRegistry,
469 tracker: &'a OutcomeTracker,
470 intent: &'a crate::intent::IntentHint,
471 ) -> AdaptiveRoutingDecision {
472 self.route_with(RouteRequest {
473 intent: Some(intent),
474 ..RouteRequest::new(prompt, registry, tracker)
475 })
476 }
477
478 pub fn route_context_aware(
481 &self,
482 prompt: &str,
483 estimated_total_tokens: usize,
484 registry: &UnifiedRegistry,
485 tracker: &OutcomeTracker,
486 has_tools: bool,
487 has_vision: bool,
488 workload: RoutingWorkload,
489 ) -> AdaptiveRoutingDecision {
490 self.route_with(RouteRequest {
491 estimated_total_tokens,
492 has_tools,
493 has_vision,
494 workload,
495 ..RouteRequest::new(prompt, registry, tracker)
496 })
497 }
498
499 pub fn route_context_aware_with_intent<'a>(
504 &self,
505 prompt: &'a str,
506 estimated_total_tokens: usize,
507 registry: &'a UnifiedRegistry,
508 tracker: &'a OutcomeTracker,
509 has_tools: bool,
510 has_vision: bool,
511 workload: RoutingWorkload,
512 intent: &'a crate::intent::IntentHint,
513 ) -> AdaptiveRoutingDecision {
514 self.route_with(RouteRequest {
515 estimated_total_tokens,
516 has_tools,
517 has_vision,
518 workload,
519 intent: Some(intent),
520 ..RouteRequest::new(prompt, registry, tracker)
521 })
522 }
523
    /// Core routing pipeline shared by every public `route_*` entry point.
    ///
    /// Steps: assess the prompt, build the required capability set, filter
    /// the registry, apply the quality-first bootstrap policy, partition
    /// candidates by context fit, score them, Thompson-sample a winner and
    /// assemble the decision with its fallbacks. When no candidate survives
    /// filtering, a cold-start decision is returned instead.
    ///
    /// `estimated_total_tokens == 0` means "unknown": every context-fit check
    /// is skipped.
    fn route_inner_with_intent(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        estimated_total_tokens: usize,
        workload: RoutingWorkload,
        intent: Option<&crate::intent::IntentHint>,
    ) -> AdaptiveRoutingDecision {
        let complexity = TaskComplexity::assess(prompt);
        // A task hint overrides the task derived from the prompt; `complexity`
        // itself (and the capabilities it implies) stays prompt-derived.
        let task = intent
            .and_then(|h| h.task)
            .map(task_hint_to_inference_task)
            .unwrap_or_else(|| complexity.inference_task());
        let mut required_caps = complexity.required_capabilities();
        // Add capabilities demanded by the hint, skipping ones already present.
        if let Some(hint) = intent {
            for cap in &hint.require {
                if !required_caps.contains(cap) {
                    required_caps.push(*cap);
                }
            }
        }
        if has_vision {
            required_caps.push(ModelCapability::Vision);
        }
        if has_tools {
            required_caps.push(ModelCapability::ToolUse);
            // Only demand multi-tool support when the prompt looks multi-step.
            if Self::needs_multi_tool_call(prompt) {
                required_caps.push(ModelCapability::MultiToolCall);
            }
        }

        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);

        // `MultiToolCall` is a soft requirement: drop it and retry rather than
        // fall through to cold start.
        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
            candidates = self.filter_candidates(&required_caps, registry, tracker);
        }

        if candidates.is_empty() {
            return self.cold_start_decision(complexity, task, registry, has_vision);
        }

        candidates = self.apply_quality_first_bootstrap_policy(
            candidates,
            task,
            tracker,
            has_vision,
            has_tools,
            workload,
        );

        // Partition by context fit. A `context_length` of 0 is treated as
        // "unknown" and counted as fitting.
        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
            let mut fits = Vec::new();
            let mut tight = Vec::new();
            for m in &candidates {
                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
                    fits.push(m.clone());
                } else {
                    tight.push(m.clone());
                }
            }
            (fits, tight)
        } else {
            (candidates.clone(), Vec::new())
        };

        // Prefer models that fit; only when none do are the too-small models
        // scored, and the decision is flagged as needing compaction.
        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
            (fits, false)
        } else if !needs_compaction_candidates.is_empty() {
            tracing::info!(
                prompt_tokens = estimated_total_tokens,
                candidates = needs_compaction_candidates.len(),
                "no model fits full prompt — compaction will be needed"
            );
            (needs_compaction_candidates.clone(), true)
        } else {
            (candidates.clone(), false)
        };

        // Scores come back sorted in descending order.
        let scored = self.score_candidates_context_aware(
            &scoring_candidates,
            task,
            tracker,
            estimated_total_tokens,
            workload,
        );

        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);

        // Fallbacks: every other scored candidate, in score order ...
        let mut fallbacks: Vec<String> = scored
            .iter()
            .filter(|(id, _)| *id != selected_id)
            .map(|(id, _)| id.clone())
            .collect();
        // ... then, when fitting models were scored, the too-small models as a
        // last resort.
        if !compaction_needed {
            for m in &needs_compaction_candidates {
                if m.id != selected_id && !fallbacks.contains(&m.id) {
                    fallbacks.push(m.id.clone());
                }
            }
        }

        // Reported quality is the deterministic score, not the sampled value.
        let predicted_quality = scored
            .iter()
            .find(|(id, _)| *id == selected_id)
            .map(|(_, score)| *score)
            .unwrap_or(0.5);

        // Look the winner up by id first, then by name.
        let selected_schema = registry
            .get(&selected_id)
            .or_else(|| registry.find_by_name(&selected_id));
        let model_name = selected_schema
            .map(|m| m.name.clone())
            .unwrap_or_else(|| selected_id.clone());
        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);

        // Re-check against the actual winner's context window.
        let needs_compact = compaction_needed
            || (estimated_total_tokens > 0
                && context_length > 0
                && estimated_total_tokens > context_length);

        let compaction_note = if needs_compact {
            format!(
                " [compaction needed: {}→{}tok]",
                estimated_total_tokens, context_length
            )
        } else {
            String::new()
        };

        let reason = format!(
            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
            complexity,
            model_name,
            strategy,
            predicted_quality,
            scoring_candidates.len(),
            compaction_note,
        );

        AdaptiveRoutingDecision {
            model_id: selected_id,
            model_name,
            task,
            complexity,
            reason,
            strategy,
            predicted_quality,
            fallbacks,
            context_length,
            needs_compaction: needs_compact,
        }
    }
694
695 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
697 let embed_models = registry.query_by_capability(ModelCapability::Embed);
698 embed_models
699 .first()
700 .map(|m| m.name.clone())
701 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
702 }
703
704 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
706 let gen_models = registry.query_by_capability(ModelCapability::Generate);
707 gen_models
709 .iter()
710 .filter(|m| m.is_local())
711 .min_by_key(|m| m.size_mb())
712 .map(|m| m.name.clone())
713 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
714 }
715
    /// Latency at or above which the latency score is 0; also the latency
    /// assumed for a non-positive tokens-per-second figure.
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// Not referenced by the code visible in this file.
    const _TPS_CEILING: f64 = 150.0;
    /// Multiplier applied to the declared tokens-per-second of local models
    /// tagged `moe` that are not MLX models.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// Multiplier applied to the declared tokens-per-second of local MLX
    /// models tagged `moe`.
    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
    /// Cost per 1k output tokens at or above which the cost score is 0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Score bonus for local models when `prefer_local` is set and the host
    /// has a GPU backend.
    const LOCAL_BONUS: f64 = 0.15;
    /// `true` only on macOS/aarch64 builds compiled without the
    /// `car_skip_mlx` cfg; gates `LOCAL_BONUS`.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const HAS_GPU_BACKEND: bool = true;
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    const HAS_GPU_BACKEND: bool = false;
    /// Extra score bonus for MLX models; only defined on builds where
    /// `HAS_GPU_BACKEND` is `true`.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const MLX_BONUS: f64 = 0.10;
    /// Score bonus for models tagged both `low_latency` and `private`.
    const SYSTEM_LLM_BONUS: f64 = 0.12;
758
    /// Returns clones of every registry model that is eligible for routing.
    ///
    /// A model is kept only if it has all `required_caps`, is available, fits
    /// the hardware size limit (local models only), respects the configured
    /// latency and cost caps, is permitted by `prefer_local`, is not excluded
    /// by the tracker and is allowed by its circuit breaker.
    fn filter_candidates(
        &self,
        required_caps: &[ModelCapability],
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
    ) -> Vec<ModelSchema> {
        registry
            .list()
            .into_iter()
            .filter(|m| {
                if !required_caps.iter().all(|c| m.has_capability(*c)) {
                    return false;
                }
                if !m.available {
                    return false;
                }
                // Local models at or above the hardware limit are rejected.
                if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
                    return false;
                }
                // Latency cap applies only to models that declare a p50.
                if let Some(max) = self.config.max_latency_ms {
                    if let Some(p50) = m.performance.latency_p50_ms {
                        if p50 > max {
                            return false;
                        }
                    }
                }
                if let Some(max) = self.config.max_cost_usd {
                    if m.cost_per_1k_output() > max {
                        return false;
                    }
                }
                // With `prefer_local` off, local models are excluded outright.
                if !self.config.prefer_local && m.is_local() {
                    return false;
                }
                if tracker.is_excluded(&m.id) {
                    return false;
                }
                // The breaker is consulted last, so it only sees models that
                // passed every other check. If the mutex is poisoned the
                // breaker check is skipped and the model is allowed.
                // NOTE(review): `allow_request` needs a mutable guard and may
                // update breaker state — confirm before reordering checks.
                if let Ok(mut cb) = self.circuit_breakers.lock() {
                    if !cb.allow_request(&m.id) {
                        tracing::debug!(model = %m.id, "skipped by circuit breaker");
                        return false;
                    }
                }
                true
            })
            .cloned()
            .collect()
    }
818
819 fn apply_quality_first_bootstrap_policy(
820 &self,
821 candidates: Vec<ModelSchema>,
822 task: InferenceTask,
823 tracker: &OutcomeTracker,
824 has_vision: bool,
825 has_tools: bool,
826 workload: RoutingWorkload,
827 ) -> Vec<ModelSchema> {
828 if !self.config.quality_first_cold_start
829 || !workload.is_latency_sensitive()
830 || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
831 {
832 return candidates;
833 }
834
835 let trusted_remote: Vec<ModelSchema> = candidates
836 .iter()
837 .filter(|model| self.is_trusted_quality_remote(model))
838 .cloned()
839 .collect();
840
841 if trusted_remote.is_empty() {
842 return candidates;
843 }
844
845 let proven_local: Vec<ModelSchema> = candidates
846 .iter()
847 .filter(|model| model.is_local() && self.is_local_model_proven_for_task(model, task, tracker))
848 .cloned()
849 .collect();
850
851 if !proven_local.is_empty() {
852 return proven_local;
853 }
854
855 trusted_remote
856 }
857
858 fn score_candidates_context_aware(
862 &self,
863 candidates: &[ModelSchema],
864 task: InferenceTask,
865 tracker: &OutcomeTracker,
866 estimated_total_tokens: usize,
867 workload: RoutingWorkload,
868 ) -> Vec<(String, f64)> {
869 let mut scored: Vec<(String, f64)> = candidates
870 .iter()
871 .map(|m| {
872 let base_score = self.score_model(m, task, tracker, workload);
873 let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
876 let ratio = m.context_length as f64 / estimated_total_tokens as f64;
877 if ratio >= 1.0 {
878 (ratio.min(4.0) - 1.0) / 3.0 * 0.10 } else {
880 -0.15 }
882 } else {
883 0.0
884 };
885 (m.id.clone(), base_score + headroom_bonus)
886 })
887 .collect();
888
889 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
890 scored
891 }
892
    /// Computes the routing score of `model` for `task`.
    ///
    /// The score is a weighted sum of quality, latency and cost components
    /// (weights come from `workload.weights()`), plus additive bonuses for
    /// local, MLX, vLLM-MLX and system-LLM models. Because of the bonuses the
    /// result may exceed 1.0.
    ///
    /// Quality and latency each come from one of four sources, depending on
    /// the model's outcome profile:
    /// * at least `min_observations` calls — observed values;
    /// * a profile with zero calls — the profile's stored values;
    /// * some calls, but fewer than `min_observations` — a linear blend of
    ///   the schema estimate and the observed value;
    /// * no profile — the schema estimate.
    fn score_model(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
        workload: RoutingWorkload,
    ) -> f64 {
        let profile = tracker.profile(&model.id);
        let schema_quality = self.schema_quality_estimate(model);
        let schema_latency = self.schema_latency_estimate(model);
        let (quality_weight, latency_weight, cost_weight) = workload.weights();

        let quality = match profile {
            // Enough data: per-task EMA, falling back to the overall EMA.
            Some(p) if p.total_calls >= self.config.min_observations => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            // Profile exists but has no calls: use its stored values as-is.
            // NOTE(review): presumably a profile seeded from priors or
            // benchmarks — confirm.
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            // Few calls: blend schema and observed by the fraction of
            // `min_observations` reached (w is strictly between 0 and 1 here).
            Some(p) if p.total_calls > 0 => {
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_quality * (1.0 - w) + p.ema_quality * w
            }
            _ => schema_quality,
        };

        let latency = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => {
                // Per-task average when that task has data, else overall average.
                let avg = p
                    .task_stats(task)
                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                    .map(|ts| ts.avg_latency_ms)
                    .unwrap_or_else(|| p.avg_latency_ms());
                self.latency_ms_to_score(avg)
            }
            // No calls: use a stored per-task latency if present, else schema.
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .filter(|ts| ts.avg_latency_ms > 0.0)
                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
                .unwrap_or(schema_latency),
            Some(p) if p.total_calls > 0 => {
                let observed = self.latency_ms_to_score(
                    p.task_stats(task)
                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                        .map(|ts| ts.avg_latency_ms)
                        .unwrap_or_else(|| p.avg_latency_ms()),
                );
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_latency * (1.0 - w) + observed * w
            }
            _ => schema_latency,
        };

        // Local models count as free; remote cost is scaled against the ceiling.
        let cost = if model.is_local() {
            1.0
        } else {
            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
        };

        // The local preference bonus is only granted on hosts with a GPU backend.
        let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
            Self::LOCAL_BONUS
        } else {
            if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
                tracing::debug!(
                    model = %model.id,
                    "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
                );
            }
            0.0
        };
        // Workload-specific local bonus; independent of `prefer_local`.
        let workload_local_bonus = if model.is_local() {
            workload.local_bonus()
        } else {
            0.0
        };

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let mlx_bonus = 0.0;

        // Granted regardless of `prefer_local` and `HAS_GPU_BACKEND`.
        let vllm_mlx_bonus = if model.is_vllm_mlx() {
            Self::LOCAL_BONUS + 0.05
        } else {
            0.0
        };

        // Models tagged both `low_latency` and `private` get the system-LLM bonus.
        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
            && model.tags.iter().any(|t| t == "private")
        {
            Self::SYSTEM_LLM_BONUS
        } else {
            0.0
        };

        quality_weight * quality
            + latency_weight * latency
            + cost_weight * cost
            + local_bonus
            + workload_local_bonus
            + mlx_bonus
            + vllm_mlx_bonus
            + system_llm_bonus
    }
1023
    /// Converts a latency in milliseconds to a score in `[0.0, 1.0]`:
    /// 1.0 at 0 ms, falling linearly to 0.0 at `LATENCY_CEILING_MS` and beyond.
    fn latency_ms_to_score(&self, ms: f64) -> f64 {
        (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
    }
1029
1030 fn tps_to_latency_ms(tps: f64) -> f64 {
1032 if tps <= 0.0 {
1033 return Self::LATENCY_CEILING_MS;
1034 }
1035 (200.0 / tps) * 1000.0
1037 }
1038
1039 fn needs_multi_tool_call(prompt: &str) -> bool {
1042 let lower = prompt.to_lowercase();
1043
1044 let has_numbered_list = {
1046 let mut count = 0u32;
1047 for i in 1..=5u32 {
1048 if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1049 count += 1;
1050 }
1051 }
1052 count >= 2
1053 };
1054
1055 let multi_keywords = [
1057 "multiple edits",
1058 "several changes",
1059 "three changes",
1060 "two changes",
1061 "all of the following",
1062 "each of these",
1063 "do both",
1064 "do all",
1065 "and also",
1066 "additionally",
1067 "as well as",
1068 "then also",
1069 ];
1070 let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1071
1072 let bullet_actions = lower.matches("- add ").count()
1074 + lower.matches("- update ").count()
1075 + lower.matches("- change ").count()
1076 + lower.matches("- remove ").count()
1077 + lower.matches("- fix ").count()
1078 + lower.matches("- edit ").count()
1079 + lower.matches("- implement ").count()
1080 + lower.matches("- create ").count();
1081 let has_bullet_list = bullet_actions >= 2;
1082
1083 has_numbered_list || has_multi_keywords || has_bullet_list
1084 }
1085
1086 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1091 match model.size_mb() {
1092 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
1099 }
1100
1101 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1106 let is_moe = model.tags.contains(&"moe".to_string());
1107
1108 if model.is_local() {
1109 if let Some(tps) = model.performance.tokens_per_second {
1110 let effective_tps = if is_moe {
1111 let multiplier = if model.is_mlx() {
1112 Self::MLX_MOE_TPS_MULTIPLIER
1113 } else {
1114 Self::MOE_TPS_MULTIPLIER
1115 };
1116 tps * multiplier
1117 } else {
1118 tps
1119 };
1120 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1121 return self.latency_ms_to_score(estimated_ms);
1122 }
1123 return 0.5; }
1125
1126 if let Some(p50) = model.performance.latency_p50_ms {
1128 return self.latency_ms_to_score(p50 as f64);
1129 }
1130 0.3 }
1132
1133 fn is_quality_critical_bootstrap_task(
1134 &self,
1135 task: InferenceTask,
1136 has_vision: bool,
1137 has_tools: bool,
1138 ) -> bool {
1139 has_vision
1140 || has_tools
1141 || matches!(
1142 task,
1143 InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1144 )
1145 }
1146
1147 fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1148 model.is_remote()
1149 && matches!(
1150 model.provider.as_str(),
1151 "openai" | "anthropic" | "google"
1152 )
1153 && !model.has_capability(ModelCapability::SpeechToText)
1154 && !model.has_capability(ModelCapability::TextToSpeech)
1155 }
1156
1157 fn is_local_model_proven_for_task(
1158 &self,
1159 model: &ModelSchema,
1160 task: InferenceTask,
1161 tracker: &OutcomeTracker,
1162 ) -> bool {
1163 let Some(profile) = tracker.profile(&model.id) else {
1164 return false;
1165 };
1166 if let Some(task_stats) = profile.task_stats(task) {
1167 if task_stats.calls >= self.config.bootstrap_min_task_observations
1168 && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1169 {
1170 return true;
1171 }
1172 }
1173
1174 profile.total_calls >= self.config.bootstrap_min_task_observations
1175 && profile.ema_quality >= self.config.bootstrap_quality_floor
1176 }
1177
    /// Picks a model from `scored` by Thompson sampling and reports how much
    /// evidence backed the pick.
    ///
    /// For each candidate a Beta distribution is built from a prior (the
    /// candidate's score clamped to `[0, 1]`, weighted by `prior_strength`)
    /// plus the recorded success and failure counts. One value is drawn per
    /// candidate and the highest draw wins, so the result is randomized.
    ///
    /// Returns an empty id with `SchemaBased` when `scored` is empty.
    fn select_with_thompson_sampling(
        &self,
        scored: &[(String, f64)],
        tracker: &OutcomeTracker,
    ) -> (String, RoutingStrategy) {
        if scored.is_empty() {
            return (String::new(), RoutingStrategy::SchemaBased);
        }

        let mut rng = rand::rng();
        let mut best_sample = f64::NEG_INFINITY;
        let mut best_id = scored[0].0.clone();
        let mut best_strategy = RoutingStrategy::SchemaBased;

        for (id, phase2_score) in scored {
            let profile = tracker.profile(id);
            let prior = self.config.prior_strength;

            // Scores can exceed 1.0 because of bonuses; a Beta mean must lie in [0, 1].
            let prior_mean = phase2_score.clamp(0.0, 1.0);

            // Split the prior weight into pseudo-successes and pseudo-failures.
            let prior_alpha = prior * prior_mean;
            let prior_beta = prior * (1.0 - prior_mean);

            let (obs_alpha, obs_beta) = match profile {
                Some(p) => (p.success_count as f64, p.fail_count as f64),
                None => (0.0, 0.0),
            };

            // Keep both shape parameters strictly positive.
            let alpha = (prior_alpha + obs_alpha).max(0.01);
            let beta = (prior_beta + obs_beta).max(0.01);

            let sample = sample_beta(&mut rng, alpha, beta);

            if sample > best_sample {
                best_sample = sample;
                best_id = id.clone();
                // The strategy label reflects how much data the winner has.
                best_strategy = match profile {
                    Some(p) if p.total_calls >= self.config.min_observations => {
                        RoutingStrategy::ProfileBased
                    }
                    Some(p) if p.total_calls > 0 => {
                        RoutingStrategy::Exploration
                    }
                    _ => RoutingStrategy::SchemaBased,
                };
            }
        }

        (best_id, best_strategy)
    }
1243
    /// Builds a decision when no candidate survived filtering.
    ///
    /// Tried in order:
    /// 1. Vision requests: an available trusted remote vision model, else the
    ///    first vision-capable model in the registry.
    /// 2. With `quality_first_cold_start`: the available trusted remote model
    ///    with the highest schema quality estimate that has the capabilities
    ///    required by `complexity`.
    /// 3. A hard-coded default model name per complexity (the hardware's
    ///    recommended model for `Complex`).
    ///
    /// Every path reports `SchemaBased`, no fallbacks and no compaction.
    fn cold_start_decision(
        &self,
        complexity: TaskComplexity,
        task: InferenceTask,
        registry: &UnifiedRegistry,
        has_vision: bool,
    ) -> AdaptiveRoutingDecision {
        if has_vision {
            // NOTE(review): the `or_else` fallback takes the first
            // vision-capable model without checking `available`.
            if let Some(model) = registry
                .query_by_capability(ModelCapability::Vision)
                .into_iter()
                .find(|model| model.available && self.is_trusted_quality_remote(model))
                .or_else(|| registry.query_by_capability(ModelCapability::Vision).first().copied())
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (cold start, vision fallback)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        if self.config.quality_first_cold_start {
            // Only the capabilities implied by `complexity` are required here;
            // tool or hint-derived requirements are not passed to this function.
            let required_caps = complexity.required_capabilities();
            if let Some(model) = registry
                .list()
                .into_iter()
                .filter(|model| {
                    model.available
                        && required_caps.iter().all(|cap| model.has_capability(*cap))
                        && self.is_trusted_quality_remote(model)
                })
                .max_by(|a, b| {
                    self.schema_quality_estimate(a)
                        .partial_cmp(&self.schema_quality_estimate(b))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (quality-first cold start)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: self.schema_quality_estimate(model),
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        // Last resort: a fixed default name per complexity level.
        let model_name = match complexity {
            TaskComplexity::Simple => "Qwen3-0.6B",
            TaskComplexity::Medium => "Qwen3-1.7B",
            TaskComplexity::Code => "Qwen3-4B",
            TaskComplexity::Complex => &self.hw.recommended_model,
        };

        // If the registry does not know the name, the name doubles as the id.
        let model_id = registry
            .find_by_name(model_name)
            .map(|m| m.id.clone())
            .unwrap_or_else(|| model_name.to_string());

        let context_length = registry
            .find_by_name(model_name)
            .map(|m| m.context_length)
            .unwrap_or(0);

        AdaptiveRoutingDecision {
            model_id,
            model_name: model_name.to_string(),
            task,
            complexity,
            reason: format!(
                "{:?} task → {} (cold start, no candidates)",
                complexity, model_name
            ),
            strategy: RoutingStrategy::SchemaBased,
            predicted_quality: 0.5,
            fallbacks: vec![],
            context_length,
            needs_compaction: false,
        }
    }
1345}
1346
1347fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1353 use crate::intent::TaskHint;
1354 match hint {
1355 TaskHint::Chat => InferenceTask::Generate,
1356 TaskHint::Classify => InferenceTask::Classify,
1357 TaskHint::Reasoning => InferenceTask::Reasoning,
1358 TaskHint::Code => InferenceTask::Code,
1359 }
1360}
1361
1362fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1370 let x = sample_gamma(rng, alpha);
1371 let y = sample_gamma(rng, beta);
1372 if x + y == 0.0 {
1373 0.5 } else {
1375 x / (x + y)
1376 }
1377}
1378
/// Draws one sample from a Gamma distribution with the given `shape` and
/// unit scale, using the Marsaglia–Tsang rejection method.
///
/// `shape` must be positive. NOTE(review): the visible caller clamps it to
/// at least 0.01; other callers are not visible here.
fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
    // The method needs shape >= 1. For smaller shapes, sample at shape + 1
    // and scale by U^(1/shape).
    if shape < 1.0 {
        let u: f64 = rng.random();
        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
    }

    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();

    loop {
        // Draw a standard normal x, rejecting values where 1 + c*x is not positive.
        let x: f64 = loop {
            let n = sample_standard_normal(rng);
            if 1.0 + c * n > 0.0 {
                break n;
            }
        };

        let v = (1.0 + c * x).powi(3);
        let u: f64 = rng.random();

        // Cheap acceptance test first; it avoids the logarithms in most iterations.
        if u < 1.0 - 0.0331 * x.powi(4) {
            return d * v;
        }
        // Exact acceptance test.
        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            return d * v;
        }
    }
}
1411
1412fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1414 let u1: f64 = rng.random();
1415 let u2: f64 = rng.random();
1416 (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1417}
1418
1419#[cfg(test)]
1420mod tests {
1421 use super::*;
1422 use crate::outcome::InferredOutcome;
1423
    /// Hardware fixture: a 10-core macOS/aarch64 host with a Metal GPU
    /// backend, 32 GiB of RAM and an 18 000 MB local model size limit.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            gpu_devices: Vec::new(),
            // Used as the cold-start default for `Complex` prompts.
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            // Local models at or above this size are filtered out.
            max_model_mb: 18000,
        }
    }
1438
    /// Registry fixture: fake on-disk files for five local Qwen3 models under
    /// a fixed temp directory, plus one registered remote OpenAI model with
    /// generation, code, reasoning, tool and vision capabilities.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        // SAFETY: NOTE(review): mutating the process environment is only
        // sound while no other thread reads or writes it concurrently. Tests
        // run in parallel by default — confirm nothing else touches the
        // environment while this runs.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
        }
        // Create placeholder model files; failures are deliberately ignored.
        for name in &[
            "Qwen3-0.6B",
            "Qwen3-1.7B",
            "Qwen3-4B",
            "Qwen3-8B",
            "Qwen3-Embedding-0.6B",
        ] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        let mut reg = UnifiedRegistry::new(tmp);
        // Remote model from a trusted provider ("openai").
        reg.register(ModelSchema {
            id: "openai/gpt-5.4-mini:latest".into(),
            name: "gpt-5.4-mini".into(),
            provider: "openai".into(),
            family: "gpt-5.4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
                ModelCapability::Vision,
            ],
            context_length: 128_000,
            param_count: "api".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_key_envs: vec![],
                api_version: None,
                protocol: crate::schema::ApiProtocol::OpenAiCompat,
            },
            tags: vec!["trusted-remote".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });
        reg
    }
1491
1492 #[test]
1493 fn routes_simple_to_trusted_remote_during_cold_start() {
1494 let router = AdaptiveRouter::new(
1495 test_hw(),
1496 RoutingConfig {
1497 prior_strength: 100.0, ..Default::default()
1499 },
1500 );
1501 let reg = test_registry();
1502 let tracker = OutcomeTracker::new();
1503
1504 let decision = router.route("What is 2+2?", ®, &tracker);
1505 assert_eq!(decision.complexity, TaskComplexity::Simple);
1506 assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
1507 let schema = reg
1509 .find_by_name(&decision.model_name)
1510 .expect("selected model should exist in registry");
1511 assert!(
1512 !schema.is_local(),
1513 "simple task should route to trusted remote model during cold start"
1514 );
1515 assert!(matches!(
1516 schema.provider.as_str(),
1517 "openai" | "anthropic" | "google"
1518 ));
1519 }
1520
1521 #[test]
1522 fn routes_code_to_code_capable_remote_during_cold_start() {
1523 let router = AdaptiveRouter::new(
1524 test_hw(),
1525 RoutingConfig {
1526 prior_strength: 100.0, ..Default::default()
1528 },
1529 );
1530 let reg = test_registry();
1531 let tracker = OutcomeTracker::new();
1532
1533 let decision = router.route(
1534 "Fix this function:\n```rust\nfn main() {}\n```",
1535 ®,
1536 &tracker,
1537 );
1538 assert_eq!(decision.complexity, TaskComplexity::Code);
1539 assert_eq!(decision.task, InferenceTask::Code);
1540 let schema = reg
1542 .find_by_name(&decision.model_name)
1543 .expect("model should exist");
1544 assert!(
1545 schema.has_capability(ModelCapability::Code),
1546 "selected model must support Code"
1547 );
1548 assert!(!schema.is_local(), "should route to trusted remote model");
1549 }
1550
1551 #[test]
1552 fn routes_images_to_vision_capable_model() {
1553 let router = AdaptiveRouter::new(
1554 test_hw(),
1555 RoutingConfig {
1556 prior_strength: 100.0,
1557 ..Default::default()
1558 },
1559 );
1560 let mut reg = test_registry();
1561 let tracker = OutcomeTracker::new();
1562
1563 reg.register(ModelSchema {
1564 id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
1565 name: "Qwen3-VL-2B-mlx-vlm".into(),
1566 provider: "qwen".into(),
1567 family: "qwen3-vl".into(),
1568 version: "bf16".into(),
1569 capabilities: vec![
1570 ModelCapability::Generate,
1571 ModelCapability::Vision,
1572 ModelCapability::Grounding,
1573 ],
1574 context_length: 262_144,
1575 param_count: "2B".into(),
1576 quantization: None,
1577 performance: Default::default(),
1578 cost: Default::default(),
1579 source: crate::schema::ModelSource::Mlx {
1580 hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
1581 hf_weight_file: None,
1582 },
1583 tags: vec!["vision".into(), "mlx-vlm-cli".into()],
1584 supported_params: vec![],
1585 public_benchmarks: vec![],
1586 available: true,
1587 });
1588
1589 let decision = router.route_with_vision("What is in this image?", ®, &tracker, false);
1590 let schema = reg
1591 .find_by_name(&decision.model_name)
1592 .expect("model should exist");
1593 assert!(
1594 schema.has_capability(ModelCapability::Vision),
1595 "selected model must support Vision"
1596 );
1597 }
1598
1599 #[test]
1600 fn profile_based_routing_favors_proven_model() {
1601 let router = AdaptiveRouter::new(
1602 test_hw(),
1603 RoutingConfig {
1604 prior_strength: 0.5, min_observations: 3,
1606 ..Default::default()
1607 },
1608 );
1609 let reg = test_registry();
1610 let mut tracker = OutcomeTracker::new();
1611
1612 let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1614 for _ in 0..20 {
1615 let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
1616 tracker.record_complete(&trace, 500, 100, 50);
1617 tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1618 }
1619
1620 let mut wins = 0;
1622 for _ in 0..20 {
1623 let decision = router.route("Fix this bug in the parser", ®, &tracker);
1624 assert_eq!(decision.complexity, TaskComplexity::Code);
1625 if decision.model_id == qwen_8b_id {
1626 wins += 1;
1627 }
1628 }
1629 assert!(
1630 wins >= 12,
1631 "proven model won only {wins}/20 times (expected >= 12)"
1632 );
1633 }
1634
1635 #[test]
1636 fn proven_local_model_can_displace_bootstrap_remote() {
1637 let router = AdaptiveRouter::new(
1638 test_hw(),
1639 RoutingConfig {
1640 prior_strength: 100.0,
1641 bootstrap_min_task_observations: 6,
1642 bootstrap_quality_floor: 0.8,
1643 ..Default::default()
1644 },
1645 );
1646 let reg = test_registry();
1647 let mut tracker = OutcomeTracker::new();
1648
1649 let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1650 for _ in 0..12 {
1651 let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
1652 tracker.record_complete(&trace, 300, 50, 20);
1653 tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1654 }
1655
1656 let mut local_wins = 0;
1657 for _ in 0..20 {
1658 let decision = router.route("Summarize this design decision.", ®, &tracker);
1659 let schema = reg.get(&decision.model_id).expect("selected model should exist");
1660 if schema.is_local() {
1661 local_wins += 1;
1662 }
1663 }
1664
1665 assert!(
1666 local_wins >= 12,
1667 "proven local model won only {local_wins}/20 times (expected >= 12)"
1668 );
1669 }
1670
1671 #[test]
1672 fn benchmark_prior_informs_background_routing() {
1673 let router = AdaptiveRouter::new(
1674 test_hw(),
1675 RoutingConfig {
1676 prior_strength: 100.0,
1677 ..Default::default()
1678 },
1679 );
1680 let reg = test_registry();
1681 let mut tracker = OutcomeTracker::new();
1682 let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1683 profile.ema_quality = 0.95;
1684 tracker.import_profiles(vec![profile]);
1685
1686 let decision = router.route_context_aware(
1687 "Write a Python fibonacci function.",
1688 128,
1689 ®,
1690 &tracker,
1691 false,
1692 false,
1693 RoutingWorkload::Background,
1694 );
1695
1696 let schema = reg.get(&decision.model_id).expect("selected model should exist");
1697 assert!(schema.is_local(), "background routing should allow strong local benchmark priors to win");
1698 }
1699
1700 #[test]
1701 fn task_specific_benchmark_prior_informs_cold_start_routing() {
1702 let router = AdaptiveRouter::new(
1703 test_hw(),
1704 RoutingConfig {
1705 prior_strength: 100.0,
1706 ..Default::default()
1707 },
1708 );
1709 let reg = test_registry();
1710 let mut tracker = OutcomeTracker::new();
1711 let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1712 profile.task_stats.insert(
1713 crate::outcome::InferenceTask::Code.to_string(),
1714 crate::outcome::TaskStats {
1715 ema_quality: 0.95,
1716 ..Default::default()
1717 },
1718 );
1719 tracker.import_profiles(vec![profile]);
1720
1721 let decision = router.route_context_aware(
1722 "Write a Python fibonacci function.",
1723 128,
1724 ®,
1725 &tracker,
1726 false,
1727 false,
1728 RoutingWorkload::Background,
1729 );
1730
1731 let schema = reg.get(&decision.model_id).expect("selected model should exist");
1732 assert!(
1733 schema.is_local(),
1734 "background routing should use task-specific cold-start priors for local code models"
1735 );
1736 }
1737
1738 #[test]
1739 fn task_specific_latency_prior_affects_cold_start_score() {
1740 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1741 let reg = test_registry();
1742 let model = reg
1743 .get("qwen/qwen3-8b:q4_k_m")
1744 .expect("local test model should exist");
1745
1746 let mut fast_tracker = OutcomeTracker::new();
1747 let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
1748 fast_profile.task_stats.insert(
1749 crate::outcome::InferenceTask::Generate.to_string(),
1750 crate::outcome::TaskStats {
1751 ema_quality: 0.95,
1752 avg_latency_ms: 1200.0,
1753 ..Default::default()
1754 },
1755 );
1756 fast_tracker.import_profiles(vec![fast_profile]);
1757
1758 let mut slow_tracker = OutcomeTracker::new();
1759 let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
1760 slow_profile.task_stats.insert(
1761 crate::outcome::InferenceTask::Generate.to_string(),
1762 crate::outcome::TaskStats {
1763 ema_quality: 0.95,
1764 avg_latency_ms: 120_000.0,
1765 ..Default::default()
1766 },
1767 );
1768 slow_tracker.import_profiles(vec![slow_profile]);
1769
1770 let fast_score = router.score_model(
1771 model,
1772 InferenceTask::Generate,
1773 &fast_tracker,
1774 RoutingWorkload::Interactive,
1775 );
1776 let slow_score = router.score_model(
1777 model,
1778 InferenceTask::Generate,
1779 &slow_tracker,
1780 RoutingWorkload::Interactive,
1781 );
1782
1783 assert!(
1784 fast_score > slow_score,
1785 "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
1786 );
1787 }
1788
1789 #[test]
1790 fn interactive_workload_keeps_remote_bootstrap_bias() {
1791 let router = AdaptiveRouter::new(
1792 test_hw(),
1793 RoutingConfig {
1794 prior_strength: 100.0,
1795 ..Default::default()
1796 },
1797 );
1798 let reg = test_registry();
1799 let mut tracker = OutcomeTracker::new();
1800 let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1801 profile.ema_quality = 0.95;
1802 tracker.import_profiles(vec![profile]);
1803
1804 let decision = router.route_context_aware(
1805 "Write a Python fibonacci function.",
1806 128,
1807 ®,
1808 &tracker,
1809 false,
1810 false,
1811 RoutingWorkload::Interactive,
1812 );
1813
1814 let schema = reg.get(&decision.model_id).expect("selected model should exist");
1815 assert!(
1816 !schema.is_local(),
1817 "interactive routing should still prefer trusted remote models during cold start"
1818 );
1819 }
1820
1821 #[test]
1822 fn fallback_chain_has_alternatives() {
1823 let router = AdaptiveRouter::new(
1824 test_hw(),
1825 RoutingConfig {
1826 prior_strength: 100.0, ..Default::default()
1828 },
1829 );
1830 let reg = test_registry();
1831 let tracker = OutcomeTracker::new();
1832
1833 let decision = router.route("Analyze the architecture trade-offs", ®, &tracker);
1834 assert!(!decision.fallbacks.is_empty());
1835 assert!(!decision.fallbacks.contains(&decision.model_id));
1837 }
1838
1839 #[test]
1840 fn latency_scoring_is_consistent() {
1841 let router = AdaptiveRouter::with_default_config(test_hw());
1843
1844 let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
1846 let observed_score = router.latency_ms_to_score(8000.0);
1848 assert!(
1849 (schema_score - observed_score).abs() < 0.01,
1850 "schema ({schema_score}) and observed ({observed_score}) should match"
1851 );
1852 }
1853
1854 #[test]
1855 fn complexity_assessment() {
1856 assert_eq!(
1857 TaskComplexity::assess("What is the capital of France?"),
1858 TaskComplexity::Simple
1859 );
1860 assert_eq!(
1861 TaskComplexity::assess("Fix this broken test"),
1862 TaskComplexity::Code
1863 );
1864 assert_eq!(
1865 TaskComplexity::assess("Analyze the trade-offs between A and B"),
1866 TaskComplexity::Complex
1867 );
1868 }
1869
1870 #[test]
1871 fn beta_sampling_produces_valid_values() {
1872 let mut rng = rand::rng();
1873 for _ in 0..100 {
1875 let s = sample_beta(&mut rng, 2.0, 5.0);
1876 assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
1877 }
1878 let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
1880 let mean = samples.iter().sum::<f64>() / samples.len() as f64;
1881 assert!(
1882 (mean - 0.5).abs() < 0.05,
1883 "Beta(1,1) mean {mean} should be ~0.5"
1884 );
1885 }
1886
1887 #[test]
1888 fn thompson_sampling_converges_to_best() {
1889 let router = AdaptiveRouter::new(
1891 test_hw(),
1892 RoutingConfig {
1893 prior_strength: 1.0, ..Default::default()
1895 },
1896 );
1897 let reg = test_registry();
1898 let mut tracker = OutcomeTracker::new();
1899
1900 let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
1902 for _ in 0..20 {
1903 let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
1904 tracker.record_complete(&trace, 500, 100, 50);
1905 tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1906 }
1907
1908 let mut wins = 0;
1910 for _ in 0..20 {
1911 let decision = router.route("Fix this parser bug", ®, &tracker);
1912 if decision.model_id == qwen_4b_id {
1913 wins += 1;
1914 }
1915 }
1916 assert!(
1917 wins >= 14,
1918 "strong model won only {wins}/20 times (expected >= 14)"
1919 );
1920 }
1921
1922 #[test]
1925 fn intent_require_filters_out_models_lacking_capability() {
1926 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1931 let reg = test_registry();
1932 let tracker = OutcomeTracker::new();
1933
1934 let intent = crate::intent::IntentHint {
1935 require: vec![ModelCapability::Vision],
1936 ..Default::default()
1937 };
1938 let decision = router.route_with_intent("hello", ®, &tracker, &intent);
1939
1940 assert_eq!(
1945 decision.strategy,
1946 RoutingStrategy::SchemaBased,
1947 "require=[vision] with no vision-capable candidates must drop to schema cold-start"
1948 );
1949 }
1950
1951 #[test]
1952 fn intent_default_does_not_override_task_or_caps() {
1953 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1960 let reg = test_registry();
1961 let tracker = OutcomeTracker::new();
1962
1963 let baseline = router.route("write a haiku", ®, &tracker);
1964 let with_default = router.route_with_intent(
1965 "write a haiku",
1966 ®,
1967 &tracker,
1968 &crate::intent::IntentHint::default(),
1969 );
1970
1971 assert_eq!(
1972 baseline.task, with_default.task,
1973 "default IntentHint must not change the prompt-derived task"
1974 );
1975 assert_eq!(
1976 baseline.complexity, with_default.complexity,
1977 "default IntentHint must not change the prompt-derived complexity"
1978 );
1979 }
1980
1981 #[test]
1982 fn intent_task_hint_overrides_prompt_complexity() {
1983 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1986 let reg = test_registry();
1987 let tracker = OutcomeTracker::new();
1988
1989 let hint = crate::intent::IntentHint {
1990 task: Some(crate::intent::TaskHint::Reasoning),
1991 ..Default::default()
1992 };
1993 let decision = router.route_with_intent("hi", ®, &tracker, &hint);
1994
1995 assert_eq!(
1996 decision.task,
1997 InferenceTask::Reasoning,
1998 "TaskHint::Reasoning should override the prompt-derived task"
1999 );
2000 }
2001}