use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Additive-increase / multiplicative-decrease (AIMD) controller for a
/// concurrency limit.
///
/// The limit grows by one after every `increase_threshold` recorded successes
/// and is halved on every recorded failure, always staying inside
/// `[min_limit, max_limit]`. All mutable state lives in atomics, so every
/// method takes `&self`.
pub struct AIMDController {
    /// Limit currently in force; kept within `[min_limit, max_limit]`.
    current_limit: AtomicUsize,
    /// Floor for `current_limit` (at least 1).
    min_limit: usize,
    /// Ceiling for `current_limit` (at least `min_limit`).
    max_limit: usize,
    /// Successes recorded since the last increase or failure.
    success_count: AtomicUsize,
    /// Number of successes that earns one additive increase.
    increase_threshold: usize,
    /// Right shift applied to the limit on failure; `1` halves it.
    decrease_shift: u32,
}

impl AIMDController {
    /// Builds a controller, sanitising inconsistent arguments instead of failing.
    ///
    /// * `min_limit` is raised to at least 1.
    /// * `max_limit` is raised to at least `min_limit`.
    /// * `initial_limit` is clamped into `[min_limit, max_limit]`.
    /// * An `increase_threshold` of 0 falls back to 10.
    /// * `_decrease_factor` is currently unused: the decrease step is always a
    ///   halving (right shift by one), whatever value is passed.
    pub fn new(
        initial_limit: usize,
        min_limit: usize,
        max_limit: usize,
        increase_threshold: usize,
        _decrease_factor: f64,
    ) -> Self {
        let min_limit = min_limit.max(1);
        let max_limit = max_limit.max(min_limit);
        let initial = initial_limit.clamp(min_limit, max_limit);
        let threshold = if increase_threshold == 0 {
            10
        } else {
            increase_threshold
        };
        Self {
            current_limit: AtomicUsize::new(initial),
            min_limit,
            max_limit,
            success_count: AtomicUsize::new(0),
            increase_threshold: threshold,
            decrease_shift: 1,
        }
    }

    /// Shorthand for [`AIMDController::new`] with a minimum of 1, an increase
    /// threshold of 10 and a decrease factor of 0.5.
    pub fn with_defaults(initial: usize, max_limit: usize) -> Self {
        Self::new(initial, 1, max_limit, 10, 0.5)
    }

    /// Records one successful operation.
    ///
    /// Every `increase_threshold`-th success since the last increase or
    /// failure raises the limit by one, up to `max_limit`.
    pub fn record_success(&self) {
        let threshold = self.increase_threshold;
        // Advance and wrap the counter in ONE atomic step. A separate
        // `fetch_add` followed by `store(0)` lets several concurrent callers
        // observe a count at or above the threshold before any of them resets
        // it, producing more than one increase per window.
        let bump = |count: usize| {
            let next = count.saturating_add(1);
            Some(if next >= threshold { 0 } else { next })
        };
        // The closure always returns `Some`, so both arms carry the value the
        // counter held immediately before this call's update.
        let before = match self
            .success_count
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, bump)
        {
            Ok(previous) | Err(previous) => previous,
        };
        if before.saturating_add(1) < threshold {
            return;
        }
        // This caller closed the window: additive increase, capped at the max.
        // An `Err` result only means the limit was already at the cap.
        let _ = self
            .current_limit
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                (cur < self.max_limit).then(|| cur + 1)
            });
    }

    /// Records one failed operation.
    ///
    /// Resets the success counter and halves the limit, never going below
    /// `min_limit`.
    pub fn record_failure(&self) {
        self.success_count.store(0, Ordering::Relaxed);
        // An `Err` result only means the limit did not need to change.
        let _ = self
            .current_limit
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                let next = (cur >> self.decrease_shift).max(self.min_limit);
                // Skip the write when already at the floor.
                (next != cur).then_some(next)
            });
    }

    /// Returns the limit currently in force.
    #[inline]
    pub fn current_limit(&self) -> usize {
        self.current_limit.load(Ordering::Relaxed)
    }

    /// Returns the (sanitised) lower bound of the limit.
    #[inline]
    pub fn min_limit(&self) -> usize {
        self.min_limit
    }

    /// Returns the (sanitised) upper bound of the limit.
    #[inline]
    pub fn max_limit(&self) -> usize {
        self.max_limit
    }
}
/// A tokio [`Semaphore`] whose total permit count can be retargeted at runtime.
///
/// Clones share the same semaphore, target and resize bookkeeping.
#[derive(Debug, Clone)]
pub struct AdaptiveSemaphore {
    /// The shared semaphore that callers acquire permits from.
    sem: Arc<Semaphore>,
    /// The most recently requested permit count.
    target: Arc<AtomicUsize>,
    /// Permits that a shrink still has to remove because they were checked out
    /// (not available) when the shrink was requested. The mutex also
    /// serialises resizes so concurrent `set_target` calls cannot interleave.
    shrink_debt: Arc<std::sync::Mutex<usize>>,
}

