use crate::config::ModelEndpoint;
use crate::error::{AppError, AppResult};
use crate::models::endpoint_name::ExclusionSet;
use crate::models::selector::ModelSelector;
use crate::router::TargetModel;
use std::sync::Arc;
/// A handle onto a shared [`ModelSelector`] that is pinned to one [`TargetModel`] tier.
///
/// The stored tier is passed along on every selection and endpoint-count query,
/// so callers never have to name the tier themselves.
#[derive(Debug)]
pub struct TierSelector {
/// Shared selector that selection, counting and health-checker access are forwarded to.
inner: Arc<ModelSelector>,
/// Tier supplied to the wrapped selector on every forwarded call.
tier: TargetModel,
}
impl TierSelector {
    /// Creates a selector restricted to `tier` on top of the shared `selector`.
    ///
    /// # Errors
    ///
    /// Returns [`AppError::Config`] when `selector` reports zero endpoints for `tier`.
    pub fn new(selector: Arc<ModelSelector>, tier: TargetModel) -> AppResult<Self> {
        match selector.endpoint_count(tier) {
            // A tier without endpoints could never yield a selection, so reject it up front.
            0 => Err(AppError::Config(format!(
                "TierSelector requires at least one {:?} tier endpoint",
                tier
            ))),
            _ => Ok(Self {
                inner: selector,
                tier,
            }),
        }
    }

    /// Asks the wrapped selector for an endpoint of this tier, passing `exclude` through.
    ///
    /// The result is whatever the wrapped selector returns for this tier and exclusion set.
    pub async fn select(&self, exclude: &ExclusionSet) -> Option<&ModelEndpoint> {
        let pending = self.inner.select(self.tier, exclude);
        pending.await
    }

    /// Returns the tier this selector was constructed with.
    pub fn tier(&self) -> TargetModel {
        self.tier
    }

    /// Returns the number of endpoints the wrapped selector reports for this tier.
    pub fn endpoint_count(&self) -> usize {
        let tier = self.tier;
        self.inner.endpoint_count(tier)
    }

