use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use tower_resilience_core::aimd::{AimdConfig, AimdController};
/// A strategy that adjusts a concurrency limit based on the outcome of requests.
///
/// All methods take `&self`, and implementors must be `Send + Sync`, so any state an
/// implementation updates has to live behind interior mutability (the implementations in
/// this file use atomics or an inner controller).
pub trait ConcurrencyAlgorithm: Send + Sync {
    /// Records a request that completed successfully after `latency`.
    fn record_success(&self, latency: Duration);
    /// Records a request that failed.
    fn record_failure(&self);
    /// Records a request that was dropped before completing.
    fn record_dropped(&self);
    /// Returns the current concurrency limit.
    fn limit(&self) -> usize;
    /// Returns the configured lower bound for the limit.
    fn min_limit(&self) -> usize;
    /// Returns the configured upper bound for the limit.
    fn max_limit(&self) -> usize;
}
/// Additive-increase / multiplicative-decrease limiter backed by an [`AimdController`].
///
/// Successful requests whose latency exceeds `latency_threshold` are reported to the
/// controller as failures, so slow responses shrink the limit just like errors do.
pub struct Aimd {
    /// Controller that holds and updates the actual limit.
    controller: AimdController,
    /// Latencies strictly greater than this are treated as failures.
    latency_threshold: Duration,
}
impl Aimd {
pub fn new(config: AimdConfig, latency_threshold: Duration) -> Self {
Self {
controller: AimdController::new(config),
latency_threshold,
}
}
pub fn builder() -> AimdBuilder {
AimdBuilder::default()
}
}
impl ConcurrencyAlgorithm for Aimd {
    /// Forwards a success to the controller, unless `latency` is above the threshold, in
    /// which case the request is reported as a failure instead.
    fn record_success(&self, latency: Duration) {
        if latency <= self.latency_threshold {
            self.controller.record_success();
            return;
        }
        // Too slow: penalise the limit exactly as an error would.
        self.controller.record_failure();
    }

    /// Forwards a failure to the controller.
    fn record_failure(&self) {
        self.controller.record_failure();
    }

    /// Dropped requests do not change the AIMD limit.
    fn record_dropped(&self) {}

    /// Current limit as reported by the controller.
    fn limit(&self) -> usize {
        self.controller.limit()
    }

    /// Lower bound as reported by the controller.
    fn min_limit(&self) -> usize {
        self.controller.min_limit()
    }

    /// Upper bound as reported by the controller.
    fn max_limit(&self) -> usize {
        self.controller.max_limit()
    }
}
/// Builder for [`Aimd`].
///
/// Every field has a default (see its `Default` impl); `build` forwards the limit settings to
/// an `AimdConfig` and keeps `latency_threshold` on the resulting [`Aimd`].
#[derive(Debug, Clone)]
pub struct AimdBuilder {
    /// Limit the controller starts at.
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Additive step passed to the controller.
    increase_by: usize,
    /// Multiplicative factor passed to the controller.
    decrease_factor: f64,
    /// Latencies strictly above this are recorded as failures by [`Aimd`].
    latency_threshold: Duration,
}
impl Default for AimdBuilder {
fn default() -> Self {
Self {
initial_limit: 10,
min_limit: 1,
max_limit: 100,
increase_by: 1,
decrease_factor: 0.5,
latency_threshold: Duration::from_millis(100),
}
}
}
impl AimdBuilder {
    /// Sets the limit the controller starts at.
    pub fn initial_limit(self, limit: usize) -> Self {
        Self {
            initial_limit: limit,
            ..self
        }
    }

    /// Sets the lower bound for the limit.
    pub fn min_limit(self, limit: usize) -> Self {
        Self {
            min_limit: limit,
            ..self
        }
    }

    /// Sets the upper bound for the limit.
    pub fn max_limit(self, limit: usize) -> Self {
        Self {
            max_limit: limit,
            ..self
        }
    }

    /// Sets the additive step applied on success.
    pub fn increase_by(self, amount: usize) -> Self {
        Self {
            increase_by: amount,
            ..self
        }
    }

    /// Sets the multiplicative factor applied on decrease.
    pub fn decrease_factor(self, factor: f64) -> Self {
        Self {
            decrease_factor: factor,
            ..self
        }
    }

