1use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse classification of a prompt, produced by [`TaskComplexity::assess`].
///
/// Variants are serialized with `snake_case` names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short prompts, or prompts matching a simple lookup-style phrase.
    Simple,
    /// Fallback class when none of the other heuristics match.
    Medium,
    /// Prompts that contain code markers or repair/debug keywords.
    Code,
    /// Prompts with reasoning keywords or an estimated length above 500 tokens.
    Complex,
}
33
34impl TaskComplexity {
35 pub fn assess(prompt: &str) -> Self {
42 let lower = prompt.to_lowercase();
43 let word_count = prompt.split_whitespace().count();
44 let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46 let has_code = Self::detect_code(prompt);
47
48 let repair_markers = [
49 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50 ];
51 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53 let reasoning_markers = [
54 "analyze",
55 "compare",
56 "explain why",
57 "step by step",
58 "think through",
59 "evaluate",
60 "trade-off",
61 "tradeoff",
62 "pros and cons",
63 "architecture",
64 "design",
65 "strategy",
66 "optimize",
67 "comprehensive",
68 ];
69 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71 let simple_patterns = [
72 "what is",
73 "who is",
74 "when did",
75 "where is",
76 "how many",
77 "yes or no",
78 "true or false",
79 "name the",
80 "list the",
81 "define ",
82 ];
83 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85 if has_code || has_repair {
86 TaskComplexity::Code
87 } else if has_reasoning || estimated_tokens > 500 {
88 TaskComplexity::Complex
89 } else if is_simple || estimated_tokens < 30 {
90 TaskComplexity::Simple
91 } else {
92 TaskComplexity::Medium
93 }
94 }
95
96 fn detect_code(prompt: &str) -> bool {
102 #[cfg(feature = "ast")]
104 {
105 if let Some(is_code) = Self::detect_code_ast(prompt) {
106 return is_code;
107 }
108 }
109
110 let code_markers = [
112 "```",
113 "fn ",
114 "def ",
115 "class ",
116 "import ",
117 "require(",
118 "async fn",
119 "pub fn",
120 "function ",
121 "const ",
122 "let ",
123 "var ",
124 "#include",
125 "package ",
126 "impl ",
127 ];
128 code_markers.iter().any(|m| prompt.contains(m))
129 }
130
131 #[cfg(feature = "ast")]
135 fn detect_code_ast(prompt: &str) -> Option<bool> {
136 let mut blocks = Vec::new();
138 let mut rest = prompt;
139 while let Some(start) = rest.find("```") {
140 let after_fence = &rest[start + 3..];
141 let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143 if let Some(end) = after_fence[code_start..].find("```") {
144 blocks.push(&after_fence[code_start..code_start + end]);
145 rest = &after_fence[code_start + end + 3..];
146 } else {
147 break;
148 }
149 }
150
151 if blocks.is_empty() {
152 return None; }
154
155 let languages = [
157 car_ast::Language::Rust,
158 car_ast::Language::Python,
159 car_ast::Language::TypeScript,
160 car_ast::Language::JavaScript,
161 car_ast::Language::Go,
162 ];
163
164 for block in &blocks {
165 let trimmed = block.trim();
166 if trimmed.is_empty() {
167 continue;
168 }
169
170 for lang in &languages {
171 if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172 if !parsed.symbols.is_empty() {
174 return Some(true);
175 }
176 }
177 }
178 }
179
180 Some(false)
183 }
184
185 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187 match self {
188 TaskComplexity::Simple => vec![ModelCapability::Generate],
189 TaskComplexity::Medium => vec![ModelCapability::Generate],
190 TaskComplexity::Code => vec![ModelCapability::Code],
191 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192 }
193 }
194
195 pub fn inference_task(&self) -> InferenceTask {
197 match self {
198 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199 TaskComplexity::Code => InferenceTask::Code,
200 TaskComplexity::Complex => InferenceTask::Reasoning,
201 }
202 }
203}
204
/// Tunable knobs for [`AdaptiveRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Number of recorded calls a model needs before its observed profile
    /// fully replaces the schema-derived estimates during scoring. Below this
    /// count (but above zero) the two are blended.
    pub min_observations: u64,
    /// Weight of the quality term.
    /// NOTE(review): not read in the visible part of this file — scoring takes
    /// its weights from the workload; confirm whether this is still used.
    pub quality_weight: f64,
    /// Weight of the latency term. NOTE(review): same caveat as `quality_weight`.
    pub latency_weight: f64,
    /// Weight of the cost term. NOTE(review): same caveat as `quality_weight`.
    pub cost_weight: f64,
    /// When set, models whose schema `latency_p50_ms` exceeds this value are
    /// removed from the candidate list.
    pub max_latency_ms: Option<u64>,
    /// When set, models whose `cost_per_1k_output()` exceeds this value are
    /// removed from the candidate list.
    pub max_cost_usd: Option<f64>,
    /// When `true`, local models receive a score bonus on hosts with a GPU
    /// backend. When `false`, local models are excluded from the candidates.
    pub prefer_local: bool,
    /// Total pseudo-observation count of the Beta prior used by Thompson
    /// sampling; the prior mean is the model's clamped score.
    pub prior_strength: f64,
    /// Enables the quality-first bootstrap policy and the quality-first
    /// cold-start fallback, both of which favour trusted remote providers.
    pub quality_first_cold_start: bool,
    /// Minimum number of calls before a local model can count as "proven"
    /// under the bootstrap policy.
    pub bootstrap_min_task_observations: u64,
    /// Minimum EMA quality a local model needs to count as "proven" under the
    /// bootstrap policy.
    pub bootstrap_quality_floor: f64,
}
234
/// Default routing configuration: no hard latency or cost caps, local models
/// preferred, and quality-first cold start enabled.
impl Default for RoutingConfig {
    fn default() -> Self {
        Self {
            min_observations: 2,
            // The three weights sum to 1.0.
            quality_weight: 0.45,
            latency_weight: 0.4,
            cost_weight: 0.15,
            max_latency_ms: None,
            max_cost_usd: None,
            prefer_local: true,
            prior_strength: 2.0,
            quality_first_cold_start: true,
            bootstrap_min_task_observations: 8,
            bootstrap_quality_floor: 0.8,
        }
    }
}
252
/// How the selected model was chosen, as reported in
/// [`AdaptiveRoutingDecision::strategy`].
///
/// Variants are serialized with `snake_case` names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// The selected model has no recorded calls (or no profile), or the
    /// decision came from a cold-start fallback.
    SchemaBased,
    /// The selected model has at least `min_observations` recorded calls.
    ProfileBased,
    /// The selected model has some recorded calls, but fewer than
    /// `min_observations`.
    Exploration,
    /// Not produced anywhere in the visible part of this file.
    /// NOTE(review): presumably set by callers that pin a model — confirm.
    Explicit,
}
266
/// Result of a routing call: the chosen model plus the context a caller needs
/// to act on the choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Identifier of the selected model.
    pub model_id: String,
    /// Name of the selected model; falls back to the id when the registry has
    /// no schema for it.
    pub model_name: String,
    /// Inference task used for scoring (from the intent hint when present,
    /// otherwise derived from `complexity`).
    pub task: InferenceTask,
    /// Complexity class assessed from the prompt.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of the decision.
    pub reason: String,
    /// How the model was selected.
    pub strategy: RoutingStrategy,
    /// Score of the selected model from the scoring pass (includes bonuses),
    /// or a fixed value on cold-start paths.
    pub predicted_quality: f64,
    /// Other candidate model ids to try if the selected model fails.
    pub fallbacks: Vec<String>,
    /// Context length of the selected model; `0` when it is not known.
    pub context_length: usize,
    /// `true` when the estimated prompt size exceeds the selected model's
    /// context length.
    pub needs_compaction: bool,
}
291
/// Routes prompts to models using registry schemas, observed outcomes and
/// per-model circuit breakers.
pub struct AdaptiveRouter {
    /// Host hardware description; supplies the local model size limit and the
    /// recommended model used by the cold-start fallback.
    hw: HardwareInfo,
    /// Active routing configuration.
    config: RoutingConfig,
    /// Shared circuit-breaker registry consulted while filtering candidates.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs to a routing call, consumed by [`AdaptiveRouter::route_with`].
///
/// Build one with [`RouteRequest::new`] and override fields with struct-update
/// syntax.
pub struct RouteRequest<'a> {
    /// Prompt text used for complexity assessment.
    pub prompt: &'a str,
    /// Registry of known models.
    pub registry: &'a UnifiedRegistry,
    /// Outcome history used for scoring and exclusions.
    pub tracker: &'a OutcomeTracker,
    /// Estimated total prompt size in tokens; `0` disables context-fit checks.
    pub estimated_total_tokens: usize,
    /// When `true`, the selected model must support tool use.
    pub has_tools: bool,
    /// When `true`, the selected model must support vision.
    pub has_vision: bool,
    /// Workload profile that supplies the scoring weights.
    pub workload: RoutingWorkload,
    /// Optional caller hint that can override the workload and task and add
    /// required capabilities.
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
impl<'a> RouteRequest<'a> {
    /// Creates a request with default options: unknown token count (`0`), no
    /// tools, no vision, an interactive workload and no intent hint.
    pub fn new(
        prompt: &'a str,
        registry: &'a UnifiedRegistry,
        tracker: &'a OutcomeTracker,
    ) -> Self {
        Self {
            prompt,
            registry,
            tracker,
            estimated_total_tokens: 0,
            has_tools: false,
            has_vision: false,
            workload: RoutingWorkload::Interactive,
            intent: None,
        }
    }
}
351
352impl AdaptiveRouter {
353 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354 let circuit_breakers = Arc::new(Mutex::new(
355 CircuitBreakerRegistry::new(3, 300), ));
357 Self {
358 hw,
359 config,
360 circuit_breakers,
361 }
362 }
363
364 pub fn with_default_config(hw: HardwareInfo) -> Self {
365 Self::new(hw, RoutingConfig::default())
366 }
367
    /// Returns the active routing configuration.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }
371
    /// Replaces the routing configuration; the circuit-breaker registry is
    /// left untouched.
    pub fn set_config(&mut self, config: RoutingConfig) {
        self.config = config;
    }
375
376 pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382 let workload = match req.intent {
388 Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389 Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390 _ => req.workload,
391 };
392 self.route_inner_with_intent(
393 req.prompt,
394 req.registry,
395 req.tracker,
396 req.has_tools,
397 req.has_vision,
398 req.estimated_total_tokens,
399 workload,
400 req.intent,
401 )
402 }
403
404 pub fn route(
407 &self,
408 prompt: &str,
409 registry: &UnifiedRegistry,
410 tracker: &OutcomeTracker,
411 ) -> AdaptiveRoutingDecision {
412 self.route_with(RouteRequest::new(prompt, registry, tracker))
413 }
414
415 pub fn route_editor(
421 &self,
422 prompt: &str,
423 registry: &UnifiedRegistry,
424 tracker: &OutcomeTracker,
425 ) -> AdaptiveRoutingDecision {
426 self.route_with(RouteRequest {
427 workload: RoutingWorkload::Background,
428 ..RouteRequest::new(prompt, registry, tracker)
429 })
430 }
431
432 pub fn route_with_tools(
434 &self,
435 prompt: &str,
436 registry: &UnifiedRegistry,
437 tracker: &OutcomeTracker,
438 ) -> AdaptiveRoutingDecision {
439 self.route_with(RouteRequest {
440 has_tools: true,
441 ..RouteRequest::new(prompt, registry, tracker)
442 })
443 }
444
445 pub fn route_with_vision(
447 &self,
448 prompt: &str,
449 registry: &UnifiedRegistry,
450 tracker: &OutcomeTracker,
451 has_tools: bool,
452 ) -> AdaptiveRoutingDecision {
453 self.route_with(RouteRequest {
454 has_tools,
455 has_vision: true,
456 ..RouteRequest::new(prompt, registry, tracker)
457 })
458 }
459
460 pub fn route_with_intent<'a>(
466 &self,
467 prompt: &'a str,
468 registry: &'a UnifiedRegistry,
469 tracker: &'a OutcomeTracker,
470 intent: &'a crate::intent::IntentHint,
471 ) -> AdaptiveRoutingDecision {
472 self.route_with(RouteRequest {
473 intent: Some(intent),
474 ..RouteRequest::new(prompt, registry, tracker)
475 })
476 }
477
478 pub fn route_context_aware(
481 &self,
482 prompt: &str,
483 estimated_total_tokens: usize,
484 registry: &UnifiedRegistry,
485 tracker: &OutcomeTracker,
486 has_tools: bool,
487 has_vision: bool,
488 workload: RoutingWorkload,
489 ) -> AdaptiveRoutingDecision {
490 self.route_with(RouteRequest {
491 estimated_total_tokens,
492 has_tools,
493 has_vision,
494 workload,
495 ..RouteRequest::new(prompt, registry, tracker)
496 })
497 }
498
499 pub fn route_context_aware_with_intent<'a>(
504 &self,
505 prompt: &'a str,
506 estimated_total_tokens: usize,
507 registry: &'a UnifiedRegistry,
508 tracker: &'a OutcomeTracker,
509 has_tools: bool,
510 has_vision: bool,
511 workload: RoutingWorkload,
512 intent: &'a crate::intent::IntentHint,
513 ) -> AdaptiveRoutingDecision {
514 self.route_with(RouteRequest {
515 estimated_total_tokens,
516 has_tools,
517 has_vision,
518 workload,
519 intent: Some(intent),
520 ..RouteRequest::new(prompt, registry, tracker)
521 })
522 }
523
    /// Core routing pipeline behind every public `route*` entry point.
    ///
    /// Steps, in order:
    /// 1. Assess prompt complexity and derive the task (an intent hint's task
    ///    wins over the complexity-derived one).
    /// 2. Build the required capability list from complexity, intent, vision
    ///    and tool flags.
    /// 3. Filter registry models; if nothing matches and multi-tool-call was
    ///    required, retry without that requirement.
    /// 4. With no candidates at all, return a cold-start decision.
    /// 5. Apply the quality-first bootstrap policy.
    /// 6. Split candidates by whether the prompt fits their context window,
    ///    score the preferred group and pick a model by Thompson sampling.
    /// 7. Assemble fallbacks, predicted quality and the compaction flag.
    fn route_inner_with_intent(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        estimated_total_tokens: usize,
        workload: RoutingWorkload,
        intent: Option<&crate::intent::IntentHint>,
    ) -> AdaptiveRoutingDecision {
        let complexity = TaskComplexity::assess(prompt);
        // An explicit task hint overrides the task implied by the prompt.
        let task = intent
            .and_then(|h| h.task)
            .map(task_hint_to_inference_task)
            .unwrap_or_else(|| complexity.inference_task());
        let mut required_caps = complexity.required_capabilities();
        if let Some(hint) = intent {
            for cap in &hint.require {
                if !required_caps.contains(cap) {
                    required_caps.push(*cap);
                }
            }
        }
        if has_vision {
            required_caps.push(ModelCapability::Vision);
        }
        if has_tools {
            required_caps.push(ModelCapability::ToolUse);
            // Only demand multi-tool-call when the prompt looks like it asks
            // for several actions at once.
            if Self::needs_multi_tool_call(prompt) {
                required_caps.push(ModelCapability::MultiToolCall);
            }
        }

        let mut candidates = self.filter_candidates(&required_caps, registry, tracker, has_vision);

        // Multi-tool-call is treated as a soft requirement: drop it and retry
        // rather than fall straight through to the cold-start path.
        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
            candidates = self.filter_candidates(&required_caps, registry, tracker, has_vision);
        }

        if candidates.is_empty() {
            return self.cold_start_decision(complexity, task, registry, has_vision);
        }

        candidates = self.apply_quality_first_bootstrap_policy(
            candidates, task, tracker, has_vision, has_tools, workload,
        );

        // Partition by context fit. A context_length of 0 is treated as
        // "fits"; a token estimate of 0 skips the check entirely.
        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
            let mut fits = Vec::new();
            let mut tight = Vec::new();
            for m in &candidates {
                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
                    fits.push(m.clone());
                } else {
                    tight.push(m.clone());
                }
            }
            (fits, tight)
        } else {
            (candidates.clone(), Vec::new())
        };

        // Prefer models that fit; otherwise score the too-small models and
        // flag that the prompt must be compacted.
        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
            (fits, false)
        } else if !needs_compaction_candidates.is_empty() {
            tracing::info!(
                prompt_tokens = estimated_total_tokens,
                candidates = needs_compaction_candidates.len(),
                "no model fits full prompt — compaction will be needed"
            );
            (needs_compaction_candidates.clone(), true)
        } else {
            (candidates.clone(), false)
        };

        let scored = self.score_candidates_context_aware(
            &scoring_candidates,
            task,
            tracker,
            estimated_total_tokens,
            workload,
        );

        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);

        // Fallbacks: every other scored model in score order ...
        let mut fallbacks: Vec<String> = scored
            .iter()
            .filter(|(id, _)| *id != selected_id)
            .map(|(id, _)| id.clone())
            .collect();
        // ... followed by the models that would need compaction, when those
        // were not the ones scored.
        if !compaction_needed {
            for m in &needs_compaction_candidates {
                if m.id != selected_id && !fallbacks.contains(&m.id) {
                    fallbacks.push(m.id.clone());
                }
            }
        }

        let predicted_quality = scored
            .iter()
            .find(|(id, _)| *id == selected_id)
            .map(|(_, score)| *score)
            .unwrap_or(0.5);

        // Resolve the schema by id first, then by name.
        let selected_schema = registry
            .get(&selected_id)
            .or_else(|| registry.find_by_name(&selected_id));
        let model_name = selected_schema
            .map(|m| m.name.clone())
            .unwrap_or_else(|| selected_id.clone());
        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);

        let needs_compact = compaction_needed
            || (estimated_total_tokens > 0
                && context_length > 0
                && estimated_total_tokens > context_length);

        let compaction_note = if needs_compact {
            format!(
                " [compaction needed: {}→{}tok]",
                estimated_total_tokens, context_length
            )
        } else {
            String::new()
        };

        let reason = format!(
            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
            complexity,
            model_name,
            strategy,
            predicted_quality,
            scoring_candidates.len(),
            compaction_note,
        );

        AdaptiveRoutingDecision {
            model_id: selected_id,
            model_name,
            task,
            complexity,
            reason,
            strategy,
            predicted_quality,
            fallbacks,
            context_length,
            needs_compaction: needs_compact,
        }
    }
