// chasm_cli/routing/model_router.rs
// Copyright (c) 2024-2027 Nervosys LLC
// SPDX-License-Identifier: Apache-2.0
//! Multi-model conversation routing
//!
//! Routes conversations to optimal models based on task type, complexity,
//! cost constraints, and performance requirements.

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

13// ============================================================================
14// Task Classification
15// ============================================================================
16
17/// Task type detected from conversation
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum TaskType {
21    /// General conversation/chat
22    Chat,
23    /// Code generation or assistance
24    Coding,
25    /// Code review and analysis
26    CodeReview,
27    /// Bug fixing and debugging
28    Debugging,
29    /// Writing and editing text
30    Writing,
31    /// Creative writing and brainstorming
32    Creative,
33    /// Mathematical reasoning
34    Math,
35    /// Data analysis
36    Analysis,
37    /// Research and information retrieval
38    Research,
39    /// Translation between languages
40    Translation,
41    /// Summarization of content
42    Summarization,
43    /// Question answering
44    QuestionAnswering,
45    /// Image understanding (multi-modal)
46    Vision,
47    /// Complex reasoning tasks
48    Reasoning,
49    /// Simple/quick queries
50    Quick,
51}
52
53impl TaskType {
54    /// Get complexity weight (1-10)
55    pub fn complexity_weight(&self) -> u8 {
56        match self {
57            TaskType::Quick => 1,
58            TaskType::Chat => 2,
59            TaskType::QuestionAnswering => 3,
60            TaskType::Translation => 4,
61            TaskType::Summarization => 4,
62            TaskType::Writing => 5,
63            TaskType::Coding => 6,
64            TaskType::CodeReview => 6,
65            TaskType::Creative => 6,
66            TaskType::Analysis => 7,
67            TaskType::Debugging => 7,
68            TaskType::Research => 7,
69            TaskType::Math => 8,
70            TaskType::Vision => 8,
71            TaskType::Reasoning => 9,
72        }
73    }
74
75    /// Detect task type from message content
76    pub fn detect(content: &str) -> Self {
77        let lower = content.to_lowercase();
78
79        // Code-related keywords
80        if lower.contains("```") || lower.contains("code") || lower.contains("function") 
81            || lower.contains("class") || lower.contains("implement") {
82            if lower.contains("review") || lower.contains("check") {
83                return TaskType::CodeReview;
84            }
85            if lower.contains("bug") || lower.contains("fix") || lower.contains("error") 
86                || lower.contains("debug") {
87                return TaskType::Debugging;
88            }
89            return TaskType::Coding;
90        }
91
92        // Math keywords
93        if lower.contains("calculate") || lower.contains("equation") || lower.contains("solve")
94            || lower.contains("math") || lower.contains("formula") {
95            return TaskType::Math;
96        }
97
98        // Analysis keywords
99        if lower.contains("analyze") || lower.contains("analysis") || lower.contains("data")
100            || lower.contains("statistics") || lower.contains("trend") {
101            return TaskType::Analysis;
102        }
103
104        // Research keywords
105        if lower.contains("research") || lower.contains("find out") || lower.contains("look up")
106            || lower.contains("search for") {
107            return TaskType::Research;
108        }
109
110        // Writing keywords
111        if lower.contains("write") || lower.contains("draft") || lower.contains("compose")
112            || lower.contains("edit") {
113            if lower.contains("creative") || lower.contains("story") || lower.contains("poem") {
114                return TaskType::Creative;
115            }
116            return TaskType::Writing;
117        }
118
119        // Translation
120        if lower.contains("translate") || lower.contains("translation") {
121            return TaskType::Translation;
122        }
123
124        // Summarization
125        if lower.contains("summarize") || lower.contains("summary") || lower.contains("tldr") {
126            return TaskType::Summarization;
127        }
128
129        // Reasoning
130        if lower.contains("why") || lower.contains("reason") || lower.contains("explain")
131            || lower.contains("logic") {
132            return TaskType::Reasoning;
133        }
134
135        // Image/vision
136        if lower.contains("image") || lower.contains("picture") || lower.contains("photo")
137            || lower.contains("see") || lower.contains("look at") {
138            return TaskType::Vision;
139        }
140
141        // Question answering
142        if lower.ends_with('?') || lower.starts_with("what") || lower.starts_with("how")
143            || lower.starts_with("when") || lower.starts_with("where") {
144            return TaskType::QuestionAnswering;
145        }
146
147        // Short messages are quick queries
148        if content.len() < 50 {
149            return TaskType::Quick;
150        }
151
152        TaskType::Chat
153    }
154}
155
156// ============================================================================
157// Model Capabilities
158// ============================================================================
159
160/// Model capability profile
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ModelCapabilities {
163    /// Model identifier
164    pub model_id: String,
165    /// Provider (e.g., "openai", "anthropic", "google")
166    pub provider: String,
167    /// Display name
168    pub name: String,
169    /// Task type scores (0.0 - 1.0)
170    pub task_scores: HashMap<TaskType, f64>,
171    /// Context window size
172    pub context_window: usize,
173    /// Whether supports vision/images
174    pub supports_vision: bool,
175    /// Whether supports function calling
176    pub supports_functions: bool,
177    /// Whether supports streaming
178    pub supports_streaming: bool,
179    /// Cost per 1K input tokens (USD)
180    pub cost_per_1k_input: f64,
181    /// Cost per 1K output tokens (USD)
182    pub cost_per_1k_output: f64,
183    /// Average latency in ms
184    pub avg_latency_ms: u32,
185    /// Whether model is available
186    pub available: bool,
187}
188
189impl ModelCapabilities {
190    /// Create a new model capabilities profile
191    pub fn new(model_id: &str, provider: &str, name: &str) -> Self {
192        Self {
193            model_id: model_id.to_string(),
194            provider: provider.to_string(),
195            name: name.to_string(),
196            task_scores: HashMap::new(),
197            context_window: 4096,
198            supports_vision: false,
199            supports_functions: false,
200            supports_streaming: true,
201            cost_per_1k_input: 0.0,
202            cost_per_1k_output: 0.0,
203            avg_latency_ms: 1000,
204            available: true,
205        }
206    }
207
208    /// Set task score
209    pub fn with_task_score(mut self, task: TaskType, score: f64) -> Self {
210        self.task_scores.insert(task, score.clamp(0.0, 1.0));
211        self
212    }
213
214    /// Set context window
215    pub fn with_context_window(mut self, size: usize) -> Self {
216        self.context_window = size;
217        self
218    }
219
220    /// Set vision support
221    pub fn with_vision(mut self, supports: bool) -> Self {
222        self.supports_vision = supports;
223        self
224    }
225
226    /// Set function calling support
227    pub fn with_functions(mut self, supports: bool) -> Self {
228        self.supports_functions = supports;
229        self
230    }
231
232    /// Set cost
233    pub fn with_cost(mut self, input: f64, output: f64) -> Self {
234        self.cost_per_1k_input = input;
235        self.cost_per_1k_output = output;
236        self
237    }
238
239    /// Set latency
240    pub fn with_latency(mut self, ms: u32) -> Self {
241        self.avg_latency_ms = ms;
242        self
243    }
244
245    /// Get score for a task type
246    pub fn score_for_task(&self, task: TaskType) -> f64 {
247        self.task_scores.get(&task).copied().unwrap_or(0.5)
248    }
249}
250
// ============================================================================
// Routing Configuration
// ============================================================================

