1use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
/// Category of work a request asks for; used to route to the best-suited model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskType {
    /// Open-ended conversation; the final fallback when no other signal matches.
    Chat,
    /// Writing new code.
    Coding,
    /// Reviewing existing code.
    CodeReview,
    /// Finding and fixing bugs.
    Debugging,
    /// General prose writing or editing.
    Writing,
    /// Creative writing (stories, poems, ...).
    Creative,
    /// Calculation or mathematical problem solving.
    Math,
    /// Data analysis, statistics, trends.
    Analysis,
    /// Looking up or gathering information.
    Research,
    /// Language translation.
    Translation,
    /// Condensing text (summaries, TL;DRs).
    Summarization,
    /// Direct factual question answering.
    QuestionAnswering,
    /// Tasks involving images or other visual input.
    Vision,
    /// Multi-step logical reasoning or explanation.
    Reasoning,
    /// Very short request (detected when content is under 50 characters).
    Quick,
}
52
53impl TaskType {
54 pub fn complexity_weight(&self) -> u8 {
56 match self {
57 TaskType::Quick => 1,
58 TaskType::Chat => 2,
59 TaskType::QuestionAnswering => 3,
60 TaskType::Translation => 4,
61 TaskType::Summarization => 4,
62 TaskType::Writing => 5,
63 TaskType::Coding => 6,
64 TaskType::CodeReview => 6,
65 TaskType::Creative => 6,
66 TaskType::Analysis => 7,
67 TaskType::Debugging => 7,
68 TaskType::Research => 7,
69 TaskType::Math => 8,
70 TaskType::Vision => 8,
71 TaskType::Reasoning => 9,
72 }
73 }
74
75 pub fn detect(content: &str) -> Self {
77 let lower = content.to_lowercase();
78
79 if lower.contains("```") || lower.contains("code") || lower.contains("function")
81 || lower.contains("class") || lower.contains("implement") {
82 if lower.contains("review") || lower.contains("check") {
83 return TaskType::CodeReview;
84 }
85 if lower.contains("bug") || lower.contains("fix") || lower.contains("error")
86 || lower.contains("debug") {
87 return TaskType::Debugging;
88 }
89 return TaskType::Coding;
90 }
91
92 if lower.contains("calculate") || lower.contains("equation") || lower.contains("solve")
94 || lower.contains("math") || lower.contains("formula") {
95 return TaskType::Math;
96 }
97
98 if lower.contains("analyze") || lower.contains("analysis") || lower.contains("data")
100 || lower.contains("statistics") || lower.contains("trend") {
101 return TaskType::Analysis;
102 }
103
104 if lower.contains("research") || lower.contains("find out") || lower.contains("look up")
106 || lower.contains("search for") {
107 return TaskType::Research;
108 }
109
110 if lower.contains("write") || lower.contains("draft") || lower.contains("compose")
112 || lower.contains("edit") {
113 if lower.contains("creative") || lower.contains("story") || lower.contains("poem") {
114 return TaskType::Creative;
115 }
116 return TaskType::Writing;
117 }
118
119 if lower.contains("translate") || lower.contains("translation") {
121 return TaskType::Translation;
122 }
123
124 if lower.contains("summarize") || lower.contains("summary") || lower.contains("tldr") {
126 return TaskType::Summarization;
127 }
128
129 if lower.contains("why") || lower.contains("reason") || lower.contains("explain")
131 || lower.contains("logic") {
132 return TaskType::Reasoning;
133 }
134
135 if lower.contains("image") || lower.contains("picture") || lower.contains("photo")
137 || lower.contains("see") || lower.contains("look at") {
138 return TaskType::Vision;
139 }
140
141 if lower.ends_with('?') || lower.starts_with("what") || lower.starts_with("how")
143 || lower.starts_with("when") || lower.starts_with("where") {
144 return TaskType::QuestionAnswering;
145 }
146
147 if content.len() < 50 {
149 return TaskType::Quick;
150 }
151
152 TaskType::Chat
153 }
154}
155
/// Static profile of a model: what it can do, what it costs, and how fast it is.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    /// Provider-specific model identifier (e.g. "gpt-4o").
    pub model_id: String,
    /// Hosting provider key (e.g. "openai", "anthropic", "ollama").
    pub provider: String,
    /// Human-readable display name.
    pub name: String,
    /// Per-task quality scores in [0.0, 1.0]; unrated tasks default to 0.5.
    pub task_scores: HashMap<TaskType, f64>,
    /// Maximum context window, in tokens.
    pub context_window: usize,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports function/tool calling.
    pub supports_functions: bool,
    /// Whether the model can stream responses.
    pub supports_streaming: bool,
    /// USD cost per 1k input tokens.
    pub cost_per_1k_input: f64,
    /// USD cost per 1k output tokens.
    pub cost_per_1k_output: f64,
    /// Typical end-to-end latency in milliseconds.
    pub avg_latency_ms: u32,
    /// Whether the model is currently eligible for routing.
    pub available: bool,
}
188
189impl ModelCapabilities {
190 pub fn new(model_id: &str, provider: &str, name: &str) -> Self {
192 Self {
193 model_id: model_id.to_string(),
194 provider: provider.to_string(),
195 name: name.to_string(),
196 task_scores: HashMap::new(),
197 context_window: 4096,
198 supports_vision: false,
199 supports_functions: false,
200 supports_streaming: true,
201 cost_per_1k_input: 0.0,
202 cost_per_1k_output: 0.0,
203 avg_latency_ms: 1000,
204 available: true,
205 }
206 }
207
208 pub fn with_task_score(mut self, task: TaskType, score: f64) -> Self {
210 self.task_scores.insert(task, score.clamp(0.0, 1.0));
211 self
212 }
213
214 pub fn with_context_window(mut self, size: usize) -> Self {
216 self.context_window = size;
217 self
218 }
219
220 pub fn with_vision(mut self, supports: bool) -> Self {
222 self.supports_vision = supports;
223 self
224 }
225
226 pub fn with_functions(mut self, supports: bool) -> Self {
228 self.supports_functions = supports;
229 self
230 }
231
232 pub fn with_cost(mut self, input: f64, output: f64) -> Self {
234 self.cost_per_1k_input = input;
235 self.cost_per_1k_output = output;
236 self
237 }
238
239 pub fn with_latency(mut self, ms: u32) -> Self {
241 self.avg_latency_ms = ms;
242 self
243 }
244
245 pub fn score_for_task(&self, task: TaskType) -> f64 {
247 self.task_scores.get(&task).copied().unwrap_or(0.5)
248 }
249}
250
/// How the router weighs quality, cost, and latency when choosing a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Maximize task quality; ignore cost and latency.
    BestQuality,
    /// Minimize cost; ignore quality and latency.
    LowestCost,
    /// Minimize latency; ignore quality and cost.
    FastestResponse,
    /// Equal-weight average of quality, cost, and latency scores.
    Balanced,
    /// Weighted sum using the weights carried in `RoutingConfig`.
    Custom,
}
270
/// Hard filters applied before scoring; models failing any are excluded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConstraints {
    /// Rough cost ceiling; the constraint check compares it against the
    /// model's per-1k output cost scaled by 10x — TODO confirm intended unit.
    pub max_cost: Option<f64>,
    /// Maximum acceptable average latency, in milliseconds.
    pub max_latency_ms: Option<u32>,
    /// Minimum required context window, in tokens.
    pub min_context_window: Option<usize>,
    /// If set, only these providers are eligible.
    pub allowed_providers: Option<Vec<String>>,
    /// Providers that must never be used.
    pub blocked_providers: Vec<String>,
    /// Require image-input support.
    pub require_vision: bool,
    /// Require function/tool-calling support.
    pub require_functions: bool,
}
289
290impl Default for RoutingConstraints {
291 fn default() -> Self {
292 Self {
293 max_cost: None,
294 max_latency_ms: None,
295 min_context_window: None,
296 allowed_providers: None,
297 blocked_providers: vec![],
298 require_vision: false,
299 require_functions: false,
300 }
301 }
302}
303
/// Per-request routing settings: strategy, hard filters, and custom weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Scoring strategy to apply.
    pub strategy: RoutingStrategy,
    /// Hard constraints filtering the candidate set.
    pub constraints: RoutingConstraints,
    /// Quality weight (only used by `RoutingStrategy::Custom`).
    pub quality_weight: f64,
    /// Cost weight (only used by `RoutingStrategy::Custom`).
    pub cost_weight: f64,
    /// Latency weight (only used by `RoutingStrategy::Custom`).
    pub latency_weight: f64,
    /// Model id to fall back to when no candidate passes the constraints.
    pub fallback_model: Option<String>,
}
320
321impl Default for RoutingConfig {
322 fn default() -> Self {
323 Self {
324 strategy: RoutingStrategy::Balanced,
325 constraints: RoutingConstraints::default(),
326 quality_weight: 0.5,
327 cost_weight: 0.3,
328 latency_weight: 0.2,
329 fallback_model: None,
330 }
331 }
332}
333
/// A single request to be routed to a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRequest {
    /// Unique request id, echoed back in the resulting decision.
    pub id: Uuid,
    /// The user content to classify and route.
    pub content: String,
    /// Prior conversation turns — NOTE(review): not read anywhere in this
    /// file; presumably reserved for future context-aware routing.
    pub context: Vec<String>,
    /// Estimated total token count, used for cost estimation.
    pub estimated_tokens: usize,
    /// Routing settings for this request.
    pub config: RoutingConfig,
    /// When the request was created.
    pub timestamp: DateTime<Utc>,
}
354
/// The outcome of routing one request: the chosen model plus diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Id of the request this decision answers.
    pub request_id: Uuid,
    /// Selected model id.
    pub model_id: String,
    /// Provider of the selected model.
    pub provider: String,
    /// Task type detected from the request content.
    pub task_type: TaskType,
    /// Total score of the winning model (also used as confidence).
    pub confidence: f64,
    /// Estimated request cost in USD.
    pub estimated_cost: f64,
    /// Estimated latency in milliseconds.
    pub estimated_latency_ms: u32,
    /// Up to three runner-up candidates, best first.
    pub alternatives: Vec<ModelScore>,
    /// Human-readable explanation of the selection.
    pub reasoning: String,
    /// When the decision was made.
    pub decided_at: DateTime<Utc>,
}
379
/// Component and total scores for one candidate model (all in [0.0, 1.0],
/// higher is better).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelScore {
    /// Candidate model id.
    pub model_id: String,
    /// Candidate's provider.
    pub provider: String,
    /// Task-fit quality score.
    pub quality_score: f64,
    /// Cost efficiency score (1.0 = free).
    pub cost_score: f64,
    /// Speed score (1.0 = instantaneous).
    pub latency_score: f64,
    /// Strategy-combined score used for ranking.
    pub total_score: f64,
    /// Why the candidate was rejected, if it was — NOTE(review): never set
    /// in this file; rejected models are filtered out before scoring.
    pub rejection_reason: Option<String>,
}
398
/// Routes requests to models by scoring candidates on task fit, cost, and
/// latency, and records every decision for aggregate statistics.
pub struct ModelRouter {
    /// Candidate model catalog.
    models: Vec<ModelCapabilities>,
    /// Default routing config — NOTE(review): currently never read; each
    /// request carries its own `RoutingConfig`.
    default_config: RoutingConfig,
    /// Every decision made, in order; feeds `stats()`. Grows unbounded.
    history: Vec<RoutingDecision>,
}
412
413impl ModelRouter {
414 pub fn new() -> Self {
416 Self {
417 models: Self::default_models(),
418 default_config: RoutingConfig::default(),
419 history: vec![],
420 }
421 }
422
423 pub fn with_models(models: Vec<ModelCapabilities>) -> Self {
425 Self {
426 models,
427 default_config: RoutingConfig::default(),
428 history: vec![],
429 }
430 }
431
432 pub fn add_model(&mut self, model: ModelCapabilities) {
434 self.models.push(model);
435 }
436
437 pub fn route(&mut self, request: &RoutingRequest) -> RoutingDecision {
439 let task_type = TaskType::detect(&request.content);
441
442 let mut scores: Vec<ModelScore> = self
444 .models
445 .iter()
446 .filter(|m| m.available)
447 .filter(|m| self.meets_constraints(m, &request.config.constraints))
448 .map(|m| self.score_model(m, task_type, request))
449 .collect();
450
451 scores.sort_by(|a, b| b.total_score.partial_cmp(&a.total_score).unwrap());
453
454 let selected = scores.first().cloned().unwrap_or_else(|| {
456 ModelScore {
458 model_id: request.config.fallback_model.clone()
459 .unwrap_or_else(|| "gpt-4o-mini".to_string()),
460 provider: "openai".to_string(),
461 quality_score: 0.5,
462 cost_score: 0.5,
463 latency_score: 0.5,
464 total_score: 0.5,
465 rejection_reason: None,
466 }
467 });
468
469 let decision = RoutingDecision {
470 request_id: request.id,
471 model_id: selected.model_id.clone(),
472 provider: selected.provider.clone(),
473 task_type,
474 confidence: selected.total_score,
475 estimated_cost: self.estimate_cost(&selected.model_id, request.estimated_tokens),
476 estimated_latency_ms: self.estimate_latency(&selected.model_id),
477 alternatives: scores.into_iter().skip(1).take(3).collect(),
478 reasoning: self.generate_reasoning(&selected, task_type),
479 decided_at: Utc::now(),
480 };
481
482 self.history.push(decision.clone());
484
485 decision
486 }
487
488 fn meets_constraints(&self, model: &ModelCapabilities, constraints: &RoutingConstraints) -> bool {
490 if let Some(max_cost) = constraints.max_cost {
492 if model.cost_per_1k_output > max_cost * 10.0 {
493 return false;
494 }
495 }
496
497 if let Some(max_latency) = constraints.max_latency_ms {
499 if model.avg_latency_ms > max_latency {
500 return false;
501 }
502 }
503
504 if let Some(min_context) = constraints.min_context_window {
506 if model.context_window < min_context {
507 return false;
508 }
509 }
510
511 if let Some(ref allowed) = constraints.allowed_providers {
513 if !allowed.contains(&model.provider) {
514 return false;
515 }
516 }
517
518 if constraints.blocked_providers.contains(&model.provider) {
520 return false;
521 }
522
523 if constraints.require_vision && !model.supports_vision {
525 return false;
526 }
527
528 if constraints.require_functions && !model.supports_functions {
530 return false;
531 }
532
533 true
534 }
535
536 fn score_model(&self, model: &ModelCapabilities, task: TaskType, request: &RoutingRequest) -> ModelScore {
538 let config = &request.config;
539
540 let quality_score = model.score_for_task(task);
542
543 let max_cost = 0.1; let cost_score = 1.0 - (model.cost_per_1k_output / max_cost).min(1.0);
546
547 let max_latency = 5000.0; let latency_score = 1.0 - (model.avg_latency_ms as f64 / max_latency).min(1.0);
550
551 let total_score = match config.strategy {
553 RoutingStrategy::BestQuality => quality_score,
554 RoutingStrategy::LowestCost => cost_score,
555 RoutingStrategy::FastestResponse => latency_score,
556 RoutingStrategy::Balanced => {
557 (quality_score + cost_score + latency_score) / 3.0
558 }
559 RoutingStrategy::Custom => {
560 config.quality_weight * quality_score
561 + config.cost_weight * cost_score
562 + config.latency_weight * latency_score
563 }
564 };
565
566 ModelScore {
567 model_id: model.model_id.clone(),
568 provider: model.provider.clone(),
569 quality_score,
570 cost_score,
571 latency_score,
572 total_score,
573 rejection_reason: None,
574 }
575 }
576
577 fn estimate_cost(&self, model_id: &str, tokens: usize) -> f64 {
579 self.models
580 .iter()
581 .find(|m| m.model_id == model_id)
582 .map(|m| (tokens as f64 / 1000.0) * (m.cost_per_1k_input + m.cost_per_1k_output))
583 .unwrap_or(0.0)
584 }
585
586 fn estimate_latency(&self, model_id: &str) -> u32 {
588 self.models
589 .iter()
590 .find(|m| m.model_id == model_id)
591 .map(|m| m.avg_latency_ms)
592 .unwrap_or(1000)
593 }
594
595 fn generate_reasoning(&self, selected: &ModelScore, task: TaskType) -> String {
597 format!(
598 "Selected {} for {:?} task. Quality: {:.0}%, Cost efficiency: {:.0}%, Speed: {:.0}%",
599 selected.model_id,
600 task,
601 selected.quality_score * 100.0,
602 selected.cost_score * 100.0,
603 selected.latency_score * 100.0
604 )
605 }
606
607 fn default_models() -> Vec<ModelCapabilities> {
609 vec![
610 ModelCapabilities::new("gpt-4o", "openai", "GPT-4o")
612 .with_context_window(128000)
613 .with_vision(true)
614 .with_functions(true)
615 .with_cost(0.005, 0.015)
616 .with_latency(800)
617 .with_task_score(TaskType::Coding, 0.95)
618 .with_task_score(TaskType::Reasoning, 0.95)
619 .with_task_score(TaskType::Vision, 0.90)
620 .with_task_score(TaskType::Writing, 0.90)
621 .with_task_score(TaskType::Analysis, 0.90),
622
623 ModelCapabilities::new("gpt-4o-mini", "openai", "GPT-4o Mini")
624 .with_context_window(128000)
625 .with_vision(true)
626 .with_functions(true)
627 .with_cost(0.00015, 0.0006)
628 .with_latency(500)
629 .with_task_score(TaskType::Chat, 0.85)
630 .with_task_score(TaskType::Quick, 0.90)
631 .with_task_score(TaskType::QuestionAnswering, 0.85)
632 .with_task_score(TaskType::Coding, 0.80),
633
634 ModelCapabilities::new("o1", "openai", "o1")
635 .with_context_window(200000)
636 .with_vision(true)
637 .with_functions(false)
638 .with_cost(0.015, 0.06)
639 .with_latency(3000)
640 .with_task_score(TaskType::Reasoning, 0.99)
641 .with_task_score(TaskType::Math, 0.98)
642 .with_task_score(TaskType::Coding, 0.97)
643 .with_task_score(TaskType::Analysis, 0.95),
644
645 ModelCapabilities::new("claude-sonnet-4-20250514", "anthropic", "Claude Sonnet 4")
647 .with_context_window(200000)
648 .with_vision(true)
649 .with_functions(true)
650 .with_cost(0.003, 0.015)
651 .with_latency(700)
652 .with_task_score(TaskType::Coding, 0.95)
653 .with_task_score(TaskType::Writing, 0.95)
654 .with_task_score(TaskType::Reasoning, 0.92)
655 .with_task_score(TaskType::Analysis, 0.90),
656
657 ModelCapabilities::new("claude-3-5-haiku-20241022", "anthropic", "Claude 3.5 Haiku")
658 .with_context_window(200000)
659 .with_vision(true)
660 .with_functions(true)
661 .with_cost(0.0008, 0.004)
662 .with_latency(400)
663 .with_task_score(TaskType::Chat, 0.85)
664 .with_task_score(TaskType::Quick, 0.90)
665 .with_task_score(TaskType::Coding, 0.80),
666
667 ModelCapabilities::new("gemini-2.5-flash", "google", "Gemini 2.5 Flash")
669 .with_context_window(1000000)
670 .with_vision(true)
671 .with_functions(true)
672 .with_cost(0.000075, 0.0003)
673 .with_latency(300)
674 .with_task_score(TaskType::Chat, 0.85)
675 .with_task_score(TaskType::Quick, 0.95)
676 .with_task_score(TaskType::Coding, 0.85)
677 .with_task_score(TaskType::Analysis, 0.85),
678
679 ModelCapabilities::new("gemini-2.5-pro", "google", "Gemini 2.5 Pro")
680 .with_context_window(1000000)
681 .with_vision(true)
682 .with_functions(true)
683 .with_cost(0.00125, 0.005)
684 .with_latency(600)
685 .with_task_score(TaskType::Coding, 0.92)
686 .with_task_score(TaskType::Reasoning, 0.90)
687 .with_task_score(TaskType::Analysis, 0.90)
688 .with_task_score(TaskType::Research, 0.90),
689
690 ModelCapabilities::new("llama3.3:70b", "ollama", "Llama 3.3 70B")
692 .with_context_window(128000)
693 .with_vision(false)
694 .with_functions(true)
695 .with_cost(0.0, 0.0)
696 .with_latency(2000)
697 .with_task_score(TaskType::Chat, 0.80)
698 .with_task_score(TaskType::Coding, 0.75)
699 .with_task_score(TaskType::Writing, 0.80),
700
701 ModelCapabilities::new("qwen2.5-coder:32b", "ollama", "Qwen 2.5 Coder 32B")
702 .with_context_window(32000)
703 .with_vision(false)
704 .with_functions(false)
705 .with_cost(0.0, 0.0)
706 .with_latency(1500)
707 .with_task_score(TaskType::Coding, 0.85)
708 .with_task_score(TaskType::CodeReview, 0.85)
709 .with_task_score(TaskType::Debugging, 0.80),
710 ]
711 }
712
713 pub fn stats(&self) -> RouterStats {
715 let mut task_counts: HashMap<TaskType, usize> = HashMap::new();
716 let mut model_counts: HashMap<String, usize> = HashMap::new();
717 let mut total_cost = 0.0;
718
719 for decision in &self.history {
720 *task_counts.entry(decision.task_type).or_insert(0) += 1;
721 *model_counts.entry(decision.model_id.clone()).or_insert(0) += 1;
722 total_cost += decision.estimated_cost;
723 }
724
725 RouterStats {
726 total_requests: self.history.len(),
727 task_distribution: task_counts,
728 model_distribution: model_counts,
729 total_estimated_cost: total_cost,
730 avg_confidence: self.history.iter().map(|d| d.confidence).sum::<f64>()
731 / self.history.len().max(1) as f64,
732 }
733 }
734}
735
736impl Default for ModelRouter {
737 fn default() -> Self {
738 Self::new()
739 }
740}
741
/// Aggregate routing statistics produced by `ModelRouter::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Number of decisions recorded.
    pub total_requests: usize,
    /// Decision count per detected task type.
    pub task_distribution: HashMap<TaskType, usize>,
    /// Decision count per selected model id.
    pub model_distribution: HashMap<String, usize>,
    /// Sum of per-decision estimated costs (USD).
    pub total_estimated_cost: f64,
    /// Mean decision confidence; 0.0 when no requests were routed.
    pub avg_confidence: f64,
}
756
#[cfg(test)]
mod tests {
    use super::*;

