use crate::auth::KeySet;
use crate::target::{ConcurrencyLimiter, FallbackConfig, LoadBalanceStrategy, RateLimiter, Target};
use rand::Rng;
use std::sync::Arc;
/// A set of [`Provider`]s plus the pool-wide settings used when choosing between them.
#[derive(Debug, Clone)]
pub struct ProviderPool {
/// Candidate providers in declaration order; `LoadBalanceStrategy::Priority` ordering
/// follows this order.
providers: Vec<Provider>,
/// Sum of every provider's `weight`, computed once at construction.
total_weight: u32,
/// Optional key set, exposed through `keys()`.
keys: Option<KeySet>,
/// Optional pool-wide rate limiter, exposed through `pool_limiter()`. Selection in this
/// file only consults each target's own `limiter`, not this one.
pool_limiter: Option<Arc<dyn RateLimiter>>,
/// Optional pool-wide concurrency limiter, exposed through `pool_concurrency_limiter()`.
pool_concurrency_limiter: Option<Arc<dyn ConcurrencyLimiter>>,
/// Fallback settings (`enabled`, `on_rate_limit`, `with_replacement`, `max_attempts`, …);
/// `None` makes every fallback query return `false`.
fallback: Option<FallbackConfig>,
/// Ordering strategy used by `select_ordered()`.
strategy: LoadBalanceStrategy,
/// Flag exposed through `is_trusted()`; its meaning is defined by callers.
trusted: bool,
}
/// One entry in a [`ProviderPool`]: a target together with its selection weight.
#[derive(Debug, Clone)]
pub struct Provider {
/// The target handed out when this provider is selected.
pub target: Target,
/// Relative weight for weighted-random selection: the chance that a weighted draw picks
/// this provider is proportional to this value.
pub weight: u32,
}
impl ProviderPool {
pub fn new(providers: Vec<Provider>) -> Self {
let total_weight = providers.iter().map(|p| p.weight).sum();
Self {
providers,
total_weight,
keys: None,
pool_limiter: None,
pool_concurrency_limiter: None,
fallback: None,
strategy: LoadBalanceStrategy::default(),
trusted: false,
}
}
pub fn with_config(
providers: Vec<Provider>,
keys: Option<KeySet>,
pool_limiter: Option<Arc<dyn RateLimiter>>,
pool_concurrency_limiter: Option<Arc<dyn ConcurrencyLimiter>>,
fallback: Option<FallbackConfig>,
strategy: LoadBalanceStrategy,
trusted: bool,
) -> Self {
let total_weight = providers.iter().map(|p| p.weight).sum();
Self {
providers,
total_weight,
keys,
pool_limiter,
pool_concurrency_limiter,
fallback,
strategy,
trusted,
}
}
pub fn single(target: Target, weight: u32) -> Self {
Self::new(vec![Provider { target, weight }])
}
pub fn select(&self) -> Option<&Target> {
if self.providers.is_empty() {
return None;
}
if self.providers.len() == 1 {
return Some(&self.providers[0].target);
}
let mut rng = rand::rng();
let random_weight: u32 = rng.random_range(0..self.total_weight);
let mut cumulative_weight = 0;
let mut selected_idx = 0;
for (idx, provider) in self.providers.iter().enumerate() {
cumulative_weight += provider.weight;
if random_weight < cumulative_weight {
selected_idx = idx;
break;
}
}
let selected = &self.providers[selected_idx];
if !is_rate_limited(&selected.target) {
return Some(&selected.target);
}
let mut indices: Vec<usize> = (0..self.providers.len())
.filter(|&i| i != selected_idx)
.collect();
indices.sort_by(|&a, &b| self.providers[b].weight.cmp(&self.providers[a].weight));
for idx in indices {
let provider = &self.providers[idx];
if !is_rate_limited(&provider.target) {
return Some(&provider.target);
}
}
Some(&self.providers[selected_idx].target)
}
pub fn providers(&self) -> &[Provider] {
&self.providers
}
pub fn len(&self) -> usize {
self.providers.len()
}
pub fn is_empty(&self) -> bool {
self.providers.is_empty()
}
pub fn first_target(&self) -> Option<&Target> {
self.providers.first().map(|p| &p.target)
}
pub fn keys(&self) -> Option<&KeySet> {
self.keys.as_ref()
}
pub fn pool_limiter(&self) -> Option<&Arc<dyn RateLimiter>> {
self.pool_limiter.as_ref()
}
pub fn pool_concurrency_limiter(&self) -> Option<&Arc<dyn ConcurrencyLimiter>> {
self.pool_concurrency_limiter.as_ref()
}
pub fn fallback(&self) -> Option<&FallbackConfig> {
self.fallback.as_ref()
}
pub fn fallback_enabled(&self) -> bool {
self.fallback.as_ref().is_some_and(|f| f.enabled)
}
pub fn should_fallback_on_status(&self, status_code: u16) -> bool {
self.fallback
.as_ref()
.is_some_and(|f| f.should_fallback_on_status(status_code))
}
pub fn should_fallback_on_rate_limit(&self) -> bool {
self.fallback
.as_ref()
.is_some_and(|f| f.enabled && f.on_rate_limit)
}
pub fn strategy(&self) -> LoadBalanceStrategy {
self.strategy
}
pub fn is_trusted(&self) -> bool {
self.trusted
}
pub fn select_ordered(&self) -> impl Iterator<Item = (usize, &Target)> {
let with_replacement = self.fallback.as_ref().is_some_and(|f| f.with_replacement);
let max_attempts = self.fallback.as_ref().and_then(|f| f.max_attempts);
let attempt_count = max_attempts.unwrap_or(self.providers.len());
let mut order = Vec::with_capacity(attempt_count);
if self.providers.is_empty() {
return order.into_iter();
}
if self.providers.len() == 1 {
let count = if with_replacement { attempt_count } else { 1 };
for _ in 0..count {
order.push((0, &self.providers[0].target));
}
return order.into_iter();
}
match self.strategy {
LoadBalanceStrategy::Priority => {
for (idx, provider) in self.providers.iter().enumerate().take(attempt_count) {
order.push((idx, &provider.target));
}
}
LoadBalanceStrategy::WeightedRandom if with_replacement => {
let weights: Vec<(usize, u32)> = self
.providers
.iter()
.enumerate()
.map(|(i, p)| (i, p.weight))
.collect();
let mut rng = rand::rng();
for _ in 0..attempt_count {
let total: u32 = weights.iter().map(|(_, w)| w).sum();
let random_weight: u32 = if total > 0 {
rng.random_range(0..total)
} else {
0
};
let mut cumulative = 0;
let mut selected_idx = 0;
for &(idx, weight) in &weights {
cumulative += weight;
if random_weight < cumulative {
selected_idx = idx;
break;
}
}
order.push((selected_idx, &self.providers[selected_idx].target));
}
}
LoadBalanceStrategy::WeightedRandom => {
let mut remaining: Vec<(usize, u32)> = self
.providers
.iter()
.enumerate()
.map(|(i, p)| (i, p.weight))
.collect();
let mut rng = rand::rng();
let cap = attempt_count.min(self.providers.len());
while order.len() < cap && !remaining.is_empty() {
let total: u32 = remaining.iter().map(|(_, w)| w).sum();
let random_weight: u32 = if total > 0 {
rng.random_range(0..total)
} else {
0
};
let mut cumulative = 0;
let mut selected_pos = 0;
for (pos, (_, weight)) in remaining.iter().enumerate() {
cumulative += weight;
if random_weight < cumulative {
selected_pos = pos;
break;
}
}
let (idx, _) = remaining.remove(selected_pos);
order.push((idx, &self.providers[idx].target));
}
}
}
order.into_iter()
}
}
fn is_rate_limited(target: &Target) -> bool {
if let Some(ref limiter) = target.limiter {
limiter.check().is_err()
} else {
false
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::target::{FallbackConfig, LoadBalanceStrategy, Target};
    use std::collections::{HashMap, HashSet};

    fn create_test_target(url: &str) -> Target {
        Target::builder().url(url.parse().unwrap()).build()
    }

    /// Shorthand for a [`Provider`] pointing at `url` with the given weight.
    fn provider(url: &str, weight: u32) -> Provider {
        Provider {
            target: create_test_target(url),
            weight,
        }
    }

    /// Builds an untrusted pool with no keys or limiters, using `fallback` and `strategy`.
    fn pool_with(
        providers: Vec<Provider>,
        fallback: Option<FallbackConfig>,
        strategy: LoadBalanceStrategy,
    ) -> ProviderPool {
        ProviderPool::with_config(providers, None, None, None, fallback, strategy, false)
    }

    #[test]
    fn test_single_provider_pool() {
        let target = create_test_target("https://api.example.com");
        let pool = ProviderPool::single(target.clone(), 1);
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());
        let selected = pool.select();
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().url.as_str(), "https://api.example.com/");
    }

    #[test]
    fn test_empty_pool_returns_none() {
        let pool = ProviderPool::new(vec![]);
        assert!(pool.is_empty());
        assert!(pool.select().is_none());
    }

    #[test]
    fn test_weighted_selection_distribution() {
        let pool = ProviderPool::new(vec![
            provider("https://api1.example.com", 3),
            provider("https://api2.example.com", 1),
        ]);
        let mut counts: HashMap<String, usize> = HashMap::new();
        for _ in 0..1000 {
            if let Some(target) = pool.select() {
                *counts.entry(target.url.to_string()).or_insert(0) += 1;
            }
        }
        let count1 = *counts.get("https://api1.example.com/").unwrap_or(&0);
        let count2 = *counts.get("https://api2.example.com/").unwrap_or(&0);
        let ratio = count1 as f64 / count2 as f64;
        assert!(
            ratio > 1.5 && ratio < 6.0,
            "Expected ratio around 3.0, got {}",
            ratio
        );
    }

    #[test]
    fn test_first_target() {
        let pool = ProviderPool::new(vec![
            provider("https://api1.example.com", 1),
            provider("https://api2.example.com", 2),
        ]);
        let first = pool.first_target();
        assert!(first.is_some());
        assert_eq!(first.unwrap().url.as_str(), "https://api1.example.com/");
    }

    #[test]
    fn test_providers_accessor() {
        let pool = ProviderPool::new(vec![
            provider("https://api1.example.com", 1),
            provider("https://api2.example.com", 2),
        ]);
        assert_eq!(pool.providers().len(), 2);
        assert_eq!(pool.providers()[0].weight, 1);
        assert_eq!(pool.providers()[1].weight, 2);
    }

    #[test]
    fn test_select_ordered_priority_strategy() {
        let pool = pool_with(
            vec![
                provider("https://primary.example.com", 1),
                provider("https://secondary.example.com", 10),
                provider("https://tertiary.example.com", 5),
            ],
            None,
            LoadBalanceStrategy::Priority,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(order.len(), 3);
        assert_eq!(order[0].0, 0);
        assert_eq!(order[1].0, 1);
        assert_eq!(order[2].0, 2);
        assert_eq!(order[0].1.url.as_str(), "https://primary.example.com/");
        assert_eq!(order[1].1.url.as_str(), "https://secondary.example.com/");
        assert_eq!(order[2].1.url.as_str(), "https://tertiary.example.com/");
    }

    #[test]
    fn test_select_ordered_weighted_random_includes_all() {
        let pool = pool_with(
            vec![
                provider("https://api1.example.com", 3),
                provider("https://api2.example.com", 1),
            ],
            None,
            LoadBalanceStrategy::WeightedRandom,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(order.len(), 2);
        let urls: HashSet<_> = order.iter().map(|(_, t)| t.url.as_str()).collect();
        assert!(urls.contains("https://api1.example.com/"));
        assert!(urls.contains("https://api2.example.com/"));
    }

    #[test]
    fn test_select_ordered_weighted_random_distribution() {
        let pool = pool_with(
            vec![
                provider("https://heavy.example.com", 9),
                provider("https://light.example.com", 1),
            ],
            None,
            LoadBalanceStrategy::WeightedRandom,
        );
        let iterations = 1000;
        let heavy_first = (0..iterations)
            .filter(|_| {
                let order: Vec<_> = pool.select_ordered().collect();
                order[0].1.url.as_str() == "https://heavy.example.com/"
            })
            .count();
        let percentage = (heavy_first * 100) / iterations;
        assert!(
            (80..=98).contains(&percentage),
            "Expected heavy to be first ~90% of the time, got {}% ({}/{})",
            percentage,
            heavy_first,
            iterations
        );
    }

    #[test]
    fn test_select_ordered_empty_pool() {
        let pool = ProviderPool::new(vec![]);
        let order: Vec<_> = pool.select_ordered().collect();
        assert!(order.is_empty());
    }

    #[test]
    fn test_select_ordered_single_provider() {
        let providers = vec![provider("https://only.example.com", 1)];
        for strategy in [
            LoadBalanceStrategy::Priority,
            LoadBalanceStrategy::WeightedRandom,
        ] {
            let pool = pool_with(providers.clone(), None, strategy);
            let order: Vec<_> = pool.select_ordered().collect();
            assert_eq!(order.len(), 1);
            assert_eq!(order[0].1.url.as_str(), "https://only.example.com/");
        }
    }

    #[test]
    fn test_select_ordered_with_replacement_allows_duplicates() {
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(5),
            ..Default::default()
        });
        let pool = pool_with(
            vec![
                provider("https://api1.example.com", 9),
                provider("https://api2.example.com", 1),
            ],
            fallback,
            LoadBalanceStrategy::WeightedRandom,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(order.len(), 5);
        // Five attempts over two providers must repeat an index (pigeonhole), so a single
        // ordering is enough to show that replacement is honoured.
        let unique: HashSet<usize> = order.iter().map(|(idx, _)| *idx).collect();
        assert!(
            unique.len() < order.len(),
            "With replacement should allow the same provider to appear multiple times"
        );
    }

    #[test]
    fn test_select_ordered_max_attempts_controls_length() {
        let fallback = Some(FallbackConfig {
            enabled: true,
            max_attempts: Some(2),
            ..Default::default()
        });
        let pool = pool_with(
            vec![
                provider("https://api1.example.com", 1),
                provider("https://api2.example.com", 1),
                provider("https://api3.example.com", 1),
            ],
            fallback,
            LoadBalanceStrategy::WeightedRandom,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(order.len(), 2, "max_attempts should cap the ordering length");
        let indices: HashSet<_> = order.iter().map(|(idx, _)| *idx).collect();
        assert_eq!(
            indices.len(),
            2,
            "Without replacement, all entries should be unique"
        );
    }

    #[test]
    fn test_select_ordered_max_attempts_with_priority() {
        let fallback = Some(FallbackConfig {
            enabled: true,
            max_attempts: Some(2),
            ..Default::default()
        });
        let pool = pool_with(
            vec![
                provider("https://primary.example.com", 1),
                provider("https://secondary.example.com", 1),
                provider("https://tertiary.example.com", 1),
            ],
            fallback,
            LoadBalanceStrategy::Priority,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(order.len(), 2);
        assert_eq!(order[0].1.url.as_str(), "https://primary.example.com/");
        assert_eq!(order[1].1.url.as_str(), "https://secondary.example.com/");
    }

    #[test]
    fn test_select_ordered_defaults_preserve_current_behavior() {
        let pool = pool_with(
            vec![
                provider("https://api1.example.com", 3),
                provider("https://api2.example.com", 1),
            ],
            None,
            LoadBalanceStrategy::WeightedRandom,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(order.len(), 2);
        let urls: HashSet<_> = order.iter().map(|(_, t)| t.url.as_str()).collect();
        assert!(urls.contains("https://api1.example.com/"));
        assert!(urls.contains("https://api2.example.com/"));
    }

    #[test]
    fn test_select_ordered_with_replacement_single_provider() {
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(3),
            ..Default::default()
        });
        let pool = pool_with(
            vec![provider("https://only.example.com", 1)],
            fallback,
            LoadBalanceStrategy::WeightedRandom,
        );
        let order: Vec<_> = pool.select_ordered().collect();
        assert_eq!(
            order.len(),
            3,
            "Single provider with replacement should repeat"
        );
        for (idx, target) in &order {
            assert_eq!(*idx, 0);
            assert_eq!(target.url.as_str(), "https://only.example.com/");
        }
    }

    #[test]
    fn test_select_ordered_with_replacement_respects_weights() {
        let fallback = Some(FallbackConfig {
            enabled: true,
            with_replacement: true,
            max_attempts: Some(10),
            ..Default::default()
        });
        let pool = pool_with(
            vec![
                provider("https://heavy.example.com", 99),
                provider("https://light.example.com", 1),
            ],
            fallback,
            LoadBalanceStrategy::WeightedRandom,
        );
        let iterations = 100;
        let heavy_count: usize = (0..iterations)
            .map(|_| {
                pool.select_ordered()
                    .filter(|(_, t)| t.url.as_str() == "https://heavy.example.com/")
                    .count()
            })
            .sum();
        let total = iterations * 10;
        let percentage = (heavy_count * 100) / total;
        assert!(
            percentage > 90,
            "Heavy provider (99:1 weight) should appear >90% of the time, got {}%",
            percentage
        );
    }
}