1use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse classification of a prompt, produced by [`TaskComplexity::assess`].
///
/// Serialized with `snake_case` variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Prompt matches a short factual-question pattern or is estimated at
    /// fewer than 30 tokens.
    Simple,
    /// Prompt matched none of the other categories.
    Medium,
    /// Prompt contains code markers or repair-related wording.
    Code,
    /// Prompt contains reasoning-related wording or is estimated at more
    /// than 500 tokens.
    Complex,
}
33
34impl TaskComplexity {
35 pub fn assess(prompt: &str) -> Self {
42 let lower = prompt.to_lowercase();
43 let word_count = prompt.split_whitespace().count();
44 let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46 let has_code = Self::detect_code(prompt);
47
48 let repair_markers = [
49 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50 ];
51 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53 let reasoning_markers = [
54 "analyze",
55 "compare",
56 "explain why",
57 "step by step",
58 "think through",
59 "evaluate",
60 "trade-off",
61 "tradeoff",
62 "pros and cons",
63 "architecture",
64 "design",
65 "strategy",
66 "optimize",
67 "comprehensive",
68 ];
69 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71 let simple_patterns = [
72 "what is",
73 "who is",
74 "when did",
75 "where is",
76 "how many",
77 "yes or no",
78 "true or false",
79 "name the",
80 "list the",
81 "define ",
82 ];
83 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85 if has_code || has_repair {
86 TaskComplexity::Code
87 } else if has_reasoning || estimated_tokens > 500 {
88 TaskComplexity::Complex
89 } else if is_simple || estimated_tokens < 30 {
90 TaskComplexity::Simple
91 } else {
92 TaskComplexity::Medium
93 }
94 }
95
96 fn detect_code(prompt: &str) -> bool {
102 #[cfg(feature = "ast")]
104 {
105 if let Some(is_code) = Self::detect_code_ast(prompt) {
106 return is_code;
107 }
108 }
109
110 let code_markers = [
112 "```",
113 "fn ",
114 "def ",
115 "class ",
116 "import ",
117 "require(",
118 "async fn",
119 "pub fn",
120 "function ",
121 "const ",
122 "let ",
123 "var ",
124 "#include",
125 "package ",
126 "impl ",
127 ];
128 code_markers.iter().any(|m| prompt.contains(m))
129 }
130
131 #[cfg(feature = "ast")]
135 fn detect_code_ast(prompt: &str) -> Option<bool> {
136 let mut blocks = Vec::new();
138 let mut rest = prompt;
139 while let Some(start) = rest.find("```") {
140 let after_fence = &rest[start + 3..];
141 let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
143 if let Some(end) = after_fence[code_start..].find("```") {
144 blocks.push(&after_fence[code_start..code_start + end]);
145 rest = &after_fence[code_start + end + 3..];
146 } else {
147 break;
148 }
149 }
150
151 if blocks.is_empty() {
152 return None; }
154
155 let languages = [
157 car_ast::Language::Rust,
158 car_ast::Language::Python,
159 car_ast::Language::TypeScript,
160 car_ast::Language::JavaScript,
161 car_ast::Language::Go,
162 ];
163
164 for block in &blocks {
165 let trimmed = block.trim();
166 if trimmed.is_empty() {
167 continue;
168 }
169
170 for lang in &languages {
171 if let Some(parsed) = car_ast::parse(trimmed, *lang) {
172 if !parsed.symbols.is_empty() {
174 return Some(true);
175 }
176 }
177 }
178 }
179
180 Some(false)
183 }
184
185 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187 match self {
188 TaskComplexity::Simple => vec![ModelCapability::Generate],
189 TaskComplexity::Medium => vec![ModelCapability::Generate],
190 TaskComplexity::Code => vec![ModelCapability::Code],
191 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192 }
193 }
194
195 pub fn inference_task(&self) -> InferenceTask {
197 match self {
198 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199 TaskComplexity::Code => InferenceTask::Code,
200 TaskComplexity::Complex => InferenceTask::Reasoning,
201 }
202 }
203}
204
/// Tunable parameters for [`AdaptiveRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Total observed calls at which a model's tracked profile is used
    /// outright; with fewer (but non-zero) calls the profile is blended with
    /// the schema-derived estimate.
    pub min_observations: u64,
    /// NOTE(review): the scoring code visible in this file takes its weights
    /// from the workload, not from this field — confirm whether it is still read.
    pub quality_weight: f64,
    /// NOTE(review): see `quality_weight`.
    pub latency_weight: f64,
    /// NOTE(review): see `quality_weight`.
    pub cost_weight: f64,
    /// When set, models whose schema p50 latency exceeds this are filtered
    /// out; models with no recorded p50 pass the filter.
    pub max_latency_ms: Option<u64>,
    /// When set, models whose per-1k-output cost exceeds this are filtered out.
    pub max_cost_usd: Option<f64>,
    /// `true` adds a score bonus to local models; `false` removes local
    /// models from the candidate set entirely.
    pub prefer_local: bool,
    /// Weight of the score-derived Beta prior used when sampling a model
    /// (multiplies the prior mean to form the prior alpha/beta).
    pub prior_strength: f64,
    /// Enables the quality-first bootstrap policy and the trusted-remote
    /// choice in the cold-start path.
    pub quality_first_cold_start: bool,
    /// Minimum observed calls before a local model counts as proven under
    /// the bootstrap policy.
    pub bootstrap_min_task_observations: u64,
    /// Minimum EMA quality for a local model to count as proven under the
    /// bootstrap policy.
    pub bootstrap_quality_floor: f64,
}
234
/// Default routing parameters.
impl Default for RoutingConfig {
    fn default() -> Self {
        Self {
            min_observations: 2,
            // The three weights sum to 1.0.
            quality_weight: 0.45,
            latency_weight: 0.4,
            cost_weight: 0.15,
            // No latency or cost ceilings by default.
            max_latency_ms: None,
            max_cost_usd: None,
            prefer_local: true,
            prior_strength: 2.0,
            quality_first_cold_start: true,
            bootstrap_min_task_observations: 8,
            bootstrap_quality_floor: 0.8,
        }
    }
}
252
/// How a routing decision was reached. Serialized with `snake_case` names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// The selected model had no observed calls (also used for cold-start
    /// decisions).
    SchemaBased,
    /// The selected model had at least `min_observations` observed calls.
    ProfileBased,
    /// The selected model had some observed calls, but fewer than
    /// `min_observations`.
    Exploration,
    /// NOTE(review): not produced by the routing code visible in this file;
    /// presumably set by callers that pin a model — confirm.
    Explicit,
}
266
/// Result of routing a single request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Id of the selected model. In the last-resort cold-start path this is
    /// the bare model name when the registry has no matching entry.
    pub model_id: String,
    /// Display name of the selected model.
    pub model_name: String,
    /// Inference task the request was classified as (or hinted to be).
    pub task: InferenceTask,
    /// Complexity assessed from the prompt text.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of the decision.
    pub reason: String,
    /// Which strategy produced the selection.
    pub strategy: RoutingStrategy,
    /// Composite score of the selected model; cold-start paths use 0.5 or
    /// the schema quality estimate. May exceed 1.0 because bonuses are added
    /// on top of the weighted sum.
    pub predicted_quality: f64,
    /// Other candidate model ids, best-scored first.
    pub fallbacks: Vec<String>,
    /// Context length of the selected model; 0 when unknown.
    pub context_length: usize,
    /// `true` when the estimated prompt size exceeds the selected model's
    /// context length.
    pub needs_compaction: bool,
}
291
/// Routes prompts to models using schema data, observed outcomes and
/// per-model circuit breakers.
pub struct AdaptiveRouter {
    /// Hardware description; bounds which local models are eligible.
    hw: HardwareInfo,
    /// Active routing parameters.
    config: RoutingConfig,
    /// Shared circuit-breaker registry consulted during candidate filtering.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs to a single routing call; build with [`RouteRequest::new`] and
/// override fields with struct-update syntax.
pub struct RouteRequest<'a> {
    /// Prompt text used for complexity assessment.
    pub prompt: &'a str,
    /// Registry of known models.
    pub registry: &'a UnifiedRegistry,
    /// Observed-outcome tracker.
    pub tracker: &'a OutcomeTracker,
    /// Estimated total prompt size in tokens; 0 disables context-length
    /// fitting.
    pub estimated_total_tokens: usize,
    /// Request needs tool use.
    pub has_tools: bool,
    /// Request includes image input.
    pub has_vision: bool,
    /// Workload class that supplies the scoring weights.
    pub workload: RoutingWorkload,
    /// Optional caller-supplied intent hint.
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
330impl<'a> RouteRequest<'a> {
331 pub fn new(
335 prompt: &'a str,
336 registry: &'a UnifiedRegistry,
337 tracker: &'a OutcomeTracker,
338 ) -> Self {
339 Self {
340 prompt,
341 registry,
342 tracker,
343 estimated_total_tokens: 0,
344 has_tools: false,
345 has_vision: false,
346 workload: RoutingWorkload::Interactive,
347 intent: None,
348 }
349 }
350}
351
352impl AdaptiveRouter {
353 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354 let circuit_breakers = Arc::new(Mutex::new(
355 CircuitBreakerRegistry::new(3, 300), ));
357 Self {
358 hw,
359 config,
360 circuit_breakers,
361 }
362 }
363
    /// Creates a router for `hw` using [`RoutingConfig::default`].
    pub fn with_default_config(hw: HardwareInfo) -> Self {
        Self::new(hw, RoutingConfig::default())
    }
367
    /// Returns the active routing configuration.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }
371
    /// Replaces the routing configuration; circuit-breaker state is untouched.
    pub fn set_config(&mut self, config: RoutingConfig) {
        self.config = config;
    }
375
376 pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382 let workload = match req.intent {
387 Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
388 _ => req.workload,
389 };
390 self.route_inner_with_intent(
391 req.prompt,
392 req.registry,
393 req.tracker,
394 req.has_tools,
395 req.has_vision,
396 req.estimated_total_tokens,
397 workload,
398 req.intent,
399 )
400 }
401
402 pub fn route(
405 &self,
406 prompt: &str,
407 registry: &UnifiedRegistry,
408 tracker: &OutcomeTracker,
409 ) -> AdaptiveRoutingDecision {
410 self.route_with(RouteRequest::new(prompt, registry, tracker))
411 }
412
413 pub fn route_editor(
419 &self,
420 prompt: &str,
421 registry: &UnifiedRegistry,
422 tracker: &OutcomeTracker,
423 ) -> AdaptiveRoutingDecision {
424 self.route_with(RouteRequest {
425 workload: RoutingWorkload::Background,
426 ..RouteRequest::new(prompt, registry, tracker)
427 })
428 }
429
430 pub fn route_with_tools(
432 &self,
433 prompt: &str,
434 registry: &UnifiedRegistry,
435 tracker: &OutcomeTracker,
436 ) -> AdaptiveRoutingDecision {
437 self.route_with(RouteRequest {
438 has_tools: true,
439 ..RouteRequest::new(prompt, registry, tracker)
440 })
441 }
442
443 pub fn route_with_vision(
445 &self,
446 prompt: &str,
447 registry: &UnifiedRegistry,
448 tracker: &OutcomeTracker,
449 has_tools: bool,
450 ) -> AdaptiveRoutingDecision {
451 self.route_with(RouteRequest {
452 has_tools,
453 has_vision: true,
454 ..RouteRequest::new(prompt, registry, tracker)
455 })
456 }
457
458 pub fn route_with_intent<'a>(
464 &self,
465 prompt: &'a str,
466 registry: &'a UnifiedRegistry,
467 tracker: &'a OutcomeTracker,
468 intent: &'a crate::intent::IntentHint,
469 ) -> AdaptiveRoutingDecision {
470 self.route_with(RouteRequest {
471 intent: Some(intent),
472 ..RouteRequest::new(prompt, registry, tracker)
473 })
474 }
475
476 pub fn route_context_aware(
479 &self,
480 prompt: &str,
481 estimated_total_tokens: usize,
482 registry: &UnifiedRegistry,
483 tracker: &OutcomeTracker,
484 has_tools: bool,
485 has_vision: bool,
486 workload: RoutingWorkload,
487 ) -> AdaptiveRoutingDecision {
488 self.route_with(RouteRequest {
489 estimated_total_tokens,
490 has_tools,
491 has_vision,
492 workload,
493 ..RouteRequest::new(prompt, registry, tracker)
494 })
495 }
496
497 pub fn route_context_aware_with_intent<'a>(
502 &self,
503 prompt: &'a str,
504 estimated_total_tokens: usize,
505 registry: &'a UnifiedRegistry,
506 tracker: &'a OutcomeTracker,
507 has_tools: bool,
508 has_vision: bool,
509 workload: RoutingWorkload,
510 intent: &'a crate::intent::IntentHint,
511 ) -> AdaptiveRoutingDecision {
512 self.route_with(RouteRequest {
513 estimated_total_tokens,
514 has_tools,
515 has_vision,
516 workload,
517 intent: Some(intent),
518 ..RouteRequest::new(prompt, registry, tracker)
519 })
520 }
521
    /// Core routing pipeline shared by all public `route*` entry points.
    ///
    /// Steps: assess complexity, derive the task and required capabilities,
    /// filter candidates, apply the quality-first bootstrap policy, split
    /// candidates by context fit, score, sample a winner, and assemble the
    /// decision (including fallbacks and the compaction flag). Falls back to
    /// `cold_start_decision` when no candidate survives filtering.
    fn route_inner_with_intent(
        &self,
        prompt: &str,
        registry: &UnifiedRegistry,
        tracker: &OutcomeTracker,
        has_tools: bool,
        has_vision: bool,
        estimated_total_tokens: usize,
        workload: RoutingWorkload,
        intent: Option<&crate::intent::IntentHint>,
    ) -> AdaptiveRoutingDecision {
        let complexity = TaskComplexity::assess(prompt);
        // An explicit task hint overrides the task derived from complexity.
        let task = intent
            .and_then(|h| h.task)
            .map(task_hint_to_inference_task)
            .unwrap_or_else(|| complexity.inference_task());
        let mut required_caps = complexity.required_capabilities();
        // Merge capabilities demanded by the hint, skipping duplicates.
        if let Some(hint) = intent {
            for cap in &hint.require {
                if !required_caps.contains(cap) {
                    required_caps.push(*cap);
                }
            }
        }
        if has_vision {
            required_caps.push(ModelCapability::Vision);
        }
        if has_tools {
            required_caps.push(ModelCapability::ToolUse);
            if Self::needs_multi_tool_call(prompt) {
                required_caps.push(ModelCapability::MultiToolCall);
            }
        }

        let mut candidates = self.filter_candidates(&required_caps, registry, tracker);

        // MultiToolCall is a soft requirement: drop it and retry rather than
        // going straight to cold start.
        if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
            required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
            candidates = self.filter_candidates(&required_caps, registry, tracker);
        }

        if candidates.is_empty() {
            return self.cold_start_decision(complexity, task, registry, has_vision);
        }

        candidates = self.apply_quality_first_bootstrap_policy(
            candidates,
            task,
            tracker,
            has_vision,
            has_tools,
            workload,
        );

        // Split candidates into those whose context window holds the whole
        // prompt and those that would need compaction. A context length of 0
        // is treated as "fits"; an estimate of 0 skips the split entirely.
        let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
            let mut fits = Vec::new();
            let mut tight = Vec::new();
            for m in &candidates {
                if m.context_length == 0 || m.context_length >= estimated_total_tokens {
                    fits.push(m.clone());
                } else {
                    tight.push(m.clone());
                }
            }
            (fits, tight)
        } else {
            (candidates.clone(), Vec::new())
        };

        // Prefer models that fit; only score the too-small ones when nothing fits.
        let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
            (fits, false)
        } else if !needs_compaction_candidates.is_empty() {
            tracing::info!(
                prompt_tokens = estimated_total_tokens,
                candidates = needs_compaction_candidates.len(),
                "no model fits full prompt — compaction will be needed"
            );
            (needs_compaction_candidates.clone(), true)
        } else {
            (candidates.clone(), false)
        };

        let scored = self.score_candidates_context_aware(
            &scoring_candidates,
            task,
            tracker,
            estimated_total_tokens,
            workload,
        );

        let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);

        // Fallbacks: remaining scored candidates in score order...
        let mut fallbacks: Vec<String> = scored
            .iter()
            .filter(|(id, _)| *id != selected_id)
            .map(|(id, _)| id.clone())
            .collect();
        // ...followed by the too-small models when they were not scored.
        if !compaction_needed {
            for m in &needs_compaction_candidates {
                if m.id != selected_id && !fallbacks.contains(&m.id) {
                    fallbacks.push(m.id.clone());
                }
            }
        }

        let predicted_quality = scored
            .iter()
            .find(|(id, _)| *id == selected_id)
            .map(|(_, score)| *score)
            .unwrap_or(0.5);

        // Look the winner up by id first, then by name.
        let selected_schema = registry
            .get(&selected_id)
            .or_else(|| registry.find_by_name(&selected_id));
        let model_name = selected_schema
            .map(|m| m.name.clone())
            .unwrap_or_else(|| selected_id.clone());
        let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);

        let needs_compact = compaction_needed
            || (estimated_total_tokens > 0
                && context_length > 0
                && estimated_total_tokens > context_length);

        let compaction_note = if needs_compact {
            format!(
                " [compaction needed: {}→{}tok]",
                estimated_total_tokens, context_length
            )
        } else {
            String::new()
        };

        let reason = format!(
            "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
            complexity,
            model_name,
            strategy,
            predicted_quality,
            scoring_candidates.len(),
            compaction_note,
        );

        AdaptiveRoutingDecision {
            model_id: selected_id,
            model_name,
            task,
            complexity,
            reason,
            strategy,
            predicted_quality,
            fallbacks,
            context_length,
            needs_compaction: needs_compact,
        }
    }
