1use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
/// Category of work detected in a request; drives model selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskType {
    /// General conversation (the fallback when nothing else matches).
    Chat,
    /// Writing new code.
    Coding,
    /// Reviewing existing code.
    CodeReview,
    /// Finding and fixing bugs.
    Debugging,
    /// General prose writing or editing.
    Writing,
    /// Creative writing (stories, poems, etc.).
    Creative,
    /// Mathematical computation or problem solving.
    Math,
    /// Data or statistical analysis.
    Analysis,
    /// Information gathering / research.
    Research,
    /// Language translation.
    Translation,
    /// Text summarization.
    Summarization,
    /// Direct question answering.
    QuestionAnswering,
    /// Image/visual understanding.
    Vision,
    /// Logical reasoning or explanation.
    Reasoning,
    /// Very short request (input under 50 bytes; see `detect`).
    Quick,
}
52
53impl TaskType {
54 pub fn complexity_weight(&self) -> u8 {
56 match self {
57 TaskType::Quick => 1,
58 TaskType::Chat => 2,
59 TaskType::QuestionAnswering => 3,
60 TaskType::Translation => 4,
61 TaskType::Summarization => 4,
62 TaskType::Writing => 5,
63 TaskType::Coding => 6,
64 TaskType::CodeReview => 6,
65 TaskType::Creative => 6,
66 TaskType::Analysis => 7,
67 TaskType::Debugging => 7,
68 TaskType::Research => 7,
69 TaskType::Math => 8,
70 TaskType::Vision => 8,
71 TaskType::Reasoning => 9,
72 }
73 }
74
75 pub fn detect(content: &str) -> Self {
77 let lower = content.to_lowercase();
78
79 if lower.contains("```") || lower.contains("code") || lower.contains("function")
81 || lower.contains("class") || lower.contains("implement") {
82 if lower.contains("review") || lower.contains("check") {
83 return TaskType::CodeReview;
84 }
85 if lower.contains("bug") || lower.contains("fix") || lower.contains("error")
86 || lower.contains("debug") {
87 return TaskType::Debugging;
88 }
89 return TaskType::Coding;
90 }
91
92 if lower.contains("calculate") || lower.contains("equation") || lower.contains("solve")
94 || lower.contains("math") || lower.contains("formula") {
95 return TaskType::Math;
96 }
97
98 if lower.contains("analyze") || lower.contains("analysis") || lower.contains("data")
100 || lower.contains("statistics") || lower.contains("trend") {
101 return TaskType::Analysis;
102 }
103
104 if lower.contains("research") || lower.contains("find out") || lower.contains("look up")
106 || lower.contains("search for") {
107 return TaskType::Research;
108 }
109
110 if lower.contains("write") || lower.contains("draft") || lower.contains("compose")
112 || lower.contains("edit") {
113 if lower.contains("creative") || lower.contains("story") || lower.contains("poem") {
114 return TaskType::Creative;
115 }
116 return TaskType::Writing;
117 }
118
119 if lower.contains("translate") || lower.contains("translation") {
121 return TaskType::Translation;
122 }
123
124 if lower.contains("summarize") || lower.contains("summary") || lower.contains("tldr") {
126 return TaskType::Summarization;
127 }
128
129 if lower.contains("why") || lower.contains("reason") || lower.contains("explain")
131 || lower.contains("logic") {
132 return TaskType::Reasoning;
133 }
134
135 if lower.contains("image") || lower.contains("picture") || lower.contains("photo")
137 || lower.contains("see") || lower.contains("look at") {
138 return TaskType::Vision;
139 }
140
141 if lower.ends_with('?') || lower.starts_with("what") || lower.starts_with("how")
143 || lower.starts_with("when") || lower.starts_with("where") {
144 return TaskType::QuestionAnswering;
145 }
146
147 if content.len() < 50 {
149 return TaskType::Quick;
150 }
151
152 TaskType::Chat
153 }
154}
155
/// Static description of one model: what it is good at, which features it
/// supports, and what it costs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    /// Provider-specific model identifier (e.g. "gpt-4o").
    pub model_id: String,
    /// Provider name (e.g. "openai", "anthropic", "ollama").
    pub provider: String,
    /// Human-readable display name.
    pub name: String,
    /// Per-task quality scores clamped to [0.0, 1.0]; tasks with no entry
    /// default to 0.5 (see `score_for_task`).
    pub task_scores: HashMap<TaskType, f64>,
    /// Maximum context size in tokens.
    pub context_window: usize,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports function/tool calling.
    pub supports_functions: bool,
    /// Whether the model can stream responses.
    pub supports_streaming: bool,
    /// Price per 1,000 input tokens (currency units assumed USD — confirm).
    pub cost_per_1k_input: f64,
    /// Price per 1,000 output tokens (same units as input cost).
    pub cost_per_1k_output: f64,
    /// Typical response latency in milliseconds.
    pub avg_latency_ms: u32,
    /// Whether the model is currently eligible for routing.
    pub available: bool,
}
188
189impl ModelCapabilities {
190 pub fn new(model_id: &str, provider: &str, name: &str) -> Self {
192 Self {
193 model_id: model_id.to_string(),
194 provider: provider.to_string(),
195 name: name.to_string(),
196 task_scores: HashMap::new(),
197 context_window: 4096,
198 supports_vision: false,
199 supports_functions: false,
200 supports_streaming: true,
201 cost_per_1k_input: 0.0,
202 cost_per_1k_output: 0.0,
203 avg_latency_ms: 1000,
204 available: true,
205 }
206 }
207
208 pub fn with_task_score(mut self, task: TaskType, score: f64) -> Self {
210 self.task_scores.insert(task, score.clamp(0.0, 1.0));
211 self
212 }
213
214 pub fn with_context_window(mut self, size: usize) -> Self {
216 self.context_window = size;
217 self
218 }
219
220 pub fn with_vision(mut self, supports: bool) -> Self {
222 self.supports_vision = supports;
223 self
224 }
225
226 pub fn with_functions(mut self, supports: bool) -> Self {
228 self.supports_functions = supports;
229 self
230 }
231
232 pub fn with_cost(mut self, input: f64, output: f64) -> Self {
234 self.cost_per_1k_input = input;
235 self.cost_per_1k_output = output;
236 self
237 }
238
239 pub fn with_latency(mut self, ms: u32) -> Self {
241 self.avg_latency_ms = ms;
242 self
243 }
244
245 pub fn score_for_task(&self, task: TaskType) -> f64 {
247 self.task_scores.get(&task).copied().unwrap_or(0.5)
248 }
249}
250
/// How the router combines quality, cost, and latency when scoring models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Rank by task quality score alone.
    BestQuality,
    /// Rank by cost efficiency alone.
    LowestCost,
    /// Rank by speed alone.
    FastestResponse,
    /// Equal-weight average of quality, cost, and latency scores.
    Balanced,
    /// Weighted sum using the weights in `RoutingConfig`.
    Custom,
}
270
/// Hard filters applied before scoring; any violation removes the model
/// from consideration (see `ModelRouter::meets_constraints`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Default)]
pub struct RoutingConstraints {
    /// Rough cost ceiling; checked against the model's output price scaled
    /// by 10 rather than an exact per-request estimate.
    pub max_cost: Option<f64>,
    /// Maximum acceptable average latency in milliseconds.
    pub max_latency_ms: Option<u32>,
    /// Minimum required context window in tokens.
    pub min_context_window: Option<usize>,
    /// If set, only these providers are eligible.
    pub allowed_providers: Option<Vec<String>>,
    /// Providers that are never eligible.
    pub blocked_providers: Vec<String>,
    /// Require image-input support.
    pub require_vision: bool,
    /// Require function/tool-calling support.
    pub require_functions: bool,
}
290
291
/// Per-request routing configuration: strategy, hard constraints, and
/// scoring weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Scoring strategy to apply.
    pub strategy: RoutingStrategy,
    /// Hard filters applied before scoring.
    pub constraints: RoutingConstraints,
    /// Quality weight (only consulted by the `Custom` strategy).
    pub quality_weight: f64,
    /// Cost weight (only consulted by the `Custom` strategy).
    pub cost_weight: f64,
    /// Latency weight (only consulted by the `Custom` strategy).
    pub latency_weight: f64,
    /// Model id used when no candidate passes the constraints.
    pub fallback_model: Option<String>,
}
308
impl Default for RoutingConfig {
    /// Balanced strategy, no constraints, no fallback, and weights
    /// 0.5/0.3/0.2 (quality/cost/latency) that sum to 1.0.
    fn default() -> Self {
        Self {
            strategy: RoutingStrategy::Balanced,
            constraints: RoutingConstraints::default(),
            quality_weight: 0.5,
            cost_weight: 0.3,
            latency_weight: 0.2,
            fallback_model: None,
        }
    }
}
321
/// A single routing query: the content to classify plus the configuration
/// governing model selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRequest {
    /// Unique id, echoed back as `RoutingDecision::request_id`.
    pub id: Uuid,
    /// The user content; also the input to task-type detection.
    pub content: String,
    /// Additional context strings (presumably prior turns — confirm with callers).
    pub context: Vec<String>,
    /// Estimated token count, used for cost estimation.
    pub estimated_tokens: usize,
    /// Strategy, constraints, and weights for this request.
    pub config: RoutingConfig,
    /// When the request was created.
    pub timestamp: DateTime<Utc>,
}
342
/// Outcome of routing one request: the chosen model plus scoring context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Id of the request this decision answers.
    pub request_id: Uuid,
    /// Selected model id.
    pub model_id: String,
    /// Provider of the selected model.
    pub provider: String,
    /// Detected task category of the request.
    pub task_type: TaskType,
    /// Total score of the winning model, reused as a confidence value.
    pub confidence: f64,
    /// Estimated request cost based on `RoutingRequest::estimated_tokens`.
    pub estimated_cost: f64,
    /// Expected latency of the selected model in milliseconds.
    pub estimated_latency_ms: u32,
    /// Up to three runner-up candidates, best first.
    pub alternatives: Vec<ModelScore>,
    /// Human-readable explanation of the choice.
    pub reasoning: String,
    /// When the decision was made.
    pub decided_at: DateTime<Utc>,
}
367
/// Score breakdown for one candidate model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelScore {
    /// Candidate model id.
    pub model_id: String,
    /// Candidate's provider.
    pub provider: String,
    /// Task-specific quality score in [0.0, 1.0].
    pub quality_score: f64,
    /// Cost efficiency in [0.0, 1.0] (1.0 = cheapest).
    pub cost_score: f64,
    /// Speed in [0.0, 1.0] (1.0 = fastest).
    pub latency_score: f64,
    /// Combined score according to the active strategy.
    pub total_score: f64,
    /// Why the model was rejected, if it was. NOTE(review): never populated
    /// anywhere in this module — candidates are filtered out before scoring.
    pub rejection_reason: Option<String>,
}
386
/// Routes requests to the most suitable model based on detected task type,
/// hard constraints, and the configured scoring strategy.
pub struct ModelRouter {
    /// Candidate model catalog.
    models: Vec<ModelCapabilities>,
    /// Default configuration. NOTE(review): currently never read — `route`
    /// always uses `request.config`; confirm whether this is intentional.
    default_config: RoutingConfig,
    /// Every decision made so far (grows unbounded; feeds `stats`).
    history: Vec<RoutingDecision>,
}
400
401impl ModelRouter {
402 pub fn new() -> Self {
404 Self {
405 models: Self::default_models(),
406 default_config: RoutingConfig::default(),
407 history: vec![],
408 }
409 }
410
411 pub fn with_models(models: Vec<ModelCapabilities>) -> Self {
413 Self {
414 models,
415 default_config: RoutingConfig::default(),
416 history: vec![],
417 }
418 }
419
420 pub fn add_model(&mut self, model: ModelCapabilities) {
422 self.models.push(model);
423 }
424
425 pub fn route(&mut self, request: &RoutingRequest) -> RoutingDecision {
427 let task_type = TaskType::detect(&request.content);
429
430 let mut scores: Vec<ModelScore> = self
432 .models
433 .iter()
434 .filter(|m| m.available)
435 .filter(|m| self.meets_constraints(m, &request.config.constraints))
436 .map(|m| self.score_model(m, task_type, request))
437 .collect();
438
439 scores.sort_by(|a, b| b.total_score.partial_cmp(&a.total_score).unwrap());
441
442 let selected = scores.first().cloned().unwrap_or_else(|| {
444 ModelScore {
446 model_id: request.config.fallback_model.clone()
447 .unwrap_or_else(|| "gpt-4o-mini".to_string()),
448 provider: "openai".to_string(),
449 quality_score: 0.5,
450 cost_score: 0.5,
451 latency_score: 0.5,
452 total_score: 0.5,
453 rejection_reason: None,
454 }
455 });
456
457 let decision = RoutingDecision {
458 request_id: request.id,
459 model_id: selected.model_id.clone(),
460 provider: selected.provider.clone(),
461 task_type,
462 confidence: selected.total_score,
463 estimated_cost: self.estimate_cost(&selected.model_id, request.estimated_tokens),
464 estimated_latency_ms: self.estimate_latency(&selected.model_id),
465 alternatives: scores.into_iter().skip(1).take(3).collect(),
466 reasoning: self.generate_reasoning(&selected, task_type),
467 decided_at: Utc::now(),
468 };
469
470 self.history.push(decision.clone());
472
473 decision
474 }
475
476 fn meets_constraints(&self, model: &ModelCapabilities, constraints: &RoutingConstraints) -> bool {
478 if let Some(max_cost) = constraints.max_cost {
480 if model.cost_per_1k_output > max_cost * 10.0 {
481 return false;
482 }
483 }
484
485 if let Some(max_latency) = constraints.max_latency_ms {
487 if model.avg_latency_ms > max_latency {
488 return false;
489 }
490 }
491
492 if let Some(min_context) = constraints.min_context_window {
494 if model.context_window < min_context {
495 return false;
496 }
497 }
498
499 if let Some(ref allowed) = constraints.allowed_providers {
501 if !allowed.contains(&model.provider) {
502 return false;
503 }
504 }
505
506 if constraints.blocked_providers.contains(&model.provider) {
508 return false;
509 }
510
511 if constraints.require_vision && !model.supports_vision {
513 return false;
514 }
515
516 if constraints.require_functions && !model.supports_functions {
518 return false;
519 }
520
521 true
522 }
523
524 fn score_model(&self, model: &ModelCapabilities, task: TaskType, request: &RoutingRequest) -> ModelScore {
526 let config = &request.config;
527
528 let quality_score = model.score_for_task(task);
530
531 let max_cost = 0.1; let cost_score = 1.0 - (model.cost_per_1k_output / max_cost).min(1.0);
534
535 let max_latency = 5000.0; let latency_score = 1.0 - (model.avg_latency_ms as f64 / max_latency).min(1.0);
538
539 let total_score = match config.strategy {
541 RoutingStrategy::BestQuality => quality_score,
542 RoutingStrategy::LowestCost => cost_score,
543 RoutingStrategy::FastestResponse => latency_score,
544 RoutingStrategy::Balanced => {
545 (quality_score + cost_score + latency_score) / 3.0
546 }
547 RoutingStrategy::Custom => {
548 config.quality_weight * quality_score
549 + config.cost_weight * cost_score
550 + config.latency_weight * latency_score
551 }
552 };
553
554 ModelScore {
555 model_id: model.model_id.clone(),
556 provider: model.provider.clone(),
557 quality_score,
558 cost_score,
559 latency_score,
560 total_score,
561 rejection_reason: None,
562 }
563 }
564
565 fn estimate_cost(&self, model_id: &str, tokens: usize) -> f64 {
567 self.models
568 .iter()
569 .find(|m| m.model_id == model_id)
570 .map(|m| (tokens as f64 / 1000.0) * (m.cost_per_1k_input + m.cost_per_1k_output))
571 .unwrap_or(0.0)
572 }
573
574 fn estimate_latency(&self, model_id: &str) -> u32 {
576 self.models
577 .iter()
578 .find(|m| m.model_id == model_id)
579 .map(|m| m.avg_latency_ms)
580 .unwrap_or(1000)
581 }
582
583 fn generate_reasoning(&self, selected: &ModelScore, task: TaskType) -> String {
585 format!(
586 "Selected {} for {:?} task. Quality: {:.0}%, Cost efficiency: {:.0}%, Speed: {:.0}%",
587 selected.model_id,
588 task,
589 selected.quality_score * 100.0,
590 selected.cost_score * 100.0,
591 selected.latency_score * 100.0
592 )
593 }
594
595 fn default_models() -> Vec<ModelCapabilities> {
597 vec![
598 ModelCapabilities::new("gpt-4o", "openai", "GPT-4o")
600 .with_context_window(128000)
601 .with_vision(true)
602 .with_functions(true)
603 .with_cost(0.005, 0.015)
604 .with_latency(800)
605 .with_task_score(TaskType::Coding, 0.95)
606 .with_task_score(TaskType::Reasoning, 0.95)
607 .with_task_score(TaskType::Vision, 0.90)
608 .with_task_score(TaskType::Writing, 0.90)
609 .with_task_score(TaskType::Analysis, 0.90),
610
611 ModelCapabilities::new("gpt-4o-mini", "openai", "GPT-4o Mini")
612 .with_context_window(128000)
613 .with_vision(true)
614 .with_functions(true)
615 .with_cost(0.00015, 0.0006)
616 .with_latency(500)
617 .with_task_score(TaskType::Chat, 0.85)
618 .with_task_score(TaskType::Quick, 0.90)
619 .with_task_score(TaskType::QuestionAnswering, 0.85)
620 .with_task_score(TaskType::Coding, 0.80),
621
622 ModelCapabilities::new("o1", "openai", "o1")
623 .with_context_window(200000)
624 .with_vision(true)
625 .with_functions(false)
626 .with_cost(0.015, 0.06)
627 .with_latency(3000)
628 .with_task_score(TaskType::Reasoning, 0.99)
629 .with_task_score(TaskType::Math, 0.98)
630 .with_task_score(TaskType::Coding, 0.97)
631 .with_task_score(TaskType::Analysis, 0.95),
632
633 ModelCapabilities::new("claude-sonnet-4-20250514", "anthropic", "Claude Sonnet 4")
635 .with_context_window(200000)
636 .with_vision(true)
637 .with_functions(true)
638 .with_cost(0.003, 0.015)
639 .with_latency(700)
640 .with_task_score(TaskType::Coding, 0.95)
641 .with_task_score(TaskType::Writing, 0.95)
642 .with_task_score(TaskType::Reasoning, 0.92)
643 .with_task_score(TaskType::Analysis, 0.90),
644
645 ModelCapabilities::new("claude-3-5-haiku-20241022", "anthropic", "Claude 3.5 Haiku")
646 .with_context_window(200000)
647 .with_vision(true)
648 .with_functions(true)
649 .with_cost(0.0008, 0.004)
650 .with_latency(400)
651 .with_task_score(TaskType::Chat, 0.85)
652 .with_task_score(TaskType::Quick, 0.90)
653 .with_task_score(TaskType::Coding, 0.80),
654
655 ModelCapabilities::new("gemini-2.5-flash", "google", "Gemini 2.5 Flash")
657 .with_context_window(1000000)
658 .with_vision(true)
659 .with_functions(true)
660 .with_cost(0.000075, 0.0003)
661 .with_latency(300)
662 .with_task_score(TaskType::Chat, 0.85)
663 .with_task_score(TaskType::Quick, 0.95)
664 .with_task_score(TaskType::Coding, 0.85)
665 .with_task_score(TaskType::Analysis, 0.85),
666
667 ModelCapabilities::new("gemini-2.5-pro", "google", "Gemini 2.5 Pro")
668 .with_context_window(1000000)
669 .with_vision(true)
670 .with_functions(true)
671 .with_cost(0.00125, 0.005)
672 .with_latency(600)
673 .with_task_score(TaskType::Coding, 0.92)
674 .with_task_score(TaskType::Reasoning, 0.90)
675 .with_task_score(TaskType::Analysis, 0.90)
676 .with_task_score(TaskType::Research, 0.90),
677
678 ModelCapabilities::new("llama3.3:70b", "ollama", "Llama 3.3 70B")
680 .with_context_window(128000)
681 .with_vision(false)
682 .with_functions(true)
683 .with_cost(0.0, 0.0)
684 .with_latency(2000)
685 .with_task_score(TaskType::Chat, 0.80)
686 .with_task_score(TaskType::Coding, 0.75)
687 .with_task_score(TaskType::Writing, 0.80),
688
689 ModelCapabilities::new("qwen2.5-coder:32b", "ollama", "Qwen 2.5 Coder 32B")
690 .with_context_window(32000)
691 .with_vision(false)
692 .with_functions(false)
693 .with_cost(0.0, 0.0)
694 .with_latency(1500)
695 .with_task_score(TaskType::Coding, 0.85)
696 .with_task_score(TaskType::CodeReview, 0.85)
697 .with_task_score(TaskType::Debugging, 0.80),
698 ]
699 }
700
701 pub fn stats(&self) -> RouterStats {
703 let mut task_counts: HashMap<TaskType, usize> = HashMap::new();
704 let mut model_counts: HashMap<String, usize> = HashMap::new();
705 let mut total_cost = 0.0;
706
707 for decision in &self.history {
708 *task_counts.entry(decision.task_type).or_insert(0) += 1;
709 *model_counts.entry(decision.model_id.clone()).or_insert(0) += 1;
710 total_cost += decision.estimated_cost;
711 }
712
713 RouterStats {
714 total_requests: self.history.len(),
715 task_distribution: task_counts,
716 model_distribution: model_counts,
717 total_estimated_cost: total_cost,
718 avg_confidence: self.history.iter().map(|d| d.confidence).sum::<f64>()
719 / self.history.len().max(1) as f64,
720 }
721 }
722}
723
impl Default for ModelRouter {
    /// Equivalent to `ModelRouter::new()`: built-in catalog, default config.
    fn default() -> Self {
        Self::new()
    }
}
729
/// Aggregated routing statistics derived from the decision history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Number of requests routed so far.
    pub total_requests: usize,
    /// How many requests fell into each task type.
    pub task_distribution: HashMap<TaskType, usize>,
    /// How many times each model was selected.
    pub model_distribution: HashMap<String, usize>,
    /// Sum of per-decision estimated costs.
    pub total_estimated_cost: f64,
    /// Mean decision confidence (0.0 when there is no history).
    pub avg_confidence: f64,
}
744
#[cfg(test)]
mod tests {
    use super::*;

