use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use crate::provider::ProviderType;
/// Strategy the [`Router`] uses to pick among routes that match a model.
///
/// Marked `#[non_exhaustive]` so new strategies can be added without a
/// breaking change for downstream matchers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum RoutingStrategy {
/// Pick the matching route with the lowest `priority` value (the default).
#[default]
Priority,
/// Rotate through matching routes using a shared atomic counter.
RoundRobin,
/// Pick the matching route with the smallest reported latency average;
/// routes with no recorded latency are tried last.
LowestLatency,
/// Currently selected identically to `Priority` (same match arm in `select`).
Direct,
}
/// A single configured upstream provider endpoint and its routing metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderRoute {
/// Which provider implementation serves this route.
pub provider: ProviderType,
/// Sort key for `Priority` routing; lower values are preferred.
pub priority: u32,
/// Model name patterns this route serves: exact strings, or prefix globs
/// with a trailing `*`. An empty list matches every model.
pub model_patterns: Vec<String>,
/// Disabled routes are skipped entirely during selection.
pub enabled: bool,
/// Upstream base URL; also part of the key for latency/health tracking.
pub base_url: String,
/// Optional API key for the upstream (absent in config -> `None`).
#[serde(default)]
pub api_key: Option<String>,
// NOTE(review): the two limits below are carried as configuration but are
// not consulted by any routing logic in this module — presumably enforced
// elsewhere; confirm against callers.
#[serde(default)]
pub max_tokens_limit: Option<u32>,
#[serde(default)]
pub rate_limit_rpm: Option<u32>,
/// TLS settings; never (de)serialized, populated programmatically.
#[serde(skip)]
pub tls_config: Option<crate::provider::TlsConfig>,
}
/// Routes model requests to provider endpoints according to a
/// [`RoutingStrategy`], filtered by enablement, model patterns, and
/// (optionally) provider health.
pub struct Router {
/// Routes kept sorted ascending by `priority` (see `new`/`reload`).
routes: Vec<ProviderRoute>,
strategy: RoutingStrategy,
/// Monotonic counter shared by all round-robin selections.
round_robin_index: std::sync::atomic::AtomicUsize,
/// Latency moving average in ms, keyed by `(provider, base_url)`.
latencies: Arc<DashMap<(ProviderType, String), AtomicU64>>,
/// When `None`, health filtering is disabled and all providers pass.
health_status: Option<crate::health::HealthMap>,
}
impl Router {
pub fn new(mut routes: Vec<ProviderRoute>, strategy: RoutingStrategy) -> Self {
routes.sort_by_key(|r| r.priority);
Self {
routes,
strategy,
round_robin_index: std::sync::atomic::AtomicUsize::new(0),
latencies: Arc::new(DashMap::new()),
health_status: None,
}
}
/// Installs a shared health map; once present, providers marked unhealthy
/// in it are skipped during selection (see `is_provider_healthy`).
pub fn set_health_map(&mut self, map: crate::health::HealthMap) {
self.health_status = Some(map);
}
/// Health check gate: `true` when no health map is configured, when the
/// provider has never been checked, or when its last check succeeded.
fn is_provider_healthy(&self, provider: ProviderType, base_url: &str) -> bool {
    // No health map at all -> health filtering is disabled.
    let Some(map) = self.health_status.as_ref() else {
        return true;
    };
    // An unchecked provider gets the benefit of the doubt.
    map.get(&(provider, base_url.to_string()))
        .is_none_or(|state| state.is_healthy)
}
/// Selects a route for `model` per the configured strategy, or `None` when
/// no enabled, healthy, pattern-matching route exists.
pub fn select(&self, model: &str) -> Option<&ProviderRoute> {
    let eligible: Vec<&ProviderRoute> = self
        .routes
        .iter()
        .filter(|route| {
            route.enabled
                && self.matches_model(route, model)
                && self.is_provider_healthy(route.provider, &route.base_url)
        })
        .collect();
    if eligible.is_empty() {
        return None;
    }
    match self.strategy {
        // Routes are pre-sorted ascending by priority, so the head wins.
        RoutingStrategy::Priority | RoutingStrategy::Direct => eligible.first().copied(),
        // Unreported latency counts as u64::MAX so measured providers win;
        // ties resolve to the earliest (highest-priority) candidate, which
        // matches `min_by_key` returning the first minimum.
        RoutingStrategy::LowestLatency => eligible.into_iter().min_by_key(|route| {
            self.latencies
                .get(&(route.provider, route.base_url.clone()))
                .map(|entry| entry.value().load(Ordering::Relaxed))
                .unwrap_or(u64::MAX)
        }),
        RoutingStrategy::RoundRobin => {
            let next = self
                .round_robin_index
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            eligible.get(next % eligible.len()).copied()
        }
    }
}
/// Selects a route for `model` while enforcing data-classification policy.
///
/// `Restricted` content is never routed anywhere. `Confidential` and above
/// may only go to local providers. Within the surviving candidates,
/// selection follows the configured strategy exactly like [`Router::select`].
#[cfg(feature = "dlp")]
pub fn select_with_classification(
    &self,
    model: &str,
    classification: crate::dlp::ClassificationLevel,
) -> Option<&ProviderRoute> {
    use crate::dlp::ClassificationLevel;
    // Restricted data must never leave the gateway.
    if classification == ClassificationLevel::Restricted {
        return None;
    }
    let local_only = classification >= ClassificationLevel::Confidential;
    let eligible: Vec<&ProviderRoute> = self
        .routes
        .iter()
        .filter(|route| {
            route.enabled
                && self.matches_model(route, model)
                && self.is_provider_healthy(route.provider, &route.base_url)
                && (!local_only || route.provider.is_local())
        })
        .collect();
    if eligible.is_empty() {
        return None;
    }
    match self.strategy {
        // Pre-sorted by ascending priority: take the head.
        RoutingStrategy::Priority | RoutingStrategy::Direct => eligible.first().copied(),
        // First minimum wins ties, same as the original strict-< scan.
        RoutingStrategy::LowestLatency => eligible.into_iter().min_by_key(|route| {
            self.latencies
                .get(&(route.provider, route.base_url.clone()))
                .map(|entry| entry.value().load(Ordering::Relaxed))
                .unwrap_or(u64::MAX)
        }),
        RoutingStrategy::RoundRobin => {
            let next = self
                .round_robin_index
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            eligible.get(next % eligible.len()).copied()
        }
    }
}
/// Selects a route for `model`, taking accelerator VRAM capacity into account.
///
/// When `model_params` is known and positive, local providers whose hardware
/// cannot hold the model (per `hw.fits_model(model_params, vram_reserve)`)
/// are moved behind all other candidates; they remain selectable as a last
/// resort. Without a usable parameter count this degrades to
/// [`Router::select`].
///
/// NOTE(review): the VRAM-based reordering only influences `Priority` /
/// `Direct`; `LowestLatency` and `RoundRobin` treat the reordered list
/// uniformly. That matches the pre-existing behavior and is preserved here.
#[cfg(feature = "hwaccel")]
#[must_use]
pub fn select_with_hardware(
    &self,
    model: &str,
    hw: &crate::hardware::HardwareManager,
    model_params: Option<u64>,
    vram_reserve: u64,
) -> Option<&ProviderRoute> {
    // Without a positive parameter count we cannot size the model: fall back.
    let model_params = match model_params {
        Some(p) if p > 0 => p,
        _ => return self.select(model),
    };
    let candidates: Vec<&ProviderRoute> = self
        .routes
        .iter()
        .filter(|r| {
            r.enabled
                && self.matches_model(r, model)
                && self.is_provider_healthy(r.provider, &r.base_url)
        })
        .collect();
    if candidates.is_empty() {
        return None;
    }
    let fits = hw.fits_model(model_params, vram_reserve);
    // Stable partition: relative order inside each group is preserved, so
    // priority ordering survives. Locals that don't fit go to the back.
    let (mut ordered, deprioritized): (Vec<&ProviderRoute>, Vec<&ProviderRoute>) =
        candidates.into_iter().partition(|c| fits || !c.provider.is_local());
    for c in &deprioritized {
        tracing::debug!(
            provider = %c.provider,
            model,
            "local provider deprioritized: model exceeds available VRAM"
        );
    }
    ordered.extend(deprioritized);
    // `ordered` now holds every candidate and candidates was non-empty, so
    // the former emptiness re-check was dead code and has been removed.
    match self.strategy {
        RoutingStrategy::Priority | RoutingStrategy::Direct => ordered.first().copied(),
        // First minimum wins ties, same as the original strict-< scan.
        RoutingStrategy::LowestLatency => ordered.into_iter().min_by_key(|c| {
            self.latencies
                .get(&(c.provider, c.base_url.clone()))
                .map(|e| e.value().load(Ordering::Relaxed))
                .unwrap_or(u64::MAX)
        }),
        RoutingStrategy::RoundRobin => {
            // Use the imported `Ordering` for consistency with the rest of
            // this impl (same item as `std::sync::atomic::Ordering`).
            let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed);
            ordered.get(idx % ordered.len()).copied()
        }
    }
}
/// Records an observed request latency for `(provider, base_url)`.
///
/// Samples are folded into an exponential moving average with weight 0.3
/// for the new sample: `new = (old * 7 + sample * 3) / 10` in integer math.
/// The first report seeds the average directly. Saturating arithmetic
/// guards the multiplications/addition against u64 overflow on pathological
/// inputs — previously `old * 7` would panic in debug builds (and wrap in
/// release) for extreme values.
///
/// The read-modify-write runs inside the DashMap entry guard, so concurrent
/// reports for the same key do not interleave their load/store.
pub fn report_latency(&self, provider: ProviderType, base_url: &str, latency_ms: u64) {
    let key = (provider, base_url.to_string());
    self.latencies
        .entry(key)
        .and_modify(|existing| {
            let old = existing.load(Ordering::Relaxed);
            let new_avg = old
                .saturating_mul(7)
                .saturating_add(latency_ms.saturating_mul(3))
                / 10;
            existing.store(new_avg, Ordering::Relaxed);
        })
        .or_insert_with(|| AtomicU64::new(latency_ms));
}
pub fn reload(&mut self, mut routes: Vec<ProviderRoute>, strategy: RoutingStrategy) {
routes.sort_by_key(|r| r.priority);
self.routes = routes;
self.strategy = strategy;
self.round_robin_index
.store(0, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the configured routes, sorted ascending by priority.
pub fn routes(&self) -> &[ProviderRoute] {
&self.routes
}
/// Reports whether `route` serves `model`.
///
/// An empty pattern list is a wildcard route and matches everything.
/// A pattern containing `*` is treated as a prefix glob: all trailing `*`s
/// are stripped and the remainder prefix-matched (a `*` anywhere else is
/// kept literally in the prefix). Any other pattern must match exactly.
fn matches_model(&self, route: &ProviderRoute, model: &str) -> bool {
    let patterns = &route.model_patterns;
    if patterns.is_empty() {
        return true;
    }
    patterns.iter().any(|pattern| {
        if pattern.contains('*') {
            model.starts_with(pattern.trim_end_matches('*'))
        } else {
            pattern == model
        }
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
// Fixture: an enabled route at http://localhost with the given patterns.
fn make_route(provider: ProviderType, priority: u32, patterns: Vec<&str>) -> ProviderRoute {
ProviderRoute {
provider,
priority,
model_patterns: patterns.into_iter().map(String::from).collect(),
enabled: true,
base_url: "http://localhost".into(),
api_key: None,
max_tokens_limit: None,
rate_limit_rpm: None,
tls_config: None,
}
}
// Priority strategy picks by pattern match and ascending priority.
#[test]
fn priority_routing() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec!["llama*"]),
make_route(ProviderType::OpenAi, 2, vec!["gpt-*"]),
];
let router = Router::new(routes, RoutingStrategy::Priority);
let selected = router.select("llama3").unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
let selected = router.select("gpt-4o").unwrap();
assert_eq!(selected.provider, ProviderType::OpenAi);
}
// Empty pattern list matches any model name.
#[test]
fn wildcard_route() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::Priority);
assert!(router.select("anything").is_some());
}
#[test]
fn no_matching_provider() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec!["llama*"])];
let router = Router::new(routes, RoutingStrategy::Priority);
assert!(router.select("gpt-4o").is_none());
}
#[test]
fn disabled_route_skipped() {
let mut route = make_route(ProviderType::Ollama, 1, vec![]);
route.enabled = false;
let router = Router::new(vec![route], RoutingStrategy::Priority);
assert!(router.select("llama3").is_none());
}
// Equal priorities: stable sort preserves config order, so round-robin
// alternates deterministically starting from the first route.
#[test]
fn round_robin() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 1, vec![]),
];
let router = Router::new(routes, RoutingStrategy::RoundRobin);
let first = router.select("llama3").unwrap().provider;
let second = router.select("llama3").unwrap().provider;
assert_ne!(first, second);
}
#[test]
fn routing_strategy_default() {
assert_eq!(RoutingStrategy::default(), RoutingStrategy::Priority);
}
// Fixture variant with a caller-chosen base URL (latency/health keys
// include the URL, so tests that exercise those need distinct URLs).
fn make_route_with_url(
provider: ProviderType,
priority: u32,
patterns: Vec<&str>,
base_url: &str,
) -> ProviderRoute {
ProviderRoute {
provider,
priority,
model_patterns: patterns.into_iter().map(String::from).collect(),
enabled: true,
base_url: base_url.to_string(),
api_key: None,
max_tokens_limit: None,
rate_limit_rpm: None,
tls_config: None,
}
}
// First report seeds the EMA; second blends 70/30: (100*7+200*3)/10 = 130.
#[test]
fn report_latency_records_values() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::LowestLatency);
let key = (ProviderType::Ollama, "http://localhost".to_string());
router.report_latency(ProviderType::Ollama, "http://localhost", 100);
// Inner scopes drop the DashMap read guard before the next
// report_latency, which takes a write lock on the same shard.
{
let latency = router.latencies.get(&key).unwrap();
assert_eq!(latency.value().load(Ordering::Relaxed), 100);
}
router.report_latency(ProviderType::Ollama, "http://localhost", 200);
{
let latency = router.latencies.get(&key).unwrap();
assert_eq!(latency.value().load(Ordering::Relaxed), 130);
}
}
// Latency beats priority under LowestLatency (Ollama has priority 2 here).
#[test]
fn lowest_latency_picks_fastest() {
let routes = vec![
make_route_with_url(ProviderType::Ollama, 2, vec![], "http://ollama"),
make_route_with_url(ProviderType::LlamaCpp, 1, vec![], "http://llamacpp"),
];
let router = Router::new(routes, RoutingStrategy::LowestLatency);
router.report_latency(ProviderType::LlamaCpp, "http://llamacpp", 500);
router.report_latency(ProviderType::Ollama, "http://ollama", 50);
let selected = router.select("any-model").unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
// A provider with no recorded latency scores u64::MAX and loses to any
// measured provider, regardless of priority.
#[test]
fn lowest_latency_deprioritizes_unknown() {
let routes = vec![
make_route_with_url(ProviderType::Ollama, 1, vec![], "http://ollama"),
make_route_with_url(ProviderType::LlamaCpp, 2, vec![], "http://llamacpp"),
];
let router = Router::new(routes, RoutingStrategy::LowestLatency);
router.report_latency(ProviderType::LlamaCpp, "http://llamacpp", 100);
let selected = router.select("any-model").unwrap();
assert_eq!(selected.provider, ProviderType::LlamaCpp);
}
#[test]
fn select_no_health_map_allows_all() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 2, vec![]),
];
let router = Router::new(routes, RoutingStrategy::Priority);
assert!(router.select("any-model").is_some());
assert_eq!(
router.select("any-model").unwrap().provider,
ProviderType::Ollama
);
}
// An unhealthy entry for the highest-priority provider forces fallback.
#[test]
fn select_filters_unhealthy_providers() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 2, vec![]),
];
let mut router = Router::new(routes, RoutingStrategy::Priority);
let health_map = crate::health::new_health_map();
health_map.insert(
(ProviderType::Ollama, "http://localhost".to_string()),
crate::health::ProviderHealthState {
is_healthy: false,
last_check: std::time::Instant::now(),
consecutive_failures: 3,
last_error: Some("connection refused".into()),
},
);
router.set_health_map(health_map);
let selected = router.select("any-model").unwrap();
assert_eq!(selected.provider, ProviderType::LlamaCpp);
}
// A provider absent from the health map is assumed healthy.
#[test]
fn select_unchecked_provider_assumed_healthy() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let mut router = Router::new(routes, RoutingStrategy::Priority);
let health_map = crate::health::new_health_map();
router.set_health_map(health_map);
assert!(router.select("any-model").is_some());
}
#[test]
fn select_all_unhealthy_returns_none() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let mut router = Router::new(routes, RoutingStrategy::Priority);
let health_map = crate::health::new_health_map();
health_map.insert(
(ProviderType::Ollama, "http://localhost".to_string()),
crate::health::ProviderHealthState {
is_healthy: false,
last_check: std::time::Instant::now(),
consecutive_failures: 5,
last_error: Some("down".into()),
},
);
router.set_health_map(health_map);
assert!(router.select("any-model").is_none());
}
#[test]
fn reload_changes_routes() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec!["llama*"])];
let mut router = Router::new(routes, RoutingStrategy::Priority);
assert_eq!(router.routes().len(), 1);
assert!(router.select("llama3").is_some());
assert!(router.select("gpt-4o").is_none());
let new_routes = vec![
make_route(ProviderType::OpenAi, 1, vec!["gpt-*"]),
make_route(ProviderType::LlamaCpp, 2, vec!["gguf-*"]),
];
router.reload(new_routes, RoutingStrategy::RoundRobin);
assert_eq!(router.routes().len(), 2);
assert!(router.select("gpt-4o").is_some());
assert!(router.select("llama3").is_none());
}
// After reload the shared counter is zeroed, so rotation restarts at the
// first route and the first two selections differ.
#[test]
fn reload_resets_round_robin_index() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 1, vec![]),
];
let mut router = Router::new(routes, RoutingStrategy::RoundRobin);
let _ = router.select("any");
let _ = router.select("any");
let _ = router.select("any");
let new_routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 1, vec![]),
];
router.reload(new_routes, RoutingStrategy::RoundRobin);
let first = router.select("any").unwrap().provider;
let second = router.select("any").unwrap().provider;
assert_ne!(first, second, "round-robin should alternate after reload");
}
// Smoke test: two simultaneous read guards on an RwLock<Router> coexist.
#[test]
fn rwlock_concurrent_reads() {
use std::sync::{Arc, RwLock};
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Arc::new(RwLock::new(Router::new(routes, RoutingStrategy::Priority)));
let r1 = router.read().unwrap();
let r2 = router.read().unwrap();
assert_eq!(r1.routes().len(), 1);
assert_eq!(r2.routes().len(), 1);
}
#[cfg(feature = "dlp")]
#[test]
fn select_with_classification_restricted_returns_none() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::Priority);
assert!(
router
.select_with_classification("any", crate::dlp::ClassificationLevel::Restricted)
.is_none()
);
}
#[cfg(feature = "dlp")]
#[test]
fn select_with_classification_public_allows_remote() {
let routes = vec![
make_route(ProviderType::OpenAi, 1, vec![]),
make_route(ProviderType::Ollama, 2, vec![]),
];
let router = Router::new(routes, RoutingStrategy::Priority);
let selected = router
.select_with_classification("any", crate::dlp::ClassificationLevel::Public)
.unwrap();
assert_eq!(selected.provider, ProviderType::OpenAi);
}
// Confidential skips the higher-priority remote provider.
#[cfg(feature = "dlp")]
#[test]
fn select_with_classification_confidential_forces_local() {
let routes = vec![
make_route(ProviderType::OpenAi, 1, vec![]),
make_route(ProviderType::Ollama, 2, vec![]),
];
let router = Router::new(routes, RoutingStrategy::Priority);
let selected = router
.select_with_classification("any", crate::dlp::ClassificationLevel::Confidential)
.unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
#[cfg(feature = "dlp")]
#[test]
fn select_with_classification_confidential_no_local_returns_none() {
let routes = vec![make_route(ProviderType::OpenAi, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::Priority);
let result =
router.select_with_classification("any", crate::dlp::ClassificationLevel::Confidential);
assert!(result.is_none());
}
#[cfg(feature = "dlp")]
#[test]
fn select_with_classification_round_robin() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 1, vec![]),
];
let router = Router::new(routes, RoutingStrategy::RoundRobin);
let first = router
.select_with_classification("any", crate::dlp::ClassificationLevel::Public)
.unwrap()
.provider;
let second = router
.select_with_classification("any", crate::dlp::ClassificationLevel::Public)
.unwrap()
.provider;
assert_ne!(first, second);
}
#[cfg(feature = "dlp")]
#[test]
fn select_with_classification_lowest_latency() {
let routes = vec![
make_route_with_url(ProviderType::Ollama, 1, vec![], "http://ollama"),
make_route_with_url(ProviderType::LlamaCpp, 1, vec![], "http://llamacpp"),
];
let router = Router::new(routes, RoutingStrategy::LowestLatency);
router.report_latency(ProviderType::Ollama, "http://ollama", 50);
router.report_latency(ProviderType::LlamaCpp, "http://llamacpp", 500);
let selected = router
.select_with_classification("any", crate::dlp::ClassificationLevel::Public)
.unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
// Direct currently shares the Priority selection arm.
#[test]
fn direct_routing_picks_first() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::LlamaCpp, 2, vec![]),
];
let router = Router::new(routes, RoutingStrategy::Direct);
let selected = router.select("any-model").unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
// Exact patterns (no `*`) do not prefix-match.
#[test]
fn exact_model_match() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec!["llama3.2:latest"])];
let router = Router::new(routes, RoutingStrategy::Priority);
assert!(router.select("llama3.2:latest").is_some());
assert!(router.select("llama3.2").is_none());
}
#[test]
fn multiple_model_patterns() {
let routes = vec![make_route(
ProviderType::Ollama,
1,
vec!["llama*", "mistral*"],
)];
let router = Router::new(routes, RoutingStrategy::Priority);
assert!(router.select("llama3").is_some());
assert!(router.select("mistral-7b").is_some());
assert!(router.select("gpt-4o").is_none());
}
// EMA chain: 100 -> (100*7+200*3)/10 = 130 -> (130*7+50*3)/10 = 106.
#[test]
fn report_latency_ema_multiple_updates() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::LowestLatency);
let key = (ProviderType::Ollama, "http://localhost".to_string());
router.report_latency(ProviderType::Ollama, "http://localhost", 100);
router.report_latency(ProviderType::Ollama, "http://localhost", 200);
router.report_latency(ProviderType::Ollama, "http://localhost", 50);
let latency = router
.latencies
.get(&key)
.unwrap()
.value()
.load(Ordering::Relaxed);
assert_eq!(latency, 106);
}
#[cfg(feature = "hwaccel")]
#[test]
fn select_with_hardware_no_params_falls_through() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::Priority);
let hw = crate::hardware::HardwareManager::detect();
let selected = router.select_with_hardware("any", &hw, None, 0).unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
// Only meaningful on hosts without an accelerator, hence the guard; a
// 200B-parameter model cannot fit, so the local route is deprioritized.
#[cfg(feature = "hwaccel")]
#[test]
fn select_with_hardware_deprioritizes_local_when_no_vram() {
let routes = vec![
make_route(ProviderType::Ollama, 1, vec![]),
make_route(ProviderType::OpenAi, 2, vec![]),
];
let router = Router::new(routes, RoutingStrategy::Priority);
let hw = crate::hardware::HardwareManager::detect();
if !hw.has_accelerator() {
let selected = router
.select_with_hardware("any", &hw, Some(200_000_000_000), 0)
.unwrap();
assert_eq!(selected.provider, ProviderType::OpenAi);
}
}
// Deprioritized locals still serve when they are the only option.
#[cfg(feature = "hwaccel")]
#[test]
fn select_with_hardware_local_still_available_as_fallback() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::Priority);
let hw = crate::hardware::HardwareManager::detect();
if !hw.has_accelerator() {
let selected = router
.select_with_hardware("any", &hw, Some(200_000_000_000), 0)
.unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
}
// Some(0) params is treated the same as None: plain select().
#[cfg(feature = "hwaccel")]
#[test]
fn select_with_hardware_zero_params_falls_through() {
let routes = vec![make_route(ProviderType::Ollama, 1, vec![])];
let router = Router::new(routes, RoutingStrategy::Priority);
let hw = crate::hardware::HardwareManager::detect();
let selected = router.select_with_hardware("any", &hw, Some(0), 0).unwrap();
assert_eq!(selected.provider, ProviderType::Ollama);
}
}