689
690 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
692 let embed_models = registry.query_by_capability(ModelCapability::Embed);
693 embed_models
694 .first()
695 .map(|m| m.name.clone())
696 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
697 }
698
699 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
701 let gen_models = registry.query_by_capability(ModelCapability::Generate);
702 gen_models
704 .iter()
705 .filter(|m| m.is_local())
706 .min_by_key(|m| m.size_mb())
707 .map(|m| m.name.clone())
708 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
709 }
710
    /// Latency (ms) at or above which the latency score is 0; also the latency
    /// assumed for a non-positive tokens-per-second figure.
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// Unused in the visible part of this file (hence the leading underscore).
    const _TPS_CEILING: f64 = 150.0;
    /// Factor applied to the schema tokens-per-second of non-MLX models tagged
    /// `moe` when estimating latency.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// Factor applied to the schema tokens-per-second of MLX models tagged
    /// `moe` when estimating latency.
    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
    /// Cost per 1k output tokens at or above which the cost score is 0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Score bonus for local models when `prefer_local` is set and the host
    /// has a GPU backend; also the base of the vllm-mlx bonus.
    const LOCAL_BONUS: f64 = 0.15;
    /// `true` only on macOS/aarch64 builds without the `car_skip_mlx` cfg.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const HAS_GPU_BACKEND: bool = true;
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    const HAS_GPU_BACKEND: bool = false;
    /// Score bonus for MLX models; only defined where `HAS_GPU_BACKEND` is true.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    const MLX_BONUS: f64 = 0.10;
    /// Score bonus for models tagged both `low_latency` and `private`.
    const SYSTEM_LLM_BONUS: f64 = 0.12;
