1use std::collections::HashMap;
8use std::sync::atomic::{AtomicU32, AtomicU64};
9use std::sync::Arc;
10
11pub mod error;
12pub mod reconciler;
13pub mod requirements;
14pub mod scoring;
15pub mod strategies; pub use error::RoutingError;
18pub use requirements::RequestRequirements;
19pub use scoring::{score_backend, ScoringWeights};
20pub use strategies::RoutingStrategy;
21
22use crate::agent::quality::QualityMetricsStore;
23use crate::agent::tokenizer::TokenizerRegistry;
24use crate::config::{BudgetConfig, PolicyMatcher, QualityConfig};
25use crate::registry::{Backend, BackendStatus, Registry};
26use crate::routing::reconciler::budget::BudgetMetrics;
27use dashmap::DashMap;
28use reconciler::budget::BudgetReconciler;
29use reconciler::decision::RoutingDecision;
30use reconciler::intent::RoutingIntent;
31use reconciler::privacy::PrivacyReconciler;
32use reconciler::quality::QualityReconciler;
33use reconciler::request_analyzer::RequestAnalyzer;
34use reconciler::scheduler::SchedulerReconciler;
35use reconciler::tier::TierReconciler;
36use reconciler::ReconcilerPipeline;
37
/// Outcome of a successful call to [`Router::select_backend`].
#[derive(Debug)]
pub struct RoutingResult {
    /// Backend selected to serve the request.
    pub backend: Arc<Backend>,
    /// Model the request is routed to: the model named by the routing decision
    /// for a primary route, or the fallback model name when a fallback was used.
    pub actual_model: String,
    /// `true` when the requested model could not be routed and a configured
    /// fallback model was selected instead.
    pub fallback_used: bool,
    /// Reason reported by the routing decision; for fallback routes it is
    /// formatted as `fallback:<requested model>:<reason>`.
    pub route_reason: String,
    /// Cost estimate (`cost_usd`) carried by the routing decision.
    pub cost_estimated: Option<f64>,
    /// Budget status computed when the backend was selected.
    pub budget_status: reconciler::intent::BudgetStatus,
    /// Monthly budget utilization as a percentage; `None` when no positive
    /// monthly limit is configured.
    pub budget_utilization: Option<f64>,
    /// Remaining monthly budget (never negative); `None` when no positive
    /// monthly limit is configured.
    pub budget_remaining: Option<f64>,
}
60
/// Selects a backend for a request by running a reconciler pipeline over the
/// backends known to the registry.
#[allow(dead_code)]
pub struct Router {
    /// Registry queried for backends by model name and by backend id.
    registry: Arc<Registry>,

    /// Strategy handed to the scheduler reconciler.
    strategy: RoutingStrategy,

    /// Scoring weights handed to the scheduler reconciler and used by the
    /// legacy smart-selection helper.
    weights: ScoringWeights,

    /// Model alias table (alias -> target), followed for at most three hops.
    aliases: HashMap<String, String>,

    /// Ordered fallback model names, keyed by the (resolved) model name.
    fallbacks: HashMap<String, Vec<String>>,

    /// Counter shared with every scheduler reconciler built by this router.
    round_robin_counter: Arc<AtomicU64>,

    /// Policy matcher cloned into the privacy and tier reconcilers.
    policy_matcher: PolicyMatcher,

    /// Budget configuration (monthly limit, soft-limit percentage).
    budget_config: BudgetConfig,

    /// Shared budget metrics; the global entry is read to report budget status.
    budget_state: Arc<DashMap<String, BudgetMetrics>>,

    /// Tokenizer registry shared with the budget reconciler.
    tokenizer_registry: Arc<TokenizerRegistry>,

    /// Quality metrics store shared with the quality and scheduler reconcilers.
    quality_store: Arc<QualityMetricsStore>,

    /// Quality configuration cloned into the quality and scheduler reconcilers.
    quality_config: QualityConfig,