    /// Sets the latency above which a success is recorded as a failure.
    pub fn latency_threshold(self, threshold: Duration) -> Self {
        Self {
            latency_threshold: threshold,
            ..self
        }
    }

    /// Turns the collected settings into an `AimdConfig` and wraps it in an [`Aimd`].
    pub fn build(self) -> Aimd {
        let Self {
            initial_limit,
            min_limit,
            max_limit,
            increase_by,
            decrease_factor,
            latency_threshold,
        } = self;
        let config = AimdConfig::new()
            .with_initial_limit(initial_limit)
            .with_min_limit(min_limit)
            .with_max_limit(max_limit)
            .with_increase_by(increase_by)
            .with_decrease_factor(decrease_factor);
        Aimd::new(config, latency_threshold)
    }
}
/// Delay-based concurrency limiter.
///
/// It tracks the minimum RTT seen and an exponentially smoothed RTT, and estimates how many
/// requests are queued from the gap between the two. On each success it moves the limit by
/// one step, depending on whether that estimate is below `alpha` or above `beta`. All mutable
/// state lives in atomics so the limiter can be updated through `&self`.
pub struct Vegas {
    /// Current concurrency limit, kept within `[min_limit, max_limit]`.
    limit: AtomicUsize,
    /// Lower bound for `limit`.
    min_limit: usize,
    /// Upper bound for `limit`.
    max_limit: usize,
    /// Smallest RTT sample seen, in nanoseconds; `u64::MAX` until the first sample.
    min_rtt_nanos: AtomicU64,
    /// Queue-estimate threshold below which the limit grows by one.
    alpha: usize,
    /// Queue-estimate threshold above which the limit shrinks by one.
    beta: usize,
    /// Weight given to the newest sample in the smoothed RTT (set to 0.5 by `new`).
    smoothing: f64,
    /// Smoothed RTT in nanoseconds; 0 means no sample has been folded in yet.
    smoothed_rtt_nanos: AtomicU64,
    /// Number of RTT samples folded in so far.
    sample_count: AtomicUsize,
    /// Samples required before the limit is adjusted (set to 10 by `new`).
    min_samples: usize,
}
impl Vegas {
    /// Creates a Vegas limiter.
    ///
    /// `initial_limit` is clamped into `[min_limit, max_limit]`. The limit grows by one while
    /// the estimated queue is below `alpha` and shrinks by one while it is above `beta`. The
    /// RTT smoothing weight (0.5) and the warm-up sample count (10) are fixed.
    ///
    /// # Panics
    ///
    /// Panics if `min_limit > max_limit` (via [`usize::clamp`]).
    pub fn new(
        initial_limit: usize,
        min_limit: usize,
        max_limit: usize,
        alpha: usize,
        beta: usize,
    ) -> Self {
        Self {
            limit: AtomicUsize::new(initial_limit.clamp(min_limit, max_limit)),
            min_limit,
            max_limit,
            min_rtt_nanos: AtomicU64::new(u64::MAX),
            alpha,
            beta,
            smoothing: 0.5,
            smoothed_rtt_nanos: AtomicU64::new(0),
            sample_count: AtomicUsize::new(0),
            min_samples: 10,
        }
    }

    /// Returns a [`VegasBuilder`] with default settings.
    pub fn builder() -> VegasBuilder {
        VegasBuilder::default()
    }