impl AdaptiveSemaphore {
    /// Creates a semaphore with `initial` permits, clamped to
    /// `1..=Semaphore::MAX_PERMITS`.
    pub fn new(initial: usize) -> Self {
        let initial = Self::clamp_permits(initial);
        Self {
            sem: Arc::new(Semaphore::new(initial)),
            target: Arc::new(AtomicUsize::new(initial)),
            shrink_debt: Arc::new(std::sync::Mutex::new(0)),
        }
    }

    /// Returns a new handle to the underlying shared semaphore.
    pub fn semaphore(&self) -> Arc<Semaphore> {
        Arc::clone(&self.sem)
    }

    /// Returns the most recently requested permit count.
    #[inline]
    pub fn target(&self) -> usize {
        self.target.load(Ordering::Relaxed)
    }

    /// Returns the number of permits currently available for acquisition.
    #[inline]
    pub fn available(&self) -> usize {
        self.sem.available_permits()
    }

    /// Retargets the semaphore to `new_target` permits (clamped to
    /// `1..=Semaphore::MAX_PERMITS`).
    ///
    /// Growing adds permits immediately. Shrinking removes as many available
    /// permits as it can; permits that are checked out at that moment are not
    /// revoked. The shortfall is remembered and either cancelled by a later
    /// growth or removed on a later call to this method (including a call with
    /// an unchanged target) once those permits have been returned.
    pub fn set_target(&self, new_target: usize) {
        let clamped = Self::clamp_permits(new_target);
        // The guarded value is a plain counter, so a poisoned lock is still
        // safe to reuse.
        let mut debt = self
            .shrink_debt
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let prev = self.target.swap(clamped, Ordering::Relaxed);
        if clamped > prev {
            // Growth first cancels permits we still owed to a previous shrink;
            // only the remainder is actually added to the semaphore.
            let grow = clamped - prev;
            let cancelled = grow.min(*debt);
            *debt -= cancelled;
            let to_add = grow - cancelled;
            if to_add > 0 {
                self.sem.add_permits(to_add);
            }
        } else {
            *debt += prev - clamped;
        }
        if *debt > 0 {
            // `forget_permits` removes only permits that are available right
            // now and reports how many it removed; keep the rest as debt.
            let removed = self.sem.forget_permits(*debt);
            *debt -= removed;
        }
    }

    /// Retargets the semaphore to the controller's current limit.
    pub fn sync_from(&self, controller: &AIMDController) {
        self.set_target(controller.current_limit());
    }