753
754 fn filter_candidates(
758 &self,
759 required_caps: &[ModelCapability],
760 registry: &UnifiedRegistry,
761 tracker: &OutcomeTracker,
762 has_vision: bool,
763 ) -> Vec<ModelSchema> {
764 registry
765 .list()
766 .into_iter()
767 .filter(|m| {
768 if !required_caps.iter().all(|c| m.has_capability(*c)) {
770 return false;
771 }
772 if !has_vision && m.tags.iter().any(|t| t == "mlx-vlm-cli") {
782 return false;
783 }
784 if !m.available {
786 return false;
787 }
788 if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
790 return false;
791 }
792 if let Some(max) = self.config.max_latency_ms {
794 if let Some(p50) = m.performance.latency_p50_ms {
795 if p50 > max {
796 return false;
797 }
798 }
799 }
800 if let Some(max) = self.config.max_cost_usd {
802 if m.cost_per_1k_output() > max {
803 return false;
804 }
805 }
806 if !self.config.prefer_local && m.is_local() {
808 return false;
809 }
810 if tracker.is_excluded(&m.id) {
812 return false;
813 }
814 if let Ok(mut cb) = self.circuit_breakers.lock() {
816 if !cb.allow_request(&m.id) {
817 tracing::debug!(model = %m.id, "skipped by circuit breaker");
818 return false;
819 }
820 }
821 true
822 })
823 .cloned()
824 .collect()
825 }
826
827 fn apply_quality_first_bootstrap_policy(
828 &self,
829 candidates: Vec<ModelSchema>,
830 task: InferenceTask,
831 tracker: &OutcomeTracker,
832 has_vision: bool,
833 has_tools: bool,
834 workload: RoutingWorkload,
835 ) -> Vec<ModelSchema> {
836 if !self.config.quality_first_cold_start
837 || !workload.is_latency_sensitive()
838 || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
839 {
840 return candidates;
841 }
842
843 let trusted_remote: Vec<ModelSchema> = candidates
844 .iter()
845 .filter(|model| self.is_trusted_quality_remote(model))
846 .cloned()
847 .collect();
848
849 if trusted_remote.is_empty() {
850 return candidates;
851 }
852
853 let proven_local: Vec<ModelSchema> = candidates
854 .iter()
855 .filter(|model| {
856 model.is_local() && self.is_local_model_proven_for_task(model, task, tracker)
857 })
858 .cloned()
859 .collect();
860
861 if !proven_local.is_empty() {
862 return proven_local;
863 }
864
865 trusted_remote
866 }
867
868 fn score_candidates_context_aware(
872 &self,
873 candidates: &[ModelSchema],
874 task: InferenceTask,
875 tracker: &OutcomeTracker,
876 estimated_total_tokens: usize,
877 workload: RoutingWorkload,
878 ) -> Vec<(String, f64)> {
879 let mut scored: Vec<(String, f64)> = candidates
880 .iter()
881 .map(|m| {
882 let base_score = self.score_model(m, task, tracker, workload);
883 let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
886 let ratio = m.context_length as f64 / estimated_total_tokens as f64;
887 if ratio >= 1.0 {
888 (ratio.min(4.0) - 1.0) / 3.0 * 0.10 } else {
890 -0.15 }
892 } else {
893 0.0
894 };
895 (m.id.clone(), base_score + headroom_bonus)
896 })
897 .collect();
898
899 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
900 scored
901 }
902
903 fn task_aware_weights(
919 task: InferenceTask,
920 workload: RoutingWorkload,
921 ) -> (f64, f64, f64) {
922 match (task, workload) {
923 (
924 InferenceTask::Code,
925 RoutingWorkload::Interactive | RoutingWorkload::LocalPreferred,
926 ) => (0.62, 0.28, 0.10),
927 _ => workload.weights(),
928 }
929 }
930
    /// Computes the composite score for one model.
    ///
    /// The score is a weighted sum of quality, latency and cost terms (weights
    /// from `task_aware_weights`) plus additive bonuses for local, MLX,
    /// vllm-mlx and system-LLM models. The bonuses mean the result can exceed
    /// 1.0.
    ///
    /// Quality and latency each come from one of four sources, depending on
    /// the model's tracker profile:
    /// * at least `min_observations` calls — observed statistics (per-task
    ///   when available, otherwise overall);
    /// * a profile with zero calls — the profile's stored values, falling back
    ///   to the schema estimate for latency;
    /// * between 1 and `min_observations` calls — a linear blend of schema
    ///   estimate and observation, weighted by `total_calls / min_observations`;
    /// * no profile — the schema estimate.
    fn score_model(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
        workload: RoutingWorkload,
    ) -> f64 {
        let profile = tracker.profile(&model.id);
        let schema_quality = self.schema_quality_estimate(model);
        let schema_latency = self.schema_latency_estimate(model);
        let (quality_weight, latency_weight, cost_weight) =
            Self::task_aware_weights(task, workload);

        let quality = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            Some(p) if p.total_calls > 0 => {
                // Only reached when total_calls < min_observations, so w is
                // strictly between 0 and 1.
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_quality * (1.0 - w) + p.ema_quality * w
            }
            _ => schema_quality,
        };

        let latency = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => {
                let avg = p
                    .task_stats(task)
                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                    .map(|ts| ts.avg_latency_ms)
                    .unwrap_or_else(|| p.avg_latency_ms());
                self.latency_ms_to_score(avg)
            }
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .filter(|ts| ts.avg_latency_ms > 0.0)
                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
                .unwrap_or(schema_latency),
            Some(p) if p.total_calls > 0 => {
                let observed = self.latency_ms_to_score(
                    p.task_stats(task)
                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                        .map(|ts| ts.avg_latency_ms)
                        .unwrap_or_else(|| p.avg_latency_ms()),
                );
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_latency * (1.0 - w) + observed * w
            }
            _ => schema_latency,
        };

        // Local models get the best cost score; remote cost is scaled against
        // the ceiling.
        let cost = if model.is_local() {
            1.0
        } else {
            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
        };

        // The local bonus is only granted when this build has a GPU backend.
        let local_bonus = if self.config.prefer_local && model.is_local() && Self::HAS_GPU_BACKEND {
            Self::LOCAL_BONUS
        } else {
            if self.config.prefer_local && model.is_local() && !Self::HAS_GPU_BACKEND {
                tracing::debug!(
                    model = %model.id,
                    "LOCAL_BONUS suppressed: no GPU backend on this host (Intel Mac or car_skip_mlx); cloud models will rank higher"
                );
            }
            0.0
        };
        let workload_local_bonus = if model.is_local() {
            workload.local_bonus()
        } else {
            0.0
        };

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let mlx_bonus = 0.0;

        // Applied regardless of `prefer_local` and of `HAS_GPU_BACKEND`.
        let vllm_mlx_bonus = if model.is_vllm_mlx() {
            Self::LOCAL_BONUS + 0.05
        } else {
            0.0
        };

        // Models tagged both `low_latency` and `private` get the system-LLM bonus.
        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
            && model.tags.iter().any(|t| t == "private")
        {
            Self::SYSTEM_LLM_BONUS
        } else {
            0.0
        };

        quality_weight * quality
            + latency_weight * latency
            + cost_weight * cost
            + local_bonus
            + workload_local_bonus
            + mlx_bonus
            + vllm_mlx_bonus
            + system_llm_bonus
    }