    /// Folds one RTT sample into the minimum and smoothed RTT estimates.
    ///
    /// Zero-length samples are ignored. They carry no delay information, and recording one
    /// would pin the minimum RTT to 0 permanently. `adjust_limit` refuses to work from a zero
    /// baseline, so the limit would then never adapt again.
    fn update_rtt(&self, rtt: Duration) {
        // Saturate rather than wrap: `as u64` alone would truncate the u128 nanosecond count.
        let rtt_nanos = rtt.as_nanos().min(u128::from(u64::MAX)) as u64;
        if rtt_nanos == 0 {
            return;
        }

        self.min_rtt_nanos.fetch_min(rtt_nanos, Ordering::Relaxed);

        // Read-modify-write atomically so concurrent samples are not lost between load and store.
        let weight = self.smoothing;
        let _ = self
            .smoothed_rtt_nanos
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                let next = if current == 0 {
                    // First sample seeds the average.
                    rtt_nanos
                } else {
                    (weight * rtt_nanos as f64 + (1.0 - weight) * current as f64) as u64
                };
                Some(next)
            });

        self.sample_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Moves the limit one step based on the estimated queue size.
    ///
    /// The estimate is `limit * (1 - min_rtt / smoothed_rtt)`: the share of the current RTT
    /// spent queueing, scaled by the number of requests allowed in flight. It is always
    /// `<= limit`. Does nothing until `min_samples` samples have been recorded, or while
    /// either RTT estimate is still unset.
    fn adjust_limit(&self) {
        if self.sample_count.load(Ordering::Relaxed) < self.min_samples {
            return;
        }
        let min_rtt = self.min_rtt_nanos.load(Ordering::Relaxed);
        let smoothed_rtt = self.smoothed_rtt_nanos.load(Ordering::Relaxed);
        if min_rtt == u64::MAX || min_rtt == 0 || smoothed_rtt == 0 {
            return;
        }

        // Fraction of the observed RTT attributable to queueing, in [0, 1).
        let queued_fraction = if smoothed_rtt > min_rtt {
            1.0 - min_rtt as f64 / smoothed_rtt as f64
        } else {
            0.0
        };

        // Compute and store atomically so concurrent adjustments/failures are not overwritten.
        let _ = self
            .limit
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current_limit| {
                let queue_estimate = (current_limit as f64 * queued_fraction) as usize;
                let new_limit = if queue_estimate < self.alpha {
                    current_limit.saturating_add(1).min(self.max_limit)
                } else if queue_estimate > self.beta {
                    current_limit.saturating_sub(1).max(self.min_limit)
                } else {
                    current_limit
                };
                Some(new_limit)
            });
    }
}
impl ConcurrencyAlgorithm for Vegas {
    /// Folds `latency` into the RTT estimates, then re-evaluates the limit.
    fn record_success(&self, latency: Duration) {
        self.update_rtt(latency);
        self.adjust_limit();
    }

    /// Halves the limit on failure, never going below `min_limit`.
    fn record_failure(&self) {
        // Atomic read-modify-write: with a separate load and store, two concurrent failures
        // could both read the same value and one of the halvings would be lost.
        let _ = self
            .limit
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some((current / 2).max(self.min_limit))
            });
    }

    /// Dropped requests do not change the Vegas limit.
    fn record_dropped(&self) {}

    /// Current concurrency limit.
    fn limit(&self) -> usize {
        self.limit.load(Ordering::Relaxed)
    }

    /// Configured lower bound for the limit.
    fn min_limit(&self) -> usize {
        self.min_limit
    }

    /// Configured upper bound for the limit.
    fn max_limit(&self) -> usize {
        self.max_limit
    }
}
/// Builder for [`Vegas`]; `build` forwards every field to `Vegas::new`.
#[derive(Debug, Clone)]
pub struct VegasBuilder {
    /// Starting limit (clamped into `[min_limit, max_limit]` by `Vegas::new`).
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Queue-estimate threshold below which the limit grows.
    alpha: usize,
    /// Queue-estimate threshold above which the limit shrinks.
    beta: usize,
}
impl Default for VegasBuilder {
fn default() -> Self {
Self {
initial_limit: 10,
min_limit: 1,
max_limit: 100,
alpha: 3,
beta: 6,
}
}
}
impl VegasBuilder {
pub fn initial_limit(mut self, limit: usize) -> Self {
self.initial_limit = limit;
self
}
pub fn min_limit(mut self, limit: usize) -> Self {
self.min_limit = limit;
self
}
pub fn max_limit(mut self, limit: usize) -> Self {
self.max_limit = limit;
self
}
pub fn alpha(mut self, alpha: usize) -> Self {
self.alpha = alpha;
self
}
pub fn beta(mut self, beta: usize) -> Self {
self.beta = beta;
self
}
pub fn build(self) -> Vegas {
Vegas::new(
self.initial_limit,
self.min_limit,
self.max_limit,
self.alpha,
self.beta,
)
}
}
/// Closed set of the built-in limit algorithms.
///
/// Implements [`ConcurrencyAlgorithm`] by forwarding every call to the wrapped variant.
pub enum Algorithm {
    /// Additive-increase / multiplicative-decrease limiter.
    Aimd(Aimd),
    /// Delay-based (RTT-gradient) limiter.
    Vegas(Vegas),
}
impl Algorithm {
    /// Borrows the wrapped algorithm as a trait object so each trait method below can forward
    /// with a single call instead of repeating the variant match.
    fn delegate(&self) -> &dyn ConcurrencyAlgorithm {
        match self {
            Self::Aimd(inner) => inner,
            Self::Vegas(inner) => inner,
        }
    }
}