    /// Clamps a requested permit count to `1..=Semaphore::MAX_PERMITS`.
    #[inline]
    fn clamp_permits(n: usize) -> usize {
        n.clamp(1, Semaphore::MAX_PERMITS)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` consecutive successes on `controller`.
    fn succeed(controller: &AIMDController, times: usize) {
        (0..times).for_each(|_| controller.record_success());
    }

    /// The initial limit is clamped into `[min, max]` from either side.
    #[test]
    fn initial_clamp() {
        let too_high = AIMDController::new(100, 2, 10, 10, 0.5);
        assert_eq!(too_high.current_limit(), 10);

        let too_low = AIMDController::new(0, 3, 10, 10, 0.5);
        assert_eq!(too_low.current_limit(), 3);
    }

    /// Each full window of successes adds exactly one to the limit.
    #[test]
    fn additive_increase_after_threshold() {
        let ctrl = AIMDController::new(5, 1, 100, 5, 0.5);
        assert_eq!(ctrl.current_limit(), 5);
        for expected in [6, 7] {
            succeed(&ctrl, 5);
            assert_eq!(ctrl.current_limit(), expected);
        }
    }

    /// Increases stop at the configured maximum.
    #[test]
    fn increase_capped_at_max() {
        let ctrl = AIMDController::new(9, 1, 10, 1, 0.5);
        for _ in 0..2 {
            ctrl.record_success();
            assert_eq!(ctrl.current_limit(), 10);
        }
    }

    /// Each failure halves the limit.
    #[test]
    fn multiplicative_decrease() {
        let ctrl = AIMDController::new(20, 1, 100, 10, 0.5);
        for expected in [10, 5] {
            ctrl.record_failure();
            assert_eq!(ctrl.current_limit(), expected);
        }
    }

    /// Halving never goes below the configured minimum.
    #[test]
    fn decrease_clamped_to_min() {
        let ctrl = AIMDController::new(4, 3, 100, 10, 0.5);
        for _ in 0..2 {
            ctrl.record_failure();
            assert_eq!(ctrl.current_limit(), 3);
        }
    }

    /// A failure discards the successes counted so far.
    #[test]
    fn failure_resets_success_counter() {
        let ctrl = AIMDController::new(10, 1, 100, 5, 0.5);
        succeed(&ctrl, 4);

        ctrl.record_failure();
        assert_eq!(ctrl.current_limit(), 5);

        // One more success would have completed the window had it survived.
        ctrl.record_success();
        assert_eq!(ctrl.current_limit(), 5);
    }

    /// `with_defaults` uses a minimum of 1 and keeps the given bounds.
    #[test]
    fn with_defaults_constructor() {
        let ctrl = AIMDController::with_defaults(8, 50);
        assert_eq!(ctrl.current_limit(), 8);
        assert_eq!(ctrl.min_limit(), 1);
        assert_eq!(ctrl.max_limit(), 50);
    }

    /// A minimum above the maximum pulls the maximum (and the limit) up.
    #[test]
    fn min_greater_than_max_corrected() {
        let ctrl = AIMDController::new(5, 20, 10, 10, 0.5);
        assert_eq!(ctrl.max_limit(), 20);
        assert_eq!(ctrl.min_limit(), 20);
        assert_eq!(ctrl.current_limit(), 20);
    }

    /// A fresh semaphore exposes its initial size as target and availability.
    #[test]
    fn adaptive_initial_target_and_available() {
        let adaptive = AdaptiveSemaphore::new(4);
        assert_eq!(adaptive.target(), 4);
        assert_eq!(adaptive.available(), 4);
    }

    /// Zero initial permits are bumped to one.
    #[test]
    fn adaptive_initial_zero_is_clamped_to_one() {
        let adaptive = AdaptiveSemaphore::new(0);
        assert_eq!(adaptive.target(), 1);
        assert_eq!(adaptive.available(), 1);
    }

    /// Raising the target adds permits.
    #[test]
    fn adaptive_set_target_expand() {
        let adaptive = AdaptiveSemaphore::new(2);
        adaptive.set_target(5);
        assert_eq!(adaptive.target(), 5);
        assert_eq!(adaptive.available(), 5);
    }

    /// Lowering the target with nothing checked out removes permits.
    #[test]
    fn adaptive_set_target_shrink_without_inflight() {
        let adaptive = AdaptiveSemaphore::new(5);
        adaptive.set_target(2);
        assert_eq!(adaptive.target(), 2);
        assert_eq!(adaptive.available(), 2);
    }

    /// Shrinking leaves an already-held permit valid until it is dropped.
    #[tokio::test]
    async fn adaptive_shrink_with_inflight_does_not_cancel_existing_permits() {
        let adaptive = AdaptiveSemaphore::new(3);
        let held = adaptive.semaphore().acquire_owned().await.unwrap();
        assert_eq!(adaptive.available(), 2);

        adaptive.set_target(1);
        assert_eq!(adaptive.target(), 1);
        assert_eq!(adaptive.available(), 0);

        drop(held);
        assert_eq!(adaptive.available(), 1);
    }

    /// Setting the current target again changes nothing.
    #[test]
    fn adaptive_set_target_same_is_noop() {
        let adaptive = AdaptiveSemaphore::new(3);
        adaptive.set_target(3);
        assert_eq!(adaptive.target(), 3);
        assert_eq!(adaptive.available(), 3);
    }

    /// Oversized targets are clamped to the semaphore's maximum.
    #[test]
    fn adaptive_set_target_above_max_clamps() {
        let adaptive = AdaptiveSemaphore::new(4);
        adaptive.set_target(usize::MAX);
        assert_eq!(adaptive.target(), Semaphore::MAX_PERMITS);
    }

    /// Clones observe each other's updates and share one semaphore.
    #[test]
    fn adaptive_clones_share_state() {
        let original = AdaptiveSemaphore::new(2);
        let copy = original.clone();
        copy.set_target(7);
        assert_eq!(original.target(), 7);
        assert_eq!(copy.target(), 7);
        assert!(Arc::ptr_eq(&original.semaphore(), &copy.semaphore()));
    }

    /// `sync_from` mirrors the controller's limit as it evolves.
    #[test]
    fn adaptive_sync_from_aimd_controller() {
        let ctrl = AIMDController::new(3, 1, 100, 5, 0.5);
        let adaptive = AdaptiveSemaphore::new(1);

        adaptive.sync_from(&ctrl);
        assert_eq!(adaptive.target(), 3);

        succeed(&ctrl, 5);
        adaptive.sync_from(&ctrl);
        assert_eq!(adaptive.target(), 4);
    }

    /// After racing `set_target` calls, availability matches the final target.
    #[tokio::test]
    async fn adaptive_concurrent_set_target_calls_converge() {
        let shared = Arc::new(AdaptiveSemaphore::new(10));
        let tasks: Vec<_> = (1..=8)
            .map(|step| {
                let handle = Arc::clone(&shared);
                tokio::spawn(async move { handle.set_target(step * 2) })
            })
            .collect();
        for task in tasks {
            task.await.unwrap();
        }

        let settled = shared.target();
        assert!((2..=16).contains(&settled));
        assert_eq!(shared.available(), settled);
    }
}