#![forbid(unsafe_code)]
use core::time::Duration;
use oorandom::Rand32;
use std::time::Instant;
/// In-place saturating addition, i.e. `x.saturating_add_assign(y)` stores
/// `x + y` into `x`, clamping at the numeric bound instead of overflowing.
trait SaturatingAddAssign<T> {
    /// Adds `other` to `self` in place, saturating instead of overflowing.
    fn saturating_add_assign(&mut self, other: T);
}

/// `u32` implementation: the sum is clamped at `u32::MAX`.
impl SaturatingAddAssign<u32> for u32 {
    fn saturating_add_assign(&mut self, other: u32) {
        let sum = u32::saturating_add(*self, other);
        *self = sum;
    }
}
/// Decides whether to admit a request given the recent cost and the cap.
///
/// With `load = recent_cost / max_cost`:
/// - `max_cost == 0` or `recent_cost >= max_cost`: always reject (`false`).
/// - `load <= 0.75`: always admit (`true`).
/// - otherwise: reject with probability `((load - 0.75) * 4)^2`. This rises
///   from 0 at 75% load to 1 at 100% load. It is compared against a single
///   draw from `rand_float`.
///
/// `rand_float` is invoked only in the probabilistic band. The two
/// deterministic outcomes never consume randomness.
fn decide(recent_cost: u32, max_cost: u32, mut rand_float: impl FnMut() -> f32) -> bool {
    // Disabled or saturated: nothing gets through.
    if max_cost == 0 || recent_cost >= max_cost {
        return false;
    }

    let load = f64::from(recent_cost) / f64::from(max_cost);
    let excess = 4.0 * (load - 0.75);

    // At or below the 75% knee every request is admitted.
    if excess <= 0.0 {
        return true;
    }

    let reject_prob = excess.powi(2);
    let roll = f64::from(rand_float());
    roll > reject_prob
}
#[cfg(test)]
#[test]
fn test_decide() {
    // Cases decided without randomness: the RNG closure must never be called.
    let deterministic: [(u32, u32, bool); 6] = [
        (0, 0, false),
        (0, 100, true),
        (50, 100, true),
        (75, 100, true),
        (100, 100, false),
        (101, 100, false),
    ];
    for &(cost, max, expected) in &deterministic {
        assert_eq!(decide(cost, max, || unreachable!()), expected);
    }

    // Cases in the probabilistic band (cap 100): the fixed roll sits just
    // below or just above the rejection probability.
    let sampled: [(u32, f32, bool); 10] = [
        (76, 0.999_999, true),
        (76, 0.0, false),
        (85, 0.15, false),
        (85, 0.17, true),
        (90, 0.35, false),
        (90, 0.37, true),
        (95, 0.63, false),
        (95, 0.65, true),
        (99, 0.92, false),
        (99, 0.93, true),
    ];
    for &(cost, roll, expected) in &sampled {
        assert_eq!(decide(cost, 100, move || roll), expected);
    }
}
/// Probabilistic rate limiter.
///
/// Admitted work adds to `cost`. The accumulated cost decays by a right shift
/// for each whole `tick_duration` that elapses. Each request is admitted or
/// rejected by `decide`, according to how close `cost` is to `max_cost`.
#[derive(Clone, Debug)]
pub struct ProbRateLimiter {
    /// Length of one decay step. `attempt` shifts `cost` right once per whole tick elapsed.
    tick_duration: Duration,
    /// Rejection cap, set by `new_custom` to twice the per-tick budget. Zero rejects everything.
    max_cost: u32,
    /// Accumulated cost, added via `record` and decayed in `attempt`.
    cost: u32,
    /// Reference time that `attempt` advances in whole-tick steps.
    last: Instant,
    /// Random source for the probabilistic admit/reject decision.
    prng: Rand32,
}
impl ProbRateLimiter {
    /// Creates a limiter with an explicit tick length, per-tick budget, start
    /// time and random number generator.
    ///
    /// The stored cap is `max_cost_per_tick * 2`, saturating at `u32::MAX`.
    /// `attempt` halves the recorded cost once per elapsed tick, so a steady
    /// load of `c` per tick accumulates toward roughly `2c`. Doubling the cap
    /// presumably accounts for that. A `max_cost_per_tick` of zero produces a
    /// limiter that rejects every request.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a description if `tick_duration` is shorter than one
    /// microsecond.
    pub fn new_custom(
        tick_duration: Duration,
        max_cost_per_tick: u32,
        now: Instant,
        prng: Rand32,
    ) -> Result<Self, String> {
        if tick_duration.as_micros() == 0 {
            return Err(format!("tick_duration too small: {:?}", tick_duration));
        }
        Ok(Self {
            tick_duration,
            // Saturate instead of overflowing for budgets above u32::MAX / 2.
            // A plain `* 2` panics in debug builds and wraps to a much smaller
            // cap in release builds.
            max_cost: max_cost_per_tick.saturating_mul(2),
            cost: 0_u32,
            last: now,
            prng,
        })
    }