/// Strategy used to rank candidate models during routing
/// (see `ModelRouter::score_model` for how each is applied).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Best quality regardless of cost — ranks by quality score only
    BestQuality,
    /// Lowest cost — ranks by cost score only
    LowestCost,
    /// Fastest response time — ranks by latency score only
    FastestResponse,
    /// Balance quality and cost — equal-weight average of all three scores
    Balanced,
    /// Custom weighted scoring using the weights in `RoutingConfig`
    Custom,
}

271/// Routing constraints
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct RoutingConstraints {
274    /// Maximum cost per request (USD)
275    pub max_cost: Option<f64>,
276    /// Maximum latency (ms)
277    pub max_latency_ms: Option<u32>,
278    /// Minimum context window required
279    pub min_context_window: Option<usize>,
280    /// Required providers (whitelist)
281    pub allowed_providers: Option<Vec<String>>,
282    /// Blocked providers (blacklist)
283    pub blocked_providers: Vec<String>,
284    /// Require vision support
285    pub require_vision: bool,
286    /// Require function calling
287    pub require_functions: bool,
288}
289
290impl Default for RoutingConstraints {
291    fn default() -> Self {
292        Self {
293            max_cost: None,
294            max_latency_ms: None,
295            min_context_window: None,
296            allowed_providers: None,
297            blocked_providers: vec![],
298            require_vision: false,
299            require_functions: false,
300        }
301    }
302}
303
/// Routing configuration
///
/// Bundles the strategy, hard constraints, scoring weights, and an optional
/// fallback model used when no candidate passes the constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Routing strategy
    pub strategy: RoutingStrategy,
    /// Constraints
    pub constraints: RoutingConstraints,
    /// Quality weight (0.0 - 1.0); only used by `RoutingStrategy::Custom`
    pub quality_weight: f64,
    /// Cost weight (0.0 - 1.0); only used by `RoutingStrategy::Custom`
    pub cost_weight: f64,
    /// Latency weight (0.0 - 1.0); only used by `RoutingStrategy::Custom`
    pub latency_weight: f64,
    /// Fallback model if routing fails (no model meets the constraints)
    pub fallback_model: Option<String>,
}