    /// Keyword detection maps representative prompts to the expected
    /// task categories, including the short-input `Quick` fallback.
    #[test]
    fn test_task_detection() {
        assert_eq!(TaskType::detect("Write a function to sort an array"), TaskType::Coding);
        assert_eq!(TaskType::detect("Review this code for bugs"), TaskType::CodeReview);
        assert_eq!(TaskType::detect("Calculate 2 + 2"), TaskType::Math);
        assert_eq!(TaskType::detect("Translate this to Spanish"), TaskType::Translation);
        assert_eq!(TaskType::detect("What is the weather?"), TaskType::QuestionAnswering);
        assert_eq!(TaskType::detect("Hi"), TaskType::Quick);
    }

    /// Routing a coding prompt yields a coding decision with a positive
    /// confidence and a non-empty model id.
    #[test]
    fn test_routing_decision() {
        let mut router = ModelRouter::new();
        let request = RoutingRequest {
            id: Uuid::new_v4(),
            content: "Write a Python function to parse JSON".to_string(),
            context: vec![],
            estimated_tokens: 500,
            config: RoutingConfig::default(),
            timestamp: Utc::now(),
        };

        let decision = router.route(&request);
        assert_eq!(decision.task_type, TaskType::Coding);
        assert!(decision.confidence > 0.0);
        assert!(!decision.model_id.is_empty());
    }

    /// The provider allow-list constraint restricts selection to the
    /// permitted provider.
    #[test]
    fn test_constraints() {
        let mut router = ModelRouter::new();
        let mut config = RoutingConfig::default();
        config.constraints.max_cost = Some(0.001);
        config.constraints.allowed_providers = Some(vec!["google".to_string()]);

        let request = RoutingRequest {
            id: Uuid::new_v4(),
            content: "Quick question".to_string(),
            context: vec![],
            estimated_tokens: 100,
            config,
            timestamp: Utc::now(),
        };

        let decision = router.route(&request);
        assert_eq!(decision.provider, "google");
    }
}