    /// Creates a limiter with a one-second tick, starting at `Instant::now()`,
    /// with a per-tick budget of `max_cost_per_sec` (see `new_custom` for how
    /// the cap is derived).
    ///
    /// The PRNG is seeded with the fixed value `0`.
    // `new_custom` only fails for ticks shorter than 1µs, so the `expect`
    // below cannot fire; hence no `# Panics` section.
    #[allow(clippy::missing_panics_doc)]
    #[must_use]
    pub fn new(max_cost_per_sec: u32) -> Self {
        Self::new_custom(
            Duration::from_secs(1),
            max_cost_per_sec,
            Instant::now(),
            Rand32::new(0),
        )
        .expect("a one-second tick is always at least one microsecond")
    }

    /// Decays the recorded cost for every whole tick elapsed since the last
    /// update, then decides (possibly at random) whether a request may proceed.
    ///
    /// This does not record any cost itself. Call [`record`](Self::record) for
    /// admitted work, or use [`check`](Self::check) to do both. Always returns
    /// `false` when the cap is zero. A `now` earlier than the last update is
    /// treated as no time having elapsed.
    pub fn attempt(&mut self, now: Instant) -> bool {
        if self.max_cost == 0 {
            return false;
        }
        let elapsed = now.saturating_duration_since(self.last);
        // Count whole ticks at nanosecond precision. `new_custom` guarantees
        // `tick_duration >= 1µs`, so the divisor is non-zero. This makes
        // `elapsed_ticks * tick_duration <= elapsed` hold exactly, so `last`
        // never moves past `now`. Counting whole microseconds over-counts for
        // ticks such as 1.5µs. Very long idle periods clamp to `u32::MAX`
        // instead of truncating to an arbitrary small count.
        let whole_ticks = elapsed.as_nanos() / self.tick_duration.as_nanos();
        #[allow(clippy::cast_possible_truncation)] // clamped to u32::MAX first
        let elapsed_ticks = whole_ticks.min(u128::from(u32::MAX)) as u32;
        // Cannot overflow: the product is at most `elapsed`.
        self.last += self.tick_duration * elapsed_ticks;
        // Halve the cost once per elapsed tick. A shift of 32 or more clears
        // every bit, and `checked_shr` reports it as `None` (mapped to 0).
        // `wrapping_shr` would instead mask the shift to `elapsed_ticks % 32`
        // and revive stale cost after 32+ idle ticks.
        self.cost = self.cost.checked_shr(elapsed_ticks).unwrap_or(0);
        decide(self.cost, self.max_cost, || self.prng.rand_float())
    }

    /// Adds `cost` to the accumulated total, saturating at `u32::MAX`.
    pub fn record(&mut self, cost: u32) {
        self.cost.saturating_add_assign(cost);
    }

    /// Runs [`attempt`](Self::attempt). If it admits the request, records
    /// `cost` and returns `true`. Otherwise records nothing and returns `false`.
    pub fn check(&mut self, cost: u32, now: Instant) -> bool {
        let admitted = self.attempt(now);
        if admitted {
            self.record(cost);
        }
        admitted
    }
}