use super::*;
/// Builds a fresh, reference-counted metrics instance for a single test.
///
/// # Panics
/// Panics with "should create metrics" if `Metrics::new()` fails.
fn test_metrics() -> Arc<crate::metrics::Metrics> {
    let metrics = crate::metrics::Metrics::new().expect("should create metrics");
    Arc::new(metrics)
}
use crate::models::endpoint_name::ExclusionSet;
use crate::models::selector::ModelSelector;
use crate::router::TargetModel;
use std::sync::Arc;
/// Both fast-tier endpoints from `create_test_config` must be reachable.
///
/// Draws up to 100 selections and stops early once both `fast-1` and
/// `fast-2` have been observed.
#[tokio::test]
async fn test_selector_weighted_fast_tier_both_endpoints_selectable() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());
    let (mut fast1_seen, mut fast2_seen) = (false, false);
    let no_exclude = ExclusionSet::new();

    let mut attempts = 0;
    while attempts < 100 && !(fast1_seen && fast2_seen) {
        let selected = selector
            .select(TargetModel::Fast, &no_exclude)
            .await
            .unwrap();
        match selected.name() {
            "fast-1" => fast1_seen = true,
            "fast-2" => fast2_seen = true,
            _ => {}
        }
        attempts += 1;
    }

    assert!(
        fast1_seen,
        "fast-1 should be selected at least once in 100 attempts"
    );
    assert!(
        fast2_seen,
        "fast-2 should be selected at least once in 100 attempts"
    );
}
/// Fast tier in which every endpoint has `weight = 0.0`.
///
/// The test passes only if it panics with a message containing
/// "MEMORY CORRUPTION DETECTED" (raised somewhere in selector construction
/// or `select`; the exact site is not visible here).
/// NOTE(review): the `_fallback` suffix suggests a fallback path, but the
/// asserted behaviour is a panic — confirm which is intended.
#[tokio::test]
#[should_panic(expected = "MEMORY CORRUPTION DETECTED")]
async fn test_selector_zero_weight_fallback() {
    const ZERO_WEIGHT_TOML: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
weight = 0.0
[[models.fast]]
name = "fast-2"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
weight = 0.0
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
    let parsed: Config = toml::from_str(ZERO_WEIGHT_TOML).expect("should parse TOML");
    let selector = ModelSelector::new(Arc::new(parsed), test_metrics());
    let no_exclude = ExclusionSet::new();
    let _outcome = selector.select(TargetModel::Fast, &no_exclude).await;
}
/// Fast tier in which every endpoint has a negative weight (-1.0 and -2.0).
///
/// The test passes only if it panics with a message containing
/// "MEMORY CORRUPTION DETECTED" (raised somewhere in selector construction
/// or `select`; the exact site is not visible here).
/// NOTE(review): the `_fallback` suffix suggests a fallback path, but the
/// asserted behaviour is a panic — confirm which is intended.
#[tokio::test]
#[should_panic(expected = "MEMORY CORRUPTION DETECTED")]
async fn test_selector_negative_weight_fallback() {
    const NEGATIVE_WEIGHT_TOML: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
weight = -1.0
[[models.fast]]
name = "fast-2"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
weight = -2.0
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
    let parsed: Config = toml::from_str(NEGATIVE_WEIGHT_TOML).expect("should parse TOML");
    let selector = ModelSelector::new(Arc::new(parsed), test_metrics());
    let no_exclude = ExclusionSet::new();
    let _outcome = selector.select(TargetModel::Fast, &no_exclude).await;
}
/// A 2:1 weight split over 3000 draws should land near 2000 / 1000.
///
/// Each window is ±200 around its expectation; the binomial standard
/// deviation is √(3000·⅔·⅓) ≈ 25.8, so that is roughly ±7.7σ.
#[tokio::test]
async fn test_weighted_selection_distribution() {
    let raw_toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
weight = 2.0
[[models.fast]]
name = "fast-2"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
weight = 1.0
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
    let parsed: Config = toml::from_str(raw_toml).expect("should parse TOML");
    let selector = ModelSelector::new(Arc::new(parsed), test_metrics());
    let no_exclude = ExclusionSet::new();

    // Only the two fast endpoints are tallied; any other name is ignored.
    let (mut fast1_count, mut fast2_count) = (0, 0);
    for _ in 0..3000 {
        let endpoint = selector
            .select(TargetModel::Fast, &no_exclude)
            .await
            .unwrap();
        match endpoint.name() {
            "fast-1" => fast1_count += 1,
            "fast-2" => fast2_count += 1,
            _ => {}
        }
    }

    assert!(
        (1800..=2200).contains(&fast1_count),
        "fast-1 (weight 2.0) should get ~2000/3000 selections, got {}",
        fast1_count
    );
    assert!(
        (800..=1200).contains(&fast2_count),
        "fast-2 (weight 1.0) should get ~1000/3000 selections, got {}",
        fast2_count
    );
}
/// A 9:1 weight split over 1000 draws should land near 900 / 100.
///
/// The binomial standard deviation is √(1000·0.9·0.1) ≈ 9.5, and both
/// windows are ±65 around their expectation (≈6.9σ). The two fast counts
/// partition the draws when only fast endpoints are returned. So fast-1 uses
/// the same ±65 window as fast-2. A wider fast-1 window (or one reaching past
/// 1000) would be implied by the fast-2 check and could never fail on its own.
#[tokio::test]
async fn test_weighted_selection_heavily_skewed() {
    let toml_config = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
weight = 9.0
[[models.fast]]
name = "fast-2"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
weight = 1.0
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
    let config: Config = toml::from_str(toml_config).expect("should parse TOML");
    let selector = ModelSelector::new(Arc::new(config), test_metrics());
    let no_exclude = ExclusionSet::new();
    let mut counts = std::collections::HashMap::new();
    for _ in 0..1000 {
        let endpoint = selector
            .select(TargetModel::Fast, &no_exclude)
            .await
            .unwrap();
        *counts.entry(endpoint.name()).or_insert(0) += 1;
    }
    let fast1_count = counts.get("fast-1").copied().unwrap_or(0);
    let fast2_count = counts.get("fast-2").copied().unwrap_or(0);
    assert!(
        (835..=965).contains(&fast1_count),
        "fast-1 (weight 9.0) should get ~900/1000 selections, got {}",
        fast1_count
    );
    assert!(
        (35..=165).contains(&fast2_count),
        "fast-2 (weight 1.0) should get ~100/1000 selections, got {}",
        fast2_count
    );
}
#[tokio::test]
async fn test_weighted_selection_all_equal_weights() {
let config = create_test_config();
let selector = ModelSelector::new(Arc::new(config), test_metrics());
let no_exclude = ExclusionSet::new();
let mut counts = std::collections::HashMap::new();
for _ in 0..2000 {
let endpoint = selector
.select(TargetModel::Fast, &no_exclude)
.await
.unwrap();
*counts.entry(endpoint.name()).or_insert(0) += 1;
}
let fast1_count = counts.get("fast-1").unwrap_or(&0);
let fast2_count = counts.get("fast-2").unwrap_or(&0);
assert!(
*fast1_count >= 850 && *fast1_count <= 1150,
"fast-1 (weight 1.0) should get ~1000/2000 selections, got {}",
fast1_count
);
assert!(
*fast2_count >= 850 && *fast2_count <= 1150,
"fast-2 (weight 1.0) should get ~1000/2000 selections, got {}",
fast2_count
);
}
#[tokio::test]
async fn test_weighted_selection_three_endpoints() {
let toml_config = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
weight = 3.0
[[models.fast]]
name = "fast-2"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
weight = 2.0
[[models.fast]]
name = "fast-3"
base_url = "http://localhost:1236/v1"
max_tokens = 2048
weight = 1.0
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
let config: Config = toml::from_str(toml_config).expect("should parse TOML");
let selector = ModelSelector::new(Arc::new(config), test_metrics());
let no_exclude = ExclusionSet::new();
let mut counts = std::collections::HashMap::new();
for _ in 0..6000 {
let endpoint = selector
.select(TargetModel::Fast, &no_exclude)
.await
.unwrap();
*counts.entry(endpoint.name()).or_insert(0) += 1;
}
let fast1_count = counts.get("fast-1").unwrap_or(&0);
let fast2_count = counts.get("fast-2").unwrap_or(&0);
let fast3_count = counts.get("fast-3").unwrap_or(&0);
assert!(
*fast1_count >= 2700 && *fast1_count <= 3300,
"fast-1 (weight 3.0) should get ~3000/6000 selections, got {}",
fast1_count
);
assert!(
*fast2_count >= 1800 && *fast2_count <= 2200,
"fast-2 (weight 2.0) should get ~2000/6000 selections, got {}",
fast2_count
);
assert!(
*fast3_count >= 900 && *fast3_count <= 1100,
"fast-3 (weight 1.0) should get ~1000/6000 selections, got {}",
fast3_count
);
}
#[tokio::test]
async fn test_weighted_selection_statistical_validation() {
let toml_config = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-light"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
weight = 1.0
priority = 1
[[models.fast]]
name = "fast-heavy"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
weight = 3.0
priority = 1
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
let config: Config = toml::from_str(toml_config).expect("should parse TOML");
let selector = ModelSelector::new(Arc::new(config), test_metrics());
const SAMPLE_SIZE: usize = 10_000;
let mut light_count = 0;
let mut heavy_count = 0;
let no_exclude = ExclusionSet::new();
for _ in 0..SAMPLE_SIZE {
let endpoint = selector
.select(TargetModel::Fast, &no_exclude)
.await
.unwrap();
match endpoint.name() {
"fast-light" => light_count += 1,
"fast-heavy" => heavy_count += 1,
other => panic!("Unexpected endpoint selected: {}", other),
}
}
let expected_light = SAMPLE_SIZE as f64 * 0.25;
let expected_heavy = SAMPLE_SIZE as f64 * 0.75;
let chi_squared = ((light_count as f64 - expected_light).powi(2) / expected_light)
+ ((heavy_count as f64 - expected_heavy).powi(2) / expected_heavy);
const CHI_SQUARED_THRESHOLD: f64 = 10.0;
assert!(
chi_squared < CHI_SQUARED_THRESHOLD,
"Chi-squared test failed: χ² = {:.2} (threshold = {}). \
Distribution does not match configured weights. \
Observed: light={} ({:.1}%), heavy={} ({:.1}%). \
Expected: light={:.0} (25.0%), heavy={:.0} (75.0%)",
chi_squared,
CHI_SQUARED_THRESHOLD,
light_count,
(light_count as f64 / SAMPLE_SIZE as f64) * 100.0,
heavy_count,
(heavy_count as f64 / SAMPLE_SIZE as f64) * 100.0,
expected_light,
expected_heavy
);
assert!(
(2_000..=3_000).contains(&light_count),
"Light endpoint selected {} times, expected ~2,500 (20-30%)",
light_count
);
assert!(
(7_000..=8_000).contains(&heavy_count),
"Heavy endpoint selected {} times, expected ~7,500 (70-80%)",
heavy_count
);
println!(
"✓ Statistical validation passed: χ² = {:.2}, light={} ({:.1}%), heavy={} ({:.1}%)",
chi_squared,
light_count,
(light_count as f64 / SAMPLE_SIZE as f64) * 100.0,
heavy_count,
(heavy_count as f64 / SAMPLE_SIZE as f64) * 100.0
);
}