    // Keyword heuristics should map representative prompts to the
    // expected task category.
    #[test]
    fn test_task_detection() {
        assert_eq!(TaskType::detect("Write a function to sort an array"), TaskType::Coding);
        assert_eq!(TaskType::detect("Review this code for bugs"), TaskType::CodeReview);
        assert_eq!(TaskType::detect("Calculate 2 + 2"), TaskType::Math);
        assert_eq!(TaskType::detect("Translate this to Spanish"), TaskType::Translation);
        assert_eq!(TaskType::detect("What is the weather?"), TaskType::QuestionAnswering);
        assert_eq!(TaskType::detect("Hi"), TaskType::Quick);
    }

    // A coding prompt routed with default config should yield a coding
    // decision with a non-empty model and positive confidence.
    #[test]
    fn test_routing_decision() {
        let mut router = ModelRouter::new();
        let request = RoutingRequest {
            id: Uuid::new_v4(),
            content: "Write a Python function to parse JSON".to_string(),
            context: vec![],
            estimated_tokens: 500,
            config: RoutingConfig::default(),
            timestamp: Utc::now(),
        };

        let decision = router.route(&request);
        assert_eq!(decision.task_type, TaskType::Coding);
        assert!(decision.confidence > 0.0);
        assert!(!decision.model_id.is_empty());
    }

    // Provider allow-list plus a tight cost cap should restrict the
    // selection to a Google model.
    #[test]
    fn test_constraints() {
        let mut router = ModelRouter::new();
        let mut config = RoutingConfig::default();
        config.constraints.max_cost = Some(0.001);
        config.constraints.allowed_providers = Some(vec!["google".to_string()]);

        let request = RoutingRequest {
            id: Uuid::new_v4(),
            content: "Quick question".to_string(),
            context: vec![],
            estimated_tokens: 100,
            config,
            timestamp: Utc::now(),
        };

        let decision = router.route(&request);
        assert_eq!(decision.provider, "google");
    }
}