use crate::auth::KeySet;
use crate::target::{
ConcurrencyGuard, ConcurrencyLimiter, FallbackConfig, LoadBalanceStrategy, RateLimiter,
RoutingAction, RoutingRule, Target,
};
use rand::Rng;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// A set of upstream providers together with the pool-wide policy used to choose between them
/// (load-balance strategy, fallback settings, optional limiters, and label-based routing rules).
#[derive(Debug, Clone)]
pub struct ProviderPool {
    /// Candidate providers, in declaration order (the order the `Priority` strategy walks).
    providers: Vec<Provider>,
    /// Optional key set attached to the pool; only exposed via `keys()` in this module.
    keys: Option<KeySet>,
    /// Optional pool-wide rate limiter; only exposed via `pool_limiter()` in this module.
    pool_limiter: Option<Arc<dyn RateLimiter>>,
    /// Optional pool-wide concurrency limiter; its active count can be taken over from an older
    /// pool by `adopt_provider_state`.
    pool_concurrency_limiter: Option<ConcurrencyLimiter>,
    /// Fallback settings; when `None`, `select_iter` uses its defaults (one attempt per
    /// provider, no replacement) and the `should_fallback_*` helpers return `false`.
    fallback: Option<FallbackConfig>,
    /// Strategy used to pick a provider for each request.
    strategy: LoadBalanceStrategy,
    /// Whether the pool is marked as trusted; only exposed via `is_trusted()` in this module.
    trusted: bool,
    /// Label-matching rules, evaluated in order by `evaluate_routing_rules` (first match wins).
    routing_rules: Vec<RoutingRule>,
}
/// One upstream target inside a [`ProviderPool`], with its selection weight and its own
/// concurrency limiter.
#[derive(Debug, Clone)]
pub struct Provider {
    /// Upstream target that requests routed to this provider are sent to.
    pub target: Target,
    /// Relative weight used by the `WeightedRandom` strategy (ignored by `Priority`).
    /// A provider with weight 0 is never chosen by `WeightedRandom`.
    pub weight: u32,
    /// Counts guards currently held for this provider and, when built with a limit,
    /// refuses new guards once that limit is reached.
    limiter: ConcurrencyLimiter,
}
impl Provider {
pub fn new(target: Target, weight: u32) -> Self {
Self {
target,
weight,
limiter: ConcurrencyLimiter::new(),
}
}
pub fn with_concurrency_limit(target: Target, weight: u32, limit: usize) -> Self {
Self {
target,
weight,
limiter: ConcurrencyLimiter::with_limit(limit),
}
}
pub fn active_connections(&self) -> usize {
self.limiter.active()
}
}
impl ProviderPool {
pub fn new(providers: Vec<Provider>) -> Self {
Self {
providers,
keys: None,
pool_limiter: None,
pool_concurrency_limiter: None,
fallback: None,
strategy: LoadBalanceStrategy::default(),
trusted: false,
routing_rules: Vec::new(),
}
}
pub fn with_config(
providers: Vec<Provider>,
keys: Option<KeySet>,
pool_limiter: Option<Arc<dyn RateLimiter>>,
pool_concurrency_limiter: Option<ConcurrencyLimiter>,
fallback: Option<FallbackConfig>,
strategy: LoadBalanceStrategy,
trusted: bool,
routing_rules: Vec<RoutingRule>,
) -> Self {
Self {
providers,
keys,
pool_limiter,
pool_concurrency_limiter,
fallback,
strategy,
trusted,
routing_rules,
}
}
pub fn single(target: Target, weight: u32) -> Self {
Self::new(vec![Provider::new(target, weight)])
}
pub fn select(&self) -> Option<(usize, &Target, ConcurrencyGuard)> {
self.select_excluding(&HashSet::new())
}
pub fn select_iter(&self) -> SelectIter<'_> {
let with_replacement = self.fallback.as_ref().is_some_and(|f| f.with_replacement);
let max_attempts = self
.fallback
.as_ref()
.and_then(|f| f.max_attempts)
.unwrap_or(self.providers.len());
SelectIter {
pool: self,
excluded: HashSet::new(),
max_attempts,
attempts: 0,
with_replacement,
}
}
fn select_excluding(
&self,
exclude: &HashSet<usize>,
) -> Option<(usize, &Target, ConcurrencyGuard)> {
if self.providers.is_empty() {
return None;
}
match self.strategy {
LoadBalanceStrategy::Priority => self.select_priority(exclude),
LoadBalanceStrategy::WeightedRandom => self.select_least_connections(exclude),
}
}
fn select_priority(
&self,
exclude: &HashSet<usize>,
) -> Option<(usize, &Target, ConcurrencyGuard)> {
for (idx, provider) in self.providers.iter().enumerate() {
if exclude.contains(&idx) {
continue;
}
if let Some(guard) = provider.limiter.try_acquire() {
return Some((idx, &provider.target, guard));
}
}
None
}
fn select_least_connections(
&self,
exclude: &HashSet<usize>,
) -> Option<(usize, &Target, ConcurrencyGuard)> {
let mut best_score = f64::INFINITY;
let mut candidates: Vec<usize> = Vec::new();
for (idx, provider) in self.providers.iter().enumerate() {
if exclude.contains(&idx) {
continue;
}
if provider.limiter.at_capacity() {
continue;
}
let score = provider.limiter.active() as f64 / provider.weight as f64;
if score < best_score - f64::EPSILON {
best_score = score;
candidates.clear();
candidates.push(idx);
} else if (score - best_score).abs() < f64::EPSILON {
candidates.push(idx);
}
}
if candidates.is_empty() {
return None;
}
let selected = if candidates.len() == 1 {
candidates[0]
} else {
let mut rng = rand::rng();
let total_weight: u32 = candidates
.iter()
.map(|&idx| self.providers[idx].weight)
.sum();
let r: u32 = rng.random_range(0..total_weight);
let mut cumulative = 0;
let mut picked = candidates[0];
for &idx in &candidates {
cumulative += self.providers[idx].weight;
if r < cumulative {
picked = idx;
break;
}
}
picked
};
let provider = &self.providers[selected];
match provider.limiter.try_acquire() {
Some(guard) => Some((selected, &provider.target, guard)),
None => {
let mut new_exclude = exclude.clone();
new_exclude.insert(selected);
self.select_least_connections(&new_exclude)
}
}
}
pub fn providers(&self) -> &[Provider] {
&self.providers
}
pub fn len(&self) -> usize {
self.providers.len()
}
pub fn is_empty(&self) -> bool {
self.providers.is_empty()
}
pub fn first_target(&self) -> Option<&Target> {
self.providers.first().map(|p| &p.target)
}
pub fn keys(&self) -> Option<&KeySet> {
self.keys.as_ref()
}
pub fn pool_limiter(&self) -> Option<&Arc<dyn RateLimiter>> {
self.pool_limiter.as_ref()
}
pub fn pool_concurrency_limiter(&self) -> Option<&ConcurrencyLimiter> {
self.pool_concurrency_limiter.as_ref()
}
pub fn fallback(&self) -> Option<&FallbackConfig> {
self.fallback.as_ref()
}
pub fn fallback_enabled(&self) -> bool {
self.fallback.as_ref().is_some_and(|f| f.enabled)
}
pub fn should_fallback_on_status(&self, status_code: u16) -> bool {
self.fallback
.as_ref()
.is_some_and(|f| f.should_fallback_on_status(status_code))
}
pub fn should_fallback_on_rate_limit(&self) -> bool {
self.fallback
.as_ref()
.is_some_and(|f| f.enabled && f.on_rate_limit)
}
pub fn strategy(&self) -> LoadBalanceStrategy {
self.strategy
}
pub fn is_trusted(&self) -> bool {
self.trusted
}
pub fn routing_rules(&self) -> &[RoutingRule] {
&self.routing_rules
}
pub fn evaluate_routing_rules(
&self,
key_labels: &HashMap<String, String>,
) -> Option<&RoutingAction> {
self.routing_rules.iter().find_map(|rule| {
let matches = rule
.match_labels
.iter()
.all(|(k, v)| key_labels.get(k).is_some_and(|kv| kv == v));
matches.then_some(&rule.action)
})
}
pub fn adopt_provider_state(&mut self, old: &ProviderPool) {
for new_provider in &mut self.providers {
if let Some(old_provider) = old.providers.iter().find(|old_p| {
old_p.target.url == new_provider.target.url
&& old_p.target.onwards_key == new_provider.target.onwards_key
&& old_p.target.onwards_model == new_provider.target.onwards_model
}) {
new_provider.limiter.adopt_active_counter(&old_provider.limiter);
}
}
if let (Some(new_limiter), Some(old_limiter)) =
(&mut self.pool_concurrency_limiter, &old.pool_concurrency_limiter)
{
new_limiter.adopt_active_counter(old_limiter);
}
}
}
/// Iterator over successive provider selections from a [`ProviderPool`], used to drive
/// fallback attempts.
///
/// Created by [`ProviderPool::select_iter`]. Each item is `(index, target, guard)`. At most
/// `max_attempts` calls to `next` attempt a selection; a call returns `None` when the budget
/// is spent or no provider can currently be selected.
#[derive(Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SelectIter<'a> {
    /// Pool the selections are drawn from.
    pool: &'a ProviderPool,
    /// Provider indices that must not be selected again.
    excluded: HashSet<usize>,
    /// Total number of selection attempts allowed.
    max_attempts: usize,
    /// Number of selection attempts made so far, successful or not.
    attempts: usize,
    /// Whether `WeightedRandom` may yield the same provider more than once.
    with_replacement: bool,
}
impl<'a> Iterator for SelectIter<'a> {
    type Item = (usize, &'a Target, ConcurrencyGuard);

    /// Makes the next selection attempt, honouring the exclusion rules of the pool's strategy.
    fn next(&mut self) -> Option<Self::Item> {
        if self.attempts >= self.max_attempts {
            return None;
        }
        // The attempt is counted even if selection fails below, so `max_attempts` bounds the
        // total number of tries rather than the number of successes.
        self.attempts += 1;
        let selection = self.pool.select_excluding(&self.excluded)?;
        let exclude_selected = match self.pool.strategy {
            // Priority always prefers the earliest provider, so it must be excluded for the
            // iterator to advance down the list.
            LoadBalanceStrategy::Priority => true,
            LoadBalanceStrategy::WeightedRandom => !self.with_replacement,
        };
        if exclude_selected {
            self.excluded.insert(selection.0);
        }
        Some(selection)
    }

    /// Every yielded item consumes one attempt, so at most the remaining attempts can still
    /// produce items. Any attempt may fail, so the lower bound is 0.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.max_attempts.saturating_sub(self.attempts)))
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for provider selection, fallback iteration, routing rules and state adoption.
    // Every type used here (Target, HashMap, FallbackConfig, ...) is already in scope through
    // the parent module's imports.
    use super::*;

    /// Builds a bare target for `url`.
    fn create_test_target(url: &str) -> Target {
        Target::builder().url(url.parse().unwrap()).build()
    }

    /// Builds a pool with the given providers, fallback, strategy and routing rules; keys,
    /// pool limiters and trust are left off.
    fn pool_with(
        providers: Vec<Provider>,
        fallback: Option<FallbackConfig>,
        strategy: LoadBalanceStrategy,
        routing_rules: Vec<RoutingRule>,
    ) -> ProviderPool {
        ProviderPool::with_config(
            providers,
            None,
            None,
            None,
            fallback,
            strategy,
            false,
            routing_rules,
        )
    }

    #[test]
    fn test_single_provider_pool() {
        let target = create_test_target("https://api.example.com");
        let pool = ProviderPool::single(target.clone(), 1);
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());
        let selected = pool.select();
        assert!(selected.is_some());
        let (_, target, _guard) = selected.unwrap();
        assert_eq!(target.url.as_str(), "https://api.example.com/");
    }