impl ConcurrencyAlgorithm for Algorithm {
    fn record_success(&self, latency: Duration) {
        self.delegate().record_success(latency)
    }

    fn record_failure(&self) {
        self.delegate().record_failure()
    }

    fn record_dropped(&self) {
        self.delegate().record_dropped()
    }

    fn limit(&self) -> usize {
        self.delegate().limit()
    }

    fn min_limit(&self) -> usize {
        self.delegate().min_limit()
    }

    fn max_limit(&self) -> usize {
        self.delegate().max_limit()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aimd_builder() {
        // Setters touch independent fields, so their order does not matter.
        let aimd = Aimd::builder()
            .latency_threshold(Duration::from_millis(50))
            .decrease_factor(0.75)
            .increase_by(2)
            .max_limit(200)
            .min_limit(5)
            .initial_limit(20)
            .build();
        let observed = (aimd.limit(), aimd.min_limit(), aimd.max_limit());
        assert_eq!(observed, (20, 5, 200));
    }

    #[test]
    fn test_aimd_success_increases() {
        let threshold = Duration::from_millis(100);
        let aimd = Aimd::builder()
            .initial_limit(10)
            .increase_by(1)
            .latency_threshold(threshold)
            .build();
        // 50 ms is under the threshold: a genuine success.
        aimd.record_success(threshold / 2);
        assert_eq!(aimd.limit(), 10 + 1);
    }

    #[test]
    fn test_aimd_high_latency_decreases() {
        let threshold = Duration::from_millis(100);
        let aimd = Aimd::builder()
            .initial_limit(10)
            .decrease_factor(0.5)
            .latency_threshold(threshold)
            .build();
        // 150 ms exceeds the threshold, so the success is recorded as a failure.
        aimd.record_success(threshold + Duration::from_millis(50));
        assert_eq!(aimd.limit(), 5);
    }

    #[test]
    fn test_aimd_failure_decreases() {
        let aimd = Aimd::builder().initial_limit(10).decrease_factor(0.5).build();
        aimd.record_failure();
        assert_eq!(aimd.limit(), 10 / 2);
    }

    #[test]
    fn test_vegas_builder() {
        let vegas = Vegas::builder()
            .beta(8)
            .alpha(2)
            .max_limit(200)
            .min_limit(5)
            .initial_limit(20)
            .build();
        let observed = (vegas.limit(), vegas.min_limit(), vegas.max_limit());
        assert_eq!(observed, (20, 5, 200));
    }

    #[test]
    fn test_vegas_failure_decreases() {
        let vegas = Vegas::builder().min_limit(1).initial_limit(20).build();
        vegas.record_failure();
        assert_eq!(vegas.limit(), 20 / 2);
    }

    #[test]
    fn test_vegas_min_rtt_tracking() {
        let vegas = Vegas::builder().initial_limit(10).build();
        for millis in [100, 50, 75] {
            vegas.record_success(Duration::from_millis(millis));
        }
        let fastest = Duration::from_millis(50).as_nanos() as u64;
        assert_eq!(vegas.min_rtt_nanos.load(Ordering::Relaxed), fastest);
    }

    #[test]
    fn test_algorithm_enum() {
        let wrapped = [
            (Algorithm::Aimd(Aimd::builder().initial_limit(10).build()), 10),
            (Algorithm::Vegas(Vegas::builder().initial_limit(20).build()), 20),
        ];
        for (algorithm, expected) in wrapped {
            assert_eq!(algorithm.limit(), expected);
        }
    }
}