impl Default for RoutingConfig {
    // Balanced strategy, unconstrained, with weights 0.5/0.3/0.2 kept for
    // callers that switch to `Custom` without setting weights explicitly.
    fn default() -> Self {
        Self {
            strategy: RoutingStrategy::Balanced,
            constraints: RoutingConstraints::default(),
            quality_weight: 0.5,
            cost_weight: 0.3,
            latency_weight: 0.2,
            fallback_model: None,
        }
    }
}

// ============================================================================
// Routing Request/Response
// ============================================================================

/// Routing request: one message (plus context) to be assigned to a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRequest {
    /// Request ID
    pub id: Uuid,
    /// Message content to route (used for task-type detection)
    pub content: String,
    /// Conversation context (previous messages)
    pub context: Vec<String>,
    /// Estimated token count, used for cost estimation
    pub estimated_tokens: usize,
    /// User preferences (strategy, constraints, weights, fallback)
    pub config: RoutingConfig,
    /// Timestamp
    pub timestamp: DateTime<Utc>,
}

/// Routing decision produced by `ModelRouter::route`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Request ID this decision answers
    pub request_id: Uuid,
    /// Selected model
    pub model_id: String,
    /// Provider of the selected model
    pub provider: String,
    /// Detected task type
    pub task_type: TaskType,
    /// Confidence score (0.0 - 1.0); equal to the winner's total score
    pub confidence: f64,
    /// Estimated cost (USD) for the request's estimated token count
    pub estimated_cost: f64,
    /// Estimated latency (the model's average latency, ms)
    pub estimated_latency_ms: u32,
    /// Alternative models considered (up to the next three by score)
    pub alternatives: Vec<ModelScore>,
    /// Human-readable reasoning for the selection
    pub reasoning: String,
    /// Decision timestamp
    pub decided_at: DateTime<Utc>,
}

