use core::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use core::time::Duration;
use clock_lib::{Clock, Monotonic, SystemClock};
use tokio::sync::Notify;
/// Result of one operation that ran while holding a limiter permit.
///
/// A value of this type is handed to [`AdaptiveStrategy::adjust`] each time a
/// permit is settled.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
/// The operation completed successfully.
Success {
/// Time elapsed between acquiring the permit and settling it, as measured
/// by the limiter's clock.
rtt: Duration,
},
/// The operation failed. A permit that is dropped without being settled is
/// reported this way as well.
Failure,
}
/// Policy deciding how the concurrency limit moves after each settled permit.
///
/// `adjust` takes `&self`, so a strategy that keeps state between calls has to
/// use interior mutability (the `Send + Sync` bound still applies).
pub trait AdaptiveStrategy: Send + Sync {
/// Returns the limit the strategy proposes after observing `outcome`.
///
/// * `current` - the limit in force when the permit was settled.
/// * `in_flight` - permits outstanding at that moment. `AdaptiveLimiter`
///   reads this before releasing the settling permit, so that permit is
///   included in the count.
/// * `outcome` - how the operation ended.
///
/// `AdaptiveLimiter` clamps the returned value to its floor and ceiling
/// before storing it, so implementations need not do so themselves.
fn adjust(&self, current: u32, in_flight: u32, outcome: Outcome) -> u32;
}
/// Additive-increase / multiplicative-decrease strategy.
///
/// On a success observed while the limiter is fully utilised the limit grows
/// by `increase`; on a failure it is multiplied by `decrease` (never dropping
/// below 1).
#[derive(Debug, Clone, Copy)]
pub struct Aimd {
    /// Amount added to the limit on growth; always at least 1.
    increase: u32,
    /// Factor applied to the limit on failure; always a number in `[0.0, 1.0]`.
    decrease: f64,
}

impl Aimd {
    /// Decrease factor used by [`Aimd::default`] and substituted for a NaN
    /// argument to [`Aimd::new`].
    const DEFAULT_DECREASE: f64 = 0.5;

    /// Creates a strategy with the given additive step and decrease factor.
    ///
    /// Arguments are normalised rather than rejected:
    ///
    /// * `increase` is raised to at least 1 so the limit can always grow.
    /// * `decrease` is clamped to `[0.0, 1.0]`.
    /// * a NaN `decrease` is replaced by the default factor `0.5`.
    #[must_use]
    pub fn new(increase: u32, decrease: f64) -> Self {
        // `f64::clamp` returns NaN when called on NaN, so a NaN factor would
        // otherwise be stored as-is and make every failure collapse the limit
        // to 1 through the saturating float-to-int cast.
        let decrease = if decrease.is_nan() {
            Self::DEFAULT_DECREASE
        } else {
            decrease.clamp(0.0, 1.0)
        };
        Self {
            increase: increase.max(1),
            decrease,
        }
    }
}

impl Default for Aimd {
    /// Grows by 1 and halves on failure.
    fn default() -> Self {
        Self::new(1, Self::DEFAULT_DECREASE)
    }
}
impl AdaptiveStrategy for Aimd {
    /// Applies the additive-increase / multiplicative-decrease rule.
    ///
    /// * Failure: the limit is scaled by `decrease`, truncated towards zero,
    ///   and kept at 1 or above.
    /// * Success with `in_flight >= current`: the limit grows by `increase`
    ///   (saturating at `u32::MAX`).
    /// * Any other success: the limit is returned unchanged, since the
    ///   current limit was not the bottleneck.
    fn adjust(&self, current: u32, in_flight: u32, outcome: Outcome) -> u32 {
        match outcome {
            Outcome::Failure => {
                // Float-to-int `as` saturates and truncates; the product is
                // never above `current` because `decrease` is at most 1.0.
                let reduced = (f64::from(current) * self.decrease) as u32;
                reduced.max(1)
            }
            Outcome::Success { .. } => {
                if in_flight >= current {
                    current.saturating_add(self.increase)
                } else {
                    current
                }
            }
        }
    }
}
/// Latency-driven strategy.
///
/// It remembers the smallest round-trip time seen so far and, on every
/// success, estimates how much of the current limit is spent queueing. The
/// limit grows while that estimate is below `alpha` and shrinks once it
/// exceeds `beta`.
#[derive(Debug)]
pub struct Vegas {
    /// Queue estimate below which the limit is increased.
    alpha: u32,
    /// Queue estimate above which the limit is decreased; never below `alpha`.
    beta: u32,
    /// Smallest observed round-trip time in nanoseconds (`u64::MAX` until the
    /// first sample arrives).
    min_rtt_ns: AtomicU64,
}

