1use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse difficulty class of a prompt, as produced by [`TaskComplexity::assess`].
///
/// Serialized in `snake_case` (e.g. `"complex"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short or lookup-style prompt; requires only `ModelCapability::Generate`.
    Simple,
    /// Ordinary prompt with no special markers; requires only `ModelCapability::Generate`.
    Medium,
    /// Prompt that looks like code or a repair/debugging request; requires `ModelCapability::Code`.
    Code,
    /// Reasoning-heavy or very long prompt; requires `ModelCapability::Reasoning`.
    Complex,
}
33
34impl TaskComplexity {
35 pub fn assess(prompt: &str) -> Self {
42 let lower = prompt.to_lowercase();
43 let word_count = prompt.split_whitespace().count();
44 let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46 let has_code = Self::detect_code(prompt);
47
48 let repair_markers = [
49 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50 ];
51 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53 let reasoning_markers = [
54 "analyze",
55 "compare",
56 "explain why",
57 "step by step",
58 "think through",
59 "evaluate",
60 "trade-off",
61 "tradeoff",
62 "pros and cons",
63 "architecture",
64 "design",
65 "strategy",
66 "optimize",
67 "comprehensive",
68 ];
69 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71 let simple_patterns = [
72 "what is",
73 "who is",
74 "when did",
75 "where is",
76 "how many",
77 "yes or no",
78 "true or false",
79 "name the",
80 "list the",
81 "define ",
82 ];
83 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85 if has_code || has_repair {
86 TaskComplexity::Code
87 } else if has_reasoning || estimated_tokens > 500 {
88 TaskComplexity::Complex
89 } else if is_simple || estimated_tokens < 30 {
90 TaskComplexity::Simple
91 } else {
92 TaskComplexity::Medium
93 }
94 }
95
96 fn detect_code(prompt: &str) -> bool {
102 #[cfg(feature = "ast")]
104 {
105 if let Some(is_code) = Self::detect_code_ast(prompt) {
106 return is_code;
107 }
108 }
109
110 let code_markers = [
112 "```",
113 "fn ",
114 "def ",
115 "class ",
116 "import ",
117 "require(",
118 "async fn",
119 "pub fn",
120 "function ",
121 "const ",
122 "let ",
123 "var ",
124 "#include",
125 "package ",
126 "impl ",
127 ];
128 code_markers.iter().any(|m| prompt.contains(m))
129 }
130
131 #[cfg(feature = "ast")]
135 fn detect_code_ast(prompt: &str) -> Option<bool> {
136 let mut blocks = Vec::new();
138 let mut rest = prompt;
139 while let Some(start) = rest.find("```") {
140 let after_fence = &rest[start + 3..];
141 let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143 if let Some(end) = after_fence[code_start..].find("```") {
144 blocks.push(&after_fence[code_start..code_start + end]);
145 rest = &after_fence[code_start + end + 3..];
146 } else {
147 break;
148 }
149 }
150
151 if blocks.is_empty() {
152 return None; }
154
155 let languages = [
157 car_ast::Language::Rust,
158 car_ast::Language::Python,
159 car_ast::Language::TypeScript,
160 car_ast::Language::JavaScript,
161 car_ast::Language::Go,
162 ];
163
164 for block in &blocks {
165 let trimmed = block.trim();
166 if trimmed.is_empty() {
167 continue;
168 }
169
170 for lang in &languages {
171 if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172 if !parsed.symbols.is_empty() {
174 return Some(true);
175 }
176 }
177 }
178 }
179
180 Some(false)
183 }
184
185 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187 match self {
188 TaskComplexity::Simple => vec![ModelCapability::Generate],
189 TaskComplexity::Medium => vec![ModelCapability::Generate],
190 TaskComplexity::Code => vec![ModelCapability::Code],
191 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192 }
193 }
194
195 pub fn inference_task(&self) -> InferenceTask {
197 match self {
198 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199 TaskComplexity::Code => InferenceTask::Code,
200 TaskComplexity::Complex => InferenceTask::Reasoning,
201 }
202 }
203}
204
/// Tuning knobs for [`AdaptiveRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Recorded calls after which a model's observed profile fully replaces the schema-based
    /// estimates in `AdaptiveRouter::score_model`. Below it, the two are blended linearly.
    /// Also decides whether a pick is labelled `ProfileBased` or `Exploration`.
    pub min_observations: u64,
    /// Relative weight of quality in scoring.
    /// NOTE(review): `AdaptiveRouter::score_model` takes its weights from
    /// `RoutingWorkload::weights` (via `task_aware_weights`), not from this field — confirm
    /// whether it is still read anywhere.
    pub quality_weight: f64,
    /// Relative weight of latency in scoring (see the NOTE on `quality_weight`).
    pub latency_weight: f64,
    /// Relative weight of cost in scoring (see the NOTE on `quality_weight`).
    pub cost_weight: f64,
    /// Optional hard cap: models whose advertised p50 latency exceeds this many milliseconds
    /// are filtered out. Models without a p50 figure are not filtered by this.
    pub max_latency_ms: Option<u64>,
    /// Optional hard cap on `cost_per_1k_output()`; costlier models are filtered out.
    pub max_cost_usd: Option<f64>,
    /// When true, local models receive `LOCAL_BONUS` (only on hosts with a GPU backend).
    /// When false, local models are excluded from the candidate set entirely.
    pub prefer_local: bool,
    /// Pseudo-count strength of the score-derived Beta prior used by Thompson sampling.
    pub prior_strength: f64,
    /// Enables the quality-first bootstrap policy and the trusted-remote cold-start pick.
    pub quality_first_cold_start: bool,
    /// Calls (task-specific, or overall as a fallback) a local model needs before it counts
    /// as "proven" under the quality-first bootstrap policy.
    pub bootstrap_min_task_observations: u64,
    /// Minimum EMA quality a local model needs to count as "proven" under that policy.
    pub bootstrap_quality_floor: f64,
}
234
235impl Default for RoutingConfig {
236 fn default() -> Self {
237 Self {
238 min_observations: 2,
239 quality_weight: 0.45,
240 latency_weight: 0.4,
241 cost_weight: 0.15,
242 max_latency_ms: None,
243 max_cost_usd: None,
244 prefer_local: true,
245 prior_strength: 2.0,
246 quality_first_cold_start: true,
247 bootstrap_min_task_observations: 8,
248 bootstrap_quality_floor: 0.8,
249 }
250 }
251}
252
/// How a routing decision was reached. Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// The chosen model had no recorded calls, or was picked by a cold-start fallback, so
    /// its score came from schema-based estimates.
    SchemaBased,
    /// The chosen model had at least `RoutingConfig::min_observations` recorded calls.
    ProfileBased,
    /// The chosen model had some, but fewer than `min_observations`, recorded calls.
    Exploration,
    /// NOTE(review): not produced by `AdaptiveRouter` in this module; presumably used when a
    /// caller names a model explicitly — confirm.
    Explicit,
}
266
/// Result of routing one prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Registry id of the chosen model. On the last-resort cold-start path this is the bare
    /// model name when the registry has no such model.
    pub model_id: String,
    /// Display name of the chosen model (falls back to `model_id` when no schema is found).
    pub model_name: String,
    /// Outcome-tracking task bucket (from the intent hint if given, else from `complexity`).
    pub task: InferenceTask,
    /// Heuristic complexity assessed from the prompt.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of the decision.
    pub reason: String,
    /// How the decision was reached.
    pub strategy: RoutingStrategy,
    /// Scoring-phase score of the chosen model. It includes locality/backend bonuses, so it
    /// can exceed 1.0. Cold-start picks use 0.5 or a size-based schema estimate.
    pub predicted_quality: f64,
    /// Other model ids to try, best-scored first.
    pub fallbacks: Vec<String>,
    /// Context window of the chosen model, or 0 when unknown.
    pub context_length: usize,
    /// True when the estimated prompt does not fit the chosen model's context window.
    pub needs_compaction: bool,
}
291
/// Picks a model per prompt by combining heuristic complexity, schema estimates, observed
/// outcomes (via `OutcomeTracker`) and Thompson sampling.
pub struct AdaptiveRouter {
    /// Host hardware description; `max_model_mb` caps local model size and
    /// `recommended_model` is the last-resort pick for complex tasks.
    hw: HardwareInfo,
    /// Active routing configuration.
    config: RoutingConfig,
    /// Per-model circuit breakers consulted while filtering candidates.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs to [`AdaptiveRouter::route_with`]. Build with [`RouteRequest::new`] and
/// override fields with struct-update syntax.
pub struct RouteRequest<'a> {
    /// The prompt to classify and route.
    pub prompt: &'a str,
    /// Registry of candidate models.
    pub registry: &'a UnifiedRegistry,
    /// Observed per-model outcomes.
    pub tracker: &'a OutcomeTracker,
    /// Estimated total prompt size in tokens; 0 disables context-window fitting.
    pub estimated_total_tokens: usize,
    /// The request uses tools (requires `ToolUse`, possibly `MultiToolCall`).
    pub has_tools: bool,
    /// The request includes images (requires `Vision`).
    pub has_vision: bool,
    /// Workload profile; may be overridden by `intent`'s `prefer_fast` / `prefer_local`.
    pub workload: RoutingWorkload,
    /// Optional caller intent: task hint, extra required capabilities, speed/locality flags.
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
330impl<'a> RouteRequest<'a> {
331 pub fn new(
335 prompt: &'a str,
336 registry: &'a UnifiedRegistry,
337 tracker: &'a OutcomeTracker,
338 ) -> Self {
339 Self {
340 prompt,
341 registry,
342 tracker,
343 estimated_total_tokens: 0,
344 has_tools: false,
345 has_vision: false,
346 workload: RoutingWorkload::Interactive,
347 intent: None,
348 }
349 }
350}
351
352impl AdaptiveRouter {
353 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354 let circuit_breakers = Arc::new(Mutex::new(
355 CircuitBreakerRegistry::new(3, 300), ));
357 Self {
358 hw,
359 config,
360 circuit_breakers,
361 }
362 }
363
364 pub fn with_default_config(hw: HardwareInfo) -> Self {
365 Self::new(hw, RoutingConfig::default())
366 }
367
368 pub fn config(&self) -> &RoutingConfig {
369 &self.config
370 }
371
372 pub fn set_config(&mut self, config: RoutingConfig) {
373 self.config = config;
374 }
375
376 pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382 let workload = match req.intent {
388 Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389 Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390 _ => req.workload,
391 };
392 self.route_inner_with_intent(
393 req.prompt,
394 req.registry,
395 req.tracker,
396 req.has_tools,
397 req.has_vision,
398 req.estimated_total_tokens,
399 workload,
400 req.intent,
401 )
402 }
403
404 pub fn route(
407 &self,
408 prompt: &str,
409 registry: &UnifiedRegistry,
410 tracker: &OutcomeTracker,
411 ) -> AdaptiveRoutingDecision {
412 self.route_with(RouteRequest::new(prompt, registry, tracker))
413 }
414
415 pub fn route_editor(
421 &self,
422 prompt: &str,
423 registry: &UnifiedRegistry,
424 tracker: &OutcomeTracker,
425 ) -> AdaptiveRoutingDecision {
426 self.route_with(RouteRequest {
427 workload: RoutingWorkload::Background,
428 ..RouteRequest::new(prompt, registry, tracker)
429 })
430 }
431
432 pub fn route_with_tools(
434 &self,
435 prompt: &str,
436 registry: &UnifiedRegistry,
437 tracker: &OutcomeTracker,
438 ) -> AdaptiveRoutingDecision {
439 self.route_with(RouteRequest {
440 has_tools: true,
441 ..RouteRequest::new(prompt, registry, tracker)
442 })
443 }
444
445 pub fn route_with_vision(
447 &self,
448 prompt: &str,
449 registry: &UnifiedRegistry,
450 tracker: &OutcomeTracker,
451 has_tools: bool,
452 ) -> AdaptiveRoutingDecision {
453 self.route_with(RouteRequest {
454 has_tools,
455 has_vision: true,
456 ..RouteRequest::new(prompt, registry, tracker)
457 })
458 }
459
460 pub fn route_with_intent<'a>(
466 &self,
467 prompt: &'a str,
468 registry: &'a UnifiedRegistry,
469 tracker: &'a OutcomeTracker,
470 intent: &'a crate::intent::IntentHint,
471 ) -> AdaptiveRoutingDecision {
472 self.route_with(RouteRequest {
473 intent: Some(intent),
474 ..RouteRequest::new(prompt, registry, tracker)
475 })
476 }
477
478 pub fn route_context_aware(
481 &self,
482 prompt: &str,
483 estimated_total_tokens: usize,
484 registry: &UnifiedRegistry,
485 tracker: &OutcomeTracker,
486 has_tools: bool,
487 has_vision: bool,
488 workload: RoutingWorkload,
489 ) -> AdaptiveRoutingDecision {
490 self.route_with(RouteRequest {
491 estimated_total_tokens,
492 has_tools,
493 has_vision,
494 workload,
495 ..RouteRequest::new(prompt, registry, tracker)
496 })
497 }
498
499 pub fn route_context_aware_with_intent<'a>(
504 &self,
505 prompt: &'a str,
506 estimated_total_tokens: usize,
507 registry: &'a UnifiedRegistry,
508 tracker: &'a OutcomeTracker,
509 has_tools: bool,
510 has_vision: bool,
511 workload: RoutingWorkload,
512 intent: &'a crate::intent::IntentHint,
513 ) -> AdaptiveRoutingDecision {
514 self.route_with(RouteRequest {
515 estimated_total_tokens,
516 has_tools,
517 has_vision,
518 workload,
519 intent: Some(intent),
520 ..RouteRequest::new(prompt, registry, tracker)
521 })
522 }
523
524 fn route_inner_with_intent(
525 &self,
526 prompt: &str,
527 registry: &UnifiedRegistry,
528 tracker: &OutcomeTracker,
529 has_tools: bool,
530 has_vision: bool,
531 estimated_total_tokens: usize,
532 workload: RoutingWorkload,
533 intent: Option<&crate::intent::IntentHint>,
534 ) -> AdaptiveRoutingDecision {
535 let complexity = TaskComplexity::assess(prompt);
536 let task = intent
538 .and_then(|h| h.task)
539 .map(task_hint_to_inference_task)
540 .unwrap_or_else(|| complexity.inference_task());
541 let mut required_caps = complexity.required_capabilities();
542 if let Some(hint) = intent {
543 for cap in &hint.require {
544 if !required_caps.contains(cap) {
545 required_caps.push(*cap);
546 }
547 }
548 }
549 if has_vision {
550 required_caps.push(ModelCapability::Vision);
551 }
552 if has_tools {
553 required_caps.push(ModelCapability::ToolUse);
554 if Self::needs_multi_tool_call(prompt) {
557 required_caps.push(ModelCapability::MultiToolCall);
558 }
559 }
560
561 let mut candidates = self.filter_candidates(&required_caps, registry, tracker, has_vision);
563
564 if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
567 required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
568 candidates = self.filter_candidates(&required_caps, registry, tracker, has_vision);
569 }
570
571 if candidates.is_empty() {
572 return self.cold_start_decision(complexity, task, registry, has_vision);
574 }
575
576 candidates = self.apply_quality_first_bootstrap_policy(
577 candidates, task, tracker, has_vision, has_tools, workload,
578 );
579
580 let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
583 let mut fits = Vec::new();
584 let mut tight = Vec::new();
585 for m in &candidates {
586 if m.context_length == 0 || m.context_length >= estimated_total_tokens {
587 fits.push(m.clone());
588 } else {
589 tight.push(m.clone());
590 }
591 }
592 (fits, tight)
593 } else {
594 (candidates.clone(), Vec::new())
595 };
596
597 let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
599 (fits, false)
600 } else if !needs_compaction_candidates.is_empty() {
601 tracing::info!(
602 prompt_tokens = estimated_total_tokens,
603 candidates = needs_compaction_candidates.len(),
604 "no model fits full prompt — compaction will be needed"
605 );
606 (needs_compaction_candidates.clone(), true)
607 } else {
608 (candidates.clone(), false)
609 };
610
611 let scored = self.score_candidates_context_aware(
613 &scoring_candidates,
614 task,
615 tracker,
616 estimated_total_tokens,
617 workload,
618 );
619
620 let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);
622
623 let mut fallbacks: Vec<String> = scored
625 .iter()
626 .filter(|(id, _)| *id != selected_id)
627 .map(|(id, _)| id.clone())
628 .collect();
629 if !compaction_needed {
631 for m in &needs_compaction_candidates {
632 if m.id != selected_id && !fallbacks.contains(&m.id) {
633 fallbacks.push(m.id.clone());
634 }
635 }
636 }
637
638 let predicted_quality = scored
639 .iter()
640 .find(|(id, _)| *id == selected_id)
641 .map(|(_, score)| *score)
642 .unwrap_or(0.5);
643
644 let selected_schema = registry
645 .get(&selected_id)
646 .or_else(|| registry.find_by_name(&selected_id));
647 let model_name = selected_schema
648 .map(|m| m.name.clone())
649 .unwrap_or_else(|| selected_id.clone());
650 let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);
651
652 let needs_compact = compaction_needed
653 || (estimated_total_tokens > 0
654 && context_length > 0
655 && estimated_total_tokens > context_length);
656
657 let compaction_note = if needs_compact {
658 format!(
659 " [compaction needed: {}→{}tok]",
660 estimated_total_tokens, context_length
661 )
662 } else {
663 String::new()
664 };
665
666 let reason = format!(
667 "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
668 complexity,
669 model_name,
670 strategy,
671 predicted_quality,
672 scoring_candidates.len(),
673 compaction_note,
674 );
675
676 AdaptiveRoutingDecision {
677 model_id: selected_id,
678 model_name,
679 task,
680 complexity,
681 reason,
682 strategy,
683 predicted_quality,
684 fallbacks,
685 context_length,
686 needs_compaction: needs_compact,
687 }
688 }
689
690 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
692 let embed_models = registry.query_by_capability(ModelCapability::Embed);
693 embed_models
694 .first()
695 .map(|m| m.name.clone())
696 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
697 }
698
699 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
701 let gen_models = registry.query_by_capability(ModelCapability::Generate);
702 gen_models
704 .iter()
705 .filter(|m| m.is_local())
706 .min_by_key(|m| m.size_mb())
707 .map(|m| m.name.clone())
708 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
709 }
710
711 const LATENCY_CEILING_MS: f64 = 10000.0;
717 const _TPS_CEILING: f64 = 150.0;
719 const MOE_TPS_MULTIPLIER: f64 = 0.10;
721 const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
723 const COST_CEILING_PER_1K: f64 = 0.1;
725 const LOCAL_BONUS: f64 = 0.15;
727 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
735 const HAS_GPU_BACKEND: bool = true;
736 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
737 const HAS_GPU_BACKEND: bool = false;
738 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
741 const MLX_BONUS: f64 = 0.10;
742 const SYSTEM_LLM_BONUS: f64 = 0.12;
753
754 fn filter_candidates(
758 &self,
759 required_caps: &[ModelCapability],
760 registry: &UnifiedRegistry,
761 tracker: &OutcomeTracker,
762 has_vision: bool,
763 ) -> Vec<ModelSchema> {
764 registry
765 .list()
766 .into_iter()
767 .filter(|m| {
768 if !required_caps.iter().all(|c| m.has_capability(*c)) {
770 return false;
771 }
772 if !has_vision && m.tags.iter().any(|t| t == "mlx-vlm-cli") {
782 return false;
783 }
784 if !m.available {
786 return false;
787 }
788 if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
790 return false;
791 }
792 if let Some(max) = self.config.max_latency_ms {
794 if let Some(p50) = m.performance.latency_p50_ms {
795 if p50 > max {
796 return false;
797 }
798 }
799 }
800 if let Some(max) = self.config.max_cost_usd {
802 if m.cost_per_1k_output() > max {
803 return false;
804 }
805 }
806 if !self.config.prefer_local && m.is_local() {
808 return false;
809 }
810 if tracker.is_excluded(&m.id) {
812 return false;
813 }
814 if let Ok(mut cb) = self.circuit_breakers.lock() {
816 if !cb.allow_request(&m.id) {
817 tracing::debug!(model = %m.id, "skipped by circuit breaker");
818 return false;
819 }
820 }
821 true
822 })
823 .cloned()
824 .collect()
825 }
826
827 fn apply_quality_first_bootstrap_policy(
828 &self,
829 candidates: Vec<ModelSchema>,
830 task: InferenceTask,
831 tracker: &OutcomeTracker,
832 has_vision: bool,
833 has_tools: bool,
834 workload: RoutingWorkload,
835 ) -> Vec<ModelSchema> {
836 if !self.config.quality_first_cold_start
837 || !workload.is_latency_sensitive()
838 || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
839 {
840 return candidates;
841 }
842
843 let trusted_remote: Vec<ModelSchema> = candidates
844 .iter()
845 .filter(|model| self.is_trusted_quality_remote(model))
846 .cloned()
847 .collect();
848
849 if trusted_remote.is_empty() {
850 return candidates;
851 }
852
853 let proven_local: Vec<ModelSchema> = candidates
854 .iter()
855 .filter(|model| {
856 model.is_local() && self.is_local_model_proven_for_task(model, task, tracker)
857 })
858 .cloned()
859 .collect();
860
861 if !proven_local.is_empty() {
862 return proven_local;
863 }
864
865 trusted_remote
866 }
867
868 fn score_candidates_context_aware(
872 &self,
873 candidates: &[ModelSchema],
874 task: InferenceTask,
875 tracker: &OutcomeTracker,
876 estimated_total_tokens: usize,
877 workload: RoutingWorkload,
878 ) -> Vec<(String, f64)> {
879 let mut scored: Vec<(String, f64)> = candidates
880 .iter()
881 .map(|m| {
882 let base_score = self.score_model(m, task, tracker, workload);
883 let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
886 let ratio = m.context_length as f64 / estimated_total_tokens as f64;
887 if ratio >= 1.0 {
888 (ratio.min(4.0) - 1.0) / 3.0 * 0.10 } else {
890 -0.15 }
892 } else {
893 0.0
894 };
895 (m.id.clone(), base_score + headroom_bonus)
896 })
897 .collect();
898
899 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
900 scored
901 }
902
903 fn task_aware_weights(
919 task: InferenceTask,
920 workload: RoutingWorkload,
921 ) -> (f64, f64, f64) {
922 match (task, workload) {
923 (
924 InferenceTask::Code,
925 RoutingWorkload::Interactive | RoutingWorkload::LocalPreferred,
926 ) => (0.62, 0.28, 0.10),
927 _ => workload.weights(),
928 }
929 }
930
931 fn score_model(
932 &self,
933 model: &ModelSchema,
934 task: InferenceTask,
935 tracker: &OutcomeTracker,
936 workload: RoutingWorkload,
937 ) -> f64 {
938 let profile = tracker.profile(&model.id);
939 let schema_quality = self.schema_quality_estimate(model);
940 let schema_latency = self.schema_latency_estimate(model);
941 let (quality_weight, latency_weight, cost_weight) =
942 Self::task_aware_weights(task, workload);
943
944 let quality = match profile {
947 Some(p) if p.total_calls >= self.config.min_observations => p
948 .task_stats(task)
949 .map(|ts| ts.ema_quality)
950 .unwrap_or(p.ema_quality),
951 Some(p) if p.total_calls == 0 => p
952 .task_stats(task)
953 .map(|ts| ts.ema_quality)
954 .unwrap_or(p.ema_quality),
955 Some(p) if p.total_calls > 0 => {
956 let w = p.total_calls as f64 / self.config.min_observations as f64;
957 schema_quality * (1.0 - w) + p.ema_quality * w
958 }
959 _ => schema_quality,
960 };
961
962 let latency = match profile {
965 Some(p) if p.total_calls >= self.config.min_observations => {
966 let avg = p
967 .task_stats(task)
968 .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
969 .map(|ts| ts.avg_latency_ms)
970 .unwrap_or_else(|| p.avg_latency_ms());
971 self.latency_ms_to_score(avg)
972 }
973 Some(p) if p.total_calls == 0 => p
974 .task_stats(task)
975 .filter(|ts| ts.avg_latency_ms > 0.0)
976 .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
977 .unwrap_or(schema_latency),
978 Some(p) if p.total_calls > 0 => {
979 let observed = self.latency_ms_to_score(
980 p.task_stats(task)
981 .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
982 .map(|ts| ts.avg_latency_ms)
983 .unwrap_or_else(|| p.avg_latency_ms()),
984 );
985 let w = p.total_calls as f64 / self.config.min_observations as f64;
986 schema_latency * (1.0 - w) + observed * w
987 }
988 _ => schema_latency,
989 };
990
991 let cost = if model.is_local() {
993 1.0
994 } else {
995 (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
996 };
997
998 let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
1007 Self::LOCAL_BONUS
1008 } else {
1009 if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
1010 tracing::debug!(
1011 model = %model.id,
1012 "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
1013 );
1014 }
1015 0.0
1016 };
1017 let workload_local_bonus = if model.is_local() {
1018 workload.local_bonus()
1019 } else {
1020 0.0
1021 };
1022
1023 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1025 let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
1026 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1027 let mlx_bonus = 0.0;
1028
1029 let vllm_mlx_bonus = if model.is_vllm_mlx() {
1031 Self::LOCAL_BONUS + 0.05
1032 } else {
1033 0.0
1034 };
1035
1036 let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
1044 && model.tags.iter().any(|t| t == "private")
1045 {
1046 Self::SYSTEM_LLM_BONUS
1047 } else {
1048 0.0
1049 };
1050
1051 quality_weight * quality
1052 + latency_weight * latency
1053 + cost_weight * cost
1054 + local_bonus
1055 + workload_local_bonus
1056 + mlx_bonus
1057 + vllm_mlx_bonus
1058 + system_llm_bonus
1059 }
1060
1061 fn latency_ms_to_score(&self, ms: f64) -> f64 {
1064 (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1065 }
1066
1067 fn tps_to_latency_ms(tps: f64) -> f64 {
1069 if tps <= 0.0 {
1070 return Self::LATENCY_CEILING_MS;
1071 }
1072 (200.0 / tps) * 1000.0
1074 }
1075
1076 fn needs_multi_tool_call(prompt: &str) -> bool {
1079 let lower = prompt.to_lowercase();
1080
1081 let has_numbered_list = {
1083 let mut count = 0u32;
1084 for i in 1..=5u32 {
1085 if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1086 count += 1;
1087 }
1088 }
1089 count >= 2
1090 };
1091
1092 let multi_keywords = [
1094 "multiple edits",
1095 "several changes",
1096 "three changes",
1097 "two changes",
1098 "all of the following",
1099 "each of these",
1100 "do both",
1101 "do all",
1102 "and also",
1103 "additionally",
1104 "as well as",
1105 "then also",
1106 ];
1107 let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1108
1109 let bullet_actions = lower.matches("- add ").count()
1111 + lower.matches("- update ").count()
1112 + lower.matches("- change ").count()
1113 + lower.matches("- remove ").count()
1114 + lower.matches("- fix ").count()
1115 + lower.matches("- edit ").count()
1116 + lower.matches("- implement ").count()
1117 + lower.matches("- create ").count();
1118 let has_bullet_list = bullet_actions >= 2;
1119
1120 has_numbered_list || has_multi_keywords || has_bullet_list
1121 }
1122
1123 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1128 match model.size_mb() {
1129 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
1136 }
1137
1138 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1143 let is_moe = model.tags.contains(&"moe".to_string());
1144
1145 if model.is_local() {
1146 if let Some(tps) = model.performance.tokens_per_second {
1147 let effective_tps = if is_moe {
1148 let multiplier = if model.is_mlx() {
1149 Self::MLX_MOE_TPS_MULTIPLIER
1150 } else {
1151 Self::MOE_TPS_MULTIPLIER
1152 };
1153 tps * multiplier
1154 } else {
1155 tps
1156 };
1157 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1158 return self.latency_ms_to_score(estimated_ms);
1159 }
1160 return 0.5; }
1162
1163 if let Some(p50) = model.performance.latency_p50_ms {
1165 return self.latency_ms_to_score(p50 as f64);
1166 }
1167 0.3 }
1169
1170 fn is_quality_critical_bootstrap_task(
1171 &self,
1172 task: InferenceTask,
1173 has_vision: bool,
1174 has_tools: bool,
1175 ) -> bool {
1176 has_vision
1177 || has_tools
1178 || matches!(
1179 task,
1180 InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1181 )
1182 }
1183
1184 fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1185 model.is_remote()
1186 && matches!(model.provider.as_str(), "openai" | "anthropic" | "google")
1187 && !model.has_capability(ModelCapability::SpeechToText)
1188 && !model.has_capability(ModelCapability::TextToSpeech)
1189 }
1190
1191 fn is_local_model_proven_for_task(
1192 &self,
1193 model: &ModelSchema,
1194 task: InferenceTask,
1195 tracker: &OutcomeTracker,
1196 ) -> bool {
1197 let Some(profile) = tracker.profile(&model.id) else {
1198 return false;
1199 };
1200 if let Some(task_stats) = profile.task_stats(task) {
1201 if task_stats.calls >= self.config.bootstrap_min_task_observations
1202 && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1203 {
1204 return true;
1205 }
1206 }
1207
1208 profile.total_calls >= self.config.bootstrap_min_task_observations
1209 && profile.ema_quality >= self.config.bootstrap_quality_floor
1210 }
1211
1212 fn select_with_thompson_sampling(
1222 &self,
1223 scored: &[(String, f64)],
1224 tracker: &OutcomeTracker,
1225 ) -> (String, RoutingStrategy) {
1226 if scored.is_empty() {
1227 return (String::new(), RoutingStrategy::SchemaBased);
1228 }
1229
1230 let mut rng = rand::rng();
1231 let mut best_sample = f64::NEG_INFINITY;
1232 let mut best_id = scored[0].0.clone();
1233 let mut best_strategy = RoutingStrategy::SchemaBased;
1234
1235 for (id, phase2_score) in scored {
1236 let profile = tracker.profile(id);
1237 let prior = self.config.prior_strength;
1238
1239 let prior_mean = phase2_score.clamp(0.0, 1.0);
1241
1242 let prior_alpha = prior * prior_mean;
1244 let prior_beta = prior * (1.0 - prior_mean);
1245
1246 let (obs_alpha, obs_beta) = match profile {
1248 Some(p) => (p.success_count as f64, p.fail_count as f64),
1249 None => (0.0, 0.0),
1250 };
1251
1252 let alpha = (prior_alpha + obs_alpha).max(0.01);
1254 let beta = (prior_beta + obs_beta).max(0.01);
1255
1256 let sample = sample_beta(&mut rng, alpha, beta);
1258
1259 if sample > best_sample {
1260 best_sample = sample;
1261 best_id = id.clone();
1262 best_strategy = match profile {
1263 Some(p) if p.total_calls >= self.config.min_observations => {
1264 RoutingStrategy::ProfileBased
1265 }
1266 Some(p) if p.total_calls > 0 => {
1267 RoutingStrategy::Exploration
1269 }
1270 _ => RoutingStrategy::SchemaBased,
1271 };
1272 }
1273 }
1274
1275 (best_id, best_strategy)
1276 }
1277
1278 fn cold_start_decision(
1280 &self,
1281 complexity: TaskComplexity,
1282 task: InferenceTask,
1283 registry: &UnifiedRegistry,
1284 has_vision: bool,
1285 ) -> AdaptiveRoutingDecision {
1286 if has_vision {
1287 if let Some(model) = registry
1288 .query_by_capability(ModelCapability::Vision)
1289 .into_iter()
1290 .find(|model| model.available && self.is_trusted_quality_remote(model))
1291 .or_else(|| {
1292 registry
1293 .query_by_capability(ModelCapability::Vision)
1294 .first()
1295 .copied()
1296 })
1297 {
1298 return AdaptiveRoutingDecision {
1299 model_id: model.id.clone(),
1300 model_name: model.name.clone(),
1301 task,
1302 complexity,
1303 reason: format!(
1304 "{:?} task → {} (cold start, vision fallback)",
1305 complexity, model.name
1306 ),
1307 strategy: RoutingStrategy::SchemaBased,
1308 predicted_quality: 0.5,
1309 fallbacks: vec![],
1310 context_length: model.context_length,
1311 needs_compaction: false,
1312 };
1313 }
1314 }
1315
1316 if self.config.quality_first_cold_start {
1317 let required_caps = complexity.required_capabilities();
1318 if let Some(model) = registry
1319 .list()
1320 .into_iter()
1321 .filter(|model| {
1322 model.available
1323 && required_caps.iter().all(|cap| model.has_capability(*cap))
1324 && self.is_trusted_quality_remote(model)
1325 })
1326 .max_by(|a, b| {
1327 self.schema_quality_estimate(a)
1328 .partial_cmp(&self.schema_quality_estimate(b))
1329 .unwrap_or(std::cmp::Ordering::Equal)
1330 })
1331 {
1332 return AdaptiveRoutingDecision {
1333 model_id: model.id.clone(),
1334 model_name: model.name.clone(),
1335 task,
1336 complexity,
1337 reason: format!(
1338 "{:?} task → {} (quality-first cold start)",
1339 complexity, model.name
1340 ),
1341 strategy: RoutingStrategy::SchemaBased,
1342 predicted_quality: self.schema_quality_estimate(model),
1343 fallbacks: vec![],
1344 context_length: model.context_length,
1345 needs_compaction: false,
1346 };
1347 }
1348 }
1349
1350 let model_name = match complexity {
1352 TaskComplexity::Simple => "Qwen3-0.6B",
1353 TaskComplexity::Medium => "Qwen3-1.7B",
1354 TaskComplexity::Code => "Qwen3-4B",
1355 TaskComplexity::Complex => &self.hw.recommended_model,
1356 };
1357
1358 let model_id = registry
1359 .find_by_name(model_name)
1360 .map(|m| m.id.clone())
1361 .unwrap_or_else(|| model_name.to_string());
1362
1363 let context_length = registry
1364 .find_by_name(model_name)
1365 .map(|m| m.context_length)
1366 .unwrap_or(0);
1367
1368 AdaptiveRoutingDecision {
1369 model_id,
1370 model_name: model_name.to_string(),
1371 task,
1372 complexity,
1373 reason: format!(
1374 "{:?} task → {} (cold start, no candidates)",
1375 complexity, model_name
1376 ),
1377 strategy: RoutingStrategy::SchemaBased,
1378 predicted_quality: 0.5,
1379 fallbacks: vec![],
1380 context_length,
1381 needs_compaction: false,
1382 }
1383 }
1384}
1385
1386fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1392 use crate::intent::TaskHint;
1393 match hint {
1394 TaskHint::Chat => InferenceTask::Generate,
1395 TaskHint::Classify => InferenceTask::Classify,
1396 TaskHint::Reasoning => InferenceTask::Reasoning,
1397 TaskHint::Code => InferenceTask::Code,
1398 }
1399}
1400
1401fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1409 let x = sample_gamma(rng, alpha);
1410 let y = sample_gamma(rng, beta);
1411 if x + y == 0.0 {
1412 0.5 } else {
1414 x / (x + y)
1415 }
1416}
1417
1418fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1421 if shape < 1.0 {
1422 let u: f64 = rng.random();
1424 return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1425 }
1426
1427 let d = shape - 1.0 / 3.0;
1429 let c = 1.0 / (9.0 * d).sqrt();
1430
1431 loop {
1432 let x: f64 = loop {
1433 let n = sample_standard_normal(rng);
1434 if 1.0 + c * n > 0.0 {
1435 break n;
1436 }
1437 };
1438
1439 let v = (1.0 + c * x).powi(3);
1440 let u: f64 = rng.random();
1441
1442 if u < 1.0 - 0.0331 * x.powi(4) {
1443 return d * v;
1444 }
1445 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1446 return d * v;
1447 }
1448 }
1449}
1450
1451fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1453 let u1: f64 = rng.random();
1454 let u2: f64 = rng.random();
1455 (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1456}
1457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outcome::InferredOutcome;

    /// Hardware fixture: a macOS/aarch64 host with a Metal GPU backend,
    /// 10 CPU cores and 32 GiB of RAM.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            gpu_devices: Vec::new(),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            max_model_mb: 18000,
        }
    }

    /// Registry fixture with two parts:
    ///
    /// * placeholder model directories for several local Qwen3 models, written
    ///   under the system temp directory;
    /// * one trusted remote OpenAI model (`gpt-5.4-mini`) with Generate, Code,
    ///   Reasoning, ToolUse, MultiToolCall and Vision capabilities.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::env::temp_dir().join("car-test-adaptive-router");

        // Tests run on parallel threads, so write the env var exactly once per
        // process instead of on every fixture call.
        static SET_TEST_API_KEY: std::sync::Once = std::sync::Once::new();
        SET_TEST_API_KEY.call_once(|| {
            // SAFETY: `set_var` is unsafe because it can race with other threads
            // reading the environment. `Once` guarantees this is the only write
            // made by these tests, and every caller of `test_registry` blocks
            // until it completes. NOTE(review): other tests in the crate that read
            // the environment concurrently could still overlap this single write.
            unsafe { std::env::set_var("OPENAI_API_KEY", "test-openai-key") };
        });

        // Best-effort fake weights/tokenizer files; failures here surface later
        // as missing local models rather than aborting the fixture.
        for name in &[
            "Qwen3-0.6B",
            "Qwen3-1.7B",
            "Qwen3-4B",
            "Qwen3-8B",
            "Qwen3-Embedding-0.6B",
        ] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }

        let mut reg = UnifiedRegistry::new(tmp);
        reg.register(ModelSchema {
            id: "openai/gpt-5.4-mini:latest".into(),
            name: "gpt-5.4-mini".into(),
            provider: "openai".into(),
            family: "gpt-5.4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
                ModelCapability::Vision,
            ],
            context_length: 128_000,
            param_count: "api".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_key_envs: vec![],
                api_version: None,
                protocol: crate::schema::ApiProtocol::OpenAiCompat,
            },
            tags: vec!["trusted-remote".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            trust_tier: crate::schema::TrustTier::Curated,
            deprecated: false,
            available: true,
        });
        reg
    }

    #[test]
    fn routes_simple_to_trusted_remote_during_cold_start() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("What is 2+2?", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Simple);
        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("selected model should exist in registry");
        assert!(
            !schema.is_local(),
            "simple task should route to trusted remote model during cold start"
        );
        assert!(matches!(
            schema.provider.as_str(),
            "openai" | "anthropic" | "google"
        ));
    }

    #[test]
    fn routes_code_to_code_capable_remote_during_cold_start() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route(
            "Fix this function:\n```rust\nfn main() {}\n```",
            &reg,
            &tracker,
        );
        assert_eq!(decision.complexity, TaskComplexity::Code);
        assert_eq!(decision.task, InferenceTask::Code);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("model should exist");
        assert!(
            schema.has_capability(ModelCapability::Code),
            "selected model must support Code"
        );
        assert!(!schema.is_local(), "should route to trusted remote model");
    }

    #[test]
    fn routes_images_to_vision_capable_model() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let mut reg = test_registry();
        let tracker = OutcomeTracker::new();

        reg.register(ModelSchema {
            id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
            name: "Qwen3-VL-2B-mlx-vlm".into(),
            provider: "qwen".into(),
            family: "qwen3-vl".into(),
            version: "bf16".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Vision,
                ModelCapability::Grounding,
            ],
            context_length: 262_144,
            param_count: "2B".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::Mlx {
                hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
                hf_weight_file: None,
            },
            tags: vec!["vision".into(), "mlx-vlm-cli".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            trust_tier: crate::schema::TrustTier::Curated,
            deprecated: false,
            available: true,
        });

        // NOTE(review): the base fixture's gpt-5.4-mini already advertises Vision,
        // so this assertion can pass even if the vision filter never engages.
        // Confirm what the final `false` argument controls.
        let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
        let schema = reg
            .find_by_name(&decision.model_name)
            .expect("model should exist");
        assert!(
            schema.has_capability(ModelCapability::Vision),
            "selected model must support Vision"
        );
    }

    #[test]
    fn profile_based_routing_favors_proven_model() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 0.5,
                min_observations: 3,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Record 20 accepted Code outcomes for the local 8B model.
        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..20 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        // Routing is sampled, so require a majority rather than every run.
        let mut wins = 0;
        for _ in 0..20 {
            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
            assert_eq!(decision.complexity, TaskComplexity::Code);
            if decision.model_id == qwen_8b_id {
                wins += 1;
            }
        }
        assert!(
            wins >= 12,
            "proven model won only {wins}/20 times (expected >= 12)"
        );
    }

    #[test]
    fn proven_local_model_can_displace_bootstrap_remote() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                bootstrap_min_task_observations: 6,
                bootstrap_quality_floor: 0.8,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // 12 accepted Generate outcomes: above the configured minimum of 6
        // observations and above the 0.8 quality floor.
        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..12 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
            tracker.record_complete(&trace, 300, 50, 20);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let mut local_wins = 0;
        for _ in 0..20 {
            let decision = router.route("Summarize this design decision.", &reg, &tracker);
            let schema = reg
                .get(&decision.model_id)
                .expect("selected model should exist");
            if schema.is_local() {
                local_wins += 1;
            }
        }

        assert!(
            local_wins >= 12,
            "proven local model won only {local_wins}/20 times (expected >= 12)"
        );
    }

    #[test]
    fn benchmark_prior_informs_background_routing() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.ema_quality = 0.95;
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Background,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            schema.is_local(),
            "background routing should allow strong local benchmark priors to win"
        );
    }

    #[test]
    fn task_specific_benchmark_prior_informs_cold_start_routing() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        // Only the Code task stats carry the strong prior; the profile-level
        // ema_quality is left at its default.
        profile.task_stats.insert(
            crate::outcome::InferenceTask::Code.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                ..Default::default()
            },
        );
        tracker.import_profiles(vec![profile]);

        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Background,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            schema.is_local(),
            "background routing should use task-specific cold-start priors for local code models"
        );
    }

    #[test]
    fn task_specific_latency_prior_affects_cold_start_score() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let model = reg
            .get("qwen/qwen3-8b:q4_k_m")
            .expect("local test model should exist");

        // Two trackers differing only in the recorded Generate latency.
        let mut fast_tracker = OutcomeTracker::new();
        let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
        fast_profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms: 1200.0,
                ..Default::default()
            },
        );
        fast_tracker.import_profiles(vec![fast_profile]);

        let mut slow_tracker = OutcomeTracker::new();
        let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
        slow_profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms: 120_000.0,
                ..Default::default()
            },
        );
        slow_tracker.import_profiles(vec![slow_profile]);

        let fast_score = router.score_model(
            model,
            InferenceTask::Generate,
            &fast_tracker,
            RoutingWorkload::Interactive,
        );
        let slow_score = router.score_model(
            model,
            InferenceTask::Generate,
            &slow_tracker,
            RoutingWorkload::Interactive,
        );

        assert!(
            fast_score > slow_score,
            "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
        );
    }

    #[test]
    fn interactive_workload_keeps_remote_bootstrap_bias() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();
        let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
        profile.ema_quality = 0.95;
        tracker.import_profiles(vec![profile]);

        // Same setup as `benchmark_prior_informs_background_routing`, but with
        // the Interactive workload the remote model must still win.
        let decision = router.route_context_aware(
            "Write a Python fibonacci function.",
            128,
            &reg,
            &tracker,
            false,
            false,
            RoutingWorkload::Interactive,
        );

        let schema = reg
            .get(&decision.model_id)
            .expect("selected model should exist");
        assert!(
            !schema.is_local(),
            "interactive routing should still prefer trusted remote models during cold start"
        );
    }

    #[test]
    fn code_interactive_weights_prioritise_quality_over_speed_over_cost() {
        let (q, lat, cost) =
            AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Interactive);
        assert!(
            q > lat && lat > cost,
            "expected quality > speed > cost, got ({q}, {lat}, {cost})"
        );
        assert!((q + lat + cost - 1.0).abs() < 1e-9, "weights must sum to 1.0");
        let (default_q, default_lat, default_cost) = RoutingWorkload::Interactive.weights();
        assert!(
            q > default_q && lat < default_lat && cost < default_cost,
            "code weighting must lift quality and cut latency/cost vs the generic interactive profile"
        );

        // LocalPreferred shares the code-specific weighting with Interactive.
        assert_eq!(
            AdaptiveRouter::task_aware_weights(
                InferenceTask::Code,
                RoutingWorkload::LocalPreferred,
            ),
            (q, lat, cost),
        );

        // Fastest and Background keep their generic workload weights for code.
        assert_eq!(
            AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Fastest),
            RoutingWorkload::Fastest.weights(),
        );
        assert_eq!(
            AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Background),
            RoutingWorkload::Background.weights(),
        );

        // Non-code tasks keep the generic interactive weights.
        assert_eq!(
            AdaptiveRouter::task_aware_weights(
                InferenceTask::Generate,
                RoutingWorkload::Interactive,
            ),
            RoutingWorkload::Interactive.weights(),
        );
    }

    #[test]
    fn fallback_chain_has_alternatives() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 100.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
        assert!(!decision.fallbacks.is_empty());
        assert!(!decision.fallbacks.contains(&decision.model_id));
    }

    #[test]
    fn latency_scoring_is_consistent() {
        let router = AdaptiveRouter::with_default_config(test_hw());

        // A schema-declared throughput of 25 tok/s is expected to score the same
        // as an observed latency of 8000 ms, so both sources share one scale.
        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
        let observed_score = router.latency_ms_to_score(8000.0);
        assert!(
            (schema_score - observed_score).abs() < 0.01,
            "schema ({schema_score}) and observed ({observed_score}) should match"
        );
    }

    #[test]
    fn complexity_assessment() {
        assert_eq!(
            TaskComplexity::assess("What is the capital of France?"),
            TaskComplexity::Simple
        );
        assert_eq!(
            TaskComplexity::assess("Fix this broken test"),
            TaskComplexity::Code
        );
        assert_eq!(
            TaskComplexity::assess("Analyze the trade-offs between A and B"),
            TaskComplexity::Complex
        );
    }

    #[test]
    fn beta_sampling_produces_valid_values() {
        let mut rng = rand::rng();
        for _ in 0..100 {
            let s = sample_beta(&mut rng, 2.0, 5.0);
            assert!((0.0..=1.0).contains(&s), "sample {s} out of [0,1] range");
        }
        // Beta(1, 1) is uniform on [0, 1], so its sample mean should sit near 0.5.
        let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            (mean - 0.5).abs() < 0.05,
            "Beta(1,1) mean {mean} should be ~0.5"
        );
    }

    #[test]
    fn thompson_sampling_converges_to_best() {
        let router = AdaptiveRouter::new(
            test_hw(),
            RoutingConfig {
                prior_strength: 1.0,
                ..Default::default()
            },
        );
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Record 20 accepted Code outcomes for the local 4B model.
        let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
        for _ in 0..20 {
            let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let mut wins = 0;
        for _ in 0..20 {
            let decision = router.route("Fix this parser bug", &reg, &tracker);
            if decision.model_id == qwen_4b_id {
                wins += 1;
            }
        }
        assert!(
            wins >= 14,
            "strong model won only {wins}/20 times (expected >= 14)"
        );
    }

    #[test]
    fn intent_require_filters_out_models_lacking_capability() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let intent = crate::intent::IntentHint {
            require: vec![ModelCapability::Vision],
            ..Default::default()
        };
        let decision = router.route_with_intent("hello", &reg, &tracker, &intent);

        // NOTE(review): `test_registry()` registers gpt-5.4-mini with Vision, so
        // "no vision-capable candidates" in the message below may be stale.
        // Confirm which candidate set the requirement is applied to.
        assert_eq!(
            decision.strategy,
            RoutingStrategy::SchemaBased,
            "require=[vision] with no vision-capable candidates must drop to schema cold-start"
        );
    }

    #[test]
    fn intent_default_does_not_override_task_or_caps() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let baseline = router.route("write a haiku", &reg, &tracker);
        let with_default = router.route_with_intent(
            "write a haiku",
            &reg,
            &tracker,
            &crate::intent::IntentHint::default(),
        );

        assert_eq!(
            baseline.task, with_default.task,
            "default IntentHint must not change the prompt-derived task"
        );
        assert_eq!(
            baseline.complexity, with_default.complexity,
            "default IntentHint must not change the prompt-derived complexity"
        );
    }

    #[test]
    fn intent_task_hint_overrides_prompt_complexity() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let hint = crate::intent::IntentHint {
            task: Some(crate::intent::TaskHint::Reasoning),
            ..Default::default()
        };
        let decision = router.route_with_intent("hi", &reg, &tracker, &hint);

        assert_eq!(
            decision.task,
            InferenceTask::Reasoning,
            "TaskHint::Reasoning should override the prompt-derived task"
        );
    }
}