692
693 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
695 let embed_models = registry.query_by_capability(ModelCapability::Embed);
696 embed_models
697 .first()
698 .map(|m| m.name.clone())
699 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
700 }
701
702 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
704 let gen_models = registry.query_by_capability(ModelCapability::Generate);
705 gen_models
707 .iter()
708 .filter(|m| m.is_local())
709 .min_by_key(|m| m.size_mb())
710 .map(|m| m.name.clone())
711 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
712 }
713
    /// Latency in milliseconds at or above which the latency score is 0.0.
    const LATENCY_CEILING_MS: f64 = 10000.0;
    /// Not referenced by the code visible in this file (hence the underscore).
    const _TPS_CEILING: f64 = 150.0;
    /// Factor applied to the schema tokens/sec of local, non-MLX models
    /// tagged `moe` when estimating latency.
    const MOE_TPS_MULTIPLIER: f64 = 0.10;
    /// Same as `MOE_TPS_MULTIPLIER`, for MLX models tagged `moe`.
    const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
    /// Cost per 1k output at or above which a remote model's cost score is 0.0.
    const COST_CEILING_PER_1K: f64 = 0.1;
    /// Score bonus for local models when `prefer_local` is set.
    const LOCAL_BONUS: f64 = 0.15;
    /// Score bonus for MLX models; only defined on macOS/aarch64 builds.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    const MLX_BONUS: f64 = 0.10;
    /// Score bonus for models tagged both `low_latency` and `private`.
    const SYSTEM_LLM_BONUS: f64 = 0.12;