impl Vegas {
    /// Creates a strategy with the given thresholds.
    ///
    /// If `beta` is smaller than `alpha` it is raised to `alpha`, so the
    /// "grow" and "shrink" bands can never overlap.
    #[must_use]
    pub fn new(alpha: u32, beta: u32) -> Self {
        let beta = if beta < alpha { alpha } else { beta };
        let min_rtt_ns = AtomicU64::new(u64::MAX);
        Self {
            alpha,
            beta,
            min_rtt_ns,
        }
    }
}

impl Default for Vegas {
    /// Uses `alpha = 3` and `beta = 6`.
    fn default() -> Self {
        Vegas::new(3, 6)
    }
}
impl AdaptiveStrategy for Vegas {
    /// Adjusts the limit from the latency of the settled operation.
    ///
    /// A failure halves the limit (never below 1) without touching the
    /// recorded minimum. A success first folds its round-trip time into the
    /// running minimum, then computes
    /// `queue = current * (rtt - min_rtt) / rtt` in integer nanoseconds and
    /// moves the limit by one step: up when `queue < alpha`, down when
    /// `queue > beta`, unchanged otherwise. `_in_flight` is not consulted.
    fn adjust(&self, current: u32, _in_flight: u32, outcome: Outcome) -> u32 {
        let sample_ns = match outcome {
            // Clamp into `1..=u64::MAX`: the lower bound keeps the division
            // below well-defined for a zero-length sample.
            Outcome::Success { rtt } => u64::try_from(rtt.as_nanos())
                .unwrap_or(u64::MAX)
                .max(1),
            Outcome::Failure => return (current / 2).max(1),
        };

        // `fetch_min` yields the value held *before* this sample, so take the
        // minimum once more to obtain the baseline that includes it.
        let previous_min = self.min_rtt_ns.fetch_min(sample_ns, Ordering::AcqRel);
        let baseline_ns = previous_min.min(sample_ns);

        let excess_ns = sample_ns.saturating_sub(baseline_ns);
        let queued = u64::from(current).saturating_mul(excess_ns) / sample_ns;

        match queued {
            q if q < u64::from(self.alpha) => current.saturating_add(1),
            q if q > u64::from(self.beta) => current.saturating_sub(1),
            _ => current,
        }
    }
}
/// Concurrency limiter whose limit is moved by a strategy `S` after every
/// settled permit.
///
/// `C` is the clock used to time operations; it defaults to `SystemClock`.
pub struct AdaptiveLimiter<S, C = SystemClock>
where
C: Clock,
{
/// Strategy consulted on every settle to propose a new limit.
strategy: S,
/// Current concurrency limit; kept within `floor..=ceiling`.
limit: AtomicU32,
/// Number of permits currently outstanding.
in_flight: AtomicU32,
/// Lowest value the limit may take (at least 1).
floor: u32,
/// Highest value the limit may take (at least `floor`).
ceiling: u32,
/// Signalled after every settle so tasks waiting for a slot can retry.
notify: Notify,
/// Time source used to stamp permits and measure round-trip times.
clock: C,
}
impl AdaptiveLimiter<core::convert::Infallible> {
    /// Starts configuring a limiter.
    ///
    /// The impl is pinned to a placeholder strategy type so that
    /// `AdaptiveLimiter::builder()` can be written without naming a strategy;
    /// the real one is supplied to `AdaptiveLimiterBuilder::build`.
    #[must_use]
    pub fn builder() -> AdaptiveLimiterBuilder {
        // The builder's `Default` impl forwards to `AdaptiveLimiterBuilder::new`.
        AdaptiveLimiterBuilder::default()
    }
}
impl<S, C> AdaptiveLimiter<S, C>
where
    S: AdaptiveStrategy,
    C: Clock + Clone,
{
    /// Assembles a limiter, normalising its bounds first.
    ///
    /// `floor` is raised to at least 1, `ceiling` to at least the normalised
    /// floor, and `initial` is clamped into that range. The in-flight count
    /// starts at zero.
    fn new(strategy: S, floor: u32, ceiling: u32, initial: u32, clock: C) -> Self {
        let lower = floor.max(1);
        let upper = ceiling.max(lower);
        let starting_limit = initial.clamp(lower, upper);
        Self {
            strategy,
            limit: AtomicU32::new(starting_limit),
            in_flight: AtomicU32::new(0),
            floor: lower,
            ceiling: upper,
            notify: Notify::new(),
            clock,
        }
    }

    /// Rebuilds this limiter around a different clock.
    ///
    /// Strategy, bounds and the current limit carry over. The limiter is
    /// consumed, so no permit borrowed from it can still be alive and the new
    /// limiter starts with an in-flight count of zero.
    #[must_use]
    pub fn with_clock<C2>(self, clock: C2) -> AdaptiveLimiter<S, C2>
    where
        C2: Clock + Clone,
    {
        let carried_limit = self.limit.load(Ordering::Acquire);
        AdaptiveLimiter::new(
            self.strategy,
            self.floor,
            self.ceiling,
            carried_limit,
            clock,
        )
    }

    /// Returns the concurrency limit currently in force.
    #[must_use]
    pub fn current_limit(&self) -> u32 {
        self.limit.load(Ordering::Acquire)
    }

    /// Returns how many permits are outstanding right now.
    #[must_use]
    pub fn in_flight(&self) -> u32 {
        self.in_flight.load(Ordering::Acquire)
    }

    /// Returns the upper bound the limit can never exceed.
    #[must_use]
    pub fn ceiling(&self) -> u32 {
        self.ceiling
    }

    /// Tries to claim one slot; returns `true` when the in-flight count was
    /// incremented.
    ///
    /// The limit is re-read on every attempt, so a concurrent limit change is
    /// picked up while retrying.
    fn try_reserve(&self) -> bool {
        let mut observed = self.in_flight.load(Ordering::Acquire);
        // `observed + 1` cannot overflow: the loop only runs while `observed`
        // is strictly below the limit, which is itself a `u32`.
        while observed < self.limit.load(Ordering::Acquire) {
            match self.in_flight.compare_exchange_weak(
                observed,
                observed + 1,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                // Lost the race (or failed spuriously): retry from the value
                // the exchange actually saw.
                Err(actual) => observed = actual,
            }
        }
        false
    }

    /// Returns a permit if a slot is free right now, without waiting.
    #[must_use]
    pub fn try_acquire(&self) -> Option<AdaptivePermit<'_, S, C>> {
        if self.try_reserve() {
            Some(AdaptivePermit::new(self))
        } else {
            None
        }
    }

    /// Records the outcome of one permit: updates the limit, releases the
    /// slot, then wakes waiters.
    ///
    /// The in-flight count is sampled *before* the slot is released, so the
    /// strategy sees the settling permit as still outstanding.
    fn settle(&self, outcome: Outcome) {
        let outstanding = self.in_flight.load(Ordering::Acquire);
        let limit_before = self.limit.load(Ordering::Acquire);
        // NOTE(review): the load/adjust/store sequence is not one atomic
        // step, so two concurrent settles can both start from the same limit.
        let limit_after = self
            .strategy
            .adjust(limit_before, outstanding, outcome)
            .clamp(self.floor, self.ceiling);
        self.limit.store(limit_after, Ordering::Release);
        let _ = self.in_flight.fetch_sub(1, Ordering::AcqRel);
        self.notify.notify_waiters();
    }

    /// Returns the time elapsed on the limiter's clock since `started`.
    fn rtt_since(&self, started: Monotonic) -> Duration {
        let finished = self.clock.now();
        finished.saturating_duration_since(started)
    }

    /// Reads the limiter's clock.
    #[inline]
    fn now(&self) -> Monotonic {
        self.clock.now()
    }
}
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, C> AdaptiveLimiter<S, C>
where
S: AdaptiveStrategy,
C: Clock + Clone,
{
/// Waits until a slot is available and returns a permit for it.
///
/// Each iteration tries to reserve a slot and, if none is free, sleeps until
/// a permit is settled (settling calls `notify_waiters`), then tries again.
/// There is no timeout: the future completes only once a reservation
/// succeeds.
pub async fn acquire(&self) -> AdaptivePermit<'_, S, C> {
loop {
// A fresh notification future is created for every attempt.
let notified = self.notify.notified();
tokio::pin!(notified);
// Enable the future *before* checking capacity so that a settle landing
// between the check and the `.await` below still wakes this task.
// NOTE(review): this relies on tokio's documented `Notified::enable`
// semantics for `notify_waiters`; confirm against the tokio version in use.
let _ = notified.as_mut().enable();
if self.try_reserve() {
return AdaptivePermit::new(self);
}
// No slot was free: wait for the next settle, then retry.
notified.await;
}
}
}
/// One reserved slot of an [`AdaptiveLimiter`].
///
/// Settle it with `success()` or `failure()` once the guarded operation has
/// finished. Dropping it unsettled reports a failure and still releases the
/// slot.
#[must_use = "settle the permit with `.success()` or `.failure()`; dropping it counts as a failure"]
pub struct AdaptivePermit<'a, S, C>
where
S: AdaptiveStrategy,
C: Clock + Clone,
{
/// Limiter the slot was reserved from.
limiter: &'a AdaptiveLimiter<S, C>,
/// Clock reading taken when the permit was created; start of the rtt measurement.
started: Monotonic,
/// Set once an outcome has been reported, so `Drop` does not report another.
settled: bool,
}
impl<'a, S, C> AdaptivePermit<'a, S, C>
where
    S: AdaptiveStrategy,
    C: Clock + Clone,
{
    /// Wraps an already reserved slot, stamping it with the limiter's current
    /// clock reading.
    fn new(limiter: &'a AdaptiveLimiter<S, C>) -> Self {
        let started = limiter.now();
        Self {
            limiter,
            started,
            settled: false,
        }
    }

    /// Reports that the guarded operation succeeded and releases the slot.
    ///
    /// The round-trip time handed to the strategy is the time elapsed on the
    /// limiter's clock since the permit was created.
    pub fn success(mut self) {
        let outcome = Outcome::Success {
            rtt: self.limiter.rtt_since(self.started),
        };
        self.limiter.settle(outcome);
        // Marked only after settling; `Drop` runs when `self` goes out of
        // scope at the end of this function and must not settle again.
        self.settled = true;
    }

    /// Reports that the guarded operation failed and releases the slot.
    pub fn failure(mut self) {
        let limiter = self.limiter;
        limiter.settle(Outcome::Failure);
        self.settled = true;
    }
}
impl<S, C> Drop for AdaptivePermit<'_, S, C>
where
    S: AdaptiveStrategy,
    C: Clock + Clone,
{
    /// Releases the slot if the permit was never settled, counting the
    /// abandoned operation as a failure.
    fn drop(&mut self) {
        if self.settled {
            // `success()` / `failure()` already reported an outcome.
            return;
        }
        self.limiter.settle(Outcome::Failure);
    }
}
/// Configuration collected before an [`AdaptiveLimiter`] is built.
#[derive(Debug, Clone, Copy)]
pub struct AdaptiveLimiterBuilder {
/// Lowest limit the limiter may reach (at least 1).
floor: u32,
/// Highest limit the limiter may reach.
ceiling: u32,
/// Starting limit; when `None`, `build` starts at `floor`.
initial: Option<u32>,
}
impl Default for AdaptiveLimiterBuilder {
fn default() -> Self {
Self::new()
}
}
impl AdaptiveLimiterBuilder {
    /// Creates a builder with a floor of 1, a ceiling of 100 and no explicit
    /// starting limit.
    #[must_use]
    pub fn new() -> Self {
        Self {
            floor: 1,
            ceiling: 100,
            initial: None,
        }
    }