/// Per-model scoring breakdown produced during routing.
/// All scores are in [0.0, 1.0]; higher is better.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelScore {
    /// Model ID
    pub model_id: String,
    /// Provider
    pub provider: String,
    /// Quality score for the detected task type
    pub quality_score: f64,
    /// Cost score (inverse of cost: cheaper = higher)
    pub cost_score: f64,
    /// Latency score (inverse of latency: faster = higher)
    pub latency_score: f64,
    /// Total weighted score under the request's strategy
    pub total_score: f64,
    /// Why not selected (if applicable)
    pub rejection_reason: Option<String>,
}

// ============================================================================
// Model Router
// ============================================================================

/// Multi-model conversation router.
///
/// Holds the registry of model capability profiles and an in-memory log of
/// past routing decisions (used by `stats()`).
pub struct ModelRouter {
    /// Available models
    models: Vec<ModelCapabilities>,
    /// Default configuration
    // NOTE(review): not read anywhere in this file — routing always uses
    // the per-request config. Verify intended use before removing.
    default_config: RoutingConfig,
    /// Routing history for learning; grows unboundedly with each route() call
    history: Vec<RoutingDecision>,
}

413impl ModelRouter {
414    /// Create a new model router
415    pub fn new() -> Self {
416        Self {
417            models: Self::default_models(),
418            default_config: RoutingConfig::default(),
419            history: vec![],
420        }
421    }
422
423    /// Create router with custom models
424    pub fn with_models(models: Vec<ModelCapabilities>) -> Self {
425        Self {
426            models,
427            default_config: RoutingConfig::default(),
428            history: vec![],
429        }
430    }
431
432    /// Add a model
433    pub fn add_model(&mut self, model: ModelCapabilities) {
434        self.models.push(model);
435    }
436
437    /// Route a request to the optimal model
438    pub fn route(&mut self, request: &RoutingRequest) -> RoutingDecision {
439        // Detect task type
440        let task_type = TaskType::detect(&request.content);
441
442        // Score all models
443        let mut scores: Vec<ModelScore> = self
444            .models
445            .iter()
446            .filter(|m| m.available)
447            .filter(|m| self.meets_constraints(m, &request.config.constraints))
448            .map(|m| self.score_model(m, task_type, request))
449            .collect();
450
451        // Sort by total score (descending)
452        scores.sort_by(|a, b| b.total_score.partial_cmp(&a.total_score).unwrap());
453
454        // Select best model
455        let selected = scores.first().cloned().unwrap_or_else(|| {
456            // Fallback
457            ModelScore {
458                model_id: request.config.fallback_model.clone()
459                    .unwrap_or_else(|| "gpt-4o-mini".to_string()),
460                provider: "openai".to_string(),
461                quality_score: 0.5,
462                cost_score: 0.5,
463                latency_score: 0.5,
464                total_score: 0.5,
465                rejection_reason: None,
466            }
467        });
468
469        let decision = RoutingDecision {
470            request_id: request.id,
471            model_id: selected.model_id.clone(),
472            provider: selected.provider.clone(),
473            task_type,
474            confidence: selected.total_score,
475            estimated_cost: self.estimate_cost(&selected.model_id, request.estimated_tokens),
476            estimated_latency_ms: self.estimate_latency(&selected.model_id),
477            alternatives: scores.into_iter().skip(1).take(3).collect(),
478            reasoning: self.generate_reasoning(&selected, task_type),
479            decided_at: Utc::now(),
480        };
481
482        // Store in history
483        self.history.push(decision.clone());
484
485        decision
486    }
487
488    /// Check if model meets constraints
489    fn meets_constraints(&self, model: &ModelCapabilities, constraints: &RoutingConstraints) -> bool {
490        // Check cost
491        if let Some(max_cost) = constraints.max_cost {
492            if model.cost_per_1k_output > max_cost * 10.0 {
493                return false;
494            }
495        }
496
497        // Check latency
498        if let Some(max_latency) = constraints.max_latency_ms {
499            if model.avg_latency_ms > max_latency {
500                return false;
501            }
502        }
503
504        // Check context window
505        if let Some(min_context) = constraints.min_context_window {
506            if model.context_window < min_context {
507                return false;
508            }
509        }
510
511        // Check allowed providers
512        if let Some(ref allowed) = constraints.allowed_providers {
513            if !allowed.contains(&model.provider) {
514                return false;
515            }
516        }
517
518        // Check blocked providers
519        if constraints.blocked_providers.contains(&model.provider) {
520            return false;
521        }
522
523        // Check vision requirement
524        if constraints.require_vision && !model.supports_vision {
525            return false;
526        }
527
528        // Check function requirement
529        if constraints.require_functions && !model.supports_functions {
530            return false;
531        }
532
533        true
534    }
535
536    /// Score a model for routing
537    fn score_model(&self, model: &ModelCapabilities, task: TaskType, request: &RoutingRequest) -> ModelScore {
538        let config = &request.config;
539
540        // Quality score based on task
541        let quality_score = model.score_for_task(task);
542
543        // Cost score (inverse - lower cost = higher score)
544        let max_cost = 0.1; // $0.10 per 1K tokens as baseline
545        let cost_score = 1.0 - (model.cost_per_1k_output / max_cost).min(1.0);
546
547        // Latency score (inverse - lower latency = higher score)
548        let max_latency = 5000.0; // 5 seconds as baseline
549        let latency_score = 1.0 - (model.avg_latency_ms as f64 / max_latency).min(1.0);
550
551        // Calculate total based on strategy
552        let total_score = match config.strategy {
553            RoutingStrategy::BestQuality => quality_score,
554            RoutingStrategy::LowestCost => cost_score,
555            RoutingStrategy::FastestResponse => latency_score,
556            RoutingStrategy::Balanced => {
557                (quality_score + cost_score + latency_score) / 3.0
558            }
559            RoutingStrategy::Custom => {
560                config.quality_weight * quality_score
561                    + config.cost_weight * cost_score
562                    + config.latency_weight * latency_score
563            }
564        };
565
566        ModelScore {
567            model_id: model.model_id.clone(),
568            provider: model.provider.clone(),
569            quality_score,
570            cost_score,
571            latency_score,
572            total_score,
573            rejection_reason: None,
574        }
575    }
576
577    /// Estimate cost for a request
578    fn estimate_cost(&self, model_id: &str, tokens: usize) -> f64 {
579        self.models
580            .iter()
581            .find(|m| m.model_id == model_id)
582            .map(|m| (tokens as f64 / 1000.0) * (m.cost_per_1k_input + m.cost_per_1k_output))
583            .unwrap_or(0.0)
584    }
585
586    /// Estimate latency for a model
587    fn estimate_latency(&self, model_id: &str) -> u32 {
588        self.models
589            .iter()
590            .find(|m| m.model_id == model_id)
591            .map(|m| m.avg_latency_ms)
592            .unwrap_or(1000)
593    }
594
595    /// Generate reasoning for the selection
596    fn generate_reasoning(&self, selected: &ModelScore, task: TaskType) -> String {
597        format!(
598            "Selected {} for {:?} task. Quality: {:.0}%, Cost efficiency: {:.0}%, Speed: {:.0}%",
599            selected.model_id,
600            task,
601            selected.quality_score * 100.0,
602            selected.cost_score * 100.0,
603            selected.latency_score * 100.0
604        )
605    }
606
607    /// Get default model profiles
608    fn default_models() -> Vec<ModelCapabilities> {
609        vec![
610            // OpenAI models
611            ModelCapabilities::new("gpt-4o", "openai", "GPT-4o")
612                .with_context_window(128000)
613                .with_vision(true)
614                .with_functions(true)
615                .with_cost(0.005, 0.015)
616                .with_latency(800)
617                .with_task_score(TaskType::Coding, 0.95)
618                .with_task_score(TaskType::Reasoning, 0.95)
619                .with_task_score(TaskType::Vision, 0.90)
620                .with_task_score(TaskType::Writing, 0.90)
621                .with_task_score(TaskType::Analysis, 0.90),
622
623            ModelCapabilities::new("gpt-4o-mini", "openai", "GPT-4o Mini")
624                .with_context_window(128000)
625                .with_vision(true)
626                .with_functions(true)
627                .with_cost(0.00015, 0.0006)
628                .with_latency(500)
629                .with_task_score(TaskType::Chat, 0.85)
630                .with_task_score(TaskType::Quick, 0.90)
631                .with_task_score(TaskType::QuestionAnswering, 0.85)
632                .with_task_score(TaskType::Coding, 0.80),
633
634            ModelCapabilities::new("o1", "openai", "o1")
635                .with_context_window(200000)
636                .with_vision(true)
637                .with_functions(false)
638                .with_cost(0.015, 0.06)
639                .with_latency(3000)
640                .with_task_score(TaskType::Reasoning, 0.99)
641                .with_task_score(TaskType::Math, 0.98)
642                .with_task_score(TaskType::Coding, 0.97)
643                .with_task_score(TaskType::Analysis, 0.95),
644
645            // Anthropic models
646            ModelCapabilities::new("claude-sonnet-4-20250514", "anthropic", "Claude Sonnet 4")
647                .with_context_window(200000)
648                .with_vision(true)
649                .with_functions(true)
650                .with_cost(0.003, 0.015)
651                .with_latency(700)
652                .with_task_score(TaskType::Coding, 0.95)
653                .with_task_score(TaskType::Writing, 0.95)
654                .with_task_score(TaskType::Reasoning, 0.92)
655                .with_task_score(TaskType::Analysis, 0.90),
656
657            ModelCapabilities::new("claude-3-5-haiku-20241022", "anthropic", "Claude 3.5 Haiku")
658                .with_context_window(200000)
659                .with_vision(true)
660                .with_functions(true)
661                .with_cost(0.0008, 0.004)
662                .with_latency(400)
663                .with_task_score(TaskType::Chat, 0.85)
664                .with_task_score(TaskType::Quick, 0.90)
665                .with_task_score(TaskType::Coding, 0.80),
666
667            // Google models
668            ModelCapabilities::new("gemini-2.5-flash", "google", "Gemini 2.5 Flash")
669                .with_context_window(1000000)
670                .with_vision(true)
671                .with_functions(true)
672                .with_cost(0.000075, 0.0003)
673                .with_latency(300)
674                .with_task_score(TaskType::Chat, 0.85)
675                .with_task_score(TaskType::Quick, 0.95)
676                .with_task_score(TaskType::Coding, 0.85)
677                .with_task_score(TaskType::Analysis, 0.85),
678
679            ModelCapabilities::new("gemini-2.5-pro", "google", "Gemini 2.5 Pro")
680                .with_context_window(1000000)
681                .with_vision(true)
682                .with_functions(true)
683                .with_cost(0.00125, 0.005)
684                .with_latency(600)
685                .with_task_score(TaskType::Coding, 0.92)
686                .with_task_score(TaskType::Reasoning, 0.90)
687                .with_task_score(TaskType::Analysis, 0.90)
688                .with_task_score(TaskType::Research, 0.90),
689
690            // Local models (Ollama)
691            ModelCapabilities::new("llama3.3:70b", "ollama", "Llama 3.3 70B")
692                .with_context_window(128000)
693                .with_vision(false)
694                .with_functions(true)
695                .with_cost(0.0, 0.0)
696                .with_latency(2000)
697                .with_task_score(TaskType::Chat, 0.80)
698                .with_task_score(TaskType::Coding, 0.75)
699                .with_task_score(TaskType::Writing, 0.80),
700
701            ModelCapabilities::new("qwen2.5-coder:32b", "ollama", "Qwen 2.5 Coder 32B")
702                .with_context_window(32000)
703                .with_vision(false)
704                .with_functions(false)
705                .with_cost(0.0, 0.0)
706                .with_latency(1500)
707                .with_task_score(TaskType::Coding, 0.85)
708                .with_task_score(TaskType::CodeReview, 0.85)
709                .with_task_score(TaskType::Debugging, 0.80),
710        ]
711    }
712
713    /// Get routing statistics
714    pub fn stats(&self) -> RouterStats {
715        let mut task_counts: HashMap<TaskType, usize> = HashMap::new();
716        let mut model_counts: HashMap<String, usize> = HashMap::new();
717        let mut total_cost = 0.0;
718
719        for decision in &self.history {
720            *task_counts.entry(decision.task_type).or_insert(0) += 1;
721            *model_counts.entry(decision.model_id.clone()).or_insert(0) += 1;
722            total_cost += decision.estimated_cost;
723        }
724
725        RouterStats {
726            total_requests: self.history.len(),
727            task_distribution: task_counts,
728            model_distribution: model_counts,
729            total_estimated_cost: total_cost,
730            avg_confidence: self.history.iter().map(|d| d.confidence).sum::<f64>()
731                / self.history.len().max(1) as f64,
732        }
733    }
734}
735
impl Default for ModelRouter {
    /// Same as `ModelRouter::new()`: built-in models, default config.
    fn default() -> Self {
        Self::new()
    }
}

