mod balanced;
pub use balanced::TierSelector;
use crate::config::{Config, ModelEndpoint};
use crate::models::endpoint_name::{EndpointName, ExclusionSet};
use crate::models::health::HealthChecker;
use crate::router::TargetModel;
use rand::Rng;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Picks a [`ModelEndpoint`] for a requested [`TargetModel`] tier.
///
/// Holds the shared configuration, the health checker consulted during
/// selection, and one counter per tier.
#[derive(Debug)]
pub struct ModelSelector {
/// Shared configuration; the per-tier endpoint lists are read from `config.models`.
config: Arc<Config>,
/// Health checker queried for every candidate endpoint during `select`.
health_checker: Arc<HealthChecker>,
/// Incremented by `select` for `TargetModel::Fast` requests that reach the weighted draw.
fast_counter: AtomicUsize,
/// Incremented by `select` for `TargetModel::Balanced` requests that reach the weighted draw.
balanced_counter: AtomicUsize,
/// Incremented by `select` for `TargetModel::Deep` requests that reach the weighted draw.
deep_counter: AtomicUsize,
}
impl ModelSelector {
pub fn new(config: Arc<Config>, metrics: Arc<crate::metrics::Metrics>) -> Self {
let health_checker = Arc::new(HealthChecker::new_with_metrics(config.clone(), metrics));
health_checker.clone().start_background_checks();
Self {
config,
health_checker,
fast_counter: AtomicUsize::new(0),
balanced_counter: AtomicUsize::new(0),
deep_counter: AtomicUsize::new(0),
}
}
/// Returns a reference to the shared [`HealthChecker`] handle held by this selector.
pub fn health_checker(&self) -> &Arc<HealthChecker> {
&self.health_checker
}
/// Chooses one endpoint of the `target` tier.
///
/// Selection runs in three stages:
/// 1. Drop every endpoint the health checker reports as unhealthy, then every
///    endpoint whose name is contained in `exclude`.
/// 2. Keep only the survivors whose `priority()` equals the maximum priority
///    among the survivors.
/// 3. Draw one of those at random, proportionally to `weight()`.
///
/// Returns `None` when the tier has no configured endpoints or when nothing
/// survives stage 1.
///
/// # Panics
///
/// Panics if the summed weight of the highest-priority group is `<= 0.0`.
pub async fn select(
    &self,
    target: TargetModel,
    exclude: &ExclusionSet,
) -> Option<&ModelEndpoint> {
    let models = &self.config.models;
    let (endpoints, counter) = match target {
        TargetModel::Fast => (&models.fast, &self.fast_counter),
        TargetModel::Balanced => (&models.balanced, &self.balanced_counter),
        TargetModel::Deep => (&models.deep, &self.deep_counter),
    };

    if endpoints.is_empty() {
        tracing::error!(
            tier = ?target,
            "No endpoints configured for tier - check config.toml"
        );
        return None;
    }

    // Stage 1: health first, exclusion second (only excluded-but-healthy endpoints are logged).
    let mut candidates: Vec<&ModelEndpoint> = Vec::new();
    for endpoint in endpoints.iter() {
        let healthy = self.health_checker.is_healthy(endpoint.name()).await;
        if !healthy {
            continue;
        }
        let excluded = exclude.contains(&EndpointName::from(endpoint));
        if excluded {
            tracing::debug!(
                tier = ?target,
                endpoint_name = %endpoint.name(),
                "Skipping excluded endpoint"
            );
        } else {
            candidates.push(endpoint);
        }
    }

    if candidates.is_empty() {
        tracing::error!(
            tier = ?target,
            total_endpoints = endpoints.len(),
            excluded_count = exclude.len(),
            "No available endpoints for tier - all endpoints either unhealthy or excluded"
        );
        return None;
    }
    tracing::debug!(
        tier = ?target,
        total_endpoints = endpoints.len(),
        excluded_count = exclude.len(),
        available_endpoints = candidates.len(),
        "Filtered to healthy and non-excluded endpoints"
    );

    // Stage 2: restrict to the highest priority present among the candidates.
    let top_priority = candidates
        .iter()
        .map(|candidate| candidate.priority())
        .max()
        .expect(
            "Defensive check: available_endpoints cannot be empty due to early return above",
        );
    let top_tier: Vec<&ModelEndpoint> = candidates
        .iter()
        .copied()
        .filter(|candidate| candidate.priority() == top_priority)
        .collect();
    tracing::debug!(
        tier = ?target,
        max_priority = top_priority,
        available_endpoints = candidates.len(),
        priority_tier_endpoints = top_tier.len(),
        "Filtered to highest priority tier among available endpoints"
    );

    counter.fetch_add(1, Ordering::Relaxed);

    // Stage 3: weighted random draw within the top-priority group.
    let total_weight: f64 = top_tier.iter().map(|candidate| candidate.weight()).sum();
    if total_weight <= 0.0 {
        let weights_by_name = top_tier
            .iter()
            .map(|ep| (ep.name(), ep.weight()))
            .collect::<Vec<_>>();
        panic!(
            "MEMORY CORRUPTION DETECTED: All endpoints in priority tier {} for {:?} have \
             total weight {}. Config validation guarantees positive weights at startup. \
             This indicates memory corruption (buffer overflow, use-after-free) or a critical \
             bug in endpoint management. Cannot safely continue operation. \
             Endpoints: {:?}",
            top_priority, target, total_weight, weights_by_name
        );
    }

    // The RNG is created only after the last `.await` above.
    let roll = rand::rng().random_range(0.0..total_weight);

    // Walk the group in order, accumulating weights, and stop at the first
    // endpoint whose cumulative weight exceeds the roll.
    let mut running_weight = 0.0;
    let drawn = top_tier.iter().enumerate().find(|(_, candidate)| {
        running_weight += candidate.weight();
        roll < running_weight
    });

    match drawn {
        Some((index, endpoint)) => {
            tracing::debug!(
                tier = ?target,
                priority = top_priority,
                endpoint_name = %endpoint.name(),
                endpoint_url = %endpoint.base_url(),
                weight = endpoint.weight(),
                index = index,
                total_weight = total_weight,
                "Selected endpoint via weighted random selection"
            );
            Some(*endpoint)
        }
        None => {
            // The roll was never below the accumulated sum; use the final endpoint.
            let last_endpoint = top_tier
                .last()
                .expect("Defensive check: highest_priority_endpoints cannot be empty");
            tracing::warn!(
                tier = ?target,
                priority = top_priority,
                endpoint_name = %last_endpoint.name(),
                "Fallback to last endpoint (likely floating-point rounding)"
            );
            Some(*last_endpoint)
        }
    }
}
/// Returns how many endpoints are configured for `target`.
///
/// This counts configured endpoints only; health and exclusion are not consulted.
pub fn endpoint_count(&self, target: TargetModel) -> usize {
    let models = &self.config.models;
    let tier_endpoints = match target {
        TargetModel::Fast => &models.fast,
        TargetModel::Balanced => &models.balanced,
        TargetModel::Deep => &models.deep,
    };
    tier_endpoints.len()
}
/// Returns the tier that owns an endpoint with the highest configured priority.
///
/// The maximum `priority()` is taken across all three tiers. When several
/// tiers contain an endpoint at that priority, the first match in the order
/// Fast, Balanced, Deep is returned. Returns `None` when no endpoints are
/// configured at all.
pub fn default_tier(&self) -> Option<TargetModel> {
    let models = &self.config.models;

    let top_priority = models
        .fast
        .iter()
        .chain(models.balanced.iter())
        .chain(models.deep.iter())
        .map(|endpoint| endpoint.priority())
        .max()?;

    // Array order encodes the tie-break: Fast, then Balanced, then Deep.
    let tiers_in_order = [
        (TargetModel::Fast, &models.fast),
        (TargetModel::Balanced, &models.balanced),
        (TargetModel::Deep, &models.deep),
    ];
    for (tier, tier_endpoints) in tiers_in_order {
        let holds_top = tier_endpoints
            .iter()
            .any(|endpoint| endpoint.priority() == top_priority);
        if holds_top {
            return Some(tier);
        }
    }
    None
}
}
#[cfg(test)]
mod tests_basic;
#[cfg(test)]
mod tests_exclusion;
#[cfg(test)]
mod tests_priority;
#[cfg(test)]
mod tests_weighted;
/// Builds the configuration fixture shared by this module's tests.
///
/// The inline TOML declares two `fast` endpoints, one `balanced` and one
/// `deep`, each with weight `1.0` and priority `1`.
///
/// # Panics
///
/// Panics if the embedded TOML cannot be parsed into a [`Config`].
#[cfg(test)]
pub(crate) fn create_test_config() -> Config {
    const FIXTURE_TOML: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
temperature = 0.7
weight = 1.0
priority = 1
[[models.fast]]
name = "fast-2"
base_url = "http://localhost:1235/v1"
max_tokens = 2048
temperature = 0.7
weight = 1.0
priority = 1
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096
temperature = 0.7
weight = 1.0
priority = 1
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192
temperature = 0.7
weight = 1.0
priority = 1
[routing]
strategy = "rule"
default_importance = "normal"
router_tier = "balanced"
"#;
    let parsed: Config = toml::from_str(FIXTURE_TOML).expect("should parse TOML config");
    parsed
}