    /// Flag forwarded to `ReconcilerPipeline::with_queue`.
    queue_enabled: bool,
}
103
104impl Router {
105 pub fn new(
107 registry: Arc<Registry>,
108 strategy: RoutingStrategy,
109 weights: ScoringWeights,
110 ) -> Self {
111 let tokenizer_registry =
112 Arc::new(TokenizerRegistry::new().expect("Failed to initialize tokenizer registry"));
113 let quality_config = QualityConfig::default();
114 let quality_store = Arc::new(QualityMetricsStore::new(quality_config.clone()));
115 Self {
116 registry,
117 strategy,
118 weights,
119 aliases: HashMap::new(),
120 fallbacks: HashMap::new(),
121 round_robin_counter: Arc::new(AtomicU64::new(0)),
122 policy_matcher: PolicyMatcher::default(),
123 budget_config: BudgetConfig::default(),
124 budget_state: Arc::new(DashMap::new()),
125 tokenizer_registry,
126 quality_store,
127 quality_config,
128 queue_enabled: false,
129 }
130 }
131
132 pub fn with_aliases_and_fallbacks(
134 registry: Arc<Registry>,
135 strategy: RoutingStrategy,
136 weights: ScoringWeights,
137 aliases: HashMap<String, String>,
138 fallbacks: HashMap<String, Vec<String>>,
139 ) -> Self {
140 let tokenizer_registry =
141 Arc::new(TokenizerRegistry::new().expect("Failed to initialize tokenizer registry"));
142 let quality_config = QualityConfig::default();
143 let quality_store = Arc::new(QualityMetricsStore::new(quality_config.clone()));
144 Self {
145 registry,
146 strategy,
147 weights,
148 aliases,
149 fallbacks,
150 round_robin_counter: Arc::new(AtomicU64::new(0)),
151 policy_matcher: PolicyMatcher::default(),
152 budget_config: BudgetConfig::default(),
153 budget_state: Arc::new(DashMap::new()),
154 tokenizer_registry,
155 quality_store,
156 quality_config,
157 queue_enabled: false,
158 }
159 }
160
161 pub fn with_aliases_fallbacks_and_policies(
163 registry: Arc<Registry>,
164 strategy: RoutingStrategy,
165 weights: ScoringWeights,
166 aliases: HashMap<String, String>,
167 fallbacks: HashMap<String, Vec<String>>,
168 policy_matcher: PolicyMatcher,
169 quality_config: QualityConfig,
170 ) -> Self {
171 let tokenizer_registry =
172 Arc::new(TokenizerRegistry::new().expect("Failed to initialize tokenizer registry"));
173 let quality_store = Arc::new(QualityMetricsStore::new(quality_config.clone()));
174 Self {
175 registry,
176 strategy,
177 weights,
178 aliases,
179 fallbacks,
180 round_robin_counter: Arc::new(AtomicU64::new(0)),
181 policy_matcher,
182 budget_config: BudgetConfig::default(),
183 budget_state: Arc::new(DashMap::new()),
184 tokenizer_registry,
185 quality_store,
186 quality_config,
187 queue_enabled: false,
188 }
189 }
190
191 #[allow(clippy::too_many_arguments)]
193 pub fn with_full_config(
194 registry: Arc<Registry>,
195 strategy: RoutingStrategy,
196 weights: ScoringWeights,
197 aliases: HashMap<String, String>,
198 fallbacks: HashMap<String, Vec<String>>,
199 policy_matcher: PolicyMatcher,
200 budget_config: BudgetConfig,
201 budget_state: Arc<DashMap<String, BudgetMetrics>>,
202 ) -> Self {
203 let tokenizer_registry =
204 Arc::new(TokenizerRegistry::new().expect("Failed to initialize tokenizer registry"));
205 let quality_config = QualityConfig::default();
206 let quality_store = Arc::new(QualityMetricsStore::new(quality_config.clone()));
207 Self {
208 registry,
209 strategy,
210 weights,
211 aliases,
212 fallbacks,
213 round_robin_counter: Arc::new(AtomicU64::new(0)),
214 policy_matcher,
215 budget_config,
216 budget_state,
217 tokenizer_registry,
218 quality_store,
219 quality_config,
220 queue_enabled: false,
221 }
222 }
223
224 fn resolve_alias(&self, model: &str) -> String {
226 let mut current = model.to_string();
227 let mut depth = 0;
228 const MAX_DEPTH: usize = 3;
229
230 while depth < MAX_DEPTH {
231 match self.aliases.get(¤t) {
232 Some(target) => {
233 tracing::debug!(
234 from = %current,
235 to = %target,
236 depth = depth + 1,
237 "Resolved alias"
238 );
239 current = target.clone();
240 depth += 1;
241 }
242 None => break,
243 }
244 }
245
246 if depth > 0 {
247 tracing::debug!(
248 original = %model,
249 resolved = %current,
250 chain_depth = depth,
251 "Alias resolution complete"
252 );
253 }
254
255 current
256 }
257
258 fn get_fallbacks(&self, model: &str) -> Vec<String> {
260 self.fallbacks.get(model).cloned().unwrap_or_default()
261 }
262
263 fn build_pipeline(&self, model_aliases: HashMap<String, String>) -> ReconcilerPipeline {
267 let analyzer = RequestAnalyzer::new(model_aliases, Arc::clone(&self.registry));
268 let privacy =
269 PrivacyReconciler::new(Arc::clone(&self.registry), self.policy_matcher.clone());
270 let budget = BudgetReconciler::new(
271 Arc::clone(&self.registry),
272 self.budget_config.clone(),
273 Arc::clone(&self.tokenizer_registry),
274 Arc::clone(&self.budget_state),
275 );
276 let tier = TierReconciler::new(Arc::clone(&self.registry), self.policy_matcher.clone());
277 let quality =
278 QualityReconciler::new(Arc::clone(&self.quality_store), self.quality_config.clone());
279 let scheduler = SchedulerReconciler::new(
280 Arc::clone(&self.registry),
281 self.strategy,
282 self.weights,
283 Arc::clone(&self.round_robin_counter),
284 Arc::clone(&self.quality_store),
285 self.quality_config.clone(),
286 );
287 ReconcilerPipeline::with_queue(
288 vec![
289 Box::new(analyzer),
290 Box::new(privacy),
291 Box::new(budget),
292 Box::new(tier),
293 Box::new(quality),
294 Box::new(scheduler),
295 ],
296 self.queue_enabled,
297 )
298 }
299
300 fn run_pipeline_for_model(
302 &self,
303 requirements: &RequestRequirements,
304 model: &str,
305 tier_enforcement_mode: Option<crate::routing::reconciler::intent::TierEnforcementMode>,
306 ) -> Result<RoutingDecision, RoutingError> {
307 let mut intent = RoutingIntent::new(
310 format!(
311 "req-{}",
312 std::time::SystemTime::now()
313 .duration_since(std::time::UNIX_EPOCH)
314 .unwrap_or_default()
315 .as_nanos()
316 ),
317 model.to_string(),
318 model.to_string(), requirements.clone(),
320 vec![], );
322
323 if let Some(mode) = tier_enforcement_mode {
325 intent.tier_enforcement_mode = mode;
326 }
327
328 let mut pipeline = self.build_pipeline(HashMap::new());
330 pipeline.execute(&mut intent)
331 }
332
333 pub fn select_backend(
345 &self,
346 requirements: &RequestRequirements,
347 tier_enforcement_mode: Option<crate::routing::reconciler::intent::TierEnforcementMode>,
348 ) -> Result<RoutingResult, RoutingError> {
349 let model = self.resolve_alias(&requirements.model);
351
352 let all_backends = self.registry.get_backends_for_model(&model);
354 let model_exists = !all_backends.is_empty();
355
356 let decision = self.run_pipeline_for_model(requirements, &model, tier_enforcement_mode)?;
358
359 if let RoutingDecision::Route {
360 agent_id,
361 model: resolved_model,
362 reason,
363 cost_estimate,
364 } = decision
365 {
366 let backend = self.registry.get_backend(&agent_id).ok_or_else(|| {
367 RoutingError::NoHealthyBackend {
368 model: model.clone(),
369 }
370 })?;
371
372 tracing::debug!(
373 backend = %backend.name,
374 backend_type = ?backend.backend_type,
375 model = %resolved_model,
376 route_reason = %reason,
377 "routing decision made"
378 );
379
380 let (budget_status, budget_utilization, budget_remaining) =
382 self.get_budget_status_and_utilization();
383
384 return Ok(RoutingResult {
385 backend: Arc::new(backend),
386 actual_model: resolved_model,
387 fallback_used: false,
388 route_reason: reason,
389 cost_estimated: Some(cost_estimate.cost_usd),
390 budget_status,
391 budget_utilization,
392 budget_remaining,
393 });
394 }
395
396 if let RoutingDecision::Queue {
398 reason,
399 estimated_wait_ms,
400 ..
401 } = &decision
402 {
403 return Err(RoutingError::Queue {
404 reason: reason.clone(),
405 estimated_wait_ms: *estimated_wait_ms,
406 });
407 }
408
409 let fallbacks = self.get_fallbacks(&model);
411 for fallback_model in &fallbacks {
412 let decision =
413 self.run_pipeline_for_model(requirements, fallback_model, tier_enforcement_mode)?;
414
415 if let RoutingDecision::Route {
416 agent_id,
417 cost_estimate,
418 reason,
419 ..
420 } = decision
421 {
422 let backend = self.registry.get_backend(&agent_id).ok_or_else(|| {
423 RoutingError::NoHealthyBackend {
424 model: fallback_model.clone(),
425 }
426 })?;
427
428 let route_reason = format!("fallback:{}:{}", model, reason);
429
430 tracing::warn!(
431 requested_model = %model,
432 fallback_model = %fallback_model,
433 backend = %backend.name,
434 "Using fallback model"
435 );
436
437 let (budget_status, budget_utilization, budget_remaining) =
439 self.get_budget_status_and_utilization();
440
441 return Ok(RoutingResult {
442 backend: Arc::new(backend),
443 actual_model: fallback_model.clone(),
444 fallback_used: true,
445 route_reason,
446 cost_estimated: Some(cost_estimate.cost_usd),
447 budget_status,
448 budget_utilization,
449 budget_remaining,
450 });
451 }
452 }
453
454 if !fallbacks.is_empty() {
456 let mut chain = vec![model.clone()];
457 chain.extend(fallbacks);
458 Err(RoutingError::FallbackChainExhausted { chain })
459 } else if model_exists {
460 Err(RoutingError::NoHealthyBackend {
461 model: model.clone(),
462 })
463 } else {
464 Err(RoutingError::ModelNotFound {
465 model: requirements.model.clone(),
466 })
467 }
468 }
469
470 #[allow(dead_code)] fn select_smart(&self, candidates: &[Backend]) -> Backend {
473 let best = candidates
474 .iter()
475 .max_by_key(|backend| {
476 let priority = backend.priority as u32;
477 let pending = backend
478 .pending_requests
479 .load(std::sync::atomic::Ordering::Relaxed);
480 let latency = backend
481 .avg_latency_ms
482 .load(std::sync::atomic::Ordering::Relaxed);
483 score_backend(priority, pending, latency, &self.weights)
484 })
485 .unwrap();
486
487 Backend {
489 id: best.id.clone(),
490 name: best.name.clone(),
491 url: best.url.clone(),
492 backend_type: best.backend_type,
493 status: best.status,
494 last_health_check: best.last_health_check,
495 last_error: best.last_error.clone(),
496 models: best.models.clone(),
497 priority: best.priority,
498 pending_requests: AtomicU32::new(
499 best.pending_requests
500 .load(std::sync::atomic::Ordering::Relaxed),
501 ),
502 total_requests: AtomicU64::new(
503 best.total_requests
504 .load(std::sync::atomic::Ordering::Relaxed),
505 ),
506 avg_latency_ms: AtomicU32::new(
507 best.avg_latency_ms
508 .load(std::sync::atomic::Ordering::Relaxed),
509 ),
510 discovery_source: best.discovery_source,
511 metadata: best.metadata.clone(),
512 }
513 }
514
515 #[allow(dead_code)] fn select_priority_only(&self, candidates: &[Backend]) -> Backend {
518 let best = candidates
519 .iter()
520 .min_by_key(|backend| backend.priority)
521 .unwrap();
522
523 Backend {
525 id: best.id.clone(),
526 name: best.name.clone(),
527 url: best.url.clone(),
528 backend_type: best.backend_type,
529 status: best.status,
530 last_health_check: best.last_health_check,
531 last_error: best.last_error.clone(),
532 models: best.models.clone(),
533 priority: best.priority,
534 pending_requests: AtomicU32::new(
535 best.pending_requests
536 .load(std::sync::atomic::Ordering::Relaxed),
537 ),
538 total_requests: AtomicU64::new(
539 best.total_requests
540 .load(std::sync::atomic::Ordering::Relaxed),
541 ),
542 avg_latency_ms: AtomicU32::new(
543 best.avg_latency_ms
544 .load(std::sync::atomic::Ordering::Relaxed),
545 ),
546 discovery_source: best.discovery_source,
547 metadata: best.metadata.clone(),
548 }
549 }
550
551 #[allow(dead_code)] fn select_random(&self, candidates: &[Backend]) -> Backend {
554 use std::collections::hash_map::RandomState;
555 use std::hash::BuildHasher;
556
557 let random_state = RandomState::new();
559 let random_value = random_state.hash_one(std::time::SystemTime::now());
560 let index = (random_value as usize) % candidates.len();
561 let best = &candidates[index];
562
563 Backend {
565 id: best.id.clone(),
566 name: best.name.clone(),
567 url: best.url.clone(),
568 backend_type: best.backend_type,
569 status: best.status,
570 last_health_check: best.last_health_check,
571 last_error: best.last_error.clone(),
572 models: best.models.clone(),
573 priority: best.priority,
574 pending_requests: AtomicU32::new(
575 best.pending_requests
576 .load(std::sync::atomic::Ordering::Relaxed),
577 ),
578 total_requests: AtomicU64::new(
579 best.total_requests
580 .load(std::sync::atomic::Ordering::Relaxed),
581 ),
582 avg_latency_ms: AtomicU32::new(
583 best.avg_latency_ms
584 .load(std::sync::atomic::Ordering::Relaxed),
585 ),
586 discovery_source: best.discovery_source,
587 metadata: best.metadata.clone(),
588 }
589 }
590
591 #[allow(dead_code)] fn filter_candidates(&self, model: &str, requirements: &RequestRequirements) -> Vec<Backend> {
594 let mut candidates = self.registry.get_backends_for_model(model);
596
597 candidates.retain(|backend| backend.status == BackendStatus::Healthy);
599
600 candidates.retain(|backend| {
602 if let Some(model_info) = backend.models.iter().find(|m| m.id == model) {
604 if requirements.needs_vision && !model_info.supports_vision {
606 return false;
607 }
608
609 if requirements.needs_tools && !model_info.supports_tools {
611 return false;
612 }
613
614 if requirements.needs_json_mode && !model_info.supports_json_mode {
616 return false;
617 }
618
619 if requirements.estimated_tokens > model_info.context_length {
621 return false;
622 }
623
624 true
625 } else {
626 false
628 }
629 });
630
631 candidates
632 }
633
    /// Returns the budget configuration this router was built with.
    pub fn budget_config(&self) -> &BudgetConfig {
        &self.budget_config
    }
