1use rand::Rng;
13use serde::{Deserialize, Serialize};
14
15use std::sync::{Arc, Mutex};
16
17use crate::hardware::HardwareInfo;
18use crate::outcome::{InferenceTask, OutcomeTracker};
19use crate::registry::UnifiedRegistry;
20use crate::routing_ext::CircuitBreakerRegistry;
21use crate::schema::{ModelCapability, ModelSchema};
22use crate::tasks::RoutingWorkload;
23
/// Coarse classification of a prompt, produced by `TaskComplexity::assess`.
///
/// Serialized with `snake_case` variant names (`simple`, `medium`, `code`, `complex`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Prompt uses a simple factual-question phrase or is estimated at fewer than 30 tokens.
    Simple,
    /// Prompt that matches none of the other categories.
    Medium,
    /// Prompt that contains code or repair/debug wording.
    Code,
    /// Prompt with reasoning wording or estimated at more than 500 tokens.
    Complex,
}
33
34impl TaskComplexity {
35 pub fn assess(prompt: &str) -> Self {
42 let lower = prompt.to_lowercase();
43 let word_count = prompt.split_whitespace().count();
44 let estimated_tokens = (word_count as f64 * 1.3) as usize;
45
46 let has_code = Self::detect_code(prompt);
47
48 let repair_markers = [
49 "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
50 ];
51 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
52
53 let reasoning_markers = [
54 "analyze",
55 "compare",
56 "explain why",
57 "step by step",
58 "think through",
59 "evaluate",
60 "trade-off",
61 "tradeoff",
62 "pros and cons",
63 "architecture",
64 "design",
65 "strategy",
66 "optimize",
67 "comprehensive",
68 ];
69 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
70
71 let simple_patterns = [
72 "what is",
73 "who is",
74 "when did",
75 "where is",
76 "how many",
77 "yes or no",
78 "true or false",
79 "name the",
80 "list the",
81 "define ",
82 ];
83 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
84
85 if has_code || has_repair {
86 TaskComplexity::Code
87 } else if has_reasoning || estimated_tokens > 500 {
88 TaskComplexity::Complex
89 } else if is_simple || estimated_tokens < 30 {
90 TaskComplexity::Simple
91 } else {
92 TaskComplexity::Medium
93 }
94 }
95
/// Returns `true` when `prompt` appears to contain source code.
///
/// With the `ast` feature enabled, a verdict from `detect_code_ast` (when it
/// returns one) is final. Otherwise the prompt is scanned for a fixed list of
/// case-sensitive code markers such as a Markdown fence or common keywords.
fn detect_code(prompt: &str) -> bool {
    #[cfg(feature = "ast")]
    {
        if let Some(verdict) = Self::detect_code_ast(prompt) {
            return verdict;
        }
    }

    const CODE_MARKERS: &[&str] = &[
        "```",
        "fn ",
        "def ",
        "class ",
        "import ",
        "require(",
        "async fn",
        "pub fn",
        "function ",
        "const ",
        "let ",
        "var ",
        "#include",
        "package ",
        "impl ",
    ];

    for marker in CODE_MARKERS {
        if prompt.contains(marker) {
            return true;
        }
    }
    false
}
130
/// Parser-backed code detection over Markdown fenced blocks.
///
/// Returns:
/// - `None` when the prompt contains no complete (opened and closed) fence, so the
///   caller falls back to its marker heuristic;
/// - `Some(true)` when any fenced block parses, in one of the probed languages,
///   into at least one symbol;
/// - `Some(false)` when fenced blocks exist but none yields symbols.
///
/// Each opening fence is paired with the next fence after it. The optional info
/// string (e.g. `rust`) is skipped only when a newline occurs *inside* that fenced
/// span; previously the newline search could run past the closing fence, so an
/// inline fence on one line mis-paired every later fence and hid real code blocks.
#[cfg(feature = "ast")]
fn detect_code_ast(prompt: &str) -> Option<bool> {
    const FENCE: &str = "```";

    let languages = [
        car_ast::Language::Rust,
        car_ast::Language::Python,
        car_ast::Language::TypeScript,
        car_ast::Language::JavaScript,
        car_ast::Language::Go,
    ];

    let mut saw_block = false;
    let mut rest = prompt;
    while let Some(open) = rest.find(FENCE) {
        // FENCE is ASCII, so these byte offsets always fall on char boundaries.
        let after_open = &rest[open + FENCE.len()..];
        let Some(close) = after_open.find(FENCE) else {
            // Unterminated fence: nothing more to extract.
            break;
        };
        let fenced = &after_open[..close];
        rest = &after_open[close + FENCE.len()..];
        saw_block = true;

        // Drop the info-string line if the block has one; otherwise keep it whole.
        let code = fenced
            .split_once('\n')
            .map_or(fenced, |(_info, body)| body)
            .trim();
        if code.is_empty() {
            continue;
        }

        let has_symbols = languages.iter().any(|lang| {
            car_ast::parse(code, *lang).is_some_and(|parsed| !parsed.symbols.is_empty())
        });
        if has_symbols {
            return Some(true);
        }
    }

    if saw_block {
        Some(false)
    } else {
        None
    }
}
184
185 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
187 match self {
188 TaskComplexity::Simple => vec![ModelCapability::Generate],
189 TaskComplexity::Medium => vec![ModelCapability::Generate],
190 TaskComplexity::Code => vec![ModelCapability::Code],
191 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
192 }
193 }
194
195 pub fn inference_task(&self) -> InferenceTask {
197 match self {
198 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
199 TaskComplexity::Code => InferenceTask::Code,
200 TaskComplexity::Complex => InferenceTask::Reasoning,
201 }
202 }
203}
204
/// Tunable parameters for `AdaptiveRouter`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Number of recorded calls at which a model's observed profile fully replaces
    /// the schema-based estimates; below it the two are blended.
    pub min_observations: u64,
    /// NOTE(review): the three `*_weight` fields are not read by the scoring code in
    /// this part of the file, which takes its weights from `RoutingWorkload::weights()`
    /// — confirm whether they are still used elsewhere.
    pub quality_weight: f64,
    /// See the note on `quality_weight`.
    pub latency_weight: f64,
    /// See the note on `quality_weight`.
    pub cost_weight: f64,
    /// When set, models whose schema p50 latency exceeds this value are filtered out.
    pub max_latency_ms: Option<u64>,
    /// When set, models whose `cost_per_1k_output()` exceeds this value are filtered out.
    pub max_cost_usd: Option<f64>,
    /// When `true`, local models receive a scoring bonus; when `false`, local models
    /// are removed from the candidate set entirely.
    pub prefer_local: bool,
    /// Weight of the score-derived Beta prior used during Thompson sampling.
    pub prior_strength: f64,
    /// Enables the quality-first bootstrap policy and the trusted-remote preference
    /// in the cold-start fallback.
    pub quality_first_cold_start: bool,
    /// Minimum call count before a local model counts as "proven" for the bootstrap policy.
    pub bootstrap_min_task_observations: u64,
    /// Minimum EMA quality before a local model counts as "proven" for the bootstrap policy.
    pub bootstrap_quality_floor: f64,
}
234
235impl Default for RoutingConfig {
236 fn default() -> Self {
237 Self {
238 min_observations: 2,
239 quality_weight: 0.45,
240 latency_weight: 0.4,
241 cost_weight: 0.15,
242 max_latency_ms: None,
243 max_cost_usd: None,
244 prefer_local: true,
245 prior_strength: 2.0,
246 quality_first_cold_start: true,
247 bootstrap_min_task_observations: 8,
248 bootstrap_quality_floor: 0.8,
249 }
250 }
251}
252
/// How the router arrived at a decision. Serialized with `snake_case` variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// The selected model has no recorded calls (or no profile); also used for
    /// every cold-start decision.
    SchemaBased,
    /// The selected model has at least `min_observations` recorded calls.
    ProfileBased,
    /// The selected model has some recorded calls, but fewer than `min_observations`.
    Exploration,
    /// NOTE(review): never produced by the code in this part of the file —
    /// presumably set by callers that pin a model explicitly; confirm.
    Explicit,
}
266
/// Result of a routing call: the chosen model plus the context behind the choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Identifier of the selected model (falls back to the model name on cold start
    /// when the registry has no matching entry).
    pub model_id: String,
    /// Display name of the selected model.
    pub model_name: String,
    /// Inference task the request was routed as.
    pub task: InferenceTask,
    /// Complexity assessed from the prompt.
    pub complexity: TaskComplexity,
    /// Human-readable summary of the decision.
    pub reason: String,
    /// Which kind of evidence backed the selection.
    pub strategy: RoutingStrategy,
    /// Composite score of the selected model (not clamped to 1.0); a fixed 0.5 or
    /// the schema quality estimate on cold-start paths.
    pub predicted_quality: f64,
    /// Other candidate model ids to try if the selected model fails.
    pub fallbacks: Vec<String>,
    /// Context window of the selected model; 0 when unknown.
    pub context_length: usize,
    /// `true` when the estimated prompt size exceeds the selected model's context window.
    pub needs_compaction: bool,
}
291
/// Routes prompts to models using schema metadata, observed outcomes and
/// Thompson sampling.
pub struct AdaptiveRouter {
    /// Hardware description; supplies the local model size limit and the
    /// recommended model used by the cold-start fallback.
    hw: HardwareInfo,
    /// Active routing parameters.
    config: RoutingConfig,
    /// Shared circuit-breaker state consulted while filtering candidates.
    pub circuit_breakers: Arc<Mutex<CircuitBreakerRegistry>>,
}
299
/// All inputs for a single routing call; build with `RouteRequest::new` and
/// override fields as needed.
pub struct RouteRequest<'a> {
    /// The user prompt to classify and route.
    pub prompt: &'a str,
    /// Registry of known models.
    pub registry: &'a UnifiedRegistry,
    /// Recorded outcomes used for scoring and exclusion.
    pub tracker: &'a OutcomeTracker,
    /// Estimated total prompt size in tokens; 0 disables context-window checks.
    pub estimated_total_tokens: usize,
    /// Whether the request uses tools (adds the `ToolUse` capability requirement).
    pub has_tools: bool,
    /// Whether the request includes images (adds the `Vision` capability requirement).
    pub has_vision: bool,
    /// Workload profile that supplies the scoring weights.
    pub workload: RoutingWorkload,
    /// Optional caller hint; may override the workload and task and add required capabilities.
    pub intent: Option<&'a crate::intent::IntentHint>,
}
329
330impl<'a> RouteRequest<'a> {
331 pub fn new(
335 prompt: &'a str,
336 registry: &'a UnifiedRegistry,
337 tracker: &'a OutcomeTracker,
338 ) -> Self {
339 Self {
340 prompt,
341 registry,
342 tracker,
343 estimated_total_tokens: 0,
344 has_tools: false,
345 has_vision: false,
346 workload: RoutingWorkload::Interactive,
347 intent: None,
348 }
349 }
350}
351
352impl AdaptiveRouter {
353 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
354 let circuit_breakers = Arc::new(Mutex::new(
355 CircuitBreakerRegistry::new(3, 300), ));
357 Self {
358 hw,
359 config,
360 circuit_breakers,
361 }
362 }
363
364 pub fn with_default_config(hw: HardwareInfo) -> Self {
365 Self::new(hw, RoutingConfig::default())
366 }
367
368 pub fn config(&self) -> &RoutingConfig {
369 &self.config
370 }
371
372 pub fn set_config(&mut self, config: RoutingConfig) {
373 self.config = config;
374 }
375
376 pub fn route_with(&self, req: RouteRequest<'_>) -> AdaptiveRoutingDecision {
382 let workload = match req.intent {
388 Some(h) if h.prefer_fast => RoutingWorkload::Fastest,
389 Some(h) if h.prefer_local => RoutingWorkload::LocalPreferred,
390 _ => req.workload,
391 };
392 self.route_inner_with_intent(
393 req.prompt,
394 req.registry,
395 req.tracker,
396 req.has_tools,
397 req.has_vision,
398 req.estimated_total_tokens,
399 workload,
400 req.intent,
401 )
402 }
403
404 pub fn route(
407 &self,
408 prompt: &str,
409 registry: &UnifiedRegistry,
410 tracker: &OutcomeTracker,
411 ) -> AdaptiveRoutingDecision {
412 self.route_with(RouteRequest::new(prompt, registry, tracker))
413 }
414
415 pub fn route_editor(
421 &self,
422 prompt: &str,
423 registry: &UnifiedRegistry,
424 tracker: &OutcomeTracker,
425 ) -> AdaptiveRoutingDecision {
426 self.route_with(RouteRequest {
427 workload: RoutingWorkload::Background,
428 ..RouteRequest::new(prompt, registry, tracker)
429 })
430 }
431
432 pub fn route_with_tools(
434 &self,
435 prompt: &str,
436 registry: &UnifiedRegistry,
437 tracker: &OutcomeTracker,
438 ) -> AdaptiveRoutingDecision {
439 self.route_with(RouteRequest {
440 has_tools: true,
441 ..RouteRequest::new(prompt, registry, tracker)
442 })
443 }
444
445 pub fn route_with_vision(
447 &self,
448 prompt: &str,
449 registry: &UnifiedRegistry,
450 tracker: &OutcomeTracker,
451 has_tools: bool,
452 ) -> AdaptiveRoutingDecision {
453 self.route_with(RouteRequest {
454 has_tools,
455 has_vision: true,
456 ..RouteRequest::new(prompt, registry, tracker)
457 })
458 }
459
460 pub fn route_with_intent<'a>(
466 &self,
467 prompt: &'a str,
468 registry: &'a UnifiedRegistry,
469 tracker: &'a OutcomeTracker,
470 intent: &'a crate::intent::IntentHint,
471 ) -> AdaptiveRoutingDecision {
472 self.route_with(RouteRequest {
473 intent: Some(intent),
474 ..RouteRequest::new(prompt, registry, tracker)
475 })
476 }
477
478 pub fn route_context_aware(
481 &self,
482 prompt: &str,
483 estimated_total_tokens: usize,
484 registry: &UnifiedRegistry,
485 tracker: &OutcomeTracker,
486 has_tools: bool,
487 has_vision: bool,
488 workload: RoutingWorkload,
489 ) -> AdaptiveRoutingDecision {
490 self.route_with(RouteRequest {
491 estimated_total_tokens,
492 has_tools,
493 has_vision,
494 workload,
495 ..RouteRequest::new(prompt, registry, tracker)
496 })
497 }
498
499 pub fn route_context_aware_with_intent<'a>(
504 &self,
505 prompt: &'a str,
506 estimated_total_tokens: usize,
507 registry: &'a UnifiedRegistry,
508 tracker: &'a OutcomeTracker,
509 has_tools: bool,
510 has_vision: bool,
511 workload: RoutingWorkload,
512 intent: &'a crate::intent::IntentHint,
513 ) -> AdaptiveRoutingDecision {
514 self.route_with(RouteRequest {
515 estimated_total_tokens,
516 has_tools,
517 has_vision,
518 workload,
519 intent: Some(intent),
520 ..RouteRequest::new(prompt, registry, tracker)
521 })
522 }
523
524 fn route_inner_with_intent(
525 &self,
526 prompt: &str,
527 registry: &UnifiedRegistry,
528 tracker: &OutcomeTracker,
529 has_tools: bool,
530 has_vision: bool,
531 estimated_total_tokens: usize,
532 workload: RoutingWorkload,
533 intent: Option<&crate::intent::IntentHint>,
534 ) -> AdaptiveRoutingDecision {
535 let complexity = TaskComplexity::assess(prompt);
536 let task = intent
538 .and_then(|h| h.task)
539 .map(task_hint_to_inference_task)
540 .unwrap_or_else(|| complexity.inference_task());
541 let mut required_caps = complexity.required_capabilities();
542 if let Some(hint) = intent {
543 for cap in &hint.require {
544 if !required_caps.contains(cap) {
545 required_caps.push(*cap);
546 }
547 }
548 }
549 if has_vision {
550 required_caps.push(ModelCapability::Vision);
551 }
552 if has_tools {
553 required_caps.push(ModelCapability::ToolUse);
554 if Self::needs_multi_tool_call(prompt) {
557 required_caps.push(ModelCapability::MultiToolCall);
558 }
559 }
560
561 let mut candidates = self.filter_candidates(&required_caps, registry, tracker);
563
564 if candidates.is_empty() && required_caps.contains(&ModelCapability::MultiToolCall) {
567 required_caps.retain(|c| *c != ModelCapability::MultiToolCall);
568 candidates = self.filter_candidates(&required_caps, registry, tracker);
569 }
570
571 if candidates.is_empty() {
572 return self.cold_start_decision(complexity, task, registry, has_vision);
574 }
575
576 candidates = self.apply_quality_first_bootstrap_policy(
577 candidates,
578 task,
579 tracker,
580 has_vision,
581 has_tools,
582 workload,
583 );
584
585 let (fits, needs_compaction_candidates) = if estimated_total_tokens > 0 {
588 let mut fits = Vec::new();
589 let mut tight = Vec::new();
590 for m in &candidates {
591 if m.context_length == 0 || m.context_length >= estimated_total_tokens {
592 fits.push(m.clone());
593 } else {
594 tight.push(m.clone());
595 }
596 }
597 (fits, tight)
598 } else {
599 (candidates.clone(), Vec::new())
600 };
601
602 let (scoring_candidates, compaction_needed) = if !fits.is_empty() {
604 (fits, false)
605 } else if !needs_compaction_candidates.is_empty() {
606 tracing::info!(
607 prompt_tokens = estimated_total_tokens,
608 candidates = needs_compaction_candidates.len(),
609 "no model fits full prompt — compaction will be needed"
610 );
611 (needs_compaction_candidates.clone(), true)
612 } else {
613 (candidates.clone(), false)
614 };
615
616 let scored = self.score_candidates_context_aware(
618 &scoring_candidates,
619 task,
620 tracker,
621 estimated_total_tokens,
622 workload,
623 );
624
625 let (selected_id, strategy) = self.select_with_thompson_sampling(&scored, tracker);
627
628 let mut fallbacks: Vec<String> = scored
630 .iter()
631 .filter(|(id, _)| *id != selected_id)
632 .map(|(id, _)| id.clone())
633 .collect();
634 if !compaction_needed {
636 for m in &needs_compaction_candidates {
637 if m.id != selected_id && !fallbacks.contains(&m.id) {
638 fallbacks.push(m.id.clone());
639 }
640 }
641 }
642
643 let predicted_quality = scored
644 .iter()
645 .find(|(id, _)| *id == selected_id)
646 .map(|(_, score)| *score)
647 .unwrap_or(0.5);
648
649 let selected_schema = registry
650 .get(&selected_id)
651 .or_else(|| registry.find_by_name(&selected_id));
652 let model_name = selected_schema
653 .map(|m| m.name.clone())
654 .unwrap_or_else(|| selected_id.clone());
655 let context_length = selected_schema.map(|m| m.context_length).unwrap_or(0);
656
657 let needs_compact = compaction_needed
658 || (estimated_total_tokens > 0
659 && context_length > 0
660 && estimated_total_tokens > context_length);
661
662 let compaction_note = if needs_compact {
663 format!(
664 " [compaction needed: {}→{}tok]",
665 estimated_total_tokens, context_length
666 )
667 } else {
668 String::new()
669 };
670
671 let reason = format!(
672 "{:?} task → {} via {:?} (quality: {:.2}, {} candidates){}",
673 complexity,
674 model_name,
675 strategy,
676 predicted_quality,
677 scoring_candidates.len(),
678 compaction_note,
679 );
680
681 AdaptiveRoutingDecision {
682 model_id: selected_id,
683 model_name,
684 task,
685 complexity,
686 reason,
687 strategy,
688 predicted_quality,
689 fallbacks,
690 context_length,
691 needs_compaction: needs_compact,
692 }
693 }
694
695 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
697 let embed_models = registry.query_by_capability(ModelCapability::Embed);
698 embed_models
699 .first()
700 .map(|m| m.name.clone())
701 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
702 }
703
704 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
706 let gen_models = registry.query_by_capability(ModelCapability::Generate);
707 gen_models
709 .iter()
710 .filter(|m| m.is_local())
711 .min_by_key(|m| m.size_mb())
712 .map(|m| m.name.clone())
713 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
714 }
715
/// Latency in milliseconds at or above which `latency_ms_to_score` yields 0.0; also
/// the latency `tps_to_latency_ms` reports for a non-positive throughput.
const LATENCY_CEILING_MS: f64 = 10000.0;
/// Not referenced by the code in this part of the file (the leading underscore
/// silences the unused warning).
const _TPS_CEILING: f64 = 150.0;
/// Factor applied to the advertised tokens/second of local, non-MLX models tagged
/// `moe` when estimating latency from the schema.
const MOE_TPS_MULTIPLIER: f64 = 0.10;
/// Factor applied to the advertised tokens/second of local MLX models tagged `moe`.
const MLX_MOE_TPS_MULTIPLIER: f64 = 0.50;
/// `cost_per_1k_output()` value at or above which a remote model's cost score is 0.0.
const COST_CEILING_PER_1K: f64 = 0.1;
/// Score bonus for local models when `config.prefer_local` is set.
const LOCAL_BONUS: f64 = 0.15;
/// Score bonus for MLX models; only compiled on macOS/aarch64.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
const MLX_BONUS: f64 = 0.10;
/// Score bonus for models tagged both `low_latency` and `private`.
const SYSTEM_LLM_BONUS: f64 = 0.12;
747
748 fn filter_candidates(
752 &self,
753 required_caps: &[ModelCapability],
754 registry: &UnifiedRegistry,
755 tracker: &OutcomeTracker,
756 ) -> Vec<ModelSchema> {
757 registry
758 .list()
759 .into_iter()
760 .filter(|m| {
761 if !required_caps.iter().all(|c| m.has_capability(*c)) {
763 return false;
764 }
765 if !m.available {
767 return false;
768 }
769 if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
771 return false;
772 }
773 if let Some(max) = self.config.max_latency_ms {
775 if let Some(p50) = m.performance.latency_p50_ms {
776 if p50 > max {
777 return false;
778 }
779 }
780 }
781 if let Some(max) = self.config.max_cost_usd {
783 if m.cost_per_1k_output() > max {
784 return false;
785 }
786 }
787 if !self.config.prefer_local && m.is_local() {
789 return false;
790 }
791 if tracker.is_excluded(&m.id) {
793 return false;
794 }
795 if let Ok(mut cb) = self.circuit_breakers.lock() {
797 if !cb.allow_request(&m.id) {
798 tracing::debug!(model = %m.id, "skipped by circuit breaker");
799 return false;
800 }
801 }
802 true
803 })
804 .cloned()
805 .collect()
806 }
807
808 fn apply_quality_first_bootstrap_policy(
809 &self,
810 candidates: Vec<ModelSchema>,
811 task: InferenceTask,
812 tracker: &OutcomeTracker,
813 has_vision: bool,
814 has_tools: bool,
815 workload: RoutingWorkload,
816 ) -> Vec<ModelSchema> {
817 if !self.config.quality_first_cold_start
818 || !workload.is_latency_sensitive()
819 || !self.is_quality_critical_bootstrap_task(task, has_vision, has_tools)
820 {
821 return candidates;
822 }
823
824 let trusted_remote: Vec<ModelSchema> = candidates
825 .iter()
826 .filter(|model| self.is_trusted_quality_remote(model))
827 .cloned()
828 .collect();
829
830 if trusted_remote.is_empty() {
831 return candidates;
832 }
833
834 let proven_local: Vec<ModelSchema> = candidates
835 .iter()
836 .filter(|model| model.is_local() && self.is_local_model_proven_for_task(model, task, tracker))
837 .cloned()
838 .collect();
839
840 if !proven_local.is_empty() {
841 return proven_local;
842 }
843
844 trusted_remote
845 }
846
847 fn score_candidates_context_aware(
851 &self,
852 candidates: &[ModelSchema],
853 task: InferenceTask,
854 tracker: &OutcomeTracker,
855 estimated_total_tokens: usize,
856 workload: RoutingWorkload,
857 ) -> Vec<(String, f64)> {
858 let mut scored: Vec<(String, f64)> = candidates
859 .iter()
860 .map(|m| {
861 let base_score = self.score_model(m, task, tracker, workload);
862 let headroom_bonus = if estimated_total_tokens > 0 && m.context_length > 0 {
865 let ratio = m.context_length as f64 / estimated_total_tokens as f64;
866 if ratio >= 1.0 {
867 (ratio.min(4.0) - 1.0) / 3.0 * 0.10 } else {
869 -0.15 }
871 } else {
872 0.0
873 };
874 (m.id.clone(), base_score + headroom_bonus)
875 })
876 .collect();
877
878 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
879 scored
880 }
881
882 fn score_model(
885 &self,
886 model: &ModelSchema,
887 task: InferenceTask,
888 tracker: &OutcomeTracker,
889 workload: RoutingWorkload,
890 ) -> f64 {
891 let profile = tracker.profile(&model.id);
892 let schema_quality = self.schema_quality_estimate(model);
893 let schema_latency = self.schema_latency_estimate(model);
894 let (quality_weight, latency_weight, cost_weight) = workload.weights();
895
896 let quality = match profile {
899 Some(p) if p.total_calls >= self.config.min_observations => p
900 .task_stats(task)
901 .map(|ts| ts.ema_quality)
902 .unwrap_or(p.ema_quality),
903 Some(p) if p.total_calls == 0 => p
904 .task_stats(task)
905 .map(|ts| ts.ema_quality)
906 .unwrap_or(p.ema_quality),
907 Some(p) if p.total_calls > 0 => {
908 let w = p.total_calls as f64 / self.config.min_observations as f64;
909 schema_quality * (1.0 - w) + p.ema_quality * w
910 }
911 _ => schema_quality,
912 };
913
914 let latency = match profile {
917 Some(p) if p.total_calls >= self.config.min_observations => {
918 let avg = p
919 .task_stats(task)
920 .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
921 .map(|ts| ts.avg_latency_ms)
922 .unwrap_or_else(|| p.avg_latency_ms());
923 self.latency_ms_to_score(avg)
924 }
925 Some(p) if p.total_calls == 0 => p
926 .task_stats(task)
927 .filter(|ts| ts.avg_latency_ms > 0.0)
928 .map(|ts| self.latency_ms_to_score(ts.avg_latency_ms))
929 .unwrap_or(schema_latency),
930 Some(p) if p.total_calls > 0 => {
931 let observed = self.latency_ms_to_score(
932 p.task_stats(task)
933 .filter(|ts| ts.calls > 0 || ts.avg_latency_ms > 0.0)
934 .map(|ts| ts.avg_latency_ms)
935 .unwrap_or_else(|| p.avg_latency_ms()),
936 );
937 let w = p.total_calls as f64 / self.config.min_observations as f64;
938 schema_latency * (1.0 - w) + observed * w
939 }
940 _ => schema_latency,
941 };
942
943 let cost = if model.is_local() {
945 1.0
946 } else {
947 (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
948 };
949
950 let local_bonus = if self.config.prefer_local && model.is_local() {
951 Self::LOCAL_BONUS
952 } else {
953 0.0
954 };
955 let workload_local_bonus = if model.is_local() {
956 workload.local_bonus()
957 } else {
958 0.0
959 };
960
961 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
963 let mlx_bonus = if model.is_mlx() { Self::MLX_BONUS } else { 0.0 };
964 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
965 let mlx_bonus = 0.0;
966
967 let vllm_mlx_bonus = if model.is_vllm_mlx() {
969 Self::LOCAL_BONUS + 0.05
970 } else {
971 0.0
972 };
973
974 let system_llm_bonus = if model.tags.iter().any(|t| t == "low_latency")
982 && model.tags.iter().any(|t| t == "private")
983 {
984 Self::SYSTEM_LLM_BONUS
985 } else {
986 0.0
987 };
988
989 quality_weight * quality
990 + latency_weight * latency
991 + cost_weight * cost
992 + local_bonus
993 + workload_local_bonus
994 + mlx_bonus
995 + vllm_mlx_bonus
996 + system_llm_bonus
997 }
998
999 fn latency_ms_to_score(&self, ms: f64) -> f64 {
1002 (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
1003 }
1004
1005 fn tps_to_latency_ms(tps: f64) -> f64 {
1007 if tps <= 0.0 {
1008 return Self::LATENCY_CEILING_MS;
1009 }
1010 (200.0 / tps) * 1000.0
1012 }
1013
/// Heuristically decides whether `prompt` asks for several actions at once.
///
/// Returns `true` when the lowercased prompt contains any of:
/// - at least two distinct list numbers from 1 to 5 written as `N) ` or `N. `;
/// - a multi-step phrase such as "and also" or "do both";
/// - at least two action bullets such as `- add ` or `- fix `.
fn needs_multi_tool_call(prompt: &str) -> bool {
    const MULTI_STEP_PHRASES: [&str; 12] = [
        "multiple edits",
        "several changes",
        "three changes",
        "two changes",
        "all of the following",
        "each of these",
        "do both",
        "do all",
        "and also",
        "additionally",
        "as well as",
        "then also",
    ];
    const ACTION_BULLETS: [&str; 8] = [
        "- add ",
        "- update ",
        "- change ",
        "- remove ",
        "- fix ",
        "- edit ",
        "- implement ",
        "- create ",
    ];

    let text = prompt.to_lowercase();

    let numbered_items = (1..=5u32)
        .filter(|n| text.contains(&format!("{}) ", n)) || text.contains(&format!("{}. ", n)))
        .count();
    if numbered_items >= 2 {
        return true;
    }

    if MULTI_STEP_PHRASES.iter().any(|phrase| text.contains(phrase)) {
        return true;
    }

    let bullet_actions: usize = ACTION_BULLETS
        .iter()
        .map(|bullet| text.matches(bullet).count())
        .sum();
    bullet_actions >= 2
}
1060
1061 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
1066 match model.size_mb() {
1067 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
1074 }
1075
1076 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
1081 let is_moe = model.tags.contains(&"moe".to_string());
1082
1083 if model.is_local() {
1084 if let Some(tps) = model.performance.tokens_per_second {
1085 let effective_tps = if is_moe {
1086 let multiplier = if model.is_mlx() {
1087 Self::MLX_MOE_TPS_MULTIPLIER
1088 } else {
1089 Self::MOE_TPS_MULTIPLIER
1090 };
1091 tps * multiplier
1092 } else {
1093 tps
1094 };
1095 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
1096 return self.latency_ms_to_score(estimated_ms);
1097 }
1098 return 0.5; }
1100
1101 if let Some(p50) = model.performance.latency_p50_ms {
1103 return self.latency_ms_to_score(p50 as f64);
1104 }
1105 0.3 }
1107
1108 fn is_quality_critical_bootstrap_task(
1109 &self,
1110 task: InferenceTask,
1111 has_vision: bool,
1112 has_tools: bool,
1113 ) -> bool {
1114 has_vision
1115 || has_tools
1116 || matches!(
1117 task,
1118 InferenceTask::Generate | InferenceTask::Code | InferenceTask::Reasoning
1119 )
1120 }
1121
1122 fn is_trusted_quality_remote(&self, model: &ModelSchema) -> bool {
1123 model.is_remote()
1124 && matches!(
1125 model.provider.as_str(),
1126 "openai" | "anthropic" | "google"
1127 )
1128 && !model.has_capability(ModelCapability::SpeechToText)
1129 && !model.has_capability(ModelCapability::TextToSpeech)
1130 }
1131
1132 fn is_local_model_proven_for_task(
1133 &self,
1134 model: &ModelSchema,
1135 task: InferenceTask,
1136 tracker: &OutcomeTracker,
1137 ) -> bool {
1138 let Some(profile) = tracker.profile(&model.id) else {
1139 return false;
1140 };
1141 if let Some(task_stats) = profile.task_stats(task) {
1142 if task_stats.calls >= self.config.bootstrap_min_task_observations
1143 && task_stats.ema_quality >= self.config.bootstrap_quality_floor
1144 {
1145 return true;
1146 }
1147 }
1148
1149 profile.total_calls >= self.config.bootstrap_min_task_observations
1150 && profile.ema_quality >= self.config.bootstrap_quality_floor
1151 }
1152
1153 fn select_with_thompson_sampling(
1163 &self,
1164 scored: &[(String, f64)],
1165 tracker: &OutcomeTracker,
1166 ) -> (String, RoutingStrategy) {
1167 if scored.is_empty() {
1168 return (String::new(), RoutingStrategy::SchemaBased);
1169 }
1170
1171 let mut rng = rand::rng();
1172 let mut best_sample = f64::NEG_INFINITY;
1173 let mut best_id = scored[0].0.clone();
1174 let mut best_strategy = RoutingStrategy::SchemaBased;
1175
1176 for (id, phase2_score) in scored {
1177 let profile = tracker.profile(id);
1178 let prior = self.config.prior_strength;
1179
1180 let prior_mean = phase2_score.clamp(0.0, 1.0);
1182
1183 let prior_alpha = prior * prior_mean;
1185 let prior_beta = prior * (1.0 - prior_mean);
1186
1187 let (obs_alpha, obs_beta) = match profile {
1189 Some(p) => (p.success_count as f64, p.fail_count as f64),
1190 None => (0.0, 0.0),
1191 };
1192
1193 let alpha = (prior_alpha + obs_alpha).max(0.01);
1195 let beta = (prior_beta + obs_beta).max(0.01);
1196
1197 let sample = sample_beta(&mut rng, alpha, beta);
1199
1200 if sample > best_sample {
1201 best_sample = sample;
1202 best_id = id.clone();
1203 best_strategy = match profile {
1204 Some(p) if p.total_calls >= self.config.min_observations => {
1205 RoutingStrategy::ProfileBased
1206 }
1207 Some(p) if p.total_calls > 0 => {
1208 RoutingStrategy::Exploration
1210 }
1211 _ => RoutingStrategy::SchemaBased,
1212 };
1213 }
1214 }
1215
1216 (best_id, best_strategy)
1217 }
1218
1219 fn cold_start_decision(
1221 &self,
1222 complexity: TaskComplexity,
1223 task: InferenceTask,
1224 registry: &UnifiedRegistry,
1225 has_vision: bool,
1226 ) -> AdaptiveRoutingDecision {
1227 if has_vision {
1228 if let Some(model) = registry
1229 .query_by_capability(ModelCapability::Vision)
1230 .into_iter()
1231 .find(|model| model.available && self.is_trusted_quality_remote(model))
1232 .or_else(|| registry.query_by_capability(ModelCapability::Vision).first().copied())
1233 {
1234 return AdaptiveRoutingDecision {
1235 model_id: model.id.clone(),
1236 model_name: model.name.clone(),
1237 task,
1238 complexity,
1239 reason: format!(
1240 "{:?} task → {} (cold start, vision fallback)",
1241 complexity, model.name
1242 ),
1243 strategy: RoutingStrategy::SchemaBased,
1244 predicted_quality: 0.5,
1245 fallbacks: vec![],
1246 context_length: model.context_length,
1247 needs_compaction: false,
1248 };
1249 }
1250 }
1251
1252 if self.config.quality_first_cold_start {
1253 let required_caps = complexity.required_capabilities();
1254 if let Some(model) = registry
1255 .list()
1256 .into_iter()
1257 .filter(|model| {
1258 model.available
1259 && required_caps.iter().all(|cap| model.has_capability(*cap))
1260 && self.is_trusted_quality_remote(model)
1261 })
1262 .max_by(|a, b| {
1263 self.schema_quality_estimate(a)
1264 .partial_cmp(&self.schema_quality_estimate(b))
1265 .unwrap_or(std::cmp::Ordering::Equal)
1266 })
1267 {
1268 return AdaptiveRoutingDecision {
1269 model_id: model.id.clone(),
1270 model_name: model.name.clone(),
1271 task,
1272 complexity,
1273 reason: format!(
1274 "{:?} task → {} (quality-first cold start)",
1275 complexity, model.name
1276 ),
1277 strategy: RoutingStrategy::SchemaBased,
1278 predicted_quality: self.schema_quality_estimate(model),
1279 fallbacks: vec![],
1280 context_length: model.context_length,
1281 needs_compaction: false,
1282 };
1283 }
1284 }
1285
1286 let model_name = match complexity {
1288 TaskComplexity::Simple => "Qwen3-0.6B",
1289 TaskComplexity::Medium => "Qwen3-1.7B",
1290 TaskComplexity::Code => "Qwen3-4B",
1291 TaskComplexity::Complex => &self.hw.recommended_model,
1292 };
1293
1294 let model_id = registry
1295 .find_by_name(model_name)
1296 .map(|m| m.id.clone())
1297 .unwrap_or_else(|| model_name.to_string());
1298
1299 let context_length = registry
1300 .find_by_name(model_name)
1301 .map(|m| m.context_length)
1302 .unwrap_or(0);
1303
1304 AdaptiveRoutingDecision {
1305 model_id,
1306 model_name: model_name.to_string(),
1307 task,
1308 complexity,
1309 reason: format!(
1310 "{:?} task → {} (cold start, no candidates)",
1311 complexity, model_name
1312 ),
1313 strategy: RoutingStrategy::SchemaBased,
1314 predicted_quality: 0.5,
1315 fallbacks: vec![],
1316 context_length,
1317 needs_compaction: false,
1318 }
1319 }
1320}
1321
1322fn task_hint_to_inference_task(hint: crate::intent::TaskHint) -> InferenceTask {
1328 use crate::intent::TaskHint;
1329 match hint {
1330 TaskHint::Chat => InferenceTask::Generate,
1331 TaskHint::Classify => InferenceTask::Classify,
1332 TaskHint::Reasoning => InferenceTask::Reasoning,
1333 TaskHint::Code => InferenceTask::Code,
1334 }
1335}
1336
1337fn sample_beta(rng: &mut impl Rng, alpha: f64, beta: f64) -> f64 {
1345 let x = sample_gamma(rng, alpha);
1346 let y = sample_gamma(rng, beta);
1347 if x + y == 0.0 {
1348 0.5 } else {
1350 x / (x + y)
1351 }
1352}
1353
1354fn sample_gamma(rng: &mut impl Rng, shape: f64) -> f64 {
1357 if shape < 1.0 {
1358 let u: f64 = rng.random();
1360 return sample_gamma(rng, shape + 1.0) * u.powf(1.0 / shape);
1361 }
1362
1363 let d = shape - 1.0 / 3.0;
1365 let c = 1.0 / (9.0 * d).sqrt();
1366
1367 loop {
1368 let x: f64 = loop {
1369 let n = sample_standard_normal(rng);
1370 if 1.0 + c * n > 0.0 {
1371 break n;
1372 }
1373 };
1374
1375 let v = (1.0 + c * x).powi(3);
1376 let u: f64 = rng.random();
1377
1378 if u < 1.0 - 0.0331 * x.powi(4) {
1379 return d * v;
1380 }
1381 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
1382 return d * v;
1383 }
1384 }
1385}
1386
1387fn sample_standard_normal(rng: &mut impl Rng) -> f64 {
1389 let u1: f64 = rng.random();
1390 let u2: f64 = rng.random();
1391 (-2.0 * u1.max(1e-300).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1392}
1393
1394#[cfg(test)]
1395mod tests {
1396 use super::*;
1397 use crate::outcome::InferredOutcome;
1398
/// Hardware fixture for the router tests: an Apple-silicon style machine
/// (`macos`/`aarch64`, Metal) whose local model size limit is 18000 MB.
fn test_hw() -> HardwareInfo {
    HardwareInfo {
        os: "macos".into(),
        arch: "aarch64".into(),
        cpu_cores: 10,
        total_ram_mb: 32768,
        gpu_backend: crate::hardware::GpuBackend::Metal,
        gpu_memory_mb: Some(28672),
        gpu_devices: Vec::new(),
        // Used by the cold-start fallback for `Complex` prompts.
        recommended_model: "Qwen3-8B".into(),
        recommended_context: 8192,
        // Local models at or above this size are filtered out by the router.
        max_model_mb: 18000,
    }
}
1413
/// Registry fixture: fake on-disk model directories for five Qwen3 models under a
/// fixed temp path, plus one registered OpenAI remote model.
fn test_registry() -> UnifiedRegistry {
    let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
    // SAFETY: NOTE(review) — `set_var` is unsafe because modifying the environment
    // races with concurrent environment access from other threads. Tests run on
    // multiple threads by default, so this relies on every caller writing the same
    // value and on no other thread reading the environment at the same time — confirm.
    unsafe {
        std::env::set_var("OPENAI_API_KEY", "test-openai-key");
    }
    // Create placeholder model files; failures are ignored on purpose
    // (the directories may already exist from a previous run).
    for name in &[
        "Qwen3-0.6B",
        "Qwen3-1.7B",
        "Qwen3-4B",
        "Qwen3-8B",
        "Qwen3-Embedding-0.6B",
    ] {
        let dir = tmp.join(name);
        let _ = std::fs::create_dir_all(&dir);
        let _ = std::fs::write(dir.join("model.gguf"), b"fake");
        let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
    }
    let mut reg = UnifiedRegistry::new(tmp);
    // Remote model from a trusted provider with every capability the tests route on.
    reg.register(ModelSchema {
        id: "openai/gpt-5.4-mini:latest".into(),
        name: "gpt-5.4-mini".into(),
        provider: "openai".into(),
        family: "gpt-5.4".into(),
        version: "latest".into(),
        capabilities: vec![
            ModelCapability::Generate,
            ModelCapability::Code,
            ModelCapability::Reasoning,
            ModelCapability::ToolUse,
            ModelCapability::MultiToolCall,
            ModelCapability::Vision,
        ],
        context_length: 128_000,
        param_count: "api".into(),
        quantization: None,
        performance: Default::default(),
        cost: Default::default(),
        source: crate::schema::ModelSource::RemoteApi {
            endpoint: "https://api.openai.com/v1".into(),
            api_key_env: "OPENAI_API_KEY".into(),
            api_key_envs: vec![],
            api_version: None,
            protocol: crate::schema::ApiProtocol::OpenAiCompat,
        },
        tags: vec!["trusted-remote".into()],
        supported_params: vec![],
        public_benchmarks: vec![],
        available: true,
    });
    reg
}
1466
/// With an empty tracker and `prior_strength: 100.0`, a trivial question must
/// be classified `Simple`, routed with the schema-based strategy, and land on
/// a non-local model from the openai / anthropic / google providers.
#[test]
fn routes_simple_to_trusted_remote_during_cold_start() {
    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let decision = router.route("What is 2+2?", &reg, &tracker);
    assert_eq!(decision.complexity, TaskComplexity::Simple);
    assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);

    let schema = reg
        .find_by_name(&decision.model_name)
        .expect("selected model should exist in registry");
    assert!(
        !schema.is_local(),
        "simple task should route to trusted remote model during cold start"
    );

    let provider = schema.provider.as_str();
    assert!(matches!(provider, "openai" | "anthropic" | "google"));
}
1495
/// With an empty tracker, a "Fix this function" prompt carrying a fenced Rust
/// snippet must be classified as a `Code` task and routed to a non-local model
/// that has the `Code` capability.
#[test]
fn routes_code_to_code_capable_remote_during_cold_start() {
    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let prompt = "Fix this function:\n```rust\nfn main() {}\n```";
    let decision = router.route(prompt, &reg, &tracker);
    assert_eq!(decision.complexity, TaskComplexity::Code);
    assert_eq!(decision.task, InferenceTask::Code);

    let schema = reg
        .find_by_name(&decision.model_name)
        .expect("model should exist");
    assert!(
        schema.has_capability(ModelCapability::Code),
        "selected model must support Code"
    );
    assert!(!schema.is_local(), "should route to trusted remote model");
}
1525
/// After adding an MLX-sourced Qwen3-VL model to the fixture registry,
/// `route_with_vision` must select a model that has the `Vision` capability.
#[test]
fn routes_images_to_vision_capable_model() {
    use crate::schema::ModelSource;

    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let mut reg = test_registry();
    let tracker = OutcomeTracker::new();

    let vision_model = ModelSchema {
        id: "mlx-vlm/qwen3-vl-2b:bf16".into(),
        name: "Qwen3-VL-2B-mlx-vlm".into(),
        provider: "qwen".into(),
        family: "qwen3-vl".into(),
        version: "bf16".into(),
        capabilities: vec![
            ModelCapability::Generate,
            ModelCapability::Vision,
            ModelCapability::Grounding,
        ],
        context_length: 262_144,
        param_count: "2B".into(),
        quantization: None,
        performance: Default::default(),
        cost: Default::default(),
        source: ModelSource::Mlx {
            hf_repo: "Qwen/Qwen3-VL-2B-Instruct".into(),
            hf_weight_file: None,
        },
        tags: vec!["vision".into(), "mlx-vlm-cli".into()],
        supported_params: vec![],
        public_benchmarks: vec![],
        available: true,
    };
    reg.register(vision_model);

    let decision = router.route_with_vision("What is in this image?", &reg, &tracker, false);
    let schema = reg
        .find_by_name(&decision.model_name)
        .expect("model should exist");
    assert!(
        schema.has_capability(ModelCapability::Vision),
        "selected model must support Vision"
    );
}
1573
/// After twenty accepted `Code` outcomes are recorded for the Qwen3-8B model
/// id, routing a code-repair prompt twenty times must pick that model in at
/// least twelve of the runs.
///
/// NOTE(review): the threshold is below 20/20, presumably because selection is
/// sampled rather than deterministic — confirm against the router.
#[test]
fn profile_based_routing_favors_proven_model() {
    let config = RoutingConfig {
        prior_strength: 0.5,
        min_observations: 3,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let proven_id = "qwen/qwen3-8b:q4_k_m";
    for _ in 0..20 {
        let trace = tracker.record_start(proven_id, InferenceTask::Code, "test");
        tracker.record_complete(&trace, 500, 100, 50);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
    }

    // Count how many of the twenty routing decisions choose the proven model.
    let wins = (0..20)
        .filter(|_| {
            let decision = router.route("Fix this bug in the parser", &reg, &tracker);
            assert_eq!(decision.complexity, TaskComplexity::Code);
            decision.model_id == proven_id
        })
        .count();

    assert!(
        wins >= 12,
        "proven model won only {wins}/20 times (expected >= 12)"
    );
}
1609
/// After twelve accepted `Generate` outcomes are recorded for the Qwen3-8B
/// model id, at least twelve of twenty routing decisions for a summarisation
/// prompt must select a local model.
#[test]
fn proven_local_model_can_displace_bootstrap_remote() {
    let config = RoutingConfig {
        prior_strength: 100.0,
        bootstrap_min_task_observations: 6,
        bootstrap_quality_floor: 0.8,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let proven_id = "qwen/qwen3-8b:q4_k_m";
    for _ in 0..12 {
        let trace = tracker.record_start(proven_id, InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 300, 50, 20);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
    }

    // Count the routing decisions that resolve to a local model.
    let local_wins = (0..20)
        .filter(|_| {
            let decision = router.route("Summarize this design decision.", &reg, &tracker);
            let schema = reg
                .get(&decision.model_id)
                .expect("selected model should exist");
            schema.is_local()
        })
        .count();

    assert!(
        local_wins >= 12,
        "proven local model won only {local_wins}/20 times (expected >= 12)"
    );
}
1645
/// With an imported Qwen3-8B profile whose `ema_quality` is 0.95, routing a
/// code prompt under `RoutingWorkload::Background` must select a local model.
#[test]
fn benchmark_prior_informs_background_routing() {
    use crate::outcome::ModelProfile;

    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();

    // Seed the tracker with an imported profile instead of recorded outcomes.
    let mut profile = ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
    profile.ema_quality = 0.95;
    let mut tracker = OutcomeTracker::new();
    tracker.import_profiles(vec![profile]);

    let decision = router.route_context_aware(
        "Write a Python fibonacci function.",
        128,
        &reg,
        &tracker,
        false,
        false,
        RoutingWorkload::Background,
    );

    let schema = reg
        .get(&decision.model_id)
        .expect("selected model should exist");
    assert!(
        schema.is_local(),
        "background routing should allow strong local benchmark priors to win"
    );
}
1674
/// With an imported Qwen3-8B profile that carries only `Code` task stats
/// (`ema_quality: 0.95`), routing a code prompt under
/// `RoutingWorkload::Background` must select a local model.
#[test]
fn task_specific_benchmark_prior_informs_cold_start_routing() {
    use crate::outcome::{ModelProfile, TaskStats};

    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();

    // Only the per-task entry is populated; every other profile field keeps
    // the value `ModelProfile::new` gives it.
    let code_stats = TaskStats {
        ema_quality: 0.95,
        ..Default::default()
    };
    let mut profile = ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
    profile
        .task_stats
        .insert(InferenceTask::Code.to_string(), code_stats);
    let mut tracker = OutcomeTracker::new();
    tracker.import_profiles(vec![profile]);

    let decision = router.route_context_aware(
        "Write a Python fibonacci function.",
        128,
        &reg,
        &tracker,
        false,
        false,
        RoutingWorkload::Background,
    );

    let schema = reg
        .get(&decision.model_id)
        .expect("selected model should exist");
    assert!(
        schema.is_local(),
        "background routing should use task-specific cold-start priors for local code models"
    );
}
1712
/// Two imported profiles for the same model that differ only in the
/// `avg_latency_ms` of their `Generate` task stats must score differently:
/// the 1.2 s profile has to beat the 120 s one for an interactive workload.
#[test]
fn task_specific_latency_prior_affects_cold_start_score() {
    use crate::outcome::{ModelProfile, TaskStats};

    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let model = reg
        .get("qwen/qwen3-8b:q4_k_m")
        .expect("local test model should exist");

    // Builds a tracker holding one imported profile for `model` whose
    // `Generate` stats use the given average latency.
    let tracker_with_latency = |avg_latency_ms| {
        let stats = TaskStats {
            ema_quality: 0.95,
            avg_latency_ms,
            ..Default::default()
        };
        let mut profile = ModelProfile::new(model.id.clone());
        profile
            .task_stats
            .insert(InferenceTask::Generate.to_string(), stats);
        let mut tracker = OutcomeTracker::new();
        tracker.import_profiles(vec![profile]);
        tracker
    };
    let fast_tracker = tracker_with_latency(1200.0);
    let slow_tracker = tracker_with_latency(120_000.0);

    let fast_score = router.score_model(
        model,
        InferenceTask::Generate,
        &fast_tracker,
        RoutingWorkload::Interactive,
    );
    let slow_score = router.score_model(
        model,
        InferenceTask::Generate,
        &slow_tracker,
        RoutingWorkload::Interactive,
    );

    assert!(
        fast_score > slow_score,
        "faster task latency prior should improve cold-start score ({fast_score} <= {slow_score})"
    );
}
1763
/// With an imported Qwen3-8B profile whose `ema_quality` is 0.95, routing the
/// same code prompt under `RoutingWorkload::Interactive` must still select a
/// non-local model.
#[test]
fn interactive_workload_keeps_remote_bootstrap_bias() {
    use crate::outcome::ModelProfile;

    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();

    let mut profile = ModelProfile::new("qwen/qwen3-8b:q4_k_m".into());
    profile.ema_quality = 0.95;
    let mut tracker = OutcomeTracker::new();
    tracker.import_profiles(vec![profile]);

    let decision = router.route_context_aware(
        "Write a Python fibonacci function.",
        128,
        &reg,
        &tracker,
        false,
        false,
        RoutingWorkload::Interactive,
    );

    let schema = reg
        .get(&decision.model_id)
        .expect("selected model should exist");
    assert!(
        !schema.is_local(),
        "interactive routing should still prefer trusted remote models during cold start"
    );
}
1795
/// A routing decision must come with a non-empty fallback list that does not
/// repeat the model chosen as the primary.
#[test]
fn fallback_chain_has_alternatives() {
    let config = RoutingConfig {
        prior_strength: 100.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);

    let has_fallbacks = !decision.fallbacks.is_empty();
    assert!(has_fallbacks);
    let primary_repeated = decision.fallbacks.contains(&decision.model_id);
    assert!(!primary_repeated);
}
1813
/// The score obtained by converting 25 tokens/s into a latency must agree,
/// within 0.01, with the score for a directly supplied latency of 8000 ms.
#[test]
fn latency_scoring_is_consistent() {
    let router = AdaptiveRouter::with_default_config(test_hw());

    let schema_latency_ms = AdaptiveRouter::tps_to_latency_ms(25.0);
    let schema_score = router.latency_ms_to_score(schema_latency_ms);
    let observed_score = router.latency_ms_to_score(8000.0);

    let gap = (schema_score - observed_score).abs();
    assert!(
        gap < 0.01,
        "schema ({schema_score}) and observed ({observed_score}) should match"
    );
}
1828
/// Spot-checks `TaskComplexity::assess` on one prompt each for the `Simple`,
/// `Code`, and `Complex` classes.
#[test]
fn complexity_assessment() {
    let cases = [
        ("What is the capital of France?", TaskComplexity::Simple),
        ("Fix this broken test", TaskComplexity::Code),
        (
            "Analyze the trade-offs between A and B",
            TaskComplexity::Complex,
        ),
    ];

    for (prompt, expected) in cases {
        assert_eq!(TaskComplexity::assess(prompt), expected);
    }
}
1844
/// `sample_beta` must stay inside [0, 1], and the empirical mean of 1000
/// draws with both shape parameters set to 1 must be within 0.05 of 0.5.
#[test]
fn beta_sampling_produces_valid_values() {
    const MEAN_DRAWS: usize = 1000;

    let mut rng = rand::rng();

    for _ in 0..100 {
        let draw = sample_beta(&mut rng, 2.0, 5.0);
        assert!(
            (0.0..=1.0).contains(&draw),
            "sample {draw} out of [0,1] range"
        );
    }

    // Accumulate the draws directly instead of materialising them first.
    let total: f64 = (0..MEAN_DRAWS)
        .map(|_| sample_beta(&mut rng, 1.0, 1.0))
        .sum();
    let mean = total / MEAN_DRAWS as f64;
    assert!(
        (mean - 0.5).abs() < 0.05,
        "Beta(1,1) mean {mean} should be ~0.5"
    );
}
1861
/// After twenty accepted `Code` outcomes are recorded for the Qwen3-4B model
/// id, at least fourteen of twenty routing decisions for a code-repair prompt
/// must select that model.
#[test]
fn thompson_sampling_converges_to_best() {
    let config = RoutingConfig {
        prior_strength: 1.0,
        ..Default::default()
    };
    let router = AdaptiveRouter::new(test_hw(), config);
    let reg = test_registry();
    let mut tracker = OutcomeTracker::new();

    let strong_id = "qwen/qwen3-4b:q4_k_m";
    for _ in 0..20 {
        let trace = tracker.record_start(strong_id, InferenceTask::Code, "test");
        tracker.record_complete(&trace, 500, 100, 50);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
    }

    // Count how many of the twenty routing decisions choose the strong model.
    let wins = (0..20)
        .filter(|_| {
            let decision = router.route("Fix this parser bug", &reg, &tracker);
            decision.model_id == strong_id
        })
        .count();

    assert!(
        wins >= 14,
        "strong model won only {wins}/20 times (expected >= 14)"
    );
}
1896
/// Routes "hello" with an intent that requires `Vision` against an empty
/// tracker and checks that the decision reports the schema-based strategy.
///
/// NOTE(review): `test_registry` registers a remote model that lists
/// `ModelCapability::Vision`, so the "no vision-capable candidates" wording in
/// the assertion message may be stale — confirm the intended scenario.
#[test]
fn intent_require_filters_out_models_lacking_capability() {
    use crate::intent::IntentHint;

    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let vision_required = IntentHint {
        require: vec![ModelCapability::Vision],
        ..Default::default()
    };
    let decision = router.route_with_intent("hello", &reg, &tracker, &vision_required);

    assert_eq!(
        decision.strategy,
        RoutingStrategy::SchemaBased,
        "require=[vision] with no vision-capable candidates must drop to schema cold-start"
    );
}
1925
/// Routing with `IntentHint::default()` must yield the same task and
/// complexity as plain `route` for the same prompt.
#[test]
fn intent_default_does_not_override_task_or_caps() {
    use crate::intent::IntentHint;

    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let tracker = OutcomeTracker::new();
    let prompt = "write a haiku";

    let baseline = router.route(prompt, &reg, &tracker);
    let neutral_hint = IntentHint::default();
    let with_default = router.route_with_intent(prompt, &reg, &tracker, &neutral_hint);

    assert_eq!(
        baseline.task, with_default.task,
        "default IntentHint must not change the prompt-derived task"
    );
    assert_eq!(
        baseline.complexity, with_default.complexity,
        "default IntentHint must not change the prompt-derived complexity"
    );
}
1955
/// An explicit `TaskHint::Reasoning` must determine the decision's task even
/// when the prompt itself is just "hi".
#[test]
fn intent_task_hint_overrides_prompt_complexity() {
    use crate::intent::{IntentHint, TaskHint};

    let router = AdaptiveRouter::new(test_hw(), RoutingConfig::default());
    let reg = test_registry();
    let tracker = OutcomeTracker::new();

    let reasoning_hint = IntentHint {
        task: Some(TaskHint::Reasoning),
        ..Default::default()
    };
    let decision = router.route_with_intent("hi", &reg, &tracker, &reasoning_hint);

    assert_eq!(
        decision.task,
        InferenceTask::Reasoning,
        "TaskHint::Reasoning should override the prompt-derived task"
    );
}
1976}