1060
1061 fn latency_ms_to_score(&self, ms: f64) -> f64 {
1064 (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1065 }
1066
1067 fn tps_to_latency_ms(tps: f64) -> f64 {
1069 if tps <= 0.0 {
1070 return Self::LATENCY_CEILING_MS;
1071 }
1072 (200.0 / tps) * 1000.0
1074 }
1075
1076 fn needs_multi_tool_call(prompt: &str) -> bool {
1079 let lower = prompt.to_lowercase();
1080
1081 let has_numbered_list = {
1083 let mut count = 0u32;
1084 for i in 1..=5u32 {
1085 if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1086 count += 1;
1087 }
1088 }
1089 count >= 2
1090 };
1091
1092 let multi_keywords = [
1094 "multiple edits",
1095 "several changes",
1096 "three changes",
1097 "two changes",
1098 "all of the following",
1099 "each of these",
1100 "do both",
1101 "do all",
1102 "and also",
1103 "additionally",
1104 "as well as",
1105 "then also",
1106 ];
1107 let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1108
1109 let bullet_actions = lower.matches("- add ").count()
1111 + lower.matches("- update ").count()
1112 + lower.matches("- change ").count()
1113 + lower.matches("- remove ").count()
1114 + lower.matches("- fix ").count()
1115 + lower.matches("- edit ").count()
1116 + lower.matches("- implement ").count()
1117 + lower.matches("- create ").count();
1118 let has_bullet_list = bullet_actions >= 2;
1119
1120 has_numbered_list || has_multi_keywords || has_bullet_list
1121 }
1122
1123 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1128 match model.size_mb() {
1129 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
1136 }
1137
1138 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1143 let is_moe = model.tags.contains(&"moe".to_string());
1144
1145 if model.is_local() {
1146 if let Some(tps) = model.performance.tokens_per_second {
1147 let effective_tps = if is_moe {
1148 let multiplier = if model.is_mlx() {
1149 Self::MLX_MOE_TPS_MULTIPLIER
1150 } else {
1151 Self::MOE_TPS_MULTIPLIER
1152 };
1153 tps * multiplier
1154 } else {
1155 tps
1156 };
1157 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1158 return self.latency_ms_to_score(estimated_ms);
1159 }
1160 return 0.5; }
1162
1163 if let Some(p50) = model.performance.latency_p50_ms {
1165 return self.latency_ms_to_score(p50 as f64);
1166 }
1167 0.3 }
1169
1170 fn is_quality_critical_bootstrap_task(
1171 &self,
1172 task: InferenceTask,
1173 has_vision: bool,
1174 has_tools: bool,
1175 ) -> bool {
1176 has_vision
1177 || has_tools
1178 || matches!(
1179 task,
1180 InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1181 )
1182 }
1183
1184 fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1185 model.is_remote()
1186 && matches!(model.provider.as_str(), "openai" | "anthropic" | "google")
1187 && !model.has_capability(ModelCapability::SpeechToText)
1188 && !model.has_capability(ModelCapability::TextToSpeech)
1189 }
1190
1191 fn is_local_model_proven_for_task(
1192 &self,
1193 model: &ModelSchema,
1194 task: InferenceTask,
1195 tracker: &OutcomeTracker,
1196 ) -> bool {
1197 let Some(profile) = tracker.profile(&model.id) else {
1198 return false;
1199 };
1200 if let Some(task_stats) = profile.task_stats(task) {
1201 if task_stats.calls >= self.config.bootstrap_min_task_observations
1202 && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1203 {
1204 return true;
1205 }
1206 }
1207
1208 profile.total_calls >= self.config.bootstrap_min_task_observations
1209 && profile.ema_quality >= self.config.bootstrap_quality_floor
1210 }
1211
1212 fn select_with_thompson_sampling(
1222 &self,
1223 scored: &[(String, f64)],
1224 tracker: &OutcomeTracker,
1225 ) -> (String, RoutingStrategy) {
1226 if scored.is_empty() {
1227 return (String::new(), RoutingStrategy::SchemaBased);
1228 }
1229
1230 let mut rng = rand::rng();
1231 let mut best_sample = f64::NEG_INFINITY;
1232 let mut best_id = scored[0].0.clone();
1233 let mut best_strategy = RoutingStrategy::SchemaBased;
1234
1235 for (id, phase2_score) in scored {
1236 let profile = tracker.profile(id);
1237 let prior = self.config.prior_strength;
1238
1239 let prior_mean = phase2_score.clamp(0.0, 1.0);
1241
1242 let prior_alpha = prior * prior_mean;
1244 let prior_beta = prior * (1.0 - prior_mean);
1245
1246 let (obs_alpha, obs_beta) = match profile {
1248 Some(p) => (p.success_count as f64, p.fail_count as f64),
1249 None => (0.0, 0.0),
1250 };
1251
1252 let alpha = (prior_alpha + obs_alpha).max(0.01);
1254 let beta = (prior_beta + obs_beta).max(0.01);
1255
1256 let sample = sample_beta(&mut rng, alpha, beta);
1258
1259 if sample > best_sample {
1260 best_sample = sample;
1261 best_id = id.clone();
1262 best_strategy = match profile {
1263 Some(p) if p.total_calls >= self.config.min_observations => {
1264 RoutingStrategy::ProfileBased
1265 }
1266 Some(p) if p.total_calls > 0 => {
1267 RoutingStrategy::Exploration
1269 }
1270 _ => RoutingStrategy::SchemaBased,
1271 };
1272 }
1273 }
1274
1275 (best_id, best_strategy)
1276 }
1277
    /// Builds a decision when filtering left no candidates.
    ///
    /// Fallbacks are tried in order:
    /// 1. For vision requests: an available trusted remote vision model, else
    ///    the first vision model in the registry (its availability is not
    ///    checked).
    /// 2. With `quality_first_cold_start`: the available trusted remote model
    ///    with the highest schema quality estimate that satisfies the
    ///    complexity's baseline capabilities.
    /// 3. A hard-coded model name per complexity class; the hardware's
    ///    recommended model for `Complex`.
    ///
    /// Every decision returned here has no fallbacks, `needs_compaction` set
    /// to `false` and strategy `SchemaBased`.
    fn cold_start_decision(
        &self,
        complexity: TaskComplexity,
        task: InferenceTask,
        registry: &UnifiedRegistry,
        has_vision: bool,
    ) -> AdaptiveRoutingDecision {
        if has_vision {
            if let Some(model) = registry
                .query_by_capability(ModelCapability::Vision)
                .into_iter()
                .find(|model| model.available && self.is_trusted_quality_remote(model))
                .or_else(|| {
                    // No trusted remote: take whatever vision model comes first.
                    registry
                        .query_by_capability(ModelCapability::Vision)
                        .first()
                        .copied()
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (cold start, vision fallback)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        if self.config.quality_first_cold_start {
            // Only the complexity's baseline capabilities are checked here;
            // tool and intent requirements are not passed to this function.
            let required_caps = complexity.required_capabilities();
            if let Some(model) = registry
                .list()
                .into_iter()
                .filter(|model| {
                    model.available
                        && required_caps.iter().all(|cap| model.has_capability(*cap))
                        && self.is_trusted_quality_remote(model)
                })
                .max_by(|a, b| {
                    self.schema_quality_estimate(a)
                        .partial_cmp(&self.schema_quality_estimate(b))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
            {
                return AdaptiveRoutingDecision {
                    model_id: model.id.clone(),
                    model_name: model.name.clone(),
                    task,
                    complexity,
                    reason: format!(
                        "{:?} task → {} (quality-first cold start)",
                        complexity, model.name
                    ),
                    strategy: RoutingStrategy::SchemaBased,
                    predicted_quality: self.schema_quality_estimate(model),
                    fallbacks: vec![],
                    context_length: model.context_length,
                    needs_compaction: false,
                };
            }
        }

        // Last resort: a fixed model name per complexity class.
        let model_name = match complexity {
            TaskComplexity::Simple => "Qwen3-0.6B",
            TaskComplexity::Medium => "Qwen3-1.7B",
            TaskComplexity::Code => "Qwen3-4B",
            TaskComplexity::Complex => &self.hw.recommended_model,
        };

        // If the registry does not know the name, the name doubles as the id.
        let model_id = registry
            .find_by_name(model_name)
            .map(|m| m.id.clone())
            .unwrap_or_else(|| model_name.to_string());

        let context_length = registry
            .find_by_name(model_name)
            .map(|m| m.context_length)
            .unwrap_or(0);

        AdaptiveRoutingDecision {
            model_id,
            model_name: model_name.to_string(),
            task,
            complexity,
            reason: format!(
                "{:?} task → {} (cold start, no candidates)",
                complexity, model_name
            ),
            strategy: RoutingStrategy::SchemaBased,
            predicted_quality: 0.5,
            fallbacks: vec![],
            context_length,
            needs_compaction: false,
        }
    }
1384}
1385
1386fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1392 use crate::intent::TaskHint;
1393 match hint {
1394 TaskHint::Chat => InferenceTask::Generate,
1395 TaskHint::Classify => InferenceTask::Classify,
1396 TaskHint::Reasoning => InferenceTask::Reasoning,
1397 TaskHint::Code => InferenceTask::Code,
1398 }
1399}
1400
1401fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1409 let x = sample_gamma(rng, alpha);
1410 let y = sample_gamma(rng, beta);
1411 if x + y == 0.0 {
1412 0.5 } else {
1414 x / (x + y)
1415 }
1416}
1417
1418fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1421 if shape < 1.0 {
1422 let u: f64 = rng.random();
1424 return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1425 }
1426
1427 let d = shape - 1.0 / 3.0;
1429 let c = 1.0 / (9.0 * d).sqrt();
1430
1431 loop {
1432 let x: f64 = loop {
1433 let n = sample_standard_normal(rng);
1434 if 1.0 + c * n > 0.0 {
1435 break n;
1436 }
1437 };
1438
1439 let v = (1.0 + c * x).powi(3);
1440 let u: f64 = rng.random();
1441
1442 if u < 1.0 - 0.0331 * x.powi(4) {
1443 return d * v;
1444 }
1445 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1446 return d * v;
1447 }
1448 }
1449}
1450
1451fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1453 let u1: f64 = rng.random();
1454 let u2: f64 = rng.random();
1455 (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1456}
1457
1458#[cfg(test)]
1459mod tests {
1460 use super::*;
1461 use crate::outcome::InferredOutcome;
1462
    /// Hardware fixture: a macOS/aarch64 host with a Metal GPU backend, 32 GiB
    /// of RAM and an 18000 MB local model size limit.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            gpu_devices: Vec::new(),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            max_model_mb: 18000,
        }
    }
