use super::*;
/// Builds a fresh metrics instance wrapped in an `Arc` for use by the tests below.
///
/// # Panics
/// Panics with "should create metrics" if `Metrics::new()` returns an error.
fn test_metrics() -> Arc<crate::metrics::Metrics> {
    let metrics = crate::metrics::Metrics::new().expect("should create metrics");
    Arc::new(metrics)
}
use crate::models::endpoint_name::ExclusionSet;
use crate::models::selector::ModelSelector;
use crate::router::TargetModel;
use std::sync::Arc;
/// Constructing a selector from the test config exposes the expected number of
/// endpoints in every tier (2 fast, 1 balanced, 1 deep).
#[tokio::test]
async fn test_selector_new_creates_selector() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());
    let expected_counts = [
        (TargetModel::Fast, 2),
        (TargetModel::Balanced, 1),
        (TargetModel::Deep, 1),
    ];
    for (tier, count) in expected_counts {
        assert_eq!(selector.endpoint_count(tier), count);
    }
}
/// With no exclusions, `select` yields an endpoint for every populated tier.
#[tokio::test]
async fn test_selector_select_returns_endpoint() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());
    let exclusions = ExclusionSet::new();
    for tier in [TargetModel::Fast, TargetModel::Balanced, TargetModel::Deep] {
        let picked = selector.select(tier, &exclusions).await;
        assert!(picked.is_some());
    }
}
/// A tier with a single endpoint returns that same endpoint on repeated selections.
#[tokio::test]
async fn test_selector_single_endpoint_tier() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());
    let exclusions = ExclusionSet::new();

    // Make both selections first, then check them, mirroring the original ordering.
    let mut picks = Vec::with_capacity(2);
    for _ in 0..2 {
        let endpoint = selector
            .select(TargetModel::Balanced, &exclusions)
            .await
            .unwrap();
        picks.push(endpoint);
    }
    for endpoint in &picks {
        assert_eq!(endpoint.name(), "balanced-1");
    }
}
#[tokio::test]
async fn test_selector_endpoint_count() {
let config = Arc::new(create_test_config());
let selector = ModelSelector::new(config, test_metrics());
assert_eq!(selector.endpoint_count(TargetModel::Fast), 2);
assert_eq!(selector.endpoint_count(TargetModel::Balanced), 1);
assert_eq!(selector.endpoint_count(TargetModel::Deep), 1);
}
/// A tier configured with an empty endpoint list (`fast = []`) makes `select`
/// return `None` for that tier.
#[tokio::test]
async fn test_selector_returns_none_for_empty_tier() {
    let toml_config = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[models]
fast = []
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
[routing]
strategy = "rule"
router_tier = "balanced"
"#;
    let parsed = toml::from_str::<Config>(toml_config).expect("should parse TOML");
    let selector = ModelSelector::new(Arc::new(parsed), test_metrics());
    let exclusions = ExclusionSet::new();
    let outcome = selector.select(TargetModel::Fast, &exclusions).await;
    assert!(outcome.is_none(), "should return None for empty tier");
}
#[tokio::test]
async fn test_selector_concurrent_weighted_selection() {
let config = Arc::new(create_test_config());
let selector = Arc::new(ModelSelector::new(config, test_metrics()));
let mut handles = vec![];
for _ in 0..10 {
let sel = selector.clone();
handles.push(tokio::spawn(async move {
let no_exclude = ExclusionSet::new();
sel.select(TargetModel::Fast, &no_exclude)
.await
.map(|e| e.name().to_string())
}));
}
let results: Vec<_> = futures::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
assert_eq!(
results.len(),
10,
"all concurrent selections should complete"
);
for result in &results {
assert!(result.is_some(), "all selections should succeed");
}
let endpoint_names: Vec<String> = results.into_iter().flatten().collect();
for name in &endpoint_names {
assert!(
name == "fast-1" || name == "fast-2",
"selected endpoint should be valid, got: {}",
name
);
}
}