/// Router statistics aggregated from the routing history by `ModelRouter::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Total routing requests recorded
    pub total_requests: usize,
    /// Task type distribution (count per detected task type)
    pub task_distribution: HashMap<TaskType, usize>,
    /// Model usage distribution (count per selected model ID)
    pub model_distribution: HashMap<String, usize>,
    /// Total estimated cost (USD) across all recorded decisions
    pub total_estimated_cost: f64,
    /// Average confidence score; 0.0 when no requests have been recorded
    pub avg_confidence: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    // Keyword heuristics: each message should land in its expected bucket.
    #[test]
    fn test_task_detection() {
        assert_eq!(TaskType::detect("Write a function to sort an array"), TaskType::Coding);
        assert_eq!(TaskType::detect("Review this code for bugs"), TaskType::CodeReview);
        assert_eq!(TaskType::detect("Calculate 2 + 2"), TaskType::Math);
        assert_eq!(TaskType::detect("Translate this to Spanish"), TaskType::Translation);
        assert_eq!(TaskType::detect("What is the weather?"), TaskType::QuestionAnswering);
        // Short unmatched messages fall through to Quick.
        assert_eq!(TaskType::detect("Hi"), TaskType::Quick);
    }

    // End-to-end routing with the default (Balanced) config: the decision
    // must detect the coding task and select some model with confidence > 0.
    #[test]
    fn test_routing_decision() {
        let mut router = ModelRouter::new();
        let request = RoutingRequest {
            id: Uuid::new_v4(),
            content: "Write a Python function to parse JSON".to_string(),
            context: vec![],
            estimated_tokens: 500,
            config: RoutingConfig::default(),
            timestamp: Utc::now(),
        };

        let decision = router.route(&request);
        assert_eq!(decision.task_type, TaskType::Coding);
        assert!(decision.confidence > 0.0);
        assert!(!decision.model_id.is_empty());
    }

    // Constraint filtering: with a provider whitelist, the selected model
    // must come from that provider.
    #[test]
    fn test_constraints() {
        let mut router = ModelRouter::new();
        let mut config = RoutingConfig::default();
        config.constraints.max_cost = Some(0.001);
        config.constraints.allowed_providers = Some(vec!["google".to_string()]);

        let request = RoutingRequest {
            id: Uuid::new_v4(),
            content: "Quick question".to_string(),
            context: vec![],
            estimated_tokens: 100,
            config,
            timestamp: Utc::now(),
        };

        let decision = router.route(&request);
        assert_eq!(decision.provider, "google");
    }
}