    /// Sets the lowest limit the limiter may reach; values below 1 become 1.
    #[must_use]
    pub fn floor(self, floor: u32) -> Self {
        Self {
            floor: floor.max(1),
            ..self
        }
    }

    /// Sets the highest limit the limiter may reach.
    ///
    /// The value is stored as given; the limiter constructor raises it to the
    /// floor if it turns out to be lower.
    #[must_use]
    pub fn ceiling(self, ceiling: u32) -> Self {
        Self { ceiling, ..self }
    }

    /// Sets the limit the limiter starts with.
    ///
    /// The limiter constructor clamps it into `floor..=ceiling`.
    #[must_use]
    pub fn initial(self, initial: u32) -> Self {
        Self {
            initial: Some(initial),
            ..self
        }
    }

    /// Builds a limiter driven by `strategy` and timed by a fresh
    /// `SystemClock`.
    ///
    /// Without an explicit starting limit the limiter begins at the floor.
    #[must_use]
    pub fn build<S>(self, strategy: S) -> AdaptiveLimiter<S, SystemClock>
    where
        S: AdaptiveStrategy,
    {
        let Self {
            floor,
            ceiling,
            initial,
        } = self;
        let starting_limit = match initial {
            Some(explicit) => explicit,
            None => floor,
        };
        AdaptiveLimiter::new(strategy, floor, ceiling, starting_limit, SystemClock::new())
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]
    use super::{AdaptiveLimiter, AdaptiveStrategy, Aimd, Outcome, Vegas};
    use clock_lib::ManualClock;
    use core::time::Duration;
    use std::sync::Arc;