745
746 fn filter_candidates(
750 &self,
751 required_caps: &[ModelCapability],
752 registry: &UnifiedRegistry,
753 tracker: &OutcomeTracker,
754 ) -> Vec<ModelSchema> {
755 registry
756 .list()
757 .into_iter()
758 .filter(|m| {
759 if !required_caps.iter().all(|c| m.has_capability(*c)) {
761 return false;
762 }
763 if !m.available {
765 return false;
766 }
767 if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
769 return false;
770 }
771 if let Some(max) = self.config.max_latency_ms {
773 if let Some(p50) = m.performance.latency_p50_ms {
774 if p50 > max {
775 return false;
776 }
777 }
778 }
779 if let Some(max) = self.config.max_cost_usd {
781 if m.cost_per_1k_output() > max {
782 return false;
783 }
784 }
785 if !self.config.prefer_local && m.is_local() {
787 return false;
788 }
789 if tracker.is_excluded(&m.id) {
791 return false;
792 }
793 if let Ok(mut cb) = self.circuit_breakers.lock() {
795 if !cb.allow_request(&m.id) {
796 tracing::debug!(model = %m.id, "skipped by circuit breaker");
797 return false;
798 }
799 }
800 true
801 })
802 .cloned()
803 .collect()
804 }
805
806 fn apply_quality_first_bootstrap_policy(
807 &self,
808 candidates: Vec<ModelSchema>,
809 task: InferenceTask,
810 tracker: &OutcomeTracker,
811 has_vision: bool,
812 has_tools: bool,
813 workload: RoutingWorkload,
814 ) -> Vec<ModelSchema> {
815 if !self.config.quality_first_cold_start
816 || !workload.is_latency_sensitive()
817 || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
818 {
819 return candidates;
820 }
821
822 let trusted_remote: Vec<ModelSchema> = candidates
823 .iter()
824 .filter(|model| self.is_trusted_quality_remote(model))
825 .cloned()
826 .collect();
827
828 if trusted_remote.is_empty() {
829 return candidates;
830 }
831
832 let proven_local: Vec<ModelSchema> = candidates
833 .iter()
834 .filter(|model| model.is_local() && self.is_local_model_proven_for_task(model, task, tracker))
835 .cloned()
836 .collect();
837
838 if !proven_local.is_empty() {
839 return proven_local;
840 }
841
842 trusted_remote
843 }
844
845 fn score_candidates_context_aware(
849 &self,
850 candidates: &[ModelSchema],
851 task: InferenceTask,
852 tracker: &OutcomeTracker,
853 estimated_total_tokens: usize,
854 workload: RoutingWorkload,
855 ) -> Vec<(String, f64)> {
856 let mut scored: Vec<(String, f64)> = candidates
857 .iter()
858 .map(|m| {
859 let base_score = self.score_model(m, task, tracker, workload);
860 let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
863 let ratio = m.context_length as f64 / estimated_total_tokens as f64;
864 if ratio >= 1.0 {
865 (ratio.min(4.0) - 1.0) / 3.0 * 0.10 } else {
867 -0.15 }
869 } else {
870 0.0
871 };
872 (m.id.clone(), base_score + headroom_bonus)
873 })
874 .collect();
875
876 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
877 scored
878 }
879
    /// Computes the composite routing score for one model.
    ///
    /// The score is a weighted sum of quality, latency and cost components
    /// (weights come from `workload`), plus additive bonuses for local, MLX,
    /// vLLM-MLX and "system LLM" models. Quality and latency come from the
    /// tracker profile when enough calls were observed, from schema estimates
    /// when there is no profile, and from a linear blend in between.
    fn score_model(
        &self,
        model: &ModelSchema,
        task: InferenceTask,
        tracker: &OutcomeTracker,
        workload: RoutingWorkload,
    ) -> f64 {
        let profile = tracker.profile(&model.id);
        let schema_quality = self.schema_quality_estimate(model);
        let schema_latency = self.schema_latency_estimate(model);
        let (quality_weight, latency_weight, cost_weight) = workload.weights();

        let quality = match profile {
            // Enough observations: trust task-specific EMA, else overall EMA.
            Some(p) if p.total_calls >= self.config.min_observations => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            // Profile exists but has no recorded calls: use its stored values.
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .map(|ts| ts.ema_quality)
                .unwrap_or(p.ema_quality),
            // Some observations, below the threshold: blend schema estimate
            // and observed EMA in proportion to calls seen (w is in (0, 1)).
            Some(p) if p.total_calls > 0 => {
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_quality * (1.0 - w) + p.ema_quality * w
            }
            _ => schema_quality,
        };

        let latency = match profile {
            Some(p) if p.total_calls >= self.config.min_observations => {
                let avg = p
                    .task_stats(task)
                    .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                    .map(|ts| ts.avg_latency_ms)
                    .unwrap_or_else(|| p.avg_latency_ms());
                self.latency_ms_to_score(avg)
            }
            // No recorded calls: use a stored task latency if present, else
            // the schema estimate.
            Some(p) if p.total_calls == 0 => p
                .task_stats(task)
                .filter(|ts| ts.avg_latency_ms > 0.0)
                .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
                .unwrap_or(schema_latency),
            // Below the threshold: blend schema and observed latency scores.
            Some(p) if p.total_calls > 0 => {
                let observed = self.latency_ms_to_score(
                    p.task_stats(task)
                        .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
                        .map(|ts| ts.avg_latency_ms)
                        .unwrap_or_else(|| p.avg_latency_ms()),
                );
                let w = p.total_calls as f64 / self.config.min_observations as f64;
                schema_latency * (1.0 - w) + observed * w
            }
            _ => schema_latency,
        };

        // Local models get the best cost score; remote cost is scaled
        // against the per-1k ceiling.
        let cost = if model.is_local() {
            1.0
        } else {
            (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
        };

        let local_bonus = if self.config.prefer_local && model.is_local() {
            Self::LOCAL_BONUS
        } else {
            0.0
        };
        let workload_local_bonus = if model.is_local() {
            workload.local_bonus()
        } else {
            0.0
        };

        // The MLX bonus only exists on macOS/aarch64 builds.
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        let mlx_bonus = 0.0;

        let vllm_mlx_bonus = if model.is_vllm_mlx() {
            Self::LOCAL_BONUS + 0.05
        } else {
            0.0
        };

        // Models tagged both `low_latency` and `private` get the system-LLM bonus.
        let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
            && model.tags.iter().any(|t| t == "private")
        {
            Self::SYSTEM_LLM_BONUS
        } else {
            0.0
        };

        quality_weight * quality
            + latency_weight * latency
            + cost_weight * cost
            + local_bonus
            + workload_local_bonus
            + mlx_bonus
            + vllm_mlx_bonus
            + system_llm_bonus
    }
996
997 fn latency_ms_to_score(&self, ms: f64) -> f64 {
1000 (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1001 }
1002
1003 fn tps_to_latency_ms(tps: f64) -> f64 {
1005 if tps <= 0.0 {
1006 return Self::LATENCY_CEILING_MS;
1007 }
1008 (200.0 / tps) * 1000.0
1010 }
1011
1012 fn needs_multi_tool_call(prompt: &str) -> bool {
1015 let lower = prompt.to_lowercase();
1016
1017 let has_numbered_list = {
1019 let mut count = 0u32;
1020 for i in 1..=5u32 {
1021 if lower.contains(&format!("{}) ", i)) || lower.contains(&format!("{}. ", i)) {
1022 count += 1;
1023 }
1024 }
1025 count >= 2
1026 };
1027
1028 let multi_keywords = [
1030 "multiple edits",
1031 "several changes",
1032 "three changes",
1033 "two changes",
1034 "all of the following",
1035 "each of these",
1036 "do both",
1037 "do all",
1038 "and also",
1039 "additionally",
1040 "as well as",
1041 "then also",
1042 ];
1043 let has_multi_keywords = multi_keywords.iter().any(|kw| lower.contains(kw));
1044
1045 let bullet_actions = lower.matches("- add ").count()
1047 + lower.matches("- update ").count()
1048 + lower.matches("- change ").count()
1049 + lower.matches("- remove ").count()
1050 + lower.matches("- fix ").count()
1051 + lower.matches("- edit ").count()
1052 + lower.matches("- implement ").count()
1053 + lower.matches("- create ").count();
1054 let has_bullet_list = bullet_actions >= 2;
1055
1056 has_numbered_list || has_multi_keywords || has_bullet_list
1057 }
1058
1059 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1064 match model.size_mb() {
1065 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
1072 }
1073
1074 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1079 let is_moe = model.tags.contains(&"moe".to_string());
1080
1081 if model.is_local() {
1082 if let Some(tps) = model.performance.tokens_per_second {
1083 let effective_tps = if is_moe {
1084 let multiplier = if model.is_mlx() {
1085 Self::MLX_MOE_TPS_MULTIPLIER
1086 } else {
1087 Self::MOE_TPS_MULTIPLIER
1088 };
1089 tps * multiplier
1090 } else {
1091 tps
1092 };
1093 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1094 return self.latency_ms_to_score(estimated_ms);
1095 }
1096 return 0.5; }
1098
1099 if let Some(p50) = model.performance.latency_p50_ms {
1101 return self.latency_ms_to_score(p50 as f64);
1102 }
1103 0.3 }
1105
1106 fn is_quality_critical_bootstrap_task(
1107 &self,
1108 task: InferenceTask,
1109 has_vision: bool,
1110 has_tools: bool,
1111 ) -> bool {
1112 has_vision
1113 || has_tools
1114 || matches!(
1115 task,
1116 InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1117 )
1118 }
1119
1120 fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1121 model.is_remote()
1122 && matches!(
1123 model.provider.as_str(),
1124 "openai" | "anthropic" | "google"
1125 )
1126 && !model.has_capability(ModelCapability::SpeechToText)
1127 && !model.has_capability(ModelCapability::TextToSpeech)
1128 }
1129
1130 fn is_local_model_proven_for_task(
1131 &self,
1132 model: &ModelSchema,
1133 task: InferenceTask,
1134 tracker: &OutcomeTracker,
1135 ) -> bool {
1136 let Some(profile) = tracker.profile(&model.id) else {
1137 return false;
1138 };
1139 if let Some(task_stats) = profile.task_stats(task) {
1140 if task_stats.calls >= self.config.bootstrap_min_task_observations
1141 && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1142 {
1143 return true;
1144 }
1145 }
1146
1147 profile.total_calls >= self.config.bootstrap_min_task_observations
1148 && profile.ema_quality >= self.config.bootstrap_quality_floor
1149 }
1150
    /// Picks a model by drawing one Beta sample per candidate and taking the
    /// largest.
    ///
    /// Each candidate's Beta parameters combine a prior built from its score
    /// (clamped to `[0, 1]` and weighted by `prior_strength`) with the
    /// success/failure counts from its tracker profile. Returns the winning
    /// id and the strategy label implied by how many calls that model has
    /// seen. For an empty `scored` slice, returns an empty id with
    /// `SchemaBased`.
    fn select_with_thompson_sampling(
        &self,
        scored: &[(String, f64)],
        tracker: &OutcomeTracker,
    ) -> (String, RoutingStrategy) {
        if scored.is_empty() {
            return (String::new(), RoutingStrategy::SchemaBased);
        }

        let mut rng = rand::rng();
        let mut best_sample = f64::NEG_INFINITY;
        let mut best_id = scored[0].0.clone();
        let mut best_strategy = RoutingStrategy::SchemaBased;

        for (id, phase2_score) in scored {
            let profile = tracker.profile(id);
            let prior = self.config.prior_strength;

            // Scores may exceed 1.0 because of bonuses; a Beta mean must lie in [0, 1].
            let prior_mean = phase2_score.clamp(0.0, 1.0);

            let prior_alpha = prior * prior_mean;
            let prior_beta = prior * (1.0 - prior_mean);

            let (obs_alpha, obs_beta) = match profile {
                Some(p) => (p.success_count as f64, p.fail_count as f64),
                None => (0.0, 0.0),
            };

            // Keep both shape parameters strictly positive (`max` also
            // replaces a NaN with 0.01).
            let alpha = (prior_alpha + obs_alpha).max(0.01);
            let beta = (prior_beta + obs_beta).max(0.01);

            let sample = sample_beta(&mut rng, alpha, beta);

            if sample > best_sample {
                best_sample = sample;
                best_id = id.clone();
                best_strategy = match profile {
                    Some(p) if p.total_calls >= self.config.min_observations => {
                        RoutingStrategy::ProfileBased
                    }
                    Some(p) if p.total_calls > 0 => {
                        RoutingStrategy::Exploration
                    }
                    _ => RoutingStrategy::SchemaBased,
                };
            }
        }

        (best_id, best_strategy)
    }