    #[test]
    fn test_empty_pool_returns_none() {
        let pool = ProviderPool::new(vec![]);
        assert!(pool.is_empty());
        assert!(pool.select().is_none());
    }

    #[test]
    fn test_weighted_selection_distribution() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 3),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let pool = ProviderPool::new(providers);
        let mut counts: HashMap<String, usize> = HashMap::new();
        for _ in 0..1000 {
            if let Some((_, target, _guard)) = pool.select() {
                *counts.entry(target.url.to_string()).or_insert(0) += 1;
            }
        }
        let count1 = *counts.get("https://api1.example.com/").unwrap_or(&0);
        let count2 = *counts.get("https://api2.example.com/").unwrap_or(&0);
        let ratio = count1 as f64 / count2 as f64;
        assert!(
            ratio > 1.5 && ratio < 6.0,
            "Expected ratio around 3.0, got {}",
            ratio
        );
    }

    #[test]
    fn test_least_connections_prefers_less_loaded() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 1),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let pool = ProviderPool::new(providers);
        let (idx1, _, guard1) = pool.select().unwrap();
        let (idx2, _, _guard2) = pool.select().unwrap();
        assert_ne!(idx1, idx2, "Should pick the less loaded provider");
        drop(guard1);
        let (idx3, _, _guard3) = pool.select().unwrap();
        assert_eq!(idx3, idx1, "Should pick the provider whose guard was dropped");
    }

    #[test]
    fn test_weighted_least_connections_respects_weights() {
        let providers = vec![
            Provider::new(create_test_target("https://heavy.example.com"), 3),
            Provider::new(create_test_target("https://light.example.com"), 1),
        ];
        let pool = ProviderPool::new(providers);
        let mut guards = Vec::new();
        for _ in 0..40 {
            if let Some((_, _, guard)) = pool.select() {
                guards.push(guard);
            }
        }
        let heavy_active = pool.providers()[0].active_connections();
        let light_active = pool.providers()[1].active_connections();
        let ratio = heavy_active as f64 / light_active as f64;
        assert!(
            ratio > 2.0 && ratio < 5.0,
            "Expected ratio around 3.0, got {} (heavy={}, light={})",
            ratio,
            heavy_active,
            light_active
        );
    }

    #[test]
    fn test_zero_weight_provider_never_selected_by_weighted_strategy() {
        let providers = vec![
            Provider::new(create_test_target("https://zero.example.com"), 0),
            Provider::new(create_test_target("https://one.example.com"), 1),
        ];
        let pool = ProviderPool::new(providers);
        for _ in 0..20 {
            let (idx, _, _guard) = pool.select().unwrap();
            assert_eq!(idx, 1, "A zero-weight provider should never be selected");
        }
        let all_zero = ProviderPool::new(vec![Provider::new(
            create_test_target("https://zero.example.com"),
            0,
        )]);
        assert!(all_zero.select().is_none());
    }

    #[test]
    fn test_concurrency_limit_skips_full_provider() {
        let providers = vec![
            Provider::with_concurrency_limit(
                create_test_target("https://limited.example.com"),
                1,
                1,
            ),
            Provider::new(create_test_target("https://unlimited.example.com"), 1),
        ];
        let pool = ProviderPool::new(providers);
        let mut guard_on_limited = None;
        for _ in 0..100 {
            let (idx, _, guard) = pool.select().unwrap();
            if idx == 0 {
                guard_on_limited = Some(guard);
                break;
            }
        }
        assert!(
            guard_on_limited.is_some(),
            "Should eventually select the limited provider"
        );
        let (idx, _, _guard) = pool.select().unwrap();
        assert_eq!(idx, 1, "Should skip the full provider");
    }

    #[test]
    fn test_all_at_capacity_returns_none() {
        let providers = vec![
            Provider::with_concurrency_limit(create_test_target("https://a.example.com"), 1, 1),
            Provider::with_concurrency_limit(create_test_target("https://b.example.com"), 1, 1),
        ];
        let pool = ProviderPool::new(providers);
        let (_, _, _g1) = pool.select().unwrap();
        let (_, _, _g2) = pool.select().unwrap();
        assert!(pool.select().is_none());
    }

    #[test]
    fn test_first_target() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 1),
            Provider::new(create_test_target("https://api2.example.com"), 2),
        ];
        let pool = ProviderPool::new(providers);
        let first = pool.first_target();
        assert!(first.is_some());
        assert_eq!(first.unwrap().url.as_str(), "https://api1.example.com/");
    }

    #[test]
    fn test_providers_accessor() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 1),
            Provider::new(create_test_target("https://api2.example.com"), 2),
        ];
        let pool = ProviderPool::new(providers);
        assert_eq!(pool.providers().len(), 2);
        assert_eq!(pool.providers()[0].weight, 1);
        assert_eq!(pool.providers()[1].weight, 2);
    }

    #[test]
    fn test_select_iter_priority_strategy() {
        let providers = vec![
            Provider::new(create_test_target("https://primary.example.com"), 1),
            Provider::new(create_test_target("https://secondary.example.com"), 10),
            Provider::new(create_test_target("https://tertiary.example.com"), 5),
        ];
        let pool = pool_with(providers, None, LoadBalanceStrategy::Priority, Vec::new());
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 3);
        assert_eq!(order[0].0, 0);
        assert_eq!(order[1].0, 1);
        assert_eq!(order[2].0, 2);
        assert_eq!(order[0].1.url.as_str(), "https://primary.example.com/");
        assert_eq!(order[1].1.url.as_str(), "https://secondary.example.com/");
        assert_eq!(order[2].1.url.as_str(), "https://tertiary.example.com/");
    }

    #[test]
    fn test_select_iter_priority_with_replacement_still_advances() {
        let providers = vec![
            Provider::new(create_test_target("https://primary.example.com"), 1),
            Provider::new(create_test_target("https://secondary.example.com"), 1),
            Provider::new(create_test_target("https://tertiary.example.com"), 1),
        ];
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(3),
            ..Default::default()
        });
        let pool = pool_with(providers, fallback, LoadBalanceStrategy::Priority, Vec::new());
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 3);
        assert_eq!(order[0].1.url.as_str(), "https://primary.example.com/");
        assert_eq!(order[1].1.url.as_str(), "https://secondary.example.com/");
        assert_eq!(order[2].1.url.as_str(), "https://tertiary.example.com/");
    }

    #[test]
    fn test_select_iter_weighted_random_includes_all() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 3),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let pool = pool_with(
            providers,
            None,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 2);
        let urls: HashSet<_> = order.iter().map(|(_, t, _)| t.url.as_str()).collect();
        assert!(urls.contains("https://api1.example.com/"));
        assert!(urls.contains("https://api2.example.com/"));
    }

    #[test]
    fn test_select_iter_weighted_random_distribution() {
        let providers = vec![
            Provider::new(create_test_target("https://heavy.example.com"), 9),
            Provider::new(create_test_target("https://light.example.com"), 1),
        ];
        let pool = pool_with(
            providers,
            None,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let mut heavy_first = 0;
        let iterations = 1000;
        for _ in 0..iterations {
            let order: Vec<_> = pool.select_iter().collect();
            if order[0].1.url.as_str() == "https://heavy.example.com/" {
                heavy_first += 1;
            }
        }
        let percentage = (heavy_first * 100) / iterations;
        assert!(
            (80..=98).contains(&percentage),
            "Expected heavy to be first ~90% of the time, got {}% ({}/{})",
            percentage,
            heavy_first,
            iterations
        );
    }

    #[test]
    fn test_select_iter_empty_pool() {
        let pool = ProviderPool::new(vec![]);
        let order: Vec<_> = pool.select_iter().collect();
        assert!(order.is_empty());
    }

    #[test]
    fn test_select_iter_single_provider() {
        let providers = vec![Provider::new(
            create_test_target("https://only.example.com"),
            1,
        )];
        for strategy in [
            LoadBalanceStrategy::Priority,
            LoadBalanceStrategy::WeightedRandom,
        ] {
            let pool = pool_with(providers.clone(), None, strategy, Vec::new());
            let order: Vec<_> = pool.select_iter().collect();
            assert_eq!(order.len(), 1);
            assert_eq!(order[0].1.url.as_str(), "https://only.example.com/");
        }
    }

    #[test]
    fn test_select_iter_with_replacement_allows_duplicates() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 9),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(5),
            ..Default::default()
        });
        let pool = pool_with(
            providers,
            fallback,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 5);
        let mut found_duplicate = false;
        for _ in 0..100 {
            let order: Vec<_> = pool.select_iter().collect();
            let indices: Vec<usize> = order.iter().map(|(idx, _, _)| *idx).collect();
            let unique: HashSet<_> = indices.iter().collect();
            if unique.len() < indices.len() {
                found_duplicate = true;
                break;
            }
        }
        assert!(
            found_duplicate,
            "With replacement should allow the same provider to appear multiple times"
        );
    }

    #[test]
    fn test_select_iter_max_attempts_controls_length() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 1),
            Provider::new(create_test_target("https://api2.example.com"), 1),
            Provider::new(create_test_target("https://api3.example.com"), 1),
        ];
        let fallback = Some(FallbackConfig {
            enabled: true,
            max_attempts: Some(2),
            ..Default::default()
        });
        let pool = pool_with(
            providers,
            fallback,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 2, "max_attempts should cap the ordering length");
        let indices: HashSet<_> = order.iter().map(|(idx, _, _)| *idx).collect();
        assert_eq!(
            indices.len(),
            2,
            "Without replacement, all entries should be unique"
        );
    }

    #[test]
    fn test_select_iter_max_attempts_with_priority() {
        let providers = vec![
            Provider::new(create_test_target("https://primary.example.com"), 1),
            Provider::new(create_test_target("https://secondary.example.com"), 1),
            Provider::new(create_test_target("https://tertiary.example.com"), 1),
        ];
        let fallback = Some(FallbackConfig {
            enabled: true,
            max_attempts: Some(2),
            ..Default::default()
        });
        let pool = pool_with(providers, fallback, LoadBalanceStrategy::Priority, Vec::new());
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 2);
        assert_eq!(order[0].1.url.as_str(), "https://primary.example.com/");
        assert_eq!(order[1].1.url.as_str(), "https://secondary.example.com/");
    }

    #[test]
    fn test_select_iter_defaults_preserve_behavior() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 3),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let pool = pool_with(
            providers,
            None,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(order.len(), 2);
        let urls: HashSet<_> = order.iter().map(|(_, t, _)| t.url.as_str()).collect();
        assert!(urls.contains("https://api1.example.com/"));
        assert!(urls.contains("https://api2.example.com/"));
    }

    #[test]
    fn test_select_iter_with_replacement_single_provider() {
        let providers = vec![Provider::new(
            create_test_target("https://only.example.com"),
            1,
        )];
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(3),
            ..Default::default()
        });
        let pool = pool_with(
            providers,
            fallback,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let order: Vec<_> = pool.select_iter().collect();
        assert_eq!(
            order.len(),
            3,
            "Single provider with replacement should repeat"
        );
        for (idx, target, _) in &order {
            assert_eq!(*idx, 0);
            assert_eq!(target.url.as_str(), "https://only.example.com/");
        }
    }

    #[test]
    fn test_select_iter_with_replacement_respects_weights() {
        let providers = vec![
            Provider::new(create_test_target("https://heavy.example.com"), 99),
            Provider::new(create_test_target("https://light.example.com"), 1),
        ];
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(10),
            ..Default::default()
        });
        let pool = pool_with(
            providers,
            fallback,
            LoadBalanceStrategy::WeightedRandom,
            Vec::new(),
        );
        let mut heavy_count = 0;
        let iterations = 100;
        for _ in 0..iterations {
            let order: Vec<_> = pool.select_iter().collect();
            heavy_count += order
                .iter()
                .filter(|(_, t, _)| t.url.as_str() == "https://heavy.example.com/")
                .count();
        }
        let total = iterations * 10;
        let percentage = (heavy_count * 100) / total;
        assert!(
            percentage > 85,
            "Heavy provider (99:1 weight) should appear >85% of the time, got {}%",
            percentage
        );
    }

    #[test]
    fn test_evaluate_routing_rules_no_rules() {
        let pool = ProviderPool::new(vec![Provider::new(
            create_test_target("https://api.example.com"),
            1,
        )]);
        let labels = HashMap::from([("purpose".to_string(), "batch".to_string())]);
        assert!(pool.evaluate_routing_rules(&labels).is_none());
    }

    #[test]
    fn test_evaluate_routing_rules_deny() {
        let rules = vec![RoutingRule {
            match_labels: HashMap::from([("purpose".to_string(), "playground".to_string())]),
            action: RoutingAction::Deny,
        }];
        let pool = pool_with(
            vec![Provider::new(
                create_test_target("https://api.example.com"),
                1,
            )],
            None,
            LoadBalanceStrategy::default(),
            rules,
        );
        let labels = HashMap::from([("purpose".to_string(), "playground".to_string())]);
        assert!(matches!(
            pool.evaluate_routing_rules(&labels),
            Some(RoutingAction::Deny)
        ));
        let labels = HashMap::from([("purpose".to_string(), "batch".to_string())]);
        assert!(pool.evaluate_routing_rules(&labels).is_none());
        assert!(pool.evaluate_routing_rules(&HashMap::new()).is_none());
    }

    #[test]
    fn test_evaluate_routing_rules_redirect() {
        let rules = vec![RoutingRule {
            match_labels: HashMap::from([("purpose".to_string(), "batch".to_string())]),
            action: RoutingAction::Redirect {
                target: "gpt-4o-mini".to_string(),
            },
        }];
        let pool = pool_with(
            vec![Provider::new(
                create_test_target("https://api.example.com"),
                1,
            )],
            None,
            LoadBalanceStrategy::default(),
            rules,
        );
        let labels = HashMap::from([("purpose".to_string(), "batch".to_string())]);
        match pool.evaluate_routing_rules(&labels) {
            Some(RoutingAction::Redirect { target }) => {
                assert_eq!(target, "gpt-4o-mini");
            }
            other => panic!("Expected Redirect, got {:?}", other),
        }
    }

    #[test]
    fn test_evaluate_routing_rules_first_match_wins() {
        let rules = vec![
            RoutingRule {
                match_labels: HashMap::from([("purpose".to_string(), "batch".to_string())]),
                action: RoutingAction::Deny,
            },
            RoutingRule {
                match_labels: HashMap::from([("purpose".to_string(), "batch".to_string())]),
                action: RoutingAction::Redirect {
                    target: "other".to_string(),
                },
            },
        ];
        let pool = pool_with(
            vec![Provider::new(
                create_test_target("https://api.example.com"),
                1,
            )],
            None,
            LoadBalanceStrategy::default(),
            rules,
        );
        let labels = HashMap::from([("purpose".to_string(), "batch".to_string())]);
        assert!(matches!(
            pool.evaluate_routing_rules(&labels),
            Some(RoutingAction::Deny)
        ));
    }

    #[test]
    fn test_evaluate_routing_rules_multiple_label_conditions() {
        let rules = vec![RoutingRule {
            match_labels: HashMap::from([
                ("purpose".to_string(), "batch".to_string()),
                ("tier".to_string(), "free".to_string()),
            ]),
            action: RoutingAction::Deny,
        }];
        let pool = pool_with(
            vec![Provider::new(
                create_test_target("https://api.example.com"),
                1,
            )],
            None,
            LoadBalanceStrategy::default(),
            rules,
        );
        let labels = HashMap::from([
            ("purpose".to_string(), "batch".to_string()),
            ("tier".to_string(), "free".to_string()),
        ]);
        assert!(matches!(
            pool.evaluate_routing_rules(&labels),
            Some(RoutingAction::Deny)
        ));
        let labels = HashMap::from([("purpose".to_string(), "batch".to_string())]);
        assert!(pool.evaluate_routing_rules(&labels).is_none());
        let labels = HashMap::from([
            ("purpose".to_string(), "batch".to_string()),
            ("tier".to_string(), "free".to_string()),
            ("org".to_string(), "acme".to_string()),
        ]);
        assert!(matches!(
            pool.evaluate_routing_rules(&labels),
            Some(RoutingAction::Deny)
        ));
    }

    #[test]
    fn test_adopt_provider_state_preserves_active_counts() {
        let providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 3),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let old_pool = ProviderPool::new(providers);
        let mut guards = Vec::new();
        for _ in 0..8 {
            if let Some((_, _, guard)) = old_pool.select() {
                guards.push(guard);
            }
        }
        let old_active_0 = old_pool.providers()[0].active_connections();
        let old_active_1 = old_pool.providers()[1].active_connections();
        assert!(old_active_0 > 0, "Should have active connections on provider 0");
        assert!(old_active_1 > 0, "Should have active connections on provider 1");
        let new_providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 3),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let mut new_pool = ProviderPool::new(new_providers);
        assert_eq!(new_pool.providers()[0].active_connections(), 0);
        assert_eq!(new_pool.providers()[1].active_connections(), 0);
        new_pool.adopt_provider_state(&old_pool);
        assert_eq!(new_pool.providers()[0].active_connections(), old_active_0);
        assert_eq!(new_pool.providers()[1].active_connections(), old_active_1);
        guards.pop();
        let total_after = new_pool.providers()[0].active_connections()
            + new_pool.providers()[1].active_connections();
        assert_eq!(total_after, old_active_0 + old_active_1 - 1);
    }

    #[test]
    fn test_adopt_provider_state_new_provider_starts_at_zero() {
        let old_providers = vec![Provider::new(
            create_test_target("https://api1.example.com"),
            1,
        )];
        let old_pool = ProviderPool::new(old_providers);
        let _guards: Vec<_> = (0..5)
            .filter_map(|_| old_pool.select().map(|(_, _, g)| g))
            .collect();
        let new_providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 1),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let mut new_pool = ProviderPool::new(new_providers);
        new_pool.adopt_provider_state(&old_pool);
        assert_eq!(new_pool.providers()[0].active_connections(), 5);
        assert_eq!(new_pool.providers()[1].active_connections(), 0);
    }

    #[test]
    fn test_adopt_provider_state_removed_provider_ignored() {
        let old_providers = vec![
            Provider::new(create_test_target("https://api1.example.com"), 1),
            Provider::new(create_test_target("https://api2.example.com"), 1),
        ];
        let old_pool = ProviderPool::new(old_providers);
        let _guards: Vec<_> = (0..4)
            .filter_map(|_| old_pool.select().map(|(_, _, g)| g))
            .collect();
        let new_providers = vec![Provider::new(
            create_test_target("https://api1.example.com"),
            1,
        )];
        let mut new_pool = ProviderPool::new(new_providers);
        new_pool.adopt_provider_state(&old_pool);
        assert!(new_pool.providers()[0].active_connections() > 0);
        assert_eq!(new_pool.providers().len(), 1);
    }

    #[test]
    fn test_adopt_provider_state_preserves_pool_concurrency_limiter() {
        let old_pool = ProviderPool::with_config(
            vec![Provider::new(create_test_target("https://api1.example.com"), 1)],
            None,
            None,
            Some(ConcurrencyLimiter::with_limit(100)),
            None,
            LoadBalanceStrategy::default(),
            false,
            Vec::new(),
        );
        let _guard1 = old_pool.pool_concurrency_limiter().unwrap().try_acquire();
        let _guard2 = old_pool.pool_concurrency_limiter().unwrap().try_acquire();
        assert_eq!(old_pool.pool_concurrency_limiter().unwrap().active(), 2);
        let mut new_pool = ProviderPool::with_config(
            vec![Provider::new(create_test_target("https://api1.example.com"), 1)],
            None,
            None,
            Some(ConcurrencyLimiter::with_limit(200)),
            None,
            LoadBalanceStrategy::default(),
            false,
            Vec::new(),
        );
        new_pool.adopt_provider_state(&old_pool);
        assert_eq!(new_pool.pool_concurrency_limiter().unwrap().active(), 2);
        assert_eq!(new_pool.pool_concurrency_limiter().unwrap().limit(), Some(200));
    }
}