638
    /// Returns the shared budget metrics map that is also handed to the budget
    /// reconciler.
    pub fn budget_state(&self) -> &Arc<DashMap<String, BudgetMetrics>> {
        &self.budget_state
    }
643
    /// Returns the shared quality metrics store that is also handed to the
    /// quality and scheduler reconcilers.
    pub fn quality_store(&self) -> &Arc<QualityMetricsStore> {
        &self.quality_store
    }
648
    /// Sets the queue flag that is forwarded to every pipeline built afterwards.
    pub fn set_queue_enabled(&mut self, enabled: bool) {
        self.queue_enabled = enabled;
    }
653
654 fn get_budget_status_and_utilization(
662 &self,
663 ) -> (reconciler::intent::BudgetStatus, Option<f64>, Option<f64>) {
664 use reconciler::budget::GLOBAL_BUDGET_KEY;
665 use reconciler::intent::BudgetStatus;
666
667 let monthly_limit = match self.budget_config.monthly_limit_usd {
668 Some(limit) if limit > 0.0 => limit,
669 _ => return (BudgetStatus::Normal, None, None),
670 };
671
672 let current_spending = self
673 .budget_state
674 .get(GLOBAL_BUDGET_KEY)
675 .map(|m| m.current_month_spending)
676 .unwrap_or(0.0);
677
678 let utilization_percent = (current_spending / monthly_limit) * 100.0;
679 let remaining = (monthly_limit - current_spending).max(0.0);
680 let soft_threshold = self.budget_config.soft_limit_percent;
681
682 let status = if utilization_percent >= 100.0 {
683 BudgetStatus::HardLimit
684 } else if utilization_percent >= soft_threshold {
685 BudgetStatus::SoftLimit
686 } else {
687 BudgetStatus::Normal
688 };
689
690 (status, Some(utilization_percent), Some(remaining))
691 }
692}
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Every strategy paired with the lowercase text it parses from and displays as.
    fn canonical_forms() -> [(RoutingStrategy, &'static str); 4] {
        [
            (RoutingStrategy::Smart, "smart"),
            (RoutingStrategy::RoundRobin, "round_robin"),
            (RoutingStrategy::PriorityOnly, "priority_only"),
            (RoutingStrategy::Random, "random"),
        ]
    }

    #[test]
    fn routing_strategy_default_is_smart() {
        let default_strategy = RoutingStrategy::default();
        assert_eq!(default_strategy, RoutingStrategy::Smart);
    }

    #[test]
    fn routing_strategy_from_str() {
        let forms = canonical_forms();
        for (expected, text) in &forms {
            assert_eq!(text.parse::<RoutingStrategy>().unwrap(), *expected);
        }
    }

    #[test]
    fn routing_strategy_from_str_case_insensitive() {
        let mixed_case = [
            ("Smart", RoutingStrategy::Smart),
            ("ROUND_ROBIN", RoutingStrategy::RoundRobin),
        ];
        for (text, expected) in &mixed_case {
            assert_eq!(text.parse::<RoutingStrategy>().unwrap(), *expected);
        }
    }

    #[test]
    fn routing_strategy_from_str_invalid() {
        let parsed = "invalid".parse::<RoutingStrategy>();
        assert!(parsed.is_err());
    }

    #[test]
    fn routing_strategy_display() {
        let forms = canonical_forms();
        for (strategy, text) in &forms {
            assert_eq!(strategy.to_string(), *text);
        }
    }

    #[test]
    fn routing_strategy_display_roundtrips() {
        let forms = canonical_forms();
        for (strategy, _) in &forms {
            let rendered = strategy.to_string();
            let reparsed: RoutingStrategy = rendered.parse().unwrap();
            assert_eq!(*strategy, reparsed);
        }
    }
}
762
#[cfg(test)]
mod filter_tests {
    use super::*;
    use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model};
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, AtomicU64};

    /// Model id served by most fixtures in this module.
    const LLAMA: &str = "llama3:8b";

    /// Builds an Ollama backend with priority 1, no load and 50 ms latency.
    fn create_test_backend(
        id: &str,
        name: &str,
        status: BackendStatus,
        models: Vec<Model>,
    ) -> Backend {
        Backend {
            id: id.to_string(),
            name: name.to_string(),
            url: format!("http://{}", name),
            backend_type: BackendType::Ollama,
            status,
            last_health_check: Utc::now(),
            last_error: None,
            models,
            priority: 1,
            pending_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            avg_latency_ms: AtomicU32::new(50),
            discovery_source: DiscoverySource::Static,
            metadata: HashMap::new(),
        }
    }

    /// Builds a model without JSON-mode support or an output-token cap.
    fn create_test_model(
        id: &str,
        context_length: u32,
        supports_vision: bool,
        supports_tools: bool,
    ) -> Model {
        Model {
            id: id.to_string(),
            name: id.to_string(),
            context_length,
            supports_vision,
            supports_tools,
            supports_json_mode: false,
            max_output_tokens: None,
        }
    }

    /// Registers `backends` in a fresh registry and wraps it in a smart router.
    fn create_test_router(backends: Vec<Backend>) -> Router {
        let registry = Arc::new(Registry::new());
        for backend in backends {
            registry.add_backend(backend).unwrap();
        }

        Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default())
    }

    /// Healthy backend serving exactly one model.
    fn healthy_backend(id: &str, name: &str, model: Model) -> Backend {
        create_test_backend(id, name, BackendStatus::Healthy, vec![model])
    }

    /// Requirements for `model` with no capability needs.
    fn plain_requirements(model: &str, estimated_tokens: u32) -> RequestRequirements {
        RequestRequirements {
            model: model.to_string(),
            estimated_tokens,
            needs_vision: false,
            needs_tools: false,
            needs_json_mode: false,
            prefers_streaming: false,
        }
    }

    #[test]
    fn filters_by_model_name() {
        let router = create_test_router(vec![
            healthy_backend(
                "backend_a",
                "Backend A",
                create_test_model(LLAMA, 4096, false, false),
            ),
            healthy_backend(
                "backend_b",
                "Backend B",
                create_test_model("mistral:7b", 4096, false, false),
            ),
        ]);
        let requirements = plain_requirements(LLAMA, 100);

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].name, "Backend A");
    }

    #[test]
    fn filters_out_unhealthy_backends() {
        let router = create_test_router(vec![
            create_test_backend(
                "backend_a",
                "Backend A",
                BackendStatus::Healthy,
                vec![create_test_model(LLAMA, 4096, false, false)],
            ),
            create_test_backend(
                "backend_b",
                "Backend B",
                BackendStatus::Unhealthy,
                vec![create_test_model(LLAMA, 4096, false, false)],
            ),
        ]);
        let requirements = plain_requirements(LLAMA, 100);

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].name, "Backend A");
    }

    #[test]
    fn filters_by_vision_capability() {
        let router = create_test_router(vec![
            healthy_backend(
                "backend_a",
                "Backend A",
                create_test_model(LLAMA, 4096, false, false),
            ),
            healthy_backend(
                "backend_b",
                "Backend B",
                create_test_model(LLAMA, 4096, true, false),
            ),
        ]);
        let requirements = RequestRequirements {
            needs_vision: true,
            ..plain_requirements(LLAMA, 100)
        };

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert!(candidates[0].models[0].supports_vision);
    }

    #[test]
    fn filters_by_context_length() {
        let router = create_test_router(vec![
            healthy_backend(
                "backend_a",
                "Backend A",
                create_test_model(LLAMA, 4096, false, false),
            ),
            healthy_backend(
                "backend_b",
                "Backend B",
                create_test_model(LLAMA, 128000, false, false),
            ),
        ]);
        let requirements = plain_requirements(LLAMA, 10000);

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert!(candidates[0].models[0].context_length >= 10000);
    }

    #[test]
    fn returns_empty_when_no_match() {
        let router = create_test_router(vec![healthy_backend(
            "backend_a",
            "Backend A",
            create_test_model(LLAMA, 4096, false, false),
        )]);
        let requirements = plain_requirements("nonexistent", 100);

        let candidates = router.filter_candidates("nonexistent", &requirements);
        assert!(candidates.is_empty());
    }

    #[test]
    fn filters_by_tools_capability() {
        let router = create_test_router(vec![
            healthy_backend(
                "backend_a",
                "Backend A",
                create_test_model(LLAMA, 4096, false, false),
            ),
            healthy_backend(
                "backend_b",
                "Backend B",
                create_test_model(LLAMA, 4096, false, true),
            ),
        ]);
        let requirements = RequestRequirements {
            needs_tools: true,
            ..plain_requirements(LLAMA, 100)
        };

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert!(candidates[0].models[0].supports_tools);
    }

    #[test]
    fn filters_by_json_mode_capability() {
        let model_no_json = create_test_model(LLAMA, 4096, false, false);
        let model_with_json = Model {
            supports_json_mode: true,
            ..create_test_model(LLAMA, 4096, false, false)
        };

        let router = create_test_router(vec![
            healthy_backend("backend_a", "Backend A", model_no_json),
            healthy_backend("backend_b", "Backend B", model_with_json),
        ]);
        let requirements = RequestRequirements {
            needs_json_mode: true,
            ..plain_requirements(LLAMA, 100)
        };

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].name, "Backend B");
    }

    #[test]
    fn filters_by_multiple_capabilities() {
        let full_model = Model {
            supports_json_mode: true,
            ..create_test_model(LLAMA, 128000, true, true)
        };

        let router = create_test_router(vec![
            healthy_backend(
                "backend_a",
                "Backend A",
                create_test_model(LLAMA, 4096, false, false),
            ),
            healthy_backend("backend_b", "Backend B", full_model),
        ]);
        let requirements = RequestRequirements {
            needs_vision: true,
            needs_tools: true,
            needs_json_mode: true,
            ..plain_requirements(LLAMA, 50000)
        };

        let candidates = router.filter_candidates(LLAMA, &requirements);
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].name, "Backend B");
    }
}
1097
#[cfg(test)]
mod smart_strategy_tests {
    use super::*;
    use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model};
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, AtomicU64};

    /// Builds a healthy Ollama backend serving `llama3:8b` with the given
    /// priority, pending-request count and average latency.
    fn create_test_backend_with_state(
        id: &str,
        name: &str,
        priority: i32,
        pending_requests: u32,
        avg_latency_ms: u32,
    ) -> Backend {
        let llama = Model {
            id: "llama3:8b".to_string(),
            name: "llama3:8b".to_string(),
            context_length: 4096,
            supports_vision: false,
            supports_tools: false,
            supports_json_mode: false,
            max_output_tokens: None,
        };

        Backend {
            id: id.to_string(),
            name: name.to_string(),
            url: format!("http://{}", name),
            backend_type: BackendType::Ollama,
            status: BackendStatus::Healthy,
            last_health_check: Utc::now(),
            last_error: None,
            models: vec![llama],
            priority,
            pending_requests: AtomicU32::new(pending_requests),
            total_requests: AtomicU64::new(0),
            avg_latency_ms: AtomicU32::new(avg_latency_ms),
            discovery_source: DiscoverySource::Static,
            metadata: HashMap::new(),
        }
    }

    /// Registers `backends` in a fresh registry and wraps it in a smart router.
    fn create_test_router(backends: Vec<Backend>) -> Router {
        let registry = Arc::new(Registry::new());
        for backend in backends {
            registry.add_backend(backend).unwrap();
        }

        Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default())
    }

    /// 100-token request for `model` with no capability needs.
    fn requirements_for(model: &str) -> RequestRequirements {
        RequestRequirements {
            model: model.to_string(),
            estimated_tokens: 100,
            needs_vision: false,
            needs_tools: false,
            needs_json_mode: false,
            prefers_streaming: false,
        }
    }

    #[test]
    fn smart_selects_highest_score() {
        let router = create_test_router(vec![
            create_test_backend_with_state("backend_a", "Backend A", 1, 0, 50),
            create_test_backend_with_state("backend_b", "Backend B", 10, 50, 500),
        ]);

        let result = router
            .select_backend(&requirements_for("llama3:8b"), None)
            .unwrap();
        assert_eq!(result.backend.name, "Backend A");
    }

    #[test]
    fn smart_considers_load() {
        // Same priority and latency; only the pending-request count differs.
        let router = create_test_router(vec![
            create_test_backend_with_state("backend_a", "Backend A", 5, 0, 100),
            create_test_backend_with_state("backend_b", "Backend B", 5, 50, 100),
        ]);

        let result = router
            .select_backend(&requirements_for("llama3:8b"), None)
            .unwrap();
        assert_eq!(result.backend.name, "Backend A");
    }

    #[test]
    fn smart_considers_latency() {
        // Same priority and load; only the average latency differs.
        let router = create_test_router(vec![
            create_test_backend_with_state("backend_a", "Backend A", 5, 10, 50),
            create_test_backend_with_state("backend_b", "Backend B", 5, 10, 500),
        ]);

        let result = router
            .select_backend(&requirements_for("llama3:8b"), None)
            .unwrap();
        assert_eq!(result.backend.name, "Backend A");
    }

    #[test]
    fn returns_error_when_no_candidates() {
        let router = create_test_router(vec![create_test_backend_with_state(
            "backend_a",
            "Backend A",
            1,
            0,
            50,
        )]);

        let result = router.select_backend(&requirements_for("nonexistent"), None);
        assert!(matches!(result, Err(RoutingError::ModelNotFound { .. })));
    }
}
1239
#[cfg(test)]
mod other_strategies_tests {
    use super::*;
    use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model};
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, AtomicU64};

    /// Builds a healthy Ollama backend that serves a single `llama3:8b` model.
    ///
    /// Only the id, display name and priority vary between fixtures; request
    /// counters start at zero and the average latency is fixed at 50.
    fn create_test_backend_simple(id: &str, name: &str, priority: i32) -> Backend {
        let llama = Model {
            id: String::from("llama3:8b"),
            name: String::from("llama3:8b"),
            context_length: 4096,
            supports_vision: false,
            supports_tools: false,
            supports_json_mode: false,
            max_output_tokens: None,
        };

        Backend {
            id: id.to_owned(),
            name: name.to_owned(),
            url: format!("http://{}", name),
            backend_type: BackendType::Ollama,
            status: BackendStatus::Healthy,
            last_health_check: Utc::now(),
            last_error: None,
            models: vec![llama],
            priority,
            pending_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            avg_latency_ms: AtomicU32::new(50),
            discovery_source: DiscoverySource::Static,
            metadata: HashMap::new(),
        }
    }

    /// Registers `backends` in a fresh registry and wraps it in a router that
    /// uses `strategy` together with the default scoring weights.
    fn create_test_router_with_strategy(
        backends: Vec<Backend>,
        strategy: RoutingStrategy,
    ) -> Router {
        let registry = Arc::new(Registry::new());
        backends.into_iter().for_each(|backend| {
            registry.add_backend(backend).unwrap();
        });

        Router::new(registry, strategy, ScoringWeights::default())
    }

    /// Plain `llama3:8b` request with no capability requirements.
    fn llama_request() -> RequestRequirements {
        RequestRequirements {
            model: String::from("llama3:8b"),
            estimated_tokens: 100,
            needs_vision: false,
            needs_tools: false,
            needs_json_mode: false,
            prefers_streaming: false,
        }
    }

    /// Routes `request` once and returns the display name of the chosen backend.
    fn routed_name(router: &Router, request: &RequestRequirements) -> String {
        router
            .select_backend(request, None)
            .unwrap()
            .backend
            .name
            .clone()
    }

    /// Six consecutive selections over three equal-priority backends must
    /// visit A, B, C and then repeat in the same order.
    #[test]
    fn round_robin_cycles_through_backends() {
        let router = create_test_router_with_strategy(
            vec![
                create_test_backend_simple("backend_a", "Backend A", 1),
                create_test_backend_simple("backend_b", "Backend B", 1),
                create_test_backend_simple("backend_c", "Backend C", 1),
            ],
            RoutingStrategy::RoundRobin,
        );
        let request = llama_request();

        let observed: Vec<String> = (0..6).map(|_| routed_name(&router, &request)).collect();

        let expected = [
            "Backend A",
            "Backend B",
            "Backend C",
            "Backend A",
            "Backend B",
            "Backend C",
        ];
        assert_eq!(observed, expected);
    }

    /// With priorities 10, 1 and 5, every selection must land on the backend
    /// whose priority value is the smallest (`Backend B`).
    #[test]
    fn priority_only_selects_lowest_priority() {
        let router = create_test_router_with_strategy(
            vec![
                create_test_backend_simple("backend_a", "Backend A", 10),
                create_test_backend_simple("backend_b", "Backend B", 1),
                create_test_backend_simple("backend_c", "Backend C", 5),
            ],
            RoutingStrategy::PriorityOnly,
        );
        let request = llama_request();

        // Repeat to show the choice is stable across calls.
        for _ in 0..5 {
            assert_eq!(routed_name(&router, &request), "Backend B");
        }
    }

    /// Over 30 draws the random strategy is expected to pick each of the three
    /// backends at least once.
    #[test]
    fn random_selects_from_candidates() {
        let router = create_test_router_with_strategy(
            vec![
                create_test_backend_simple("backend_a", "Backend A", 1),
                create_test_backend_simple("backend_b", "Backend B", 1),
                create_test_backend_simple("backend_c", "Backend C", 1),
            ],
            RoutingStrategy::Random,
        );
        let request = llama_request();

        // Only membership matters, so a set of the observed names is enough.
        let seen: std::collections::HashSet<String> =
            (0..30).map(|_| routed_name(&router, &request)).collect();

        assert!(seen.contains("Backend A"));
        assert!(seen.contains("Backend B"));
        assert!(seen.contains("Backend C"));
    }
}
1382
#[cfg(test)]
mod alias_and_fallback_tests {
    use super::*;
    use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model};
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, AtomicU64};

    /// Builds a healthy, priority-1 Ollama backend that serves exactly one
    /// model whose id and name are both `model_id`.
    fn create_test_backend_with_model(id: &str, name: &str, model_id: &str) -> Backend {
        let served = Model {
            id: model_id.to_owned(),
            name: model_id.to_owned(),
            context_length: 4096,
            supports_vision: false,
            supports_tools: false,
            supports_json_mode: false,
            max_output_tokens: None,
        };

        Backend {
            id: id.to_owned(),
            name: name.to_owned(),
            url: format!("http://{}", name),
            backend_type: BackendType::Ollama,
            status: BackendStatus::Healthy,
            last_health_check: Utc::now(),
            last_error: None,
            models: vec![served],
            priority: 1,
            pending_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            avg_latency_ms: AtomicU32::new(50),
            discovery_source: DiscoverySource::Static,
            metadata: HashMap::new(),
        }
    }

    /// Turns `(alias, target)` pairs into the owned map the router expects.
    fn alias_table(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(alias, target)| (alias.to_owned(), target.to_owned()))
            .collect()
    }

    /// Builds a fallback map with a single entry: `primary` -> `chain`.
    fn fallback_table(primary: &str, chain: &[&str]) -> HashMap<String, Vec<String>> {
        let owned_chain = chain
            .iter()
            .map(|model| model.to_string())
            .collect::<Vec<String>>();

        let mut table = HashMap::new();
        table.insert(primary.to_owned(), owned_chain);
        table
    }

    /// Registers `backends` in a fresh registry and builds a `Smart` router
    /// with default scoring weights plus the given alias and fallback maps.
    fn smart_router(
        backends: Vec<Backend>,
        aliases: HashMap<String, String>,
        fallbacks: HashMap<String, Vec<String>>,
    ) -> Router {
        let registry = Arc::new(Registry::new());
        for backend in backends {
            registry.add_backend(backend).unwrap();
        }

        Router::with_aliases_and_fallbacks(
            registry,
            RoutingStrategy::Smart,
            ScoringWeights::default(),
            aliases,
            fallbacks,
        )
    }

    /// Request for `model` with no capability requirements.
    fn request_for(model: &str) -> RequestRequirements {
        RequestRequirements {
            model: model.to_owned(),
            estimated_tokens: 100,
            needs_vision: false,
            needs_tools: false,
            needs_json_mode: false,
            prefers_streaming: false,
        }
    }

    /// A request for the alias `gpt-4` reaches the backend serving `llama3:70b`.
    #[test]
    fn resolves_alias_transparently() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "llama3:70b",
            )],
            alias_table(&[("gpt-4", "llama3:70b")]),
            HashMap::new(),
        );

        let routed = router.select_backend(&request_for("gpt-4"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend A");
    }

    /// `llama3:70b` is not served anywhere, so the router must walk its
    /// fallback chain and end up on the backend serving `mistral:7b`.
    #[test]
    fn uses_fallback_when_primary_unavailable() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "mistral:7b",
            )],
            HashMap::new(),
            fallback_table("llama3:70b", &["llama3:8b", "mistral:7b"]),
        );

        let routed = router
            .select_backend(&request_for("llama3:70b"), None)
            .unwrap();
        assert_eq!(routed.backend.name, "Backend A");
    }

    /// Neither the requested model nor any fallback is served, so the router
    /// must report `FallbackChainExhausted`.
    #[test]
    fn exhausts_fallback_chain() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "some-other-model",
            )],
            HashMap::new(),
            fallback_table("llama3:70b", &["llama3:8b", "mistral:7b"]),
        );

        let outcome = router.select_backend(&request_for("llama3:70b"), None);
        assert!(matches!(
            outcome,
            Err(RoutingError::FallbackChainExhausted { .. })
        ));
    }

    /// The alias `gpt-4` maps to `llama3:70b`, whose fallback `mistral:7b` is
    /// the only model actually served.
    #[test]
    fn alias_then_fallback() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "mistral:7b",
            )],
            alias_table(&[("gpt-4", "llama3:70b")]),
            fallback_table("llama3:70b", &["mistral:7b"]),
        );

        let routed = router.select_backend(&request_for("gpt-4"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend A");
    }

    /// `gpt-4` -> `llama-large` -> `llama3:70b` must be followed end to end.
    #[test]
    fn alias_chain_two_levels() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "llama3:70b",
            )],
            alias_table(&[("gpt-4", "llama-large"), ("llama-large", "llama3:70b")]),
            HashMap::new(),
        );

        let routed = router.select_backend(&request_for("gpt-4"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend A");
    }

    /// `a` -> `b` -> `c` -> `final-model` must be followed end to end.
    #[test]
    fn alias_chain_three_levels() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "final-model",
            )],
            alias_table(&[("a", "b"), ("b", "c"), ("c", "final-model")]),
            HashMap::new(),
        );

        let routed = router.select_backend(&request_for("a"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend A");
    }

    /// With a four-hop chain `a` -> `b` -> `c` -> `d` -> `e`, the request is
    /// expected to land on the backend for `d`, not the one for `e`.
    #[test]
    fn alias_chain_stops_at_max_depth() {
        let router = smart_router(
            vec![
                create_test_backend_with_model("backend_d", "Backend D", "d"),
                create_test_backend_with_model("backend_e", "Backend E", "e"),
            ],
            alias_table(&[("a", "b"), ("b", "c"), ("c", "d"), ("d", "e")]),
            HashMap::new(),
        );

        let routed = router.select_backend(&request_for("a"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend D");
    }

    /// Same single-hop scenario as `resolves_alias_transparently`, kept as a
    /// separate test for one-level alias behavior.
    #[test]
    fn alias_preserves_existing_single_level_behavior() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_a",
                "Backend A",
                "llama3:70b",
            )],
            alias_table(&[("gpt-4", "llama3:70b")]),
            HashMap::new(),
        );

        let routed = router.select_backend(&request_for("gpt-4"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend A");
    }

    /// When only the fallback model is served, the result must flag the
    /// fallback and report the model that was actually used.
    #[test]
    fn routing_result_contains_fallback_info() {
        let router = smart_router(
            vec![create_test_backend_with_model(
                "backend_fallback",
                "Backend Fallback",
                "fallback",
            )],
            HashMap::new(),
            fallback_table("primary", &["fallback"]),
        );

        let routed = router
            .select_backend(&request_for("primary"), None)
            .unwrap();

        assert!(routed.fallback_used, "Expected fallback_used to be true");
        assert_eq!(routed.actual_model, "fallback");
        assert_eq!(routed.backend.name, "Backend Fallback");
    }

    /// When the primary model is served, the fallback must not be used even
    /// though one is configured and available.
    #[test]
    fn routing_result_no_fallback_when_primary_used() {
        let router = smart_router(
            vec![
                create_test_backend_with_model("backend_primary", "Backend Primary", "primary"),
                create_test_backend_with_model("backend_fallback", "Backend Fallback", "fallback"),
            ],
            HashMap::new(),
            fallback_table("primary", &["fallback"]),
        );

        let routed = router
            .select_backend(&request_for("primary"), None)
            .unwrap();

        assert!(!routed.fallback_used, "Expected fallback_used to be false");
        assert_eq!(routed.actual_model, "primary");
        assert_eq!(routed.backend.name, "Backend Primary");
    }

    /// A cycle `a` -> `b` -> `c` -> `a` must terminate; the test expects the
    /// resolution to come back to `a` and routing to succeed.
    #[test]
    fn test_circular_alias_detection() {
        let router = smart_router(
            vec![
                create_test_backend_with_model("backend_a", "Backend A", "a"),
                create_test_backend_with_model("backend_b", "Backend B", "b"),
                create_test_backend_with_model("backend_c", "Backend C", "c"),
            ],
            alias_table(&[("a", "b"), ("b", "c"), ("c", "a")]),
            HashMap::new(),
        );

        let resolved = router.resolve_alias("a");
        assert_eq!(resolved, "a", "Circular alias should stop at MAX_DEPTH");

        let outcome = router.select_backend(&request_for("a"), None);
        assert!(
            outcome.is_ok(),
            "Should route to model 'a' after circular alias resolution"
        );
    }

    /// A four-hop chain `x` -> `y` -> `z` -> `w` -> `final` is expected to
    /// resolve to `w`, both through `resolve_alias` and through routing.
    #[test]
    fn test_alias_max_depth() {
        let router = smart_router(
            vec![
                create_test_backend_with_model("backend_w", "Backend W", "w"),
                create_test_backend_with_model("backend_final", "Backend Final", "final"),
            ],
            alias_table(&[("x", "y"), ("y", "z"), ("z", "w"), ("w", "final")]),
            HashMap::new(),
        );

        let resolved = router.resolve_alias("x");
        assert_eq!(resolved, "w", "Should stop at 3 levels, resolving to 'w'");

        let routed = router.select_backend(&request_for("x"), None).unwrap();
        assert_eq!(routed.backend.name, "Backend W");
    }
}
1927
#[cfg(test)]
mod constructor_tests {
    use super::*;
    use crate::config::{BudgetConfig, PolicyMatcher, QualityConfig};
    use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model, Registry};
    use chrono::Utc;
    use dashmap::DashMap;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, AtomicU64};

    /// Builds a healthy, priority-1 Ollama backend that serves exactly one
    /// model whose id and name are both `model_id`.
    fn create_test_backend_with_model(id: &str, name: &str, model_id: &str) -> Backend {
        let served = Model {
            id: model_id.to_owned(),
            name: model_id.to_owned(),
            context_length: 4096,
            supports_vision: false,
            supports_tools: false,
            supports_json_mode: false,
            max_output_tokens: None,
        };

        Backend {
            id: id.to_owned(),
            name: name.to_owned(),
            url: format!("http://{}", name),
            backend_type: BackendType::Ollama,
            status: BackendStatus::Healthy,
            last_health_check: Utc::now(),
            last_error: None,
            models: vec![served],
            priority: 1,
            pending_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            avg_latency_ms: AtomicU32::new(50),
            discovery_source: DiscoverySource::Static,
            metadata: HashMap::new(),
        }
    }

    /// Request for `model` with no capability requirements.
    fn simple_requirements(model: &str) -> RequestRequirements {
        RequestRequirements {
            model: model.to_owned(),
            estimated_tokens: 100,
            needs_vision: false,
            needs_tools: false,
            needs_json_mode: false,
            prefers_streaming: false,
        }
    }

    /// Registry holding the single `Backend1` fixture that serves `llama3:8b`.
    fn single_backend_registry() -> Arc<Registry> {
        let registry = Arc::new(Registry::new());
        let fixture = create_test_backend_with_model("b1", "Backend1", "llama3:8b");
        registry.add_backend(fixture).unwrap();
        registry
    }

    /// A router built through `with_full_config` (100.0 monthly limit, 75%
    /// soft limit, empty budget state) must still route a plain request.
    #[test]
    fn test_with_full_config() {
        let registry = single_backend_registry();

        let budget_state = Arc::new(DashMap::new());
        let budget_config = BudgetConfig {
            monthly_limit_usd: Some(100.0),
            soft_limit_percent: 75.0,
            ..BudgetConfig::default()
        };

        let router = Router::with_full_config(
            registry,
            RoutingStrategy::Smart,
            ScoringWeights::default(),
            HashMap::new(),
            HashMap::new(),
            PolicyMatcher::default(),
            budget_config,
            budget_state,
        );

        let outcome = router.select_backend(&simple_requirements("llama3:8b"), None);
        assert!(outcome.is_ok());
        let routed = outcome.unwrap();
        assert_eq!(routed.backend.name, "Backend1");
    }

    /// A router built through `with_aliases_fallbacks_and_policies` must
    /// resolve the `gpt-4` alias to the backend serving `llama3:8b`.
    #[test]
    fn test_with_aliases_fallbacks_and_policies() {
        let registry = single_backend_registry();

        let aliases: HashMap<String, String> =
            std::iter::once(("gpt-4".to_owned(), "llama3:8b".to_owned())).collect();

        let router = Router::with_aliases_fallbacks_and_policies(
            registry,
            RoutingStrategy::Smart,
            ScoringWeights::default(),
            aliases,
            HashMap::new(),
            PolicyMatcher::default(),
            QualityConfig::default(),
        );

        let outcome = router.select_backend(&simple_requirements("gpt-4"), None);
        assert!(outcome.is_ok());
        let routed = outcome.unwrap();
        assert_eq!(routed.backend.name, "Backend1");
    }
}
2040
#[cfg(test)]
mod select_backend_error_tests {
    use super::*;
    use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model, Registry};
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, AtomicU64};

    /// Builds a priority-1 Ollama backend with the given health `status` and
    /// model list.
    fn create_backend(id: &str, name: &str, status: BackendStatus, models: Vec<Model>) -> Backend {
        Backend {
            id: id.to_owned(),
            name: name.to_owned(),
            url: format!("http://{}", name),
            backend_type: BackendType::Ollama,
            status,
            last_health_check: Utc::now(),
            last_error: None,
            models,
            priority: 1,
            pending_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            avg_latency_ms: AtomicU32::new(50),
            discovery_source: DiscoverySource::Static,
            metadata: HashMap::new(),
        }
    }

    /// Request for `model` with no capability requirements.
    fn simple_requirements(model: &str) -> RequestRequirements {
        RequestRequirements {
            model: model.to_owned(),
            estimated_tokens: 100,
            needs_vision: false,
            needs_tools: false,
            needs_json_mode: false,
            prefers_streaming: false,
        }
    }

    /// `llama3:8b` model description with vision, tools and JSON mode all off.
    fn plain_llama_model() -> Model {
        Model {
            id: String::from("llama3:8b"),
            name: String::from("llama3:8b"),
            context_length: 4096,
            supports_vision: false,
            supports_tools: false,
            supports_json_mode: false,
            max_output_tokens: None,
        }
    }

    /// `Smart` router over a registry holding one `Backend1` fixture in the
    /// given health `status`, serving only the plain `llama3:8b` model.
    fn router_with_single_backend(status: BackendStatus) -> Router {
        let registry = Arc::new(Registry::new());
        registry
            .add_backend(create_backend(
                "b1",
                "Backend1",
                status,
                vec![plain_llama_model()],
            ))
            .unwrap();

        Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default())
    }

    /// An unknown model id must yield `ModelNotFound` carrying that id.
    #[test]
    fn test_select_backend_model_not_found() {
        let router = router_with_single_backend(BackendStatus::Healthy);
        let outcome = router.select_backend(&simple_requirements("nonexistent-model"), None);

        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            RoutingError::ModelNotFound { model } => {
                assert_eq!(model, "nonexistent-model");
            }
            other => panic!("Expected ModelNotFound, got: {:?}", other),
        }
    }

    /// A known model whose only backend is unhealthy must yield
    /// `NoHealthyBackend` carrying the model id.
    #[test]
    fn test_select_backend_no_healthy_backend() {
        let router = router_with_single_backend(BackendStatus::Unhealthy);
        let outcome = router.select_backend(&simple_requirements("llama3:8b"), None);

        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            RoutingError::NoHealthyBackend { model } => {
                assert_eq!(model, "llama3:8b");
            }
            other => panic!("Expected NoHealthyBackend, got: {:?}", other),
        }
    }

    /// Requesting vision from a model that does not support it must fail.
    /// Any of `NoHealthyBackend`, `CapabilityMismatch` or `Reject` is accepted.
    #[test]
    fn test_select_backend_capability_mismatch() {
        let router = router_with_single_backend(BackendStatus::Healthy);

        // Same request as `simple_requirements`, but with vision required.
        let request = RequestRequirements {
            needs_vision: true,
            ..simple_requirements("llama3:8b")
        };

        let outcome = router.select_backend(&request, None);
        assert!(
            outcome.is_err(),
            "Expected error for vision capability mismatch"
        );

        let err = outcome.unwrap_err();
        let acceptable = matches!(
            &err,
            RoutingError::NoHealthyBackend { .. }
                | RoutingError::CapabilityMismatch { .. }
                | RoutingError::Reject { .. }
        );
        assert!(
            acceptable,
            "Expected NoHealthyBackend, CapabilityMismatch, or Reject, got: {:?}",
            &err
        );
    }
}
2197
2198#[cfg(test)]
2199mod budget_status_tests {
2200 use super::*;
2201 use crate::config::BudgetConfig;
2202 use crate::registry::{Backend, BackendStatus, BackendType, DiscoverySource, Model, Registry};
2203 use crate::routing::reconciler::budget::{BudgetMetrics, GLOBAL_BUDGET_KEY};
2204 use crate::routing::reconciler::intent::BudgetStatus;
2205 use chrono::Utc;
2206 use dashmap::DashMap;
2207 use std::collections::HashMap;
2208 use std::sync::atomic::{AtomicU32, AtomicU64};
2209
2210 fn create_backend_with_model(model_id: &str) -> Backend {
2211 Backend {
2212 id: "b1".to_string(),
2213 name: "Backend1".to_string(),
2214 url: "http://backend1".to_string(),
2215 backend_type: BackendType::Ollama,
2216 status: BackendStatus::Healthy,
2217 last_health_check: Utc::now(),
2218 last_error: None,
2219 models: vec![Model {
2220 id: model_id.to_string(),
2221 name: model_id.to_string(),
2222 context_length: 4096,
2223 supports_vision: false,
2224 supports_tools: false,
2225 supports_json_mode: false,
2226 max_output_tokens: None,
2227 }],
2228 priority: 1,
2229 pending_requests: AtomicU32::new(0),
2230 total_requests: AtomicU64::new(0),
2231 avg_latency_ms: AtomicU32::new(50),
2232 discovery_source: DiscoverySource::Static,
2233 metadata: HashMap::new(),
2234 }
2235 }
2236
2237 fn make_router_with_budget(
2238 monthly_limit: Option<f64>,
2239 soft_limit_percent: f64,
2240 spending: f64,
2241 ) -> Router {
2242 let registry = Arc::new(Registry::new());
2243 registry
2244 .add_backend(create_backend_with_model("llama3:8b"))
2245 .unwrap();
2246
2247 let budget_config = BudgetConfig {
2248 monthly_limit_usd: monthly_limit,
2249 soft_limit_percent,
2250 ..BudgetConfig::default()
2251 };
2252 let budget_state = Arc::new(DashMap::new());
2253 if spending > 0.0 {
2254 budget_state.insert(
2255 GLOBAL_BUDGET_KEY.to_string(),
2256 BudgetMetrics {
2257 current_month_spending: spending,
2258 last_reconciliation_time: Utc::now(),
2259 month_key: Utc::now().format("%Y-%m").to_string(),
2260 },
2261 );
2262 }
2263
2264 Router::with_full_config(
2265 registry,
2266 RoutingStrategy::Smart,
2267 ScoringWeights::default(),
2268 HashMap::new(),
2269 HashMap::new(),
2270 PolicyMatcher::default(),
2271 budget_config,
2272 budget_state,
2273 )
2274 }
2275
2276 #[test]
2277 fn budget_normal_when_no_limit() {
2278 let router = make_router_with_budget(None, 75.0, 0.0);
2279 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2280
2281 assert!(matches!(status, BudgetStatus::Normal));
2282 assert!(utilization.is_none());
2283 assert!(remaining.is_none());
2284 }
2285
2286 #[test]
2287 fn budget_normal_when_zero_limit() {
2288 let router = make_router_with_budget(Some(0.0), 75.0, 0.0);
2289 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2290
2291 assert!(matches!(status, BudgetStatus::Normal));
2292 assert!(utilization.is_none());
2293 assert!(remaining.is_none());
2294 }
2295
2296 #[test]
2297 fn budget_normal_when_below_soft_limit() {
2298 let router = make_router_with_budget(Some(100.0), 75.0, 50.0);
2300 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2301
2302 assert!(matches!(status, BudgetStatus::Normal));
2303 let util = utilization.unwrap();
2304 assert!((util - 50.0).abs() < 0.01);
2305 let rem = remaining.unwrap();
2306 assert!((rem - 50.0).abs() < 0.01);
2307 }
2308
2309 #[test]
2310 fn budget_soft_limit_when_at_threshold() {
2311 let router = make_router_with_budget(Some(100.0), 75.0, 75.0);
2313 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2314
2315 assert!(matches!(status, BudgetStatus::SoftLimit));
2316 let util = utilization.unwrap();
2317 assert!((util - 75.0).abs() < 0.01);
2318 let rem = remaining.unwrap();
2319 assert!((rem - 25.0).abs() < 0.01);
2320 }
2321
2322 #[test]
2323 fn budget_soft_limit_when_above_soft_below_hard() {
2324 let router = make_router_with_budget(Some(100.0), 75.0, 90.0);
2326 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2327
2328 assert!(matches!(status, BudgetStatus::SoftLimit));
2329 let util = utilization.unwrap();
2330 assert!((util - 90.0).abs() < 0.01);
2331 let rem = remaining.unwrap();
2332 assert!((rem - 10.0).abs() < 0.01);
2333 }
2334
2335 #[test]
2336 fn budget_hard_limit_when_at_100_percent() {
2337 let router = make_router_with_budget(Some(100.0), 75.0, 100.0);
2339 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2340
2341 assert!(matches!(status, BudgetStatus::HardLimit));
2342 let util = utilization.unwrap();
2343 assert!((util - 100.0).abs() < 0.01);
2344 let rem = remaining.unwrap();
2345 assert!((rem - 0.0).abs() < 0.01);
2346 }
2347
2348 #[test]
2349 fn budget_hard_limit_when_over_budget() {
2350 let router = make_router_with_budget(Some(100.0), 75.0, 150.0);
2352 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2353
2354 assert!(matches!(status, BudgetStatus::HardLimit));
2355 let util = utilization.unwrap();
2356 assert!((util - 150.0).abs() < 0.01);
2357 let rem = remaining.unwrap();
2359 assert!((rem - 0.0).abs() < 0.01);
2360 }
2361
2362 #[test]
2363 fn budget_normal_when_no_spending_recorded() {
2364 let router = make_router_with_budget(Some(100.0), 75.0, 0.0);
2366 let (status, utilization, remaining) = router.get_budget_status_and_utilization();
2367
2368 assert!(matches!(status, BudgetStatus::Normal));
2369 let util = utilization.unwrap();
2370 assert!((util - 0.0).abs() < 0.01);
2371 let rem = remaining.unwrap();
2372 assert!((rem - 100.0).abs() < 0.01);
2373 }
2374
/// With queueing enabled and the only registered backend marked unhealthy,
/// `select_backend` must fail with `RoutingError::Queue` and a reason that
/// mentions "capacity".
#[test]
fn test_select_backend_queue_decision() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    // Register one backend, then mark it unhealthy ("b1" is the id the
    // status update targets; presumably the helper's default id).
    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();
    let unhealthy = crate::registry::BackendStatus::Unhealthy;
    registry
        .update_status("b1", unhealthy, Some("test".to_string()))
        .unwrap();

    let mut router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());
    router.set_queue_enabled(true);

    let outcome = router.select_backend(&requirements, None);
    assert!(outcome.is_err());

    let err = outcome.unwrap_err();
    if let RoutingError::Queue { reason, .. } = &err {
        assert!(
            reason.contains("capacity"),
            "Expected capacity reason, got: {}",
            reason
        );
    } else {
        panic!("Expected Queue, got: {:?}", err);
    }
}
2417
/// When the only backend serving both the primary and the fallback model is
/// unhealthy, routing must fail with `FallbackChainExhausted`, and the
/// reported chain must list both models.
#[test]
fn test_select_backend_with_fallback_chain_all_unhealthy() {
    // primary-model falls back to fallback-model.
    let fallbacks = HashMap::from([(
        "primary-model".to_string(),
        vec!["fallback-model".to_string()],
    )]);

    // One backend carries both models, so taking it down removes every
    // option in the chain.
    let mut shared_backend = create_backend_with_model("primary-model");
    let fallback_model = Model {
        id: "fallback-model".to_string(),
        name: "fallback-model".to_string(),
        context_length: 4096,
        supports_vision: false,
        supports_tools: false,
        supports_json_mode: false,
        max_output_tokens: None,
    };
    shared_backend.models.push(fallback_model);

    let registry = Arc::new(Registry::new());
    registry.add_backend(shared_backend).unwrap();
    let unhealthy = crate::registry::BackendStatus::Unhealthy;
    registry
        .update_status("b1", unhealthy, Some("down".to_string()))
        .unwrap();

    let router = Router::with_aliases_and_fallbacks(
        registry,
        RoutingStrategy::Smart,
        ScoringWeights::default(),
        HashMap::new(),
        fallbacks,
    );

    let requirements = RequestRequirements {
        model: "primary-model".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let outcome = router.select_backend(&requirements, None);
    assert!(outcome.is_err());

    let err = outcome.unwrap_err();
    if let RoutingError::FallbackChainExhausted { chain } = &err {
        assert!(chain.contains(&"primary-model".to_string()));
        assert!(chain.contains(&"fallback-model".to_string()));
    } else {
        panic!("Expected FallbackChainExhausted, got: {:?}", err);
    }
}
2475
/// When the primary model's backend is unhealthy but a separate backend
/// serving the configured fallback model is registered, routing must succeed
/// via the fallback and say so in the result.
#[test]
fn test_select_backend_fallback_succeeds() {
    // primary-model falls back to fallback-model.
    let fallbacks = HashMap::from([(
        "primary-model".to_string(),
        vec!["fallback-model".to_string()],
    )]);

    let registry = Arc::new(Registry::new());

    // Primary backend: registered, then marked unhealthy.
    let mut primary = create_backend_with_model("primary-model");
    primary.id = "b-primary".to_string();
    primary.name = "PrimaryBackend".to_string();
    registry.add_backend(primary).unwrap();
    let unhealthy = crate::registry::BackendStatus::Unhealthy;
    registry
        .update_status("b-primary", unhealthy, Some("down".to_string()))
        .unwrap();

    // Fallback backend: registered and left with whatever status the helper
    // assigns.
    let mut secondary = create_backend_with_model("fallback-model");
    secondary.id = "b-fallback".to_string();
    secondary.name = "FallbackBackend".to_string();
    registry.add_backend(secondary).unwrap();

    let router = Router::with_aliases_and_fallbacks(
        registry,
        RoutingStrategy::Smart,
        ScoringWeights::default(),
        HashMap::new(),
        fallbacks,
    );

    let requirements = RequestRequirements {
        model: "primary-model".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let outcome = router.select_backend(&requirements, None);
    assert!(outcome.is_ok(), "Fallback should succeed");

    let routed = outcome.unwrap();
    assert!(routed.fallback_used);
    assert_eq!(routed.actual_model, "fallback-model");
    assert!(routed.route_reason.starts_with("fallback:"));
}
2530
/// A successful route with a configured budget must populate the budget
/// fields of the result (status, utilization, remaining) together with a
/// cost estimate, and must not report fallback usage.
#[test]
fn test_select_backend_success_returns_budget_fields() {
    // Seed the global budget entry with 50.0 of spending for the current
    // month, against a 100.0 monthly limit with a 75% soft limit.
    let seeded_metrics = BudgetMetrics {
        current_month_spending: 50.0,
        last_reconciliation_time: chrono::Utc::now(),
        month_key: chrono::Utc::now().format("%Y-%m").to_string(),
    };
    let budget_state = Arc::new(DashMap::new());
    budget_state.insert(
        crate::routing::reconciler::budget::GLOBAL_BUDGET_KEY.to_string(),
        seeded_metrics,
    );
    let budget_config = BudgetConfig {
        monthly_limit_usd: Some(100.0),
        soft_limit_percent: 75.0,
        ..BudgetConfig::default()
    };

    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::with_full_config(
        registry,
        RoutingStrategy::Smart,
        ScoringWeights::default(),
        HashMap::new(),
        HashMap::new(),
        PolicyMatcher::default(),
        budget_config,
        budget_state,
    );

    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let outcome = router.select_backend(&requirements, None);
    assert!(outcome.is_ok());

    let routed = outcome.unwrap();
    assert_eq!(routed.budget_status, BudgetStatus::Normal);
    assert!(routed.budget_utilization.is_some());
    assert!(routed.budget_remaining.is_some());
    assert!(!routed.fallback_used);
    assert!(routed.cost_estimated.is_some());
}
2582
/// Smart routing through `select_backend` must pick the better-scored of two
/// backends serving the same model.
///
/// The two backends are given different `priority` values here (B1 = 1,
/// B2 = 10). B1 is expected to win, in line with the direct `select_smart`
/// test in this module. This assumes the helper gives both backends otherwise
/// identical scoring inputs (load, latency) — confirm against
/// `create_backend_with_model`.
#[test]
fn test_select_smart_selects_best_scored_backend() {
    let registry = Arc::new(Registry::new());

    let mut b1 = create_backend_with_model("llama3:8b");
    b1.id = "b1".to_string();
    b1.name = "B1".to_string();
    b1.priority = 1;
    registry.add_backend(b1).unwrap();

    let mut b2 = create_backend_with_model("llama3:8b");
    b2.id = "b2".to_string();
    b2.name = "B2".to_string();
    b2.priority = 10;
    registry.add_backend(b2).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let routed = router
        .select_backend(&requirements, None)
        .expect("smart routing should succeed when healthy backends serve the model");

    // Checking only `is_ok()` would not exercise what the test name promises;
    // verify which backend was actually chosen.
    assert_eq!(routed.backend.name, "B1");
}
2613
/// With the `PriorityOnly` strategy, `select_backend` must choose the backend
/// with the lower `priority` value (B2 with priority 1 over B1 with 10).
#[test]
fn test_select_priority_only_via_router() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let registry = Arc::new(Registry::new());

    let mut low_preference = create_backend_with_model("llama3:8b");
    low_preference.id = "b1".to_string();
    low_preference.name = "B1".to_string();
    low_preference.priority = 10;
    registry.add_backend(low_preference).unwrap();

    let mut high_preference = create_backend_with_model("llama3:8b");
    high_preference.id = "b2".to_string();
    high_preference.name = "B2".to_string();
    high_preference.priority = 1;
    registry.add_backend(high_preference).unwrap();

    let strategy = RoutingStrategy::PriorityOnly;
    let router = Router::new(registry, strategy, ScoringWeights::default());

    let outcome = router.select_backend(&requirements, None);
    assert!(outcome.is_ok());

    let routed = outcome.unwrap();
    assert_eq!(routed.backend.name, "B2");
}
2650
/// With the `RoundRobin` strategy and two backends for the same model, two
/// consecutive `select_backend` calls must return different backends.
#[test]
fn test_select_round_robin_via_router() {
    let registry = Arc::new(Registry::new());

    let mut first_backend = create_backend_with_model("llama3:8b");
    first_backend.id = "b1".to_string();
    first_backend.name = "B1".to_string();
    registry.add_backend(first_backend).unwrap();

    let mut second_backend = create_backend_with_model("llama3:8b");
    second_backend.id = "b2".to_string();
    second_backend.name = "B2".to_string();
    registry.add_backend(second_backend).unwrap();

    let strategy = RoutingStrategy::RoundRobin;
    let router = Router::new(registry, strategy, ScoringWeights::default());

    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    // Same request twice: the rotation should move on to the other backend.
    let first_pick = router.select_backend(&requirements, None).unwrap();
    let second_pick = router.select_backend(&requirements, None).unwrap();
    assert_ne!(first_pick.backend.name, second_pick.backend.name);
}
2682
/// With the `Random` strategy and a single registered backend,
/// `select_backend` must succeed.
#[test]
fn test_select_random_via_router() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let registry = Arc::new(Registry::new());
    let mut only_backend = create_backend_with_model("llama3:8b");
    only_backend.id = "b1".to_string();
    only_backend.name = "B1".to_string();
    registry.add_backend(only_backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Random, ScoringWeights::default());

    let outcome = router.select_backend(&requirements, None);
    assert!(outcome.is_ok());
}
2705
/// Passing an explicit `TierEnforcementMode::Strict` to `select_backend`
/// must still yield a successful route when a backend for the requested
/// model is registered.
#[test]
fn test_select_backend_with_tier_enforcement() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };
    let tier_mode = crate::routing::reconciler::intent::TierEnforcementMode::Strict;

    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let outcome = router.select_backend(&requirements, Some(tier_mode));
    assert!(outcome.is_ok());
}
2727
/// `filter_candidates` must keep a backend whose model supports every
/// capability the request asks for (vision, tools, JSON mode).
#[test]
fn test_filter_candidates_legacy() {
    // Enable all three capabilities on the backend's first model.
    let mut capable_backend = create_backend_with_model("llama3:8b");
    let model = &mut capable_backend.models[0];
    model.supports_vision = true;
    model.supports_tools = true;
    model.supports_json_mode = true;

    let registry = Arc::new(Registry::new());
    registry.add_backend(capable_backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    // The request demands all three capabilities.
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: true,
        needs_tools: true,
        needs_json_mode: true,
        prefers_streaming: false,
    };

    let candidates = router.filter_candidates("llama3:8b", &requirements);
    assert_eq!(candidates.len(), 1);
}
2753
/// A request that needs vision must get no candidates from a backend built
/// by the default helper (which presumably creates a model without vision
/// support — confirm against `create_backend_with_model`).
#[test]
fn test_filter_candidates_vision_mismatch() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: true,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let candidates = router.filter_candidates("llama3:8b", &requirements);
    assert!(candidates.is_empty());
}
2776
/// A request that needs tool support must get no candidates from a backend
/// built by the default helper (which presumably creates a model without
/// tool support — confirm against `create_backend_with_model`).
#[test]
fn test_filter_candidates_tools_mismatch() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: true,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let candidates = router.filter_candidates("llama3:8b", &requirements);
    assert!(candidates.is_empty());
}
2799
/// A request that needs JSON mode must get no candidates from a backend
/// built by the default helper (which presumably creates a model without
/// JSON-mode support — confirm against `create_backend_with_model`).
#[test]
fn test_filter_candidates_json_mode_mismatch() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: true,
        prefers_streaming: false,
    };

    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let candidates = router.filter_candidates("llama3:8b", &requirements);
    assert!(candidates.is_empty());
}
2822
/// A request with a very large token estimate (999999) must get no
/// candidates; presumably this exceeds the context length of the helper's
/// default model — confirm against `create_backend_with_model`.
#[test]
fn test_filter_candidates_context_length_exceeded() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 999999,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let candidates = router.filter_candidates("llama3:8b", &requirements);
    assert!(candidates.is_empty());
}
2845
/// A backend marked `Unhealthy` must not appear among the candidates, even
/// when the request has no special capability requirements.
#[test]
fn test_filter_candidates_unhealthy_excluded() {
    let requirements = RequestRequirements {
        model: "llama3:8b".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    // Register the backend, then flip it to unhealthy ("b1" is the id the
    // status update targets; presumably the helper's default id).
    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();
    let unhealthy = crate::registry::BackendStatus::Unhealthy;
    registry
        .update_status("b1", unhealthy, Some("down".to_string()))
        .unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let candidates = router.filter_candidates("llama3:8b", &requirements);
    assert!(candidates.is_empty());
}
2875
/// Asking for a model that no registered backend serves must produce an
/// empty candidate list.
#[test]
fn test_filter_candidates_model_not_found() {
    let requirements = RequestRequirements {
        model: "nonexistent".to_string(),
        estimated_tokens: 100,
        needs_vision: false,
        needs_tools: false,
        needs_json_mode: false,
        prefers_streaming: false,
    };

    // The only registered backend serves a different model.
    let registry = Arc::new(Registry::new());
    let backend = create_backend_with_model("llama3:8b");
    registry.add_backend(backend).unwrap();

    let router = Router::new(registry, RoutingStrategy::Smart, ScoringWeights::default());

    let candidates = router.filter_candidates("nonexistent", &requirements);
    assert!(candidates.is_empty());
}
2898
/// Calling `select_smart` directly on the healthy candidates must pick "b1",
/// the backend given the lower `priority` value (1 versus 10).
#[test]
fn test_select_smart_with_candidates() {
    let registry = Arc::new(Registry::new());

    let mut preferred = create_backend_with_model("llama3:8b");
    preferred.id = "b1".to_string();
    preferred.name = "B1".to_string();
    preferred.priority = 1;
    registry.add_backend(preferred).unwrap();

    let mut other = create_backend_with_model("llama3:8b");
    other.id = "b2".to_string();
    other.name = "B2".to_string();
    other.priority = 10;
    registry.add_backend(other).unwrap();

    let router = Router::new(
        Arc::clone(&registry),
        RoutingStrategy::Smart,
        ScoringWeights::default(),
    );

    // Collect only the backends currently reported as healthy.
    let mut candidates: Vec<Backend> = Vec::new();
    for backend in registry.get_backends_for_model("llama3:8b") {
        if backend.status == BackendStatus::Healthy {
            candidates.push(backend);
        }
    }

    let selected = router.select_smart(&candidates);
    assert_eq!(selected.id, "b1");
}
2927
/// Calling `select_priority_only` directly on the healthy candidates must
/// pick "b2", the backend given the lower `priority` value (1 versus 10).
#[test]
fn test_select_priority_only_with_candidates() {
    let registry = Arc::new(Registry::new());

    let mut other = create_backend_with_model("llama3:8b");
    other.id = "b1".to_string();
    other.name = "B1".to_string();
    other.priority = 10;
    registry.add_backend(other).unwrap();

    let mut preferred = create_backend_with_model("llama3:8b");
    preferred.id = "b2".to_string();
    preferred.name = "B2".to_string();
    preferred.priority = 1;
    registry.add_backend(preferred).unwrap();

    let router = Router::new(
        Arc::clone(&registry),
        RoutingStrategy::PriorityOnly,
        ScoringWeights::default(),
    );

    // Collect only the backends currently reported as healthy.
    let mut candidates: Vec<Backend> = Vec::new();
    for backend in registry.get_backends_for_model("llama3:8b") {
        if backend.status == BackendStatus::Healthy {
            candidates.push(backend);
        }
    }

    let selected = router.select_priority_only(&candidates);
    assert_eq!(selected.id, "b2");
}
2955
/// Calling `select_random` directly with a single healthy candidate must
/// return that candidate ("b1").
#[test]
fn test_select_random_with_candidates() {
    let registry = Arc::new(Registry::new());

    let mut only_backend = create_backend_with_model("llama3:8b");
    only_backend.id = "b1".to_string();
    only_backend.name = "B1".to_string();
    registry.add_backend(only_backend).unwrap();

    let router = Router::new(
        Arc::clone(&registry),
        RoutingStrategy::Random,
        ScoringWeights::default(),
    );

    // Collect only the backends currently reported as healthy.
    let mut candidates: Vec<Backend> = Vec::new();
    for backend in registry.get_backends_for_model("llama3:8b") {
        if backend.status == BackendStatus::Healthy {
            candidates.push(backend);
        }
    }

    let selected = router.select_random(&candidates);
    assert_eq!(selected.id, "b1");
}
2977
/// The router's accessors must expose the budget configuration it was built
/// with, an empty budget state, and a quality store with no metrics.
#[test]
fn test_budget_config_and_state_accessors() {
    let router = Router::with_full_config(
        Arc::new(Registry::new()),
        RoutingStrategy::Smart,
        ScoringWeights::default(),
        HashMap::new(),
        HashMap::new(),
        PolicyMatcher::default(),
        BudgetConfig {
            monthly_limit_usd: Some(200.0),
            soft_limit_percent: 80.0,
            ..BudgetConfig::default()
        },
        // Nothing is inserted into the budget state before construction.
        Arc::new(DashMap::new()),
    );

    let configured_limit = router.budget_config().monthly_limit_usd;
    assert_eq!(configured_limit, Some(200.0));
    assert!(router.budget_state().is_empty());
    assert!(router.quality_store().get_all_metrics().is_empty());
}
3001}