    /// Returns the health checker exposed by the wrapped selector.
    pub fn health_checker(&self) -> &Arc<crate::models::health::HealthChecker> {
        self.inner.health_checker()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::models::EndpointName;

    /// Creates a fresh metrics instance for a test-local `ModelSelector`.
    fn test_metrics() -> Arc<crate::metrics::Metrics> {
        Arc::new(crate::metrics::Metrics::new().expect("should create metrics"))
    }

    /// Config with one fast endpoint, two balanced endpoints (`balanced-1`,
    /// `balanced-2`) and one deep endpoint.
    fn create_test_config_with_balanced() -> Arc<Config> {
        let toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:11434/v1"
max_tokens = 2048
weight = 1.0
priority = 1
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1234/v1"
max_tokens = 4096
weight = 1.0
priority = 1
[[models.balanced]]
name = "balanced-2"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
weight = 1.0
priority = 1
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:8080/v1"
max_tokens = 8192
weight = 1.0
priority = 1
[routing]
strategy = "hybrid"
default_importance = "normal"
router_tier = "balanced"
"#;
        let config: Config = toml::from_str(toml).expect("should parse config");
        Arc::new(config)
    }

    /// Config with one fast and one deep endpoint and an explicitly empty
    /// `balanced` list.
    fn create_test_config_without_balanced() -> Arc<Config> {
        let toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:11434/v1"
max_tokens = 2048
weight = 1.0
priority = 1
[models]
balanced = []
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:8080/v1"
max_tokens = 8192
weight = 1.0
priority = 1
[routing]
strategy = "rule"
default_importance = "normal"
router_tier = "balanced"
"#;
        let config: Config = toml::from_str(toml).expect("should parse config");
        Arc::new(config)
    }

    /// Builds the shared `ModelSelector` for `config` with fresh test metrics.
    fn build_selector(config: Arc<Config>) -> Arc<ModelSelector> {
        Arc::new(ModelSelector::new(config, test_metrics()))
    }

    /// Builds a Balanced-tier `TierSelector` over the config that has balanced endpoints.
    fn balanced_tier_selector() -> TierSelector {
        let selector = build_selector(create_test_config_with_balanced());
        TierSelector::new(selector, TargetModel::Balanced).expect("should create TierSelector")
    }

    /// Unwraps the message of an `AppError::Config` failure.
    ///
    /// Panics with `not_err_msg` when `result` is `Ok`, and with the unexpected
    /// error's debug output when it is any other `AppError` variant.
    fn expect_config_error(result: AppResult<TierSelector>, not_err_msg: &str) -> String {
        match result {
            Ok(_) => panic!("{}", not_err_msg),
            Err(AppError::Config(msg)) => msg,
            Err(other) => panic!("expected Config error, got: {:?}", other),
        }
    }

    /// Construction succeeds when the requested tier has endpoints.
    #[tokio::test]
    async fn test_tier_selector_new_with_balanced_endpoints() {
        let selector = build_selector(create_test_config_with_balanced());
        let result = TierSelector::new(selector, TargetModel::Balanced);
        assert!(
            result.is_ok(),
            "should create TierSelector with balanced endpoints"
        );
    }

    /// Construction fails with a Config error naming the empty tier.
    #[tokio::test]
    async fn test_tier_selector_new_without_balanced_endpoints() {
        let selector = build_selector(create_test_config_without_balanced());
        let msg = expect_config_error(
            TierSelector::new(selector, TargetModel::Balanced),
            "should fail without balanced endpoints",
        );
        assert!(
            msg.contains("Balanced") && msg.contains("tier endpoint"),
            "error should mention Balanced tier requirement, got: {}",
            msg
        );
    }

    /// With nothing excluded, the selected endpoint comes from the balanced tier.
    #[tokio::test]
    async fn test_tier_selector_selects_balanced_endpoint() {
        let tier_selector = balanced_tier_selector();
        let exclude = ExclusionSet::new();
        let endpoint = tier_selector
            .select(&exclude)
            .await
            .expect("should select a balanced endpoint");
        assert!(
            endpoint.name().starts_with("balanced-"),
            "selected endpoint should be from balanced tier, got: {}",
            endpoint.name()
        );
    }

    /// Excluding every balanced endpoint leaves nothing to select.
    #[tokio::test]
    async fn test_tier_selector_respects_exclusion() {
        let tier_selector = balanced_tier_selector();

        // Baseline: with an empty exclusion set a selection is available.
        let no_exclusions = ExclusionSet::new();
        let first = tier_selector.select(&no_exclusions).await;
        assert!(first.is_some());

        let mut exclude = ExclusionSet::new();
        exclude.insert(EndpointName::from("balanced-1"));
        exclude.insert(EndpointName::from("balanced-2"));
        let excluded = tier_selector.select(&exclude).await;
        assert!(
            matches!(excluded, None),
            "should return None when all balanced endpoints excluded"
        );
    }

    /// `endpoint_count` reports only the endpoints of the pinned tier.
    #[tokio::test]
    async fn test_tier_selector_endpoint_count() {
        let tier_selector = balanced_tier_selector();
        assert_eq!(
            tier_selector.endpoint_count(),
            2,
            "should have 2 balanced endpoints"
        );
    }

    /// The wrapper works for the Fast tier as well.
    #[tokio::test]
    async fn test_tier_selector_with_fast_tier() {
        let selector = build_selector(create_test_config_with_balanced());
        let tier_selector = TierSelector::new(selector, TargetModel::Fast)
            .expect("should create TierSelector for Fast tier");
        assert_eq!(tier_selector.tier(), TargetModel::Fast);
        assert_eq!(
            tier_selector.endpoint_count(),
            1,
            "should have 1 fast endpoint"
        );
        let exclude = ExclusionSet::new();
        let endpoint = tier_selector
            .select(&exclude)
            .await
            .expect("should select a fast endpoint");
        assert_eq!(endpoint.name(), "fast-1");
    }

    /// The wrapper works for the Deep tier as well.
    #[tokio::test]
    async fn test_tier_selector_with_deep_tier() {
        let selector = build_selector(create_test_config_with_balanced());
        let tier_selector = TierSelector::new(selector, TargetModel::Deep)
            .expect("should create TierSelector for Deep tier");
        assert_eq!(tier_selector.tier(), TargetModel::Deep);
        assert_eq!(
            tier_selector.endpoint_count(),
            1,
            "should have 1 deep endpoint"
        );
        let exclude = ExclusionSet::new();
        let endpoint = tier_selector
            .select(&exclude)
            .await
            .expect("should select a deep endpoint");
        assert_eq!(endpoint.name(), "deep-1");
    }

    /// Same scenario as `test_tier_selector_new_without_balanced_endpoints`.
    ///
    /// The selector is moved into `TierSelector::new` rather than cloned, since
    /// it is not used again afterwards.
    #[tokio::test]
    async fn test_tier_selector_fails_with_no_endpoints_for_tier() {
        let selector = build_selector(create_test_config_without_balanced());
        let msg = expect_config_error(
            TierSelector::new(selector, TargetModel::Balanced),
            "should fail when tier has no endpoints",
        );
        assert!(
            msg.contains("Balanced") && msg.contains("tier endpoint"),
            "error should mention missing tier, got: {}",
            msg
        );
    }
}