1477
    /// Registry fixture: creates placeholder model directories under a fixed
    /// temp path, then registers one remote OpenAI model with a broad
    /// capability set.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        // NOTE(review): mutating the process environment is unsafe because
        // other threads may read it concurrently; tests run in parallel by
        // default — confirm nothing else reads or writes the environment
        // while this runs.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "test-openai-key");
        }
        // Placeholder model files. I/O errors are deliberately ignored, since
        // the directories may already exist from an earlier run.
        // NOTE(review): presumably the registry treats these directories as
        // installed local models — confirm against `UnifiedRegistry::new`.
        for name in &[
            "Qwen3-0.6B",
            "Qwen3-1.7B",
            "Qwen3-4B",
            "Qwen3-8B",
            "Qwen3-Embedding-0.6B",
        ] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        let mut reg = UnifiedRegistry::new(tmp);
        reg.register(ModelSchema {
            id: "openai/gpt-5.4-mini:latest".into(),
            name: "gpt-5.4-mini".into(),
            provider: "openai".into(),
            family: "gpt-5.4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
                ModelCapability::Vision,
            ],
            context_length: 128_000,
            param_count: "api".into(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: crate::schema::ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_key_envs: vec![],
                api_version: None,
                protocol: crate::schema::ApiProtocol::OpenAiCompat,
            },
            tags: vec!["trusted-remote".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: true,
        });
        reg
    }
