use clocksource::precise::{AtomicInstant, Duration, Instant};
use core::sync::atomic::{AtomicU64, Ordering};
use parking_lot::RwLock;
use thiserror::Error;
#[derive(Error, Debug, PartialEq, Eq)]
pub enum Error {
#[error("available tokens cannot be set higher than max tokens")]
AvailableTokensTooHigh,
#[error("max tokens cannot be less than the refill amount")]
MaxTokensTooLow,
#[error("refill amount cannot exceed the max tokens")]
RefillAmountTooHigh,
#[error("refill interval in nanoseconds exceeds maximum u64")]
RefillIntervalTooLong,
}
/// Refill configuration, kept together behind the ratelimiter's `RwLock` so
/// that readers always observe a consistent set of values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Parameters {
    /// Maximum number of tokens the bucket may hold.
    capacity: u64,
    /// Number of tokens added for each elapsed refill interval.
    refill_amount: u64,
    /// Time between consecutive refills.
    refill_interval: Duration,
}
/// A token-bucket ratelimiter.
///
/// All mutable state lives in atomics or behind an `RwLock`, which is why
/// every method takes `&self`.
pub struct Ratelimiter {
    /// Tokens currently available to be taken.
    available: AtomicU64,
    /// Running total of refill tokens discarded because the bucket was full.
    dropped: AtomicU64,
    /// Capacity and refill configuration.
    parameters: RwLock<Parameters>,
    /// Instant at which the next refill becomes due.
    refill_at: AtomicInstant,
}
impl Ratelimiter {
pub fn builder(amount: u64, interval: core::time::Duration) -> Builder {
Builder::new(amount, interval)
}
pub fn rate(&self) -> f64 {
let parameters = self.parameters.read();
parameters.refill_amount as f64 * 1_000_000_000.0
/ parameters.refill_interval.as_nanos() as f64
}
pub fn refill_interval(&self) -> core::time::Duration {
let parameters = self.parameters.read();
core::time::Duration::from_nanos(parameters.refill_interval.as_nanos())
}
pub fn set_refill_interval(&self, duration: core::time::Duration) -> Result<(), Error> {
if duration.as_nanos() > u64::MAX as u128 {
return Err(Error::RefillIntervalTooLong);
}
let mut parameters = self.parameters.write();
parameters.refill_interval = Duration::from_nanos(duration.as_nanos() as u64);
Ok(())
}
pub fn refill_amount(&self) -> u64 {
let parameters = self.parameters.read();
parameters.refill_amount
}
pub fn set_refill_amount(&self, amount: u64) -> Result<(), Error> {
let mut parameters = self.parameters.write();
if amount > parameters.capacity {
Err(Error::RefillAmountTooHigh)
} else {
parameters.refill_amount = amount;
Ok(())
}
}
pub fn max_tokens(&self) -> u64 {
let parameters = self.parameters.read();
parameters.capacity
}
pub fn set_max_tokens(&self, amount: u64) -> Result<(), Error> {
let mut parameters = self.parameters.write();
if amount < parameters.refill_amount {
Err(Error::MaxTokensTooLow)
} else {
parameters.capacity = amount;
loop {
let available = self.available();
if amount > available {
if self
.available
.compare_exchange(available, amount, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
break;
}
} else {
break;
}
}
Ok(())
}
}
pub fn available(&self) -> u64 {
self.available.load(Ordering::Relaxed)
}
pub fn next_refill(&self) -> Instant {
self.refill_at.load(Ordering::Relaxed)
}
pub fn set_available(&self, amount: u64) -> Result<(), Error> {
let parameters = self.parameters.read();
if amount > parameters.capacity {
Err(Error::AvailableTokensTooHigh)
} else {
self.available.store(amount, Ordering::Release);
Ok(())
}
}
pub fn dropped(&self) -> u64 {
self.dropped.load(Ordering::Relaxed)
}
fn refill(&self, time: Instant) -> Result<(), core::time::Duration> {
let mut intervals;
let mut parameters;
loop {
let refill_at = self.refill_at.load(Ordering::Relaxed);
if time < refill_at {
return Err(core::time::Duration::from_nanos(
(refill_at - time).as_nanos(),
));
}
parameters = self.parameters.read();
intervals = (time - refill_at).as_nanos() / parameters.refill_interval.as_nanos() + 1;
let next_refill =
refill_at + Duration::from_nanos(intervals * parameters.refill_interval.as_nanos());
if self
.refill_at
.compare_exchange(refill_at, next_refill, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
break;
}
}
let amount = intervals * parameters.refill_amount;
let available = self.available.load(Ordering::Acquire);
if available + amount >= parameters.capacity {
let to_add = parameters.capacity - available;
self.available.fetch_add(to_add, Ordering::Release);
self.dropped.fetch_add(amount - to_add, Ordering::Relaxed);
} else {
self.available.fetch_add(amount, Ordering::Release);
}
Ok(())
}
pub fn try_wait(&self) -> Result<(), core::time::Duration> {
loop {
let refill_result = self.refill(Instant::now());
loop {
let available = self.available.load(Ordering::Acquire);
if available == 0 {
match refill_result {
Ok(_) => {
break;
}
Err(e) => {
return Err(e);
}
}
}
let new = available - 1;
if self
.available
.compare_exchange(available, new, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return Ok(());
}
}
}
}
}
/// Configuration collected before constructing a [`Ratelimiter`].
pub struct Builder {
    /// Tokens available immediately after `build`.
    initial_available: u64,
    /// Bucket capacity.
    max_tokens: u64,
    /// Tokens added per refill interval.
    refill_amount: u64,
    /// Time between refills.
    refill_interval: core::time::Duration,
}
impl Builder {
fn new(amount: u64, interval: core::time::Duration) -> Self {
Self {
initial_available: 0,
max_tokens: 1,
refill_amount: amount,
refill_interval: interval,
}
}
pub fn max_tokens(mut self, tokens: u64) -> Self {
self.max_tokens = tokens;
self
}
pub fn initial_available(mut self, tokens: u64) -> Self {
self.initial_available = tokens;
self
}
pub fn build(self) -> Result<Ratelimiter, Error> {
if self.max_tokens < self.refill_amount {
return Err(Error::MaxTokensTooLow);
}
if self.refill_interval.as_nanos() > u64::MAX as u128 {
return Err(Error::RefillIntervalTooLong);
}
let available = AtomicU64::new(self.initial_available);
let parameters = Parameters {
capacity: self.max_tokens,
refill_amount: self.refill_amount,
refill_interval: Duration::from_nanos(self.refill_interval.as_nanos() as u64),
};
let refill_at = AtomicInstant::new(Instant::now() + self.refill_interval);
Ok(Ratelimiter {
available,
dropped: AtomicU64::new(0),
parameters: parameters.into(),
refill_at,
})
}
}
#[cfg(test)]
mod tests {
    use crate::*;
    use std::time::{Duration, Instant};