    /// Success outcome carrying a zero round-trip time.
    const INSTANT_SUCCESS: Outcome = Outcome::Success {
        rtt: Duration::ZERO,
    };

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    /// Builds a system-clock limiter with explicit bounds and starting limit.
    fn limiter_with<S>(floor: u32, ceiling: u32, initial: u32, strategy: S) -> AdaptiveLimiter<S>
    where
        S: AdaptiveStrategy,
    {
        AdaptiveLimiter::builder()
            .floor(floor)
            .ceiling(ceiling)
            .initial(initial)
            .build(strategy)
    }

    #[test]
    fn test_adaptive_is_send_sync() {
        assert_send_sync::<AdaptiveLimiter<Aimd>>();
        assert_send_sync::<AdaptiveLimiter<Vegas>>();
    }

    #[test]
    fn test_aimd_adjust_rules() {
        let aimd = Aimd::new(2, 0.5);
        // Fully utilised success grows by the additive step.
        assert_eq!(aimd.adjust(10, 10, INSTANT_SUCCESS), 12);
        // Under-utilised success leaves the limit alone.
        assert_eq!(aimd.adjust(10, 3, INSTANT_SUCCESS), 10);
        // Failure applies the multiplicative decrease.
        assert_eq!(aimd.adjust(10, 10, Outcome::Failure), 5);
    }

