use crate::core::providers::Provider;
use crate::core::providers::openai::OpenAIProvider;
use crate::core::router::deployment::{Deployment, DeploymentConfig, DeploymentState};
use crate::core::router::strategy_impl::*;
use dashmap::DashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
/// Builds an OpenAI-backed [`Provider`] from a dummy API key for use in tests.
///
/// # Panics
///
/// Panics if `OpenAIProvider::with_api_key` returns an error.
async fn create_test_provider() -> Provider {
    let creation = OpenAIProvider::with_api_key("sk-test-key-for-unit-testing-only").await;
    Provider::OpenAI(creation.expect("Failed to create OpenAI provider"))
}
/// Assembles a [`Deployment`] for tests.
///
/// The deployment gets the given `id` and `config`, a fresh test provider,
/// `"gpt-4"` for both `model` and `model_name`, a new `DeploymentState`, and
/// no tags.
async fn create_test_deployment(id: &str, config: DeploymentConfig) -> Deployment {
    let provider = create_test_provider().await;
    let model_label = String::from("gpt-4");

    Deployment {
        id: id.to_owned(),
        provider,
        model: model_label.clone(),
        model_name: model_label,
        config,
        state: DeploymentState::new(),
        tags: Vec::new(),
    }
}
/// Verifies that `build_routing_contexts` produces a context only for the
/// candidate that exists in the deployment map.
///
/// The candidate list names `d1` (registered) and `missing` (never inserted).
/// Exactly one context is expected, and it must carry the weight, priority and
/// counter values configured below.
#[tokio::test]
async fn test_build_routing_contexts_skips_missing_deployments() {
    let registry = DashMap::new();

    let d1 = create_test_deployment(
        "d1",
        DeploymentConfig {
            weight: 3,
            priority: 7,
            ..Default::default()
        },
    )
    .await;
    d1.state.active_requests.store(2, Relaxed);
    d1.state.tpm_current.store(120, Relaxed);
    d1.state.rpm_current.store(12, Relaxed);
    d1.state.avg_latency_us.store(55, Relaxed);
    registry.insert(String::from("d1"), d1);

    let candidate_ids = vec![String::from("d1"), String::from("missing")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    assert_eq!(contexts.len(), 1);

    let only = &contexts[0];
    assert_eq!(only.deployment_id, "d1");
    assert_eq!(only.weight, 3);
    assert_eq!(only.priority, 7);
    assert_eq!(only.active_requests, 2);
    assert_eq!(only.tpm_current, 120);
    assert_eq!(only.rpm_current, 12);
    assert_eq!(only.avg_latency_us, 55);
}
/// With a single weight-1 deployment, weighted-random selection must return it.
#[tokio::test]
async fn test_weighted_random_single_candidate() {
    let registry = DashMap::new();
    let sole = create_test_deployment(
        "d1",
        DeploymentConfig {
            weight: 1,
            ..Default::default()
        },
    )
    .await;
    registry.insert(String::from("d1"), sole);

    let candidate_ids = vec![String::from("d1")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = weighted_random_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// Across 100 draws over three equally weighted deployments, every pick must be
/// one of the candidate ids.
#[tokio::test]
async fn test_weighted_random_returns_valid_candidate() {
    let candidates: Vec<String> = (1..=3).map(|n| format!("d{}", n)).collect();

    let registry = DashMap::new();
    for id in &candidates {
        let equal_weight = DeploymentConfig {
            weight: 1,
            ..Default::default()
        };
        registry.insert(id.clone(), create_test_deployment(id, equal_weight).await);
    }

    let contexts = build_routing_contexts(&candidates, &registry);
    for _ in 0..100 {
        let pick = weighted_random_from_context(&contexts).unwrap();
        assert!(candidates.contains(pick));
    }
}
/// Statistical check that weights bias the selection.
///
/// `d1` has weight 10 and `d2` weight 1. Over 1000 draws, `d1` must be picked
/// more than five times as often as everything else.
#[tokio::test]
async fn test_weighted_random_respects_weights() {
    let registry = DashMap::new();
    for (id, weight) in [("d1", 10), ("d2", 1)] {
        let config = DeploymentConfig {
            weight,
            ..Default::default()
        };
        registry.insert(id.to_string(), create_test_deployment(id, config).await);
    }

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    // Every draw that is not `d1` is counted towards `d2`.
    let total_draws = 1000;
    let d1_count = (0..total_draws)
        .filter(|_| weighted_random_from_context(&contexts).unwrap() == "d1")
        .count();
    let d2_count = total_draws - d1_count;

    assert!(
        d1_count > d2_count * 5,
        "d1 should be selected much more often due to higher weight"
    );
}
/// When every deployment has weight 0, selection must still return one of the
/// candidates (checked over 10 draws) rather than `None`.
#[tokio::test]
async fn test_weighted_random_all_zero_weights() {
    let candidates: Vec<String> = (1..=3).map(|n| format!("d{}", n)).collect();

    let registry = DashMap::new();
    for id in &candidates {
        let zero_weight = DeploymentConfig {
            weight: 0,
            ..Default::default()
        };
        registry.insert(id.clone(), create_test_deployment(id, zero_weight).await);
    }

    let contexts = build_routing_contexts(&candidates, &registry);
    for _ in 0..10 {
        let pick = weighted_random_from_context(&contexts).unwrap();
        assert!(candidates.contains(pick));
    }
}
/// An empty context slice must yield no weighted-random selection.
#[test]
fn test_weighted_random_empty_candidates() {
    let nothing_selected = weighted_random_from_context(&[]).is_none();
    assert!(nothing_selected);
}
/// With only one default-configured deployment, least-busy must select it.
#[tokio::test]
async fn test_least_busy_single_candidate() {
    let registry = DashMap::new();
    let sole = create_test_deployment("d1", DeploymentConfig::default()).await;
    registry.insert(String::from("d1"), sole);

    let candidate_ids = vec![String::from("d1")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = least_busy_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// Least-busy must pick the deployment with the fewest active requests.
///
/// Active request counts are d1 = 10, d2 = 5, d3 = 15, so `d2` is expected.
#[tokio::test]
async fn test_least_busy_selects_lowest_active() {
    let registry = DashMap::new();
    // (deployment id, active request count)
    for (id, active) in [("d1", 10), ("d2", 5), ("d3", 15)] {
        let deployment = create_test_deployment(id, DeploymentConfig::default()).await;
        deployment.state.active_requests.store(active, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![
        String::from("d1"),
        String::from("d2"),
        String::from("d3"),
    ];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = least_busy_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d2");
}
/// When two deployments have the same active request count (5 each), any of
/// the tied ids is an acceptable result. Checked over 10 selections.
#[tokio::test]
async fn test_least_busy_with_ties() {
    let tied_ids = ["d1", "d2"];

    let registry = DashMap::new();
    for id in tied_ids {
        let deployment = create_test_deployment(id, DeploymentConfig::default()).await;
        deployment.state.active_requests.store(5, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    for _ in 0..10 {
        let chosen = least_busy_from_context(&contexts).unwrap();
        assert!(tied_ids.iter().any(|id| chosen == *id));
    }
}
/// With all three deployments at zero active requests, least-busy must still
/// return one of the candidates.
#[tokio::test]
async fn test_least_busy_all_zero() {
    let candidates: Vec<String> = (1..=3).map(|n| format!("d{}", n)).collect();

    let registry = DashMap::new();
    for id in &candidates {
        let idle = create_test_deployment(id, DeploymentConfig::default()).await;
        idle.state.active_requests.store(0, Relaxed);
        registry.insert(id.clone(), idle);
    }

    let contexts = build_routing_contexts(&candidates, &registry);
    let chosen = least_busy_from_context(&contexts).unwrap();
    assert!(candidates.contains(chosen));
}
/// An empty context slice must yield no least-busy selection.
#[test]
fn test_least_busy_empty_candidates() {
    let nothing_selected = least_busy_from_context(&[]).is_none();
    assert!(nothing_selected);
}
/// With a single deployment (TPM limit 1000), lowest-usage must select it.
#[tokio::test]
async fn test_lowest_usage_single_candidate() {
    let registry = DashMap::new();
    let sole = create_test_deployment(
        "d1",
        DeploymentConfig {
            tpm_limit: Some(1000),
            ..Default::default()
        },
    )
    .await;
    registry.insert(String::from("d1"), sole);

    let candidate_ids = vec![String::from("d1")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_usage_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// Lowest-usage must pick the deployment with the smallest share of its TPM
/// limit consumed.
///
/// All three deployments have a limit of 1000. Current usage is d1 = 500,
/// d2 = 200, d3 = 800, so `d2` is expected.
#[tokio::test]
async fn test_lowest_usage_selects_lowest_percentage() {
    let registry = DashMap::new();
    // (deployment id, tokens used in the current minute)
    for (id, tpm_used) in [("d1", 500), ("d2", 200), ("d3", 800)] {
        let config = DeploymentConfig {
            tpm_limit: Some(1000),
            ..Default::default()
        };
        let deployment = create_test_deployment(id, config).await;
        deployment.state.tpm_current.store(tpm_used, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![
        String::from("d1"),
        String::from("d2"),
        String::from("d3"),
    ];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_usage_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d2");
}
/// A deployment without a TPM limit must win over one that has consumed part
/// of its limit.
///
/// `d1` has no limit and an untouched counter; `d2` sits at 500 of 1000. The
/// test expects `d1`.
#[tokio::test]
async fn test_lowest_usage_no_limit_treated_as_zero() {
    let registry = DashMap::new();

    let unlimited = create_test_deployment(
        "d1",
        DeploymentConfig {
            tpm_limit: None,
            ..Default::default()
        },
    )
    .await;
    registry.insert(String::from("d1"), unlimited);

    let half_used = create_test_deployment(
        "d2",
        DeploymentConfig {
            tpm_limit: Some(1000),
            ..Default::default()
        },
    )
    .await;
    half_used.state.tpm_current.store(500, Relaxed);
    registry.insert(String::from("d2"), half_used);

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_usage_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// An empty context slice must yield no lowest-usage selection.
#[test]
fn test_lowest_usage_empty_candidates() {
    let nothing_selected = lowest_usage_from_context(&[]).is_none();
    assert!(nothing_selected);
}
/// With a single deployment (average latency 100), lowest-latency must select it.
#[tokio::test]
async fn test_lowest_latency_single_candidate() {
    let registry = DashMap::new();
    let sole = create_test_deployment("d1", DeploymentConfig::default()).await;
    sole.state.avg_latency_us.store(100, Relaxed);
    registry.insert(String::from("d1"), sole);

    let candidate_ids = vec![String::from("d1")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_latency_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// Lowest-latency must pick the deployment with the smallest average latency.
///
/// Latencies are d1 = 500, d2 = 100, d3 = 300, so `d2` is expected.
#[tokio::test]
async fn test_lowest_latency_selects_fastest() {
    let registry = DashMap::new();
    // (deployment id, recorded average latency)
    for (id, latency) in [("d1", 500), ("d2", 100), ("d3", 300)] {
        let deployment = create_test_deployment(id, DeploymentConfig::default()).await;
        deployment.state.avg_latency_us.store(latency, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![
        String::from("d1"),
        String::from("d2"),
        String::from("d3"),
    ];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_latency_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d2");
}
/// A deployment whose recorded latency is 0 must not automatically win.
///
/// Latencies are d1 = 1000, d2 = 0, d3 = 2000, and the test expects `d1`.
/// NOTE(review): the test name suggests a zero latency is treated as "no data
/// yet" and replaced by the average of the others. Confirm against the
/// strategy implementation.
#[tokio::test]
async fn test_lowest_latency_new_deployment_uses_average() {
    let registry = DashMap::new();
    // (deployment id, recorded average latency)
    for (id, latency) in [("d1", 1000), ("d2", 0), ("d3", 2000)] {
        let deployment = create_test_deployment(id, DeploymentConfig::default()).await;
        deployment.state.avg_latency_us.store(latency, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![
        String::from("d1"),
        String::from("d2"),
        String::from("d3"),
    ];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_latency_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// With every deployment reporting a latency of 0, lowest-latency must still
/// return one of the candidates.
#[tokio::test]
async fn test_lowest_latency_all_zero() {
    let candidates: Vec<String> = (1..=3).map(|n| format!("d{}", n)).collect();

    let registry = DashMap::new();
    for id in &candidates {
        let unmeasured = create_test_deployment(id, DeploymentConfig::default()).await;
        unmeasured.state.avg_latency_us.store(0, Relaxed);
        registry.insert(id.clone(), unmeasured);
    }

    let contexts = build_routing_contexts(&candidates, &registry);
    let chosen = lowest_latency_from_context(&contexts).unwrap();
    assert!(candidates.contains(chosen));
}
/// An empty context slice must yield no lowest-latency selection.
#[test]
fn test_lowest_latency_empty_candidates() {
    let nothing_selected = lowest_latency_from_context(&[]).is_none();
    assert!(nothing_selected);
}
/// With a single deployment (priority 5), the priority strategy must select it.
#[tokio::test]
async fn test_lowest_priority_single_candidate() {
    let registry = DashMap::new();
    let sole = create_test_deployment(
        "d1",
        DeploymentConfig {
            priority: 5,
            ..Default::default()
        },
    )
    .await;
    registry.insert(String::from("d1"), sole);

    let candidate_ids = vec![String::from("d1")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_priority_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// The priority strategy must pick the deployment with the numerically
/// smallest priority value.
///
/// Priorities are d1 = 10, d2 = 1, d3 = 5, so `d2` is expected.
#[tokio::test]
async fn test_lowest_priority_selects_lowest_priority() {
    let registry = DashMap::new();
    // (deployment id, configured priority)
    for (id, priority) in [("d1", 10), ("d2", 1), ("d3", 5)] {
        let config = DeploymentConfig {
            priority,
            ..Default::default()
        };
        registry.insert(id.to_string(), create_test_deployment(id, config).await);
    }

    let candidate_ids = vec![
        String::from("d1"),
        String::from("d2"),
        String::from("d3"),
    ];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = lowest_priority_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d2");
}
/// When all three deployments share priority 5, the test expects the first
/// candidate in the list (`d1`) to be returned.
#[tokio::test]
async fn test_lowest_priority_all_same_priority() {
    let candidates: Vec<String> = (1..=3).map(|n| format!("d{}", n)).collect();

    let registry = DashMap::new();
    for id in &candidates {
        let same_priority = DeploymentConfig {
            priority: 5,
            ..Default::default()
        };
        registry.insert(id.clone(), create_test_deployment(id, same_priority).await);
    }

    let contexts = build_routing_contexts(&candidates, &registry);
    let chosen = lowest_priority_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// An empty context slice must yield no priority-based selection.
#[test]
fn test_lowest_priority_empty_candidates() {
    let nothing_selected = lowest_priority_from_context(&[]).is_none();
    assert!(nothing_selected);
}
/// With a single deployment (TPM limit 1000, RPM limit 100), the
/// rate-limit-aware strategy must select it.
#[tokio::test]
async fn test_rate_limit_aware_single_candidate() {
    let registry = DashMap::new();
    let sole = create_test_deployment(
        "d1",
        DeploymentConfig {
            tpm_limit: Some(1000),
            rpm_limit: Some(100),
            ..Default::default()
        },
    )
    .await;
    registry.insert(String::from("d1"), sole);

    let candidate_ids = vec![String::from("d1")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = rate_limit_aware_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// With identical limits and identical RPM usage, the deployment with lower
/// TPM usage must be chosen.
///
/// Both deployments have limits of 1000 TPM / 100 RPM and 20 RPM in use.
/// `d1` uses 800 TPM and `d2` uses 200 TPM, so `d2` is expected.
#[tokio::test]
async fn test_rate_limit_aware_selects_most_headroom() {
    let registry = DashMap::new();
    // (deployment id, TPM in use, RPM in use)
    for (id, tpm_used, rpm_used) in [("d1", 800, 20), ("d2", 200, 20)] {
        let config = DeploymentConfig {
            tpm_limit: Some(1000),
            rpm_limit: Some(100),
            ..Default::default()
        };
        let deployment = create_test_deployment(id, config).await;
        deployment.state.tpm_current.store(tpm_used, Relaxed);
        deployment.state.rpm_current.store(rpm_used, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = rate_limit_aware_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d2");
}
/// RPM usage must influence the choice, not just TPM usage.
///
/// Both deployments have limits of 1000 TPM / 100 RPM. `d1` has the lower TPM
/// usage (100) but 90 RPM in use; `d2` is at 400 TPM and 40 RPM. The test
/// expects `d2`.
#[tokio::test]
async fn test_rate_limit_aware_considers_rpm() {
    let registry = DashMap::new();
    // (deployment id, TPM in use, RPM in use)
    for (id, tpm_used, rpm_used) in [("d1", 100, 90), ("d2", 400, 40)] {
        let config = DeploymentConfig {
            tpm_limit: Some(1000),
            rpm_limit: Some(100),
            ..Default::default()
        };
        let deployment = create_test_deployment(id, config).await;
        deployment.state.tpm_current.store(tpm_used, Relaxed);
        deployment.state.rpm_current.store(rpm_used, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = rate_limit_aware_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d2");
}
/// When neither deployment has a TPM or RPM limit, the test expects the first
/// candidate (`d1`) to be returned.
#[tokio::test]
async fn test_rate_limit_aware_no_limits() {
    let unlimited = DeploymentConfig {
        tpm_limit: None,
        rpm_limit: None,
        ..Default::default()
    };
    let first = create_test_deployment("d1", unlimited.clone()).await;
    let second = create_test_deployment("d2", unlimited).await;

    let registry = DashMap::new();
    registry.insert(String::from("d1"), first);
    registry.insert(String::from("d2"), second);

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    let chosen = rate_limit_aware_from_context(&contexts).unwrap();
    assert_eq!(chosen, "d1");
}
/// An empty context slice must yield no rate-limit-aware selection.
#[test]
fn test_rate_limit_aware_empty_candidates() {
    let nothing_selected = rate_limit_aware_from_context(&[]).is_none();
    assert!(nothing_selected);
}
/// Round-robin over a single hand-built context must return that context's id.
#[test]
fn test_round_robin_single_candidate() {
    let only_id = String::from("d1");
    let contexts: Vec<RoutingContext<'_>> = vec![RoutingContext {
        deployment_id: &only_id,
        weight: 1,
        priority: 1,
        active_requests: 0,
        tpm_current: 0,
        tpm_limit: None,
        rpm_current: 0,
        rpm_limit: None,
        avg_latency_us: 0,
    }];
    let counters: DashMap<String, AtomicUsize> = DashMap::new();

    let chosen = round_robin_from_context("gpt-4", &contexts, &counters).unwrap();
    assert_eq!(chosen, "d1");
}
/// Successive round-robin calls for the same model must walk the candidates in
/// order and start over after the last one: d1, d2, d3, d1, d2.
#[test]
fn test_round_robin_cycles_through_candidates() {
    let candidate_ids = ["d1".to_string(), "d2".to_string(), "d3".to_string()];

    // Idle contexts with identical weight and priority, in candidate order.
    let mut contexts: Vec<RoutingContext<'_>> = Vec::with_capacity(candidate_ids.len());
    for id in &candidate_ids {
        contexts.push(RoutingContext {
            deployment_id: id,
            weight: 1,
            priority: 1,
            active_requests: 0,
            tpm_current: 0,
            tpm_limit: None,
            rpm_current: 0,
            rpm_limit: None,
            avg_latency_us: 0,
        });
    }

    let counters: DashMap<String, AtomicUsize> = DashMap::new();
    for expected in ["d1", "d2", "d3", "d1", "d2"] {
        assert_eq!(
            round_robin_from_context("gpt-4", &contexts, &counters).unwrap(),
            expected
        );
    }
}
/// Round-robin position must be tracked per model name.
///
/// Two calls for `gpt-4` advance its rotation. `claude-3` then starts from the
/// first candidate, and a later `gpt-4` call continues from where `gpt-4` left
/// off.
#[test]
fn test_round_robin_separate_counters_per_model() {
    let candidate_ids = ["d1".to_string(), "d2".to_string()];

    let mut contexts: Vec<RoutingContext<'_>> = Vec::with_capacity(candidate_ids.len());
    for id in &candidate_ids {
        contexts.push(RoutingContext {
            deployment_id: id,
            weight: 1,
            priority: 1,
            active_requests: 0,
            tpm_current: 0,
            tpm_limit: None,
            rpm_current: 0,
            rpm_limit: None,
            avg_latency_us: 0,
        });
    }

    let counters: DashMap<String, AtomicUsize> = DashMap::new();
    // (model name requested, deployment id expected for that call)
    let call_sequence = [
        ("gpt-4", "d1"),
        ("gpt-4", "d2"),
        ("claude-3", "d1"),
        ("claude-3", "d2"),
        ("gpt-4", "d1"),
    ];
    for (model, expected) in call_sequence {
        assert_eq!(
            round_robin_from_context(model, &contexts, &counters).unwrap(),
            expected
        );
    }
}
/// Over 100 consecutive calls with two candidates, round-robin must strictly
/// alternate: `d1` on even-numbered calls, `d2` on odd-numbered ones.
#[test]
fn test_round_robin_wraps_around() {
    let candidate_ids = ["d1".to_string(), "d2".to_string()];

    let mut contexts: Vec<RoutingContext<'_>> = Vec::with_capacity(candidate_ids.len());
    for id in &candidate_ids {
        contexts.push(RoutingContext {
            deployment_id: id,
            weight: 1,
            priority: 1,
            active_requests: 0,
            tpm_current: 0,
            tpm_limit: None,
            rpm_current: 0,
            rpm_limit: None,
            avg_latency_us: 0,
        });
    }

    let counters: DashMap<String, AtomicUsize> = DashMap::new();
    for call_index in 0..100 {
        let expected = if call_index % 2 == 0 { "d1" } else { "d2" };
        let chosen = round_robin_from_context("gpt-4", &contexts, &counters).unwrap();
        assert_eq!(chosen, expected);
    }
}
/// Round-robin over an empty context list must return `None`.
#[test]
fn test_round_robin_empty_candidates() {
    let no_contexts: Vec<RoutingContext<'_>> = Vec::new();
    let counters: DashMap<String, AtomicUsize> = DashMap::new();

    let outcome = round_robin_from_context("gpt-4", &no_contexts, &counters);
    assert!(outcome.is_none());
}
/// The first three round-robin calls over three candidates must return d1, d2
/// and d3 in that order.
#[test]
fn test_round_robin_from_context_cycles_through_candidates() {
    let candidate_ids = ["d1".to_string(), "d2".to_string(), "d3".to_string()];

    let mut contexts: Vec<RoutingContext<'_>> = Vec::with_capacity(candidate_ids.len());
    for id in &candidate_ids {
        contexts.push(RoutingContext {
            deployment_id: id,
            weight: 1,
            priority: 1,
            active_requests: 0,
            tpm_current: 0,
            tpm_limit: None,
            rpm_current: 0,
            rpm_limit: None,
            avg_latency_us: 0,
        });
    }

    let counters: DashMap<String, AtomicUsize> = DashMap::new();
    for expected in ["d1", "d2", "d3"] {
        assert_eq!(
            round_robin_from_context("gpt-4", &contexts, &counters).unwrap(),
            expected
        );
    }
}
/// The deterministic strategies must give the same answer on every call for an
/// unchanged set of contexts.
///
/// Setup: `d1` has priority 10, 500/1000 TPM, 5 active requests, latency 100.
/// `d2` has priority 1, 100/1000 TPM, 2 active requests, latency 200. Over 10
/// rounds, least-busy, lowest-usage and lowest-priority must each return `d2`,
/// and lowest-latency must return `d1`.
#[tokio::test]
async fn test_strategy_consistency() {
    let registry = DashMap::new();
    // (deployment id, priority, TPM in use, active requests, average latency)
    let fixtures = [("d1", 10, 500, 5, 100), ("d2", 1, 100, 2, 200)];
    for (id, priority, tpm_used, active, latency) in fixtures {
        let config = DeploymentConfig {
            weight: 1,
            priority,
            tpm_limit: Some(1000),
            ..Default::default()
        };
        let deployment = create_test_deployment(id, config).await;
        deployment.state.tpm_current.store(tpm_used, Relaxed);
        deployment.state.active_requests.store(active, Relaxed);
        deployment.state.avg_latency_us.store(latency, Relaxed);
        registry.insert(id.to_string(), deployment);
    }

    let candidate_ids = vec![String::from("d1"), String::from("d2")];
    let contexts = build_routing_contexts(&candidate_ids, &registry);

    for _ in 0..10 {
        let least_busy = least_busy_from_context(&contexts).unwrap();
        assert_eq!(least_busy, "d2");

        let lowest_usage = lowest_usage_from_context(&contexts).unwrap();
        assert_eq!(lowest_usage, "d2");

        let lowest_latency = lowest_latency_from_context(&contexts).unwrap();
        assert_eq!(lowest_latency, "d1");

        let lowest_priority = lowest_priority_from_context(&contexts).unwrap();
        assert_eq!(lowest_priority, "d2");
    }
}