    /// Asserts that `$value` lies within ±0.1% of `$target`.
    macro_rules! approx_eq {
        ($value:expr, $target:expr) => {
            let value: f64 = $value;
            let target: f64 = $target;
            assert!(value >= target * 0.999, "{value} >= {}", target * 0.999);
            assert!(value <= target * 1.001, "{value} <= {}", target * 1.001);
        };
    }

    #[test]
    pub fn rate() {
        // 4 tokens every 333ns is roughly 12.012 million tokens per second.
        let limiter = Ratelimiter::builder(4, Duration::from_nanos(333))
            .max_tokens(4)
            .build()
            .unwrap();
        approx_eq!(limiter.rate(), 12012012.0);
    }

    #[test]
    pub fn wait() {
        let limiter = Ratelimiter::builder(1, Duration::from_micros(10))
            .build()
            .unwrap();
        // Spin for 10ms. At one token per 10us, roughly 1000 should be granted.
        let deadline = Instant::now() + Duration::from_millis(10);
        let mut granted = 0;
        while Instant::now() < deadline {
            granted += i32::from(limiter.try_wait().is_ok());
        }
        assert!(granted >= 600);
        assert!(granted <= 1400);
    }

    #[test]
    pub fn idle() {
        use clocksource::precise::Instant as PreciseInstant;

        let limiter = Ratelimiter::builder(1, Duration::from_millis(1))
            .initial_available(1)
            .build()
            .unwrap();
        std::thread::sleep(Duration::from_millis(10));
        assert!(limiter.next_refill() < PreciseInstant::now());
        assert!(limiter.try_wait().is_ok());
        assert!(limiter.try_wait().is_err());
        assert!(limiter.dropped() >= 8);
        assert!(limiter.next_refill() >= PreciseInstant::now());
        std::thread::sleep(Duration::from_millis(5));
        assert!(limiter.next_refill() < PreciseInstant::now());
    }

    #[test]
    pub fn capacity() {
        let limiter = Ratelimiter::builder(1, Duration::from_millis(10))
            .max_tokens(10)
            .initial_available(0)
            .build()
            .unwrap();
        std::thread::sleep(Duration::from_millis(100));
        // The bucket should have refilled up to its 10-token capacity by now.
        for _ in 0..10 {
            assert!(limiter.try_wait().is_ok());
        }
        assert!(limiter.try_wait().is_err());
    }
}