1216
1217 fn cold_start_decision(
1219 &self,
1220 complexity: TaskComplexity,
1221 task: InferenceTask,
1222 registry: &UnifiedRegistry,
1223 has_vision: bool,
1224 ) -> AdaptiveRoutingDecision {
1225 if has_vision {
1226 if let Some(model) = registry
1227 .query_by_capability(ModelCapability::Vision)
1228 .into_iter()
1229 .find(|model| model.available && self.is_trusted_quality_remote(model))
1230 .or_else(|| registry.query_by_capability(ModelCapability::Vision).first().copied())
1231 {
1232 return AdaptiveRoutingDecision {
1233 model_id: model.id.clone(),
1234 model_name: model.name.clone(),
1235 task,
1236 complexity,
1237 reason: format!(
1238 "{:?} task → {} (cold start, vision fallback)",
1239 complexity, model.name
1240 ),
1241 strategy: RoutingStrategy::SchemaBased,
1242 predicted_quality: 0.5,
1243 fallbacks: vec![],
1244 context_length: model.context_length,
1245 needs_compaction: false,
1246 };
1247 }
1248 }
1249
1250 if self.config.quality_first_cold_start {
1251 let required_caps = complexity.required_capabilities();
1252 if let Some(model) = registry
1253 .list()
1254 .into_iter()
1255 .filter(|model| {
1256 model.available
1257 && required_caps.iter().all(|cap| model.has_capability(*cap))
1258 && self.is_trusted_quality_remote(model)
1259 })
1260 .max_by(|a, b| {
1261 self.schema_quality_estimate(a)
1262 .partial_cmp(&self.schema_quality_estimate(b))
1263 .unwrap_or(std::cmp::Ordering::Equal)
1264 })
1265 {
1266 return AdaptiveRoutingDecision {
1267 model_id: model.id.clone(),
1268 model_name: model.name.clone(),
1269 task,
1270 complexity,
1271 reason: format!(
1272 "{:?} task → {} (quality-first cold start)",
1273 complexity, model.name
1274 ),
1275 strategy: RoutingStrategy::SchemaBased,
1276 predicted_quality: self.schema_quality_estimate(model),
1277 fallbacks: vec![],
1278 context_length: model.context_length,
1279 needs_compaction: false,
1280 };
1281 }
1282 }
1283
1284 let model_name = match complexity {
1286 TaskComplexity::Simple => "Qwen3-0.6B",
1287 TaskComplexity::Medium => "Qwen3-1.7B",
1288 TaskComplexity::Code => "Qwen3-4B",
1289 TaskComplexity::Complex => &self.hw.recommended_model,
1290 };
1291
1292 let model_id = registry
1293 .find_by_name(model_name)
1294 .map(|m| m.id.clone())
1295 .unwrap_or_else(|| model_name.to_string());
1296
1297 let context_length = registry
1298 .find_by_name(model_name)
1299 .map(|m| m.context_length)
1300 .unwrap_or(0);
1301
1302 AdaptiveRoutingDecision {
1303 model_id,
1304 model_name: model_name.to_string(),
1305 task,
1306 complexity,
1307 reason: format!(
1308 "{:?} task → {} (cold start, no candidates)",
1309 complexity, model_name
1310 ),
1311 strategy: RoutingStrategy::SchemaBased,
1312 predicted_quality: 0.5,
1313 fallbacks: vec![],
1314 context_length,
1315 needs_compaction: false,
1316 }
1317 }
1318}
1319
1320fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1326 use crate::intent::TaskHint;
1327 match hint {
1328 TaskHint::Chat => InferenceTask::Generate,
1329 TaskHint::Classify => InferenceTask::Classify,
1330 TaskHint::Reasoning => InferenceTask::Reasoning,
1331 TaskHint::Code => InferenceTask::Code,
1332 }
1333}
1334
1335fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1343 let x = sample_gamma(rng, alpha);
1344 let y = sample_gamma(rng, beta);
1345 if x + y == 0.0 {
1346 0.5 } else {
1348 x / (x + y)
1349 }
1350}
1351
/// Draws a Gamma(`shape`, 1) sample using the Marsaglia–Tsang
/// squeeze/rejection method.
///
/// Callers in this file pass `shape >= 0.01`. NOTE(review): a NaN `shape`
/// would never satisfy the inner acceptance test and would loop forever —
/// callers must keep sanitising their inputs.
fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
    // The method needs shape >= 1; for smaller shapes, sample at shape + 1
    // and scale the result by U^(1/shape).
    if shape < 1.0 {
        let u: f64 = rng.random();
        return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
    }

    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();

    loop {
        // Draw a standard normal until (1 + c*x) is positive so that v > 0.
        let x: f64 = loop {
            let n = sample_standard_normal(rng);
            if 1.0 + c * n > 0.0 {
                break n;
            }
        };

        let v = (1.0 + c * x).powi(3);
        let u: f64 = rng.random();

        // Cheap squeeze test first; it accepts most draws without a logarithm.
        if u < 1.0 - 0.0331 * x.powi(4) {
            return d * v;
        }
        // Full acceptance test.
        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            return d * v;
        }
    }
}
1384
1385fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1387 let u1: f64 = rng.random();
1388 let u2: f64 = rng.random();
1389 (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1390}
1391
1392#[cfg(test)]
1393mod tests {
1394 use super::*;
1395 use crate::outcome::InferredOutcome;
1396
1397 fn test_hw() -> HardwareInfo {
1398 HardwareInfo {
1399 os: "macos".into(),
1400 arch: "aarch64".into(),
1401 cpu_cores: 10,
1402 total_ram_mb: 32768,
1403 gpu_backend: crate::hardware::GpuBackend::Metal,
1404 gpu_memory_mb: Some(28672),
1405 gpu_devices: Vec::new(),
1406 recommended_model: "Qwen3-8B".into(),
1407 recommended_context: 8192,
1408 max_model_mb: 18000, }
1410 }
1411
1412 fn test_registry() -> UnifiedRegistry {
1413 let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
1414 unsafe {
1415 std::env::set_var("OPENAI_API_KEY", "test-openai-key");
1416 }
1417 for name in &[
1419 "Qwen3-0.6B",
1420 "Qwen3-1.7B",
1421 "Qwen3-4B",
1422 "Qwen3-8B",
1423 "Qwen3-Embedding-0.6B",
1424 ] {
1425 let dir = tmp.join(name);
1426 let _ = std::fs::create_dir_all(&dir);
1427 let _ = std::fs::write(dir.join("model.gguf"), b"fake");
1428 let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
1429 }
1430 let mut reg = UnifiedRegistry::new(tmp);
1431 reg.register(ModelSchema {
1432 id: "openai/gpt-5.4-mini:latest".into(),
1433 name: "gpt-5.4-mini".into(),
1434 provider: "openai".into(),
1435 family: "gpt-5.4".into(),
1436 version: "latest".into(),
1437 capabilities: vec![
1438 ModelCapability::Generate,
1439 ModelCapability::Code,
1440 ModelCapability::Reasoning,
1441 ModelCapability::ToolUse,
1442 ModelCapability::MultiToolCall,
1443 ModelCapability::Vision,
1444 ],
1445 context_length: 128_000,
1446 param_count: "api".into(),
1447 quantization: None,
1448 performance: Default::default(),
1449 cost: Default::default(),
1450 source: crate::schema::ModelSource::RemoteApi {
1451 endpoint: "https://api.openai.com/v1".into(),
1452 api_key_env: "OPENAI_API_KEY".into(),
1453 api_key_envs: vec![],
1454 api_version: None,
1455 protocol: crate::schema::ApiProtocol::OpenAiCompat,
1456 },
1457 tags: vec!["trusted-remote".into()],
1458 supported_params: vec![],
1459 public_benchmarks: vec![],
1460 available: true,
1461 });
1462 reg
1463 }
1464
/// With an empty outcome tracker, a short factual prompt is classified as
/// `Simple`, routed with the schema-based strategy, and lands on a non-local
/// model from one of the expected remote providers.
#[test]
fn routes_simple_to_trusted_remote_during_cold_start() {
    // NOTE(review): the large `prior_strength` presumably keeps the schema
    // prior dominant — confirm against `RoutingConfig`.
    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let decision = router.route("What is 2+2?", &reg, &tracker);
    assert_eq!(decision.complexity, TaskComplexity::Simple);
    assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);

    let schema = reg
        .find_by_name(&decision.model_name)
        .expect("selected model should exist in registry");
    assert!(
        !schema.is_local(),
        "simple task should route to trusted remote model during cold start"
    );
    assert!(matches!(
        schema.provider.as_str(),
        "openai" | "anthropic" | "google"
    ));
}
1493
/// With an empty outcome tracker, a prompt containing a fenced Rust snippet
/// is classified as `Code` (both complexity and task) and routed to a
/// non-local model that advertises the `Code` capability.
#[test]
fn routes_code_to_code_capable_remote_during_cold_start() {
    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let decision = router.route(
        "Fix this function:\n```rust\nfn main() {}\n```",
        &reg,
        &tracker,
    );
    assert_eq!(decision.complexity, TaskComplexity::Code);
    assert_eq!(decision.task, InferenceTask::Code);

    let schema = reg
        .find_by_name(&decision.model_name)
        .expect("model should exist");
    assert!(
        schema.has_capability(ModelCapability::Code),
        "selected model must support Code"
    );
    assert!(!schema.is_local(), "should route to trusted remote model");
}
1523
/// Registers an additional MLX-sourced Qwen3-VL model, routes through
/// `route_with_vision`, and checks that whichever model is selected
/// advertises the `Vision` capability.
#[test]
fn routes_images_to_vision_capable_model() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 100.0,
            ..Default::default()
        },
    );
    let mut reg = test_registry();
    let tracker = OutcomeTracker::new();

    reg.register(ModelSchema {
        id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
        name: "Qwen3-VL-2B-mlx-vlm".into(),
        provider: "qwen".into(),
        family: "qwen3-vl".into(),
        version: "bf16".into(),
        capabilities: vec![
            ModelCapability::Generate,
            ModelCapability::Vision,
            ModelCapability::Grounding,
        ],
        context_length: 262_144,
        param_count: "2B".into(),
        quantization: None,
        performance: Default::default(),
        cost: Default::default(),
        source: crate::schema::ModelSource::Mlx {
            hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
            hf_weight_file: None,
        },
        tags: vec!["vision".into(), "mlx-vlm-cli".into()],
        supported_params: vec![],
        public_benchmarks: vec![],
        available: true,
    });

    // NOTE(review): the meaning of the trailing `false` flag is not visible
    // from this test — confirm it is not meant to signal image input.
    let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
    let schema = reg
        .find_by_name(&decision.model_name)
        .expect("model should exist");
    assert!(
        schema.has_capability(ModelCapability::Vision),
        "selected model must support Vision"
    );
}
1571
/// Records 20 accepted `Code` outcomes for the local Qwen3-8B model, then
/// routes a repair-style prompt 20 times and expects that model to be picked
/// in at least 12 of them.
///
/// The threshold is below 20 because the outcome is allowed to vary between
/// routing calls.
#[test]
fn profile_based_routing_favors_proven_model() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 0.5,
            min_observations: 3,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
    // Build up an observed track record for the model under test.
    for _ in 0..20 {
        let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
        tracker.record_complete(&trace, 500, 100, 50);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
    }

    let mut wins = 0;
    for _ in 0..20 {
        let decision = router.route("Fix this bug in the parser", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Code);
        if decision.model_id == qwen_8b_id {
            wins += 1;
        }
    }
    assert!(
        wins >= 12,
        "proven model won only {wins}/20 times (expected >= 12)"
    );
}
1607
/// Records 12 accepted `Generate` outcomes for the local Qwen3-8B model
/// (above the configured `bootstrap_min_task_observations` of 6), then routes
/// 20 times and expects a local model to be chosen in at least 12 of them.
#[test]
fn proven_local_model_can_displace_bootstrap_remote() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 100.0,
            bootstrap_min_task_observations: 6,
            bootstrap_quality_floor: 0.8,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
    for _ in 0..12 {
        let trace = tracker.record_start(qwen_8b_id, InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 300, 50, 20);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
    }

    // Count how many of 20 routing decisions land on a local model.
    let local_wins = (0..20)
        .filter(|_| {
            let decision = router.route("Summarize this design decision.", &reg, &tracker);
            reg.get(&decision.model_id)
                .expect("selected model should exist")
                .is_local()
        })
        .count();

    assert!(
        local_wins >= 12,
        "proven local model won only {local_wins}/20 times (expected >= 12)"
    );
}
1643
/// Imports a profile for the local Qwen3-8B model with `ema_quality` 0.95
/// and no recorded calls, then checks that a `Background` workload routes to
/// a local model.
#[test]
fn benchmark_prior_informs_background_routing() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 100.0,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
    profile.ema_quality = 0.95;
    tracker.import_profiles(vec![profile]);

    let decision = router.route_context_aware(
        "Write a Python fibonacci function.",
        128,
        &reg,
        &tracker,
        false,
        false,
        RoutingWorkload::Background,
    );

    let schema = reg
        .get(&decision.model_id)
        .expect("selected model should exist");
    assert!(
        schema.is_local(),
        "background routing should allow strong local benchmark priors to win"
    );
}
1672
/// Imports a profile whose only signal is a `Code` task entry with
/// `ema_quality` 0.95 for the local Qwen3-8B model, then checks that a
/// `Background` workload routes a code-writing prompt to a local model.
#[test]
fn task_specific_benchmark_prior_informs_cold_start_routing() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 100.0,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
    // Task stats are keyed by the task's string form.
    profile.task_stats.insert(
        crate::outcome::InferenceTask::Code.to_string(),
        crate::outcome::TaskStats {
            ema_quality: 0.95,
            ..Default::default()
        },
    );
    tracker.import_profiles(vec![profile]);

    let decision = router.route_context_aware(
        "Write a Python fibonacci function.",
        128,
        &reg,
        &tracker,
        false,
        false,
        RoutingWorkload::Background,
    );

    let schema = reg
        .get(&decision.model_id)
        .expect("selected model should exist");
    assert!(
        schema.is_local(),
        "background routing should use task-specific cold-start priors for local code models"
    );
}
1710
/// Scores the same local model against two trackers that differ only in the
/// imported `Generate` latency (1.2 s vs. 120 s) and expects the faster one
/// to score strictly higher for an `Interactive` workload.
#[test]
fn task_specific_latency_prior_affects_cold_start_score() {
    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let model = reg
        .get("qwen/qwen3-8b:q4_k_m")
        .expect("local test model should exist");

    // Builds a tracker holding one imported profile for `model`, identical
    // across calls except for the average latency of the `Generate` task.
    let tracker_with_latency = |avg_latency_ms| {
        let mut profile = crate::outcome::ModelProfile::new(model.id.clone());
        profile.task_stats.insert(
            crate::outcome::InferenceTask::Generate.to_string(),
            crate::outcome::TaskStats {
                ema_quality: 0.95,
                avg_latency_ms,
                ..Default::default()
            },
        );
        let mut tracker = OutcomeTracker::new();
        tracker.import_profiles(vec![profile]);
        tracker
    };

    let fast_tracker = tracker_with_latency(1200.0);
    let slow_tracker = tracker_with_latency(120_000.0);

    let fast_score = router.score_model(
        model,
        InferenceTask::Generate,
        &fast_tracker,
        RoutingWorkload::Interactive,
    );
    let slow_score = router.score_model(
        model,
        InferenceTask::Generate,
        &slow_tracker,
        RoutingWorkload::Interactive,
    );

    assert!(
        fast_score > slow_score,
        "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
    );
}
1761
/// Same setup as the background benchmark-prior test (imported profile with
/// `ema_quality` 0.95 for the local Qwen3-8B model), but routed as an
/// `Interactive` workload: here the selected model must be non-local.
#[test]
fn interactive_workload_keeps_remote_bootstrap_bias() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 100.0,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let mut profile = crate::outcome::ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
    profile.ema_quality = 0.95;
    tracker.import_profiles(vec![profile]);

    let decision = router.route_context_aware(
        "Write a Python fibonacci function.",
        128,
        &reg,
        &tracker,
        false,
        false,
        RoutingWorkload::Interactive,
    );

    let schema = reg
        .get(&decision.model_id)
        .expect("selected model should exist");
    assert!(
        !schema.is_local(),
        "interactive routing should still prefer trusted remote models during cold start"
    );
}
1793
/// A routing decision must carry at least one fallback, and the fallback
/// list must not repeat the primary model's id.
#[test]
fn fallback_chain_has_alternatives() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 100.0,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
    assert!(!decision.fallbacks.is_empty());
    assert!(!decision.fallbacks.contains(&decision.model_id));
}
1811
/// The score derived from a 25 tokens/s throughput (converted to a latency
/// first) must agree, within 0.01, with the score of an 8000 ms latency.
#[test]
fn latency_scoring_is_consistent() {
    let router = AdaptiveRouter::with_default_config(test_hw());

    let schema_latency_ms = AdaptiveRouter::tps_to_latency_ms(25.0);
    let schema_score = router.latency_ms_to_score(schema_latency_ms);
    let observed_score = router.latency_ms_to_score(8000.0);

    let gap = (schema_score - observed_score).abs();
    assert!(
        gap < 0.01,
        "schema ({schema_score}) and observed ({observed_score}) should match"
    );
}
1826
/// Spot-checks `TaskComplexity::assess` with one prompt per expected
/// outcome: a factual question, a repair request, and an analysis request.
#[test]
fn complexity_assessment() {
    let cases = [
        ("What is the capital of France?", TaskComplexity::Simple),
        ("Fix this broken test", TaskComplexity::Code),
        (
            "Analyze the trade-offs between A and B",
            TaskComplexity::Complex,
        ),
    ];

    for (prompt, expected) in cases {
        assert_eq!(TaskComplexity::assess(prompt), expected);
    }
}
1842
/// Checks two properties of `sample_beta`: 100 draws with parameters
/// (2, 5) all fall inside [0, 1], and the mean of 1000 draws with
/// parameters (1, 1) is within 0.05 of 0.5.
#[test]
fn beta_sampling_produces_valid_values() {
    const MEAN_DRAWS: usize = 1000;

    let mut rng = rand::rng();

    for _ in 0..100 {
        let s = sample_beta(&mut rng, 2.0, 5.0);
        assert!((0.0..=1.0).contains(&s), "sample {s} out of [0,1] range");
    }

    // Accumulate in draw order; no need to keep the individual samples.
    let total: f64 = (0..MEAN_DRAWS)
        .map(|_| sample_beta(&mut rng, 1.0, 1.0))
        .sum();
    let mean = total / MEAN_DRAWS as f64;
    assert!(
        (mean - 0.5).abs() < 0.05,
        "Beta(1,1) mean {mean} should be ~0.5"
    );
}
1859
/// Records 20 accepted `Code` outcomes for the local Qwen3-4B model, then
/// routes a repair-style prompt 20 times and expects that model to be
/// selected in at least 14 of them.
#[test]
fn thompson_sampling_converges_to_best() {
    let router = AdaptiveRouter::new(
        test_hw(),
        RoutingConfig {
            prior_strength: 1.0,
            ..Default::default()
        },
    );
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let qwen_4b_id = "qwen/qwen3-4b:q4_k_m";
    // Build up an observed track record for the model under test.
    for _ in 0..20 {
        let trace = tracker.record_start(qwen_4b_id, InferenceTask::Code, "test");
        tracker.record_complete(&trace, 500, 100, 50);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
    }

    // Count how many of 20 routing decisions pick that model.
    let wins = (0..20)
        .filter(|_| router.route("Fix this parser bug", &reg, &tracker).model_id == qwen_4b_id)
        .count();
    assert!(
        wins >= 14,
        "strong model won only {wins}/20 times (expected >= 14)"
    );
}
1894
/// Routes with an `IntentHint` that requires `Vision` and asserts the
/// decision is made with the schema-based strategy.
#[test]
fn intent_require_filters_out_models_lacking_capability() {
    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let intent = crate::intent::IntentHint {
        require: vec![ModelCapability::Vision],
        ..Default::default()
    };
    let decision = router.route_with_intent("hello", &reg, &tracker, &intent);

    // NOTE(review): the fixture's remote model lists `Vision`, so "no
    // vision-capable candidates" in the message below may not describe this
    // registry — confirm which candidate set it refers to.
    assert_eq!(
        decision.strategy,
        RoutingStrategy::SchemaBased,
        "require=[vision] with no vision-capable candidates must drop to schema cold-start"
    );
}
1923
/// Routing with a default `IntentHint` must yield the same task and
/// complexity as plain `route` for the same prompt.
#[test]
fn intent_default_does_not_override_task_or_caps() {
    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let baseline = router.route("write a haiku", &reg, &tracker);
    let with_default = router.route_with_intent(
        "write a haiku",
        &reg,
        &tracker,
        &crate::intent::IntentHint::default(),
    );

    assert_eq!(
        baseline.task, with_default.task,
        "default IntentHint must not change the prompt-derived task"
    );
    assert_eq!(
        baseline.complexity, with_default.complexity,
        "default IntentHint must not change the prompt-derived complexity"
    );
}
1953
/// An explicit `TaskHint::Reasoning` on the intent must win over whatever
/// task the trivial prompt "hi" would produce on its own.
#[test]
fn intent_task_hint_overrides_prompt_complexity() {
    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let hint = crate::intent::IntentHint {
        task: Some(crate::intent::TaskHint::Reasoning),
        ..Default::default()
    };
    let decision = router.route_with_intent("hi", &reg, &tracker, &hint);

    assert_eq!(
        decision.task,
        InferenceTask::Reasoning,
        "TaskHint::Reasoning should override the prompt-derived task"
    );
}
1974}