1530
1531 #[test]
1532 fn routes_simple_to_trusted_remote_during_cold_start() {
1533 let router = AdaptiveRouter::new(
1534 test_hw(),
1535 RoutingConfig {
1536 prior_strength: 100.0, ..Default::default()
1538 },
1539 );
1540 let reg = test_registry();
1541 let tracker = OutcomeTracker::new();
1542
1543 let decision = router.route("What is 2+2?", ®, &tracker);
1544 assert_eq!(decision.complexity, TaskComplexity::Simple);
1545 assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
1546 let schema = reg
1548 .find_by_name(&decision.model_name)
1549 .expect("selected model should exist in registry");
1550 assert!(
1551 !schema.is_local(),
1552 "simple task should route to trusted remote model during cold start"
1553 );
1554 assert!(matches!(
1555 schema.provider.as_str(),
1556 "openai" | "anthropic" | "google"
1557 ));
1558 }
1559
1560 #[test]
1561 fn routes_code_to_code_capable_remote_during_cold_start() {
1562 let router = AdaptiveRouter::new(
1563 test_hw(),
1564 RoutingConfig {
1565 prior_strength: 100.0, ..Default::default()
1567 },
1568 );
1569 let reg = test_registry();
1570 let tracker = OutcomeTracker::new();
1571
1572 let decision = router.route(
1573 "Fix this function:\n```rust\nfn main() {}\n```",
1574 ®,
1575 &tracker,
1576 );
1577 assert_eq!(decision.complexity, TaskComplexity::Code);
1578 assert_eq!(decision.task, InferenceTask::Code);
1579 let schema = reg
1581 .find_by_name(&decision.model_name)
1582 .expect("model should exist");
1583 assert!(
1584 schema.has_capability(ModelCapability::Code),
1585 "selected model must support Code"
1586 );
1587 assert!(!schema.is_local(), "should route to trusted remote model");
1588 }
1589
1590 #[test]
1591 fn routes_images_to_vision_capable_model() {
1592 let router = AdaptiveRouter::new(
1593 test_hw(),
1594 RoutingConfig {
1595 prior_strength: 100.0,
1596 ..Default::default()
1597 },
1598 );
1599 let mut reg = test_registry();
1600 let tracker = OutcomeTracker::new();
1601
1602 reg.register(ModelSchema {
1603 id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
1604 name: "Qwen3-VL-2B-mlx-vlm".into(),
1605 provider: "qwen".into(),
1606 family: "qwen3-vl".into(),
1607 version: "bf16".into(),
1608 capabilities: vec![
1609 ModelCapability::Generate,
1610 ModelCapability::Vision,
1611 ModelCapability::Grounding,
1612 ],
1613 context_length: 262_144,
1614 param_count: "2B".into(),
1615 quantization: None,
1616 performance: Default::default(),
1617 cost: Default::default(),
1618 source: crate::schema::ModelSource::Mlx {
1619 hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
1620 hf_weight_file: None,
1621 },
1622 tags: vec!["vision".into(), "mlx-vlm-cli".into()],
1623 supported_params: vec![],
1624 public_benchmarks: vec![],
1625 available: true,
1626 });
1627
1628 let decision = router.route_with_vision("What is in this image?", ®, &tracker, false);
1629 let schema = reg
1630 .find_by_name(&decision.model_name)
1631 .expect("model should exist");
1632 assert!(
1633 schema.has_capability(ModelCapability::Vision),
1634 "selected model must support Vision"
1635 );
1636 }
1637
1638 #[test]
1639 fn profile_based_routing_favors_proven_model() {
1640 let router = AdaptiveRouter::new(
1641 test_hw(),
1642 RoutingConfig {
1643 prior_strength: 0.5, min_observations: 3,
1645 ..Default::default()
1646 },
1647 );
1648 let reg = test_registry();
1649 let mut tracker = OutcomeTracker::new();
1650
1651 let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1653 for _ in 0..20 {
1654 let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
1655 tracker.record_complete(&trace, 500, 100, 50);
1656 tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1657 }
1658
1659 let mut wins = 0;
1661 for _ in 0..20 {
1662 let decision = router.route("Fix this bug in the parser", ®, &tracker);
1663 assert_eq!(decision.complexity, TaskComplexity::Code);
1664 if decision.model_id == qwen_8b_id {
1665 wins += 1;
1666 }
1667 }
1668 assert!(
1669 wins >= 12,
1670 "proven model won only {wins}/20 times (expected >= 12)"
1671 );
1672 }
1673
1674 #[test]
1675 fn proven_local_model_can_displace_bootstrap_remote() {
1676 let router = AdaptiveRouter::new(
1677 test_hw(),
1678 RoutingConfig {
1679 prior_strength: 100.0,
1680 bootstrap_min_task_observations: 6,
1681 bootstrap_quality_floor: 0.8,
1682 ..Default::default()
1683 },
1684 );
1685 let reg = test_registry();
1686 let mut tracker = OutcomeTracker::new();
1687
1688 let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
1689 for _ in 0..12 {
1690 let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
1691 tracker.record_complete(&trace, 300, 50, 20);
1692 tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
1693 }
1694
1695 let mut local_wins = 0;
1696 for _ in 0..20 {
1697 let decision = router.route("Summarize this design decision.", ®, &tracker);
1698 let schema = reg
1699 .get(&decision.model_id)
1700 .expect("selected model should exist");
1701 if schema.is_local() {
1702 local_wins += 1;
1703 }
1704 }
1705
1706 assert!(
1707 local_wins >= 12,
1708 "proven local model won only {local_wins}/20 times (expected >= 12)"
1709 );
1710 }
1711
1712 #[test]
1713 fn benchmark_prior_informs_background_routing() {
1714 let router = AdaptiveRouter::new(
1715 test_hw(),
1716 RoutingConfig {
1717 prior_strength: 100.0,
1718 ..Default::default()
1719 },
1720 );
1721 let reg = test_registry();
1722 let mut tracker = OutcomeTracker::new();
1723 let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1724 profile.ema_quality = 0.95;
1725 tracker.import_profiles(vec![profile]);
1726
1727 let decision = router.route_context_aware(
1728 "Write a Python fibonacci function.",
1729 128,
1730 ®,
1731 &tracker,
1732 false,
1733 false,
1734 RoutingWorkload::Background,
1735 );
1736
1737 let schema = reg
1738 .get(&decision.model_id)
1739 .expect("selected model should exist");
1740 assert!(
1741 schema.is_local(),
1742 "background routing should allow strong local benchmark priors to win"
1743 );
1744 }
1745
1746 #[test]
1747 fn task_specific_benchmark_prior_informs_cold_start_routing() {
1748 let router = AdaptiveRouter::new(
1749 test_hw(),
1750 RoutingConfig {
1751 prior_strength: 100.0,
1752 ..Default::default()
1753 },
1754 );
1755 let reg = test_registry();
1756 let mut tracker = OutcomeTracker::new();
1757 let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1758 profile.task_stats.insert(
1759 crate::outcome::InferenceTask::Code.to_string(),
1760 crate::outcome::TaskStats {
1761 ema_quality: 0.95,
1762 ..Default::default()
1763 },
1764 );
1765 tracker.import_profiles(vec![profile]);
1766
1767 let decision = router.route_context_aware(
1768 "Write a Python fibonacci function.",
1769 128,
1770 ®,
1771 &tracker,
1772 false,
1773 false,
1774 RoutingWorkload::Background,
1775 );
1776
1777 let schema = reg
1778 .get(&decision.model_id)
1779 .expect("selected model should exist");
1780 assert!(
1781 schema.is_local(),
1782 "background routing should use task-specific cold-start priors for local code models"
1783 );
1784 }
1785
1786 #[test]
1787 fn task_specific_latency_prior_affects_cold_start_score() {
1788 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
1789 let reg = test_registry();
1790 let model = reg
1791 .get("qwen/qwen3-8b:q4_k_m")
1792 .expect("local test model should exist");
1793
1794 let mut fast_tracker = OutcomeTracker::new();
1795 let mut fast_profile = crate::outcome::ModelProfile::new(model.id.clone());
1796 fast_profile.task_stats.insert(
1797 crate::outcome::InferenceTask::Generate.to_string(),
1798 crate::outcome::TaskStats {
1799 ema_quality: 0.95,
1800 avg_latency_ms: 1200.0,
1801 ..Default::default()
1802 },
1803 );
1804 fast_tracker.import_profiles(vec![fast_profile]);
1805
1806 let mut slow_tracker = OutcomeTracker::new();
1807 let mut slow_profile = crate::outcome::ModelProfile::new(model.id.clone());
1808 slow_profile.task_stats.insert(
1809 crate::outcome::InferenceTask::Generate.to_string(),
1810 crate::outcome::TaskStats {
1811 ema_quality: 0.95,
1812 avg_latency_ms: 120_000.0,
1813 ..Default::default()
1814 },
1815 );
1816 slow_tracker.import_profiles(vec![slow_profile]);
1817
1818 let fast_score = router.score_model(
1819 model,
1820 InferenceTask::Generate,
1821 &fast_tracker,
1822 RoutingWorkload::Interactive,
1823 );
1824 let slow_score = router.score_model(
1825 model,
1826 InferenceTask::Generate,
1827 &slow_tracker,
1828 RoutingWorkload::Interactive,
1829 );
1830
1831 assert!(
1832 fast_score > slow_score,
1833 "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
1834 );
1835 }
1836
1837 #[test]
1838 fn interactive_workload_keeps_remote_bootstrap_bias() {
1839 let router = AdaptiveRouter::new(
1840 test_hw(),
1841 RoutingConfig {
1842 prior_strength: 100.0,
1843 ..Default::default()
1844 },
1845 );
1846 let reg = test_registry();
1847 let mut tracker = OutcomeTracker::new();
1848 let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
1849 profile.ema_quality = 0.95;
1850 tracker.import_profiles(vec![profile]);
1851
1852 let decision = router.route_context_aware(
1853 "Write a Python fibonacci function.",
1854 128,
1855 ®,
1856 &tracker,
1857 false,
1858 false,
1859 RoutingWorkload::Interactive,
1860 );
1861
1862 let schema = reg
1863 .get(&decision.model_id)
1864 .expect("selected model should exist");
1865 assert!(
1866 !schema.is_local(),
1867 "interactive routing should still prefer trusted remote models during cold start"
1868 );
1869 }
1870
1871 #[test]
1872 fn code_interactive_weights_prioritise_quality_over_speed_over_cost() {
1873 let (q, lat, cost) = AdaptiveRouter::task_aware_weights(
1879 InferenceTask::Code,
1880 RoutingWorkload::Interactive,
1881 );
1882 assert!(q > lat && lat > cost, "expected quality > speed > cost, got ({q}, {lat}, {cost})");
1883 assert!((q + lat + cost - 1.0).abs() < 1e-9, "weights must sum to 1.0");
1884 let (default_q, default_lat, default_cost) = RoutingWorkload::Interactive.weights();
1885 assert!(
1886 q > default_q && lat < default_lat && cost < default_cost,
1887 "code weighting must lift quality and cut latency/cost vs the generic interactive profile"
1888 );
1889
1890 assert_eq!(
1892 AdaptiveRouter::task_aware_weights(
1893 InferenceTask::Code,
1894 RoutingWorkload::LocalPreferred,
1895 ),
1896 (q, lat, cost),
1897 );
1898
1899 assert_eq!(
1902 AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Fastest),
1903 RoutingWorkload::Fastest.weights(),
1904 );
1905 assert_eq!(
1907 AdaptiveRouter::task_aware_weights(InferenceTask::Code, RoutingWorkload::Background),
1908 RoutingWorkload::Background.weights(),
1909 );
1910 assert_eq!(
1912 AdaptiveRouter::task_aware_weights(
1913 InferenceTask::Generate,
1914 RoutingWorkload::Interactive,
1915 ),
1916 RoutingWorkload::Interactive.weights(),
1917 );
1918 }
1919
1920 #[test]
1921 fn fallback_chain_has_alternatives() {
1922 let router = AdaptiveRouter::new(
1923 test_hw(),
1924 RoutingConfig {
1925 prior_strength: 100.0, ..Default::default()
1927 },
1928 );
1929 let reg = test_registry();
1930 let tracker = OutcomeTracker::new();
1931
1932 let decision = router.route("Analyze the architecture trade-offs", ®, &tracker);
1933 assert!(!decision.fallbacks.is_empty());
1934 assert!(!decision.fallbacks.contains(&decision.model_id));
1936 }
1937
1938 #[test]
1939 fn latency_scoring_is_consistent() {
1940 let router = AdaptiveRouter::with_default_config(test_hw());
1942
1943 let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
1945 let observed_score = router.latency_ms_to_score(8000.0);
1947 assert!(
1948 (schema_score - observed_score).abs() < 0.01,
1949 "schema ({schema_score}) and observed ({observed_score}) should match"
1950 );
1951 }
1952
1953 #[test]
1954 fn complexity_assessment() {
1955 assert_eq!(
1956 TaskComplexity::assess("What is the capital of France?"),
1957 TaskComplexity::Simple
1958 );
1959 assert_eq!(
1960 TaskComplexity::assess("Fix this broken test"),
1961 TaskComplexity::Code
1962 );
1963 assert_eq!(
1964 TaskComplexity::assess("Analyze the trade-offs between A and B"),
1965 TaskComplexity::Complex
1966 );
1967 }
1968
1969 #[test]
1970 fn beta_sampling_produces_valid_values() {
1971 let mut rng = rand::rng();
1972 for _ in 0..100 {
1974 let s = sample_beta(&mut rng, 2.0, 5.0);
1975 assert!(s >= 0.0 && s <= 1.0, "sample {s} out of [0,1] range");
1976 }
1977 let samples: Vec<f64> = (0..1000).map(|_| sample_beta(&mut rng, 1.0, 1.0)).collect();
1979 let mean = samples.iter().sum::<f64>() / samples.len() as f64;
1980 assert!(
1981 (mean - 0.5).abs() < 0.05,
1982 "Beta(1,1) mean {mean} should be ~0.5"
1983 );
1984 }
1985
1986 #[test]
1987 fn thompson_sampling_converges_to_best() {
1988 let router = AdaptiveRouter::new(
1990 test_hw(),
1991 RoutingConfig {
1992 prior_strength: 1.0, ..Default::default()
1994 },
1995 );
1996 let reg = test_registry();
1997 let mut tracker = OutcomeTracker::new();
1998
1999 let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
2001 for _ in 0..20 {
2002 let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
2003 tracker.record_complete(&trace, 500, 100, 50);
2004 tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
2005 }
2006
2007 let mut wins = 0;
2009 for _ in 0..20 {
2010 let decision = router.route("Fix this parser bug", ®, &tracker);
2011 if decision.model_id == qwen_4b_id {
2012 wins += 1;
2013 }
2014 }
2015 assert!(
2016 wins >= 14,
2017 "strong model won only {wins}/20 times (expected >= 14)"
2018 );
2019 }
2020
2021 #[test]
2024 fn intent_require_filters_out_models_lacking_capability() {
2025 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
2030 let reg = test_registry();
2031 let tracker = OutcomeTracker::new();
2032
2033 let intent = crate::intent::IntentHint {
2034 require: vec![ModelCapability::Vision],
2035 ..Default::default()
2036 };
2037 let decision = router.route_with_intent("hello", ®, &tracker, &intent);
2038
2039 assert_eq!(
2044 decision.strategy,
2045 RoutingStrategy::SchemaBased,
2046 "require=[vision] with no vision-capable candidates must drop to schema cold-start"
2047 );
2048 }
2049
2050 #[test]
2051 fn intent_default_does_not_override_task_or_caps() {
2052 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
2059 let reg = test_registry();
2060 let tracker = OutcomeTracker::new();
2061
2062 let baseline = router.route("write a haiku", ®, &tracker);
2063 let with_default = router.route_with_intent(
2064 "write a haiku",
2065 ®,
2066 &tracker,
2067 &crate::intent::IntentHint::default(),
2068 );
2069
2070 assert_eq!(
2071 baseline.task, with_default.task,
2072 "default IntentHint must not change the prompt-derived task"
2073 );
2074 assert_eq!(
2075 baseline.complexity, with_default.complexity,
2076 "default IntentHint must not change the prompt-derived complexity"
2077 );
2078 }
2079
2080 #[test]
2081 fn intent_task_hint_overrides_prompt_complexity() {
2082 let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
2085 let reg = test_registry();
2086 let tracker = OutcomeTracker::new();
2087
2088 let hint = crate::intent::IntentHint {
2089 task: Some(crate::intent::TaskHint::Reasoning),
2090 ..Default::default()
2091 };
2092 let decision = router.route_with_intent("hi", ®, &tracker, &hint);
2093
2094 assert_eq!(
2095 decision.task,
2096 InferenceTask::Reasoning,
2097 "TaskHint::Reasoning should override the prompt-derived task"
2098 );
2099 }
2100}