1use rand::Rng;
11use serde::{Deserialize, Serialize};
12
13use crate::hardware::HardwareInfo;
14use crate::outcome::{InferenceTask, OutcomeTracker};
15use crate::registry::UnifiedRegistry;
16use crate::schema::{ModelCapability, ModelSchema};
17
/// Coarse complexity tier assigned to an incoming prompt.
///
/// Produced by [`TaskComplexity::assess`] and used to pick required model
/// capabilities and the outcome-tracking task category.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Short factual questions; cheapest models suffice.
    Simple,
    /// Default tier for prompts that match no stronger signal.
    Medium,
    /// Prompts containing (or asking to repair) source code.
    Code,
    /// Prompts needing analysis/reasoning, or very long prompts.
    Complex,
}
27
28impl TaskComplexity {
29 pub fn assess(prompt: &str) -> Self {
36 let lower = prompt.to_lowercase();
37 let word_count = prompt.split_whitespace().count();
38 let estimated_tokens = (word_count as f64 * 1.3) as usize;
39
40 let has_code = Self::detect_code(prompt);
41
42 let repair_markers = [
43 "fix", "repair", "debug", "refactor", "broken",
44 "failing", "error", "bug",
45 ];
46 let has_repair = repair_markers.iter().any(|m| lower.contains(m));
47
48 let reasoning_markers = [
49 "analyze", "compare", "explain why", "step by step",
50 "think through", "evaluate", "trade-off", "tradeoff",
51 "pros and cons", "architecture", "design", "strategy",
52 "optimize", "comprehensive",
53 ];
54 let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));
55
56 let simple_patterns = [
57 "what is", "who is", "when did", "where is",
58 "how many", "yes or no", "true or false", "name the",
59 "list the", "define ",
60 ];
61 let is_simple = simple_patterns.iter().any(|p| lower.contains(p));
62
63 if has_code || has_repair {
64 TaskComplexity::Code
65 } else if has_reasoning || estimated_tokens > 500 {
66 TaskComplexity::Complex
67 } else if is_simple || estimated_tokens < 30 {
68 TaskComplexity::Simple
69 } else {
70 TaskComplexity::Medium
71 }
72 }
73
74 fn detect_code(prompt: &str) -> bool {
80 #[cfg(feature = "ast")]
82 {
83 if let Some(is_code) = Self::detect_code_ast(prompt) {
84 return is_code;
85 }
86 }
87
88 let code_markers = [
90 "```", "fn ", "def ", "class ", "import ", "require(",
91 "async fn", "pub fn", "function ", "const ", "let ", "var ",
92 "#include", "package ", "impl ",
93 ];
94 code_markers.iter().any(|m| prompt.contains(m))
95 }
96
97 #[cfg(feature = "ast")]
101 fn detect_code_ast(prompt: &str) -> Option<bool> {
102 let mut blocks = Vec::new();
104 let mut rest = prompt;
105 while let Some(start) = rest.find("```") {
106 let after_fence = &rest[start + 3..];
107 let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
109 if let Some(end) = after_fence[code_start..].find("```") {
110 blocks.push(&after_fence[code_start..code_start + end]);
111 rest = &after_fence[code_start + end + 3..];
112 } else {
113 break;
114 }
115 }
116
117 if blocks.is_empty() {
118 return None; }
120
121 let languages = [
123 car_ast::Language::Rust,
124 car_ast::Language::Python,
125 car_ast::Language::TypeScript,
126 car_ast::Language::JavaScript,
127 car_ast::Language::Go,
128 ];
129
130 for block in &blocks {
131 let trimmed = block.trim();
132 if trimmed.is_empty() { continue; }
133
134 for lang in &languages {
135 if let Some(parsed) = car_ast::parse(trimmed, *lang) {
136 if !parsed.symbols.is_empty() {
138 return Some(true);
139 }
140 }
141 }
142 }
143
144 Some(false)
147 }
148
149 pub fn required_capabilities(&self) -> Vec<ModelCapability> {
151 match self {
152 TaskComplexity::Simple => vec![ModelCapability::Generate],
153 TaskComplexity::Medium => vec![ModelCapability::Generate],
154 TaskComplexity::Code => vec![ModelCapability::Code],
155 TaskComplexity::Complex => vec![ModelCapability::Reasoning],
156 }
157 }
158
159 pub fn inference_task(&self) -> InferenceTask {
161 match self {
162 TaskComplexity::Simple | TaskComplexity::Medium => InferenceTask::Generate,
163 TaskComplexity::Code => InferenceTask::Code,
164 TaskComplexity::Complex => InferenceTask::Reasoning,
165 }
166 }
167}
168
/// Tunable knobs for [`AdaptiveRouter`] scoring and exploration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Probability in [0.0, 1.0] of routing to an under-tested model
    /// instead of the top-scoring one.
    pub exploration_rate: f64,
    /// Observations required before a model's live profile fully replaces
    /// schema-based estimates; below this the two are blended.
    pub min_observations: u64,
    /// Weight of the quality component in the composite score.
    pub quality_weight: f64,
    /// Weight of the latency component in the composite score.
    pub latency_weight: f64,
    /// Weight of the cost component in the composite score.
    pub cost_weight: f64,
    /// Hard cap on a model's p50 latency (ms); models above it are filtered out.
    /// `None` disables the cap.
    pub max_latency_ms: Option<u64>,
    /// Hard cap on cost per 1k output tokens (USD); `None` disables the cap.
    pub max_cost_usd: Option<f64>,
    /// When true, local models receive an additive score bonus.
    pub prefer_local: bool,
}
187
impl Default for RoutingConfig {
    /// Defaults: 10% exploration, 5 observations to trust a profile, no
    /// latency/cost caps, prefer local models.
    fn default() -> Self {
        Self {
            exploration_rate: 0.1,
            min_observations: 5,
            // The three weights sum to 1.0 so the composite score stays
            // roughly in [0, 1] (plus the optional local bonus).
            quality_weight: 0.45,
            latency_weight: 0.4,
            cost_weight: 0.15,
            max_latency_ms: None,
            max_cost_usd: None,
            prefer_local: true,
        }
    }
}
202
/// How a routing decision was made, recorded for observability.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Chosen from static schema estimates (no/insufficient live data).
    SchemaBased,
    /// Chosen from an observed performance profile with enough samples.
    ProfileBased,
    /// Randomly chosen under-tested model (epsilon-greedy exploration).
    Exploration,
    /// Explicitly requested by the caller.
    Explicit,
}
216
/// The outcome of a routing request: which model to use and why.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveRoutingDecision {
    /// Registry id of the selected model.
    pub model_id: String,
    /// Human-readable model name (falls back to the id if unresolvable).
    pub model_name: String,
    /// Task category used for outcome tracking.
    pub task: InferenceTask,
    /// Assessed prompt complexity that drove the decision.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of the decision.
    pub reason: String,
    /// Strategy that produced the selection.
    pub strategy: RoutingStrategy,
    /// Composite score of the selected model (0.5 when unknown).
    pub predicted_quality: f64,
    /// Remaining candidates in descending score order, for failover.
    pub fallbacks: Vec<String>,
}
237
/// Routes prompts to models by blending static schema estimates with
/// live outcome profiles, under the constraints in [`RoutingConfig`].
pub struct AdaptiveRouter {
    // Host hardware limits (e.g. max loadable model size) and defaults.
    hw: HardwareInfo,
    // Scoring weights, caps, and exploration settings.
    config: RoutingConfig,
}
243
244impl AdaptiveRouter {
245 pub fn new(hw: HardwareInfo, config: RoutingConfig) -> Self {
246 Self { hw, config }
247 }
248
249 pub fn with_default_config(hw: HardwareInfo) -> Self {
250 Self::new(hw, RoutingConfig::default())
251 }
252
253 pub fn route(
255 &self,
256 prompt: &str,
257 registry: &UnifiedRegistry,
258 tracker: &OutcomeTracker,
259 ) -> AdaptiveRoutingDecision {
260 let complexity = TaskComplexity::assess(prompt);
261 let task = complexity.inference_task();
262 let required_caps = complexity.required_capabilities();
263
264 let candidates = self.filter_candidates(&required_caps, registry);
266
267 if candidates.is_empty() {
268 return self.cold_start_decision(complexity, task, registry);
270 }
271
272 let scored = self.score_candidates(&candidates, task, tracker);
274
275 let (selected_id, strategy) = self.select_with_exploration(&scored, tracker);
277
278 let fallbacks: Vec<String> = scored.iter()
280 .filter(|(id, _)| *id != selected_id)
281 .map(|(id, _)| id.clone())
282 .collect();
283
284 let predicted_quality = scored.iter()
285 .find(|(id, _)| *id == selected_id)
286 .map(|(_, score)| *score)
287 .unwrap_or(0.5);
288
289 let model_name = registry.get(&selected_id)
290 .or_else(|| registry.find_by_name(&selected_id))
291 .map(|m| m.name.clone())
292 .unwrap_or_else(|| selected_id.clone());
293
294 let reason = format!(
295 "{:?} task → {} via {:?} (quality: {:.2}, {} candidates)",
296 complexity, model_name, strategy, predicted_quality, candidates.len()
297 );
298
299 AdaptiveRoutingDecision {
300 model_id: selected_id,
301 model_name,
302 task,
303 complexity,
304 reason,
305 strategy,
306 predicted_quality,
307 fallbacks,
308 }
309 }
310
311 pub fn route_embedding(&self, registry: &UnifiedRegistry) -> String {
313 let embed_models = registry.query_by_capability(ModelCapability::Embed);
314 embed_models.first()
315 .map(|m| m.name.clone())
316 .unwrap_or_else(|| "Qwen3-Embedding-0.6B".to_string())
317 }
318
319 pub fn route_small(&self, registry: &UnifiedRegistry) -> String {
321 let gen_models = registry.query_by_capability(ModelCapability::Generate);
322 gen_models.iter()
324 .filter(|m| m.is_local())
325 .min_by_key(|m| m.size_mb())
326 .map(|m| m.name.clone())
327 .unwrap_or_else(|| "Qwen3-0.6B".to_string())
328 }
329
330 const LATENCY_CEILING_MS: f64 = 10000.0;
336 const TPS_CEILING: f64 = 150.0;
338 const MOE_TPS_MULTIPLIER: f64 = 0.10;
340 const COST_CEILING_PER_1K: f64 = 0.1;
342 const LOCAL_BONUS: f64 = 0.15;
344
345 fn filter_candidates(
349 &self,
350 required_caps: &[ModelCapability],
351 registry: &UnifiedRegistry,
352 ) -> Vec<ModelSchema> {
353 registry.list().into_iter()
354 .filter(|m| {
355 if !required_caps.iter().all(|c| m.has_capability(*c)) {
357 return false;
358 }
359 if !m.available {
361 return false;
362 }
363 if m.is_local() && m.size_mb() >= self.hw.max_model_mb {
365 return false;
366 }
367 if let Some(max) = self.config.max_latency_ms {
369 if let Some(p50) = m.performance.latency_p50_ms {
370 if p50 > max {
371 return false;
372 }
373 }
374 }
375 if let Some(max) = self.config.max_cost_usd {
377 if m.cost_per_1k_output() > max {
378 return false;
379 }
380 }
381 true
382 })
383 .cloned()
384 .collect()
385 }
386
387 fn score_candidates(
389 &self,
390 candidates: &[ModelSchema],
391 task: InferenceTask,
392 tracker: &OutcomeTracker,
393 ) -> Vec<(String, f64)> {
394 let mut scored: Vec<(String, f64)> = candidates.iter()
395 .map(|m| {
396 let score = self.score_model(m, task, tracker);
397 (m.id.clone(), score)
398 })
399 .collect();
400
401 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
402 scored
403 }
404
405 fn score_model(
408 &self,
409 model: &ModelSchema,
410 task: InferenceTask,
411 tracker: &OutcomeTracker,
412 ) -> f64 {
413 let profile = tracker.profile(&model.id);
414 let schema_quality = self.schema_quality_estimate(model);
415 let schema_latency = self.schema_latency_estimate(model);
416
417 let quality = match profile {
420 Some(p) if p.total_calls >= self.config.min_observations => {
421 p.task_stats(task).map(|ts| ts.ema_quality).unwrap_or(p.ema_quality)
422 }
423 Some(p) if p.total_calls > 0 => {
424 let w = p.total_calls as f64 / self.config.min_observations as f64;
425 schema_quality * (1.0 - w) + p.ema_quality * w
426 }
427 _ => schema_quality,
428 };
429
430 let latency = match profile {
433 Some(p) if p.total_calls >= self.config.min_observations => {
434 let avg = p.avg_latency_ms();
435 self.latency_ms_to_score(avg)
436 }
437 Some(p) if p.total_calls > 0 => {
438 let observed = self.latency_ms_to_score(p.avg_latency_ms());
439 let w = p.total_calls as f64 / self.config.min_observations as f64;
440 schema_latency * (1.0 - w) + observed * w
441 }
442 _ => schema_latency,
443 };
444
445 let cost = if model.is_local() {
447 1.0
448 } else {
449 (1.0 - (model.cost_per_1k_output() / Self::COST_CEILING_PER_1K)).clamp(0.0, 1.0)
450 };
451
452 let local_bonus = if self.config.prefer_local && model.is_local() {
453 Self::LOCAL_BONUS
454 } else {
455 0.0
456 };
457
458 self.config.quality_weight * quality
459 + self.config.latency_weight * latency
460 + self.config.cost_weight * cost
461 + local_bonus
462 }
463
464 fn latency_ms_to_score(&self, ms: f64) -> f64 {
467 (1.0 - (ms / Self::LATENCY_CEILING_MS)).clamp(0.0, 1.0)
468 }
469
470 fn tps_to_latency_ms(tps: f64) -> f64 {
472 if tps <= 0.0 { return Self::LATENCY_CEILING_MS; }
473 (200.0 / tps) * 1000.0
475 }
476
477 fn schema_quality_estimate(&self, model: &ModelSchema) -> f64 {
482 match model.size_mb() {
483 0 => 0.5, s if s < 1000 => 0.4, s if s < 2000 => 0.5, s if s < 3000 => 0.6, s if s < 6000 => 0.7, _ => 0.75, }
490 }
491
492 fn schema_latency_estimate(&self, model: &ModelSchema) -> f64 {
497 let is_moe = model.tags.contains(&"moe".to_string());
498
499 if model.is_local() {
500 if let Some(tps) = model.performance.tokens_per_second {
501 let effective_tps = if is_moe { tps * Self::MOE_TPS_MULTIPLIER } else { tps };
502 let estimated_ms = Self::tps_to_latency_ms(effective_tps);
503 return self.latency_ms_to_score(estimated_ms);
504 }
505 return 0.5; }
507
508 if let Some(p50) = model.performance.latency_p50_ms {
510 return self.latency_ms_to_score(p50 as f64);
511 }
512 0.3 }
514
515 fn select_with_exploration(
517 &self,
518 scored: &[(String, f64)],
519 tracker: &OutcomeTracker,
520 ) -> (String, RoutingStrategy) {
521 if scored.is_empty() {
522 return (String::new(), RoutingStrategy::SchemaBased);
523 }
524
525 let mut rng = rand::rng();
526
527 if rng.random::<f64>() < self.config.exploration_rate {
529 let under_tested: Vec<&str> = scored.iter()
531 .filter(|(id, _)| {
532 tracker.profile(id)
533 .map(|p| p.total_calls < self.config.min_observations)
534 .unwrap_or(true) })
536 .map(|(id, _)| id.as_str())
537 .collect();
538
539 if !under_tested.is_empty() {
540 let idx = rng.random_range(0..under_tested.len());
541 return (under_tested[idx].to_string(), RoutingStrategy::Exploration);
542 }
543 }
544
545 let best = &scored[0];
547 let strategy = if tracker.profile(&best.0)
548 .map(|p| p.total_calls >= self.config.min_observations)
549 .unwrap_or(false)
550 {
551 RoutingStrategy::ProfileBased
552 } else {
553 RoutingStrategy::SchemaBased
554 };
555
556 (best.0.clone(), strategy)
557 }
558
559 fn cold_start_decision(
561 &self,
562 complexity: TaskComplexity,
563 task: InferenceTask,
564 registry: &UnifiedRegistry,
565 ) -> AdaptiveRoutingDecision {
566 let model_name = match complexity {
568 TaskComplexity::Simple => "Qwen3-0.6B",
569 TaskComplexity::Medium => "Qwen3-1.7B",
570 TaskComplexity::Code => "Qwen3-4B",
571 TaskComplexity::Complex => &self.hw.recommended_model,
572 };
573
574 let model_id = registry.find_by_name(model_name)
575 .map(|m| m.id.clone())
576 .unwrap_or_else(|| model_name.to_string());
577
578 AdaptiveRoutingDecision {
579 model_id,
580 model_name: model_name.to_string(),
581 task,
582 complexity,
583 reason: format!("{:?} task → {} (cold start, no candidates)", complexity, model_name),
584 strategy: RoutingStrategy::SchemaBased,
585 predicted_quality: 0.5,
586 fallbacks: vec![],
587 }
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outcome::InferredOutcome;

    // Hardware fixture resembling an Apple Silicon laptop.
    fn test_hw() -> HardwareInfo {
        HardwareInfo {
            os: "macos".into(),
            arch: "aarch64".into(),
            cpu_cores: 10,
            total_ram_mb: 32768,
            gpu_backend: crate::hardware::GpuBackend::Metal,
            gpu_memory_mb: Some(28672),
            recommended_model: "Qwen3-8B".into(),
            recommended_context: 8192,
            max_model_mb: 18000,
        }
    }

    // Registry fixture backed by fake on-disk model directories.
    fn test_registry() -> UnifiedRegistry {
        let tmp = std::path::PathBuf::from("/tmp/car-test-adaptive-router");
        for name in &["Qwen3-0.6B", "Qwen3-1.7B", "Qwen3-4B", "Qwen3-8B", "Qwen3-Embedding-0.6B"] {
            let dir = tmp.join(name);
            let _ = std::fs::create_dir_all(&dir);
            let _ = std::fs::write(dir.join("model.gguf"), b"fake");
            let _ = std::fs::write(dir.join("tokenizer.json"), b"{}");
        }
        UnifiedRegistry::new(tmp)
    }

    #[test]
    fn routes_simple_to_balanced_local() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            ..Default::default()
        });
        // FIX: `&reg` had been mojibake-mangled to `®` (HTML entity `&reg`).
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("What is 2+2?", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Simple);
        assert_eq!(decision.strategy, RoutingStrategy::SchemaBased);
        let schema = reg.find_by_name(&decision.model_name);
        assert!(schema.is_some(), "selected model should exist in registry");
        assert!(schema.unwrap().is_local(), "simple task should route to local model");
    }

    #[test]
    fn routes_code_to_code_capable_local() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            ..Default::default()
        });
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Fix this function:\n```rust\nfn main() {}\n```", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Code);
        assert_eq!(decision.task, InferenceTask::Code);
        assert_eq!(decision.model_name, "Qwen3-1.7B");
    }

    #[test]
    fn profile_based_routing_selects_proven_model() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            min_observations: 3,
            ..Default::default()
        });
        let reg = test_registry();
        let mut tracker = OutcomeTracker::new();

        // Record enough high-quality outcomes to build a trusted profile.
        let qwen_8b_id = "qwen/qwen3-8b:q4_k_m";
        for _ in 0..5 {
            let trace = tracker.record_start(qwen_8b_id, InferenceTask::Code, "test");
            tracker.record_complete(&trace, 500, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.95 });
        }

        let decision = router.route("Fix this bug in the parser", &reg, &tracker);
        assert_eq!(decision.complexity, TaskComplexity::Code);
        assert_eq!(decision.model_name, "Qwen3-8B");
        assert_eq!(decision.strategy, RoutingStrategy::ProfileBased);
    }

    #[test]
    fn fallback_chain_has_alternatives() {
        let router = AdaptiveRouter::new(test_hw(), RoutingConfig {
            exploration_rate: 0.0,
            ..Default::default()
        });
        let reg = test_registry();
        let tracker = OutcomeTracker::new();

        let decision = router.route("Analyze the architecture trade-offs", &reg, &tracker);
        assert!(!decision.fallbacks.is_empty());
        assert!(!decision.fallbacks.contains(&decision.model_id));
    }

    #[test]
    fn latency_scoring_is_consistent() {
        let router = AdaptiveRouter::with_default_config(test_hw());

        // 25 tps over the assumed 200-token response = 8000 ms; schema-derived
        // and directly-observed latency must score identically.
        let schema_score = router.latency_ms_to_score(AdaptiveRouter::tps_to_latency_ms(25.0));
        let observed_score = router.latency_ms_to_score(8000.0);
        assert!((schema_score - observed_score).abs() < 0.01,
            "schema ({schema_score}) and observed ({observed_score}) should match");
    }

    #[test]
    fn complexity_assessment() {
        assert_eq!(TaskComplexity::assess("What is the capital of France?"), TaskComplexity::Simple);
        assert_eq!(TaskComplexity::assess("Fix this broken test"), TaskComplexity::Code);
        assert_eq!(TaskComplexity::assess("Analyze the trade-offs between A and B"), TaskComplexity::Complex);
    }
}