    #[test]
    fn test_degradation_drives_limit_to_floor() {
        let limiter = limiter_with(4, 100, 64, Aimd::new(4, 0.5));
        for _ in 0..10 {
            limiter
                .try_acquire()
                .expect("a slot under the limit")
                .failure();
        }
        assert_eq!(limiter.current_limit(), 4);
    }

    #[test]
    fn test_recovery_drives_limit_up_bounded_by_ceiling() {
        let limiter = limiter_with(1, 8, 1, Aimd::new(1, 0.5));
        for _ in 0..50 {
            // Saturate the limiter, then settle the newest permit first so it
            // is observed at full utilisation, followed by the rest in
            // acquisition order.
            let mut held = Vec::new();
            while let Some(permit) = limiter.try_acquire() {
                held.push(permit);
            }
            if let Some(newest) = held.pop() {
                newest.success();
            }
            for permit in held {
                permit.success();
            }
        }
        assert_eq!(limiter.current_limit(), 8, "grows to the ceiling");
        for _ in 0..20 {
            limiter.try_acquire().expect("slot").success();
        }
        assert_eq!(limiter.current_limit(), 8, "never exceeds the ceiling");
    }

    #[test]
    fn test_never_admits_more_than_the_limit() {
        let limiter = limiter_with(3, 3, 3, Aimd::default());
        let first = limiter.try_acquire().expect("1");
        let second = limiter.try_acquire().expect("2");
        let third = limiter.try_acquire().expect("3");
        assert_eq!(limiter.in_flight(), 3);
        assert!(limiter.try_acquire().is_none());
        drop((first, second, third));
    }

    #[test]
    fn test_dropping_permit_counts_as_failure() {
        let limiter = limiter_with(1, 100, 10, Aimd::new(1, 0.5));
        drop(limiter.try_acquire().expect("slot"));
        assert_eq!(limiter.current_limit(), 5);
        assert_eq!(limiter.in_flight(), 0, "the slot is released");
    }

    #[test]
    fn test_vegas_grows_on_low_latency_shrinks_on_high() {
        let clock = Arc::new(ManualClock::new());
        let limiter = limiter_with(1, 100, 20, Vegas::new(3, 6)).with_clock(Arc::clone(&clock));

        // First sample defines the minimum rtt, so no queueing is estimated.
        let fast = limiter.try_acquire().expect("slot");
        clock.advance(Duration::from_millis(10));
        fast.success();
        assert_eq!(limiter.current_limit(), 21);

        // A much slower sample pushes the queue estimate above `beta`.
        let slow = limiter.try_acquire().expect("slot");
        clock.advance(Duration::from_millis(200));
        slow.success();
        assert!(
            limiter.current_limit() < 21,
            "high latency shrinks the limit"
        );
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_async_acquire_waits_for_a_freed_slot() {
        let limiter = Arc::new(limiter_with(1, 1, 1, Aimd::default()));
        let held = limiter.try_acquire().expect("the one slot");
        assert!(limiter.try_acquire().is_none());

        let shared = Arc::clone(&limiter);
        let waiter = tokio::spawn(async move { shared.acquire().await.success() });
        // Give the waiter a chance to park before the slot is freed.
        tokio::task::yield_now().await;
        held.success();
        waiter.await.unwrap();
    }
}