use core::cmp::min;
use core::sync::atomic::{AtomicU64, Ordering};
use crate::math::{elapsed_for_tokens, refill_tokens};
use crate::Clock;
#[cfg(feature = "std")]
use crate::StdClock;
/// Token-bucket rate limiter whose mutable state is held in atomics, so every
/// operation works through a shared `&self`.
///
/// Time comes from the pluggable [`Clock`] `C`; the bucket starts full and is
/// topped back up (never beyond `capacity`) as the clock advances.
pub struct RateLimiter<C>
where
    C: Clock,
{
    /// Upper bound on the number of tokens the bucket can hold.
    capacity: u64,
    /// Refill rate handed to the crate's refill helpers; `0` disables refilling.
    refill_per_sec: u64,
    /// Tokens currently available for consumption.
    tokens: AtomicU64,
    /// Clock reading up to which elapsed time has already been turned into tokens.
    last_refill_ns: AtomicU64,
    /// Source of `now_ns()` readings.
    clock: C,
}
/// Plain-value copy of a limiter's state, as produced by `RateLimiter::snapshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Snapshot {
    /// Tokens available at the time of the snapshot (never above `capacity`).
    pub tokens: u64,
    /// Clock reading up to which elapsed time had been converted into tokens.
    pub last_refill_ns: u64,
    /// Maximum number of tokens the bucket can hold.
    pub capacity: u64,
    /// Configured refill rate; `0` means the bucket never refills.
    pub refill_per_sec: u64,
}
impl<C: Clock> RateLimiter<C> {
    /// Builds a limiter that starts with a full bucket (`capacity` tokens).
    ///
    /// The clock is read once here, and that reading becomes the initial
    /// "last refill" instant. `refill_per_sec` is the rate passed to the crate's
    /// `refill_tokens` / `elapsed_for_tokens` helpers (presumably tokens per
    /// second, going by the name). A rate of `0` turns refilling off entirely.
    pub fn with_clock(capacity: u64, refill_per_sec: u64, clock: C) -> Self {
        let started_at = clock.now_ns();
        Self {
            tokens: AtomicU64::new(capacity),
            last_refill_ns: AtomicU64::new(started_at),
            capacity,
            refill_per_sec,
            clock,
        }
    }

    /// Tries to take a single token. Shorthand for `allow_n(1)`.
    #[inline]
    pub fn allow(&self) -> bool {
        self.allow_n(1)
    }

    /// Tries to take `n` tokens at once and reports whether it succeeded.
    ///
    /// - `n == 0` always succeeds and does not read the clock.
    /// - `n > capacity` always fails, because it could never be satisfied.
    /// - Otherwise the bucket is refilled first, unless the refill rate is `0`.
    ///   Then `n` tokens are deducted only if that many are available. On
    ///   failure nothing is deducted.
    pub fn allow_n(&self, n: u64) -> bool {
        match n {
            0 => true,
            _ if n > self.capacity => false,
            // With no refill rate there is nothing to gain from reading the clock.
            _ if self.refill_per_sec == 0 => self.consume_tokens(n),
            _ => {
                self.try_refill(self.clock.now_ns());
                self.consume_tokens(n)
            }
        }
    }

    /// Refills the bucket and returns how many tokens are currently available.
    ///
    /// Under concurrent use the value may already be out of date by the time
    /// the caller looks at it.
    pub fn remaining(&self) -> u64 {
        let now = self.clock.now_ns();
        self.try_refill(now);
        self.bounded_tokens()
    }

    /// Maximum number of tokens the bucket can hold.
    #[inline]
    pub const fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Configured refill rate (`0` means the bucket never refills).
    #[inline]
    pub const fn refill_per_sec(&self) -> u64 {
        self.refill_per_sec
    }

    /// Refills the bucket, then captures its state as a [`Snapshot`].
    ///
    /// `tokens` and `last_refill_ns` come from two separate atomic loads. Under
    /// concurrent use they are therefore not guaranteed to describe the same
    /// instant.
    pub fn snapshot(&self) -> Snapshot {
        let now = self.clock.now_ns();
        self.try_refill(now);
        // Load order (tokens, then timestamp) matches the original reads.
        let tokens = self.bounded_tokens();
        let last_refill_ns = self.last_refill_ns.load(Ordering::Relaxed);
        Snapshot {
            tokens,
            last_refill_ns,
            capacity: self.capacity,
            refill_per_sec: self.refill_per_sec,
        }
    }

    /// Atomically subtracts `n` tokens if at least `n` are available.
    ///
    /// Retries `compare_exchange_weak` until it either succeeds (`true`) or
    /// observes fewer than `n` tokens (`false`).
    #[inline]
    fn consume_tokens(&self, n: u64) -> bool {
        let mut seen = self.tokens.load(Ordering::Relaxed);
        while seen >= n {
            match self.tokens.compare_exchange_weak(
                seen,
                seen - n,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => return true,
                Err(actual) => {
                    seen = actual;
                    core::hint::spin_loop();
                }
            }
        }
        false
    }

    /// Atomically adds `add` tokens, saturating, and never exceeding `capacity`.
    #[inline]
    fn add_tokens(&self, add: u64) {
        let mut seen = self.tokens.load(Ordering::Relaxed);
        while let Err(actual) = self.tokens.compare_exchange_weak(
            seen,
            min(seen.saturating_add(add), self.capacity),
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            seen = actual;
            core::hint::spin_loop();
        }
    }

    /// Current token count, clamped to `capacity`.
    fn bounded_tokens(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed).min(self.capacity)
    }

    /// Converts time elapsed since the last refill into tokens.
    ///
    /// Does nothing when refilling is disabled (`refill_per_sec == 0`) or the
    /// bucket has zero capacity. It also does nothing when `now_ns` is not past
    /// the recorded refill instant, which covers a clock that stalls or steps
    /// backwards.
    ///
    /// The thread that wins the timestamp CAS is the only one that credits the
    /// tokens, so each slice of time is counted once. The timestamp is claimed
    /// first and the tokens are added afterwards, using two independent relaxed
    /// atomics. Another thread can therefore briefly see the advanced timestamp
    /// before the tokens land.
    fn try_refill(&self, now_ns: u64) {
        if self.refill_per_sec == 0 || self.capacity == 0 {
            return;
        }
        loop {
            let last = self.last_refill_ns.load(Ordering::Relaxed);
            let elapsed = match now_ns.checked_sub(last) {
                Some(delta) if delta > 0 => delta,
                _ => return,
            };
            let add = refill_tokens(elapsed, self.refill_per_sec);
            if add == 0 {
                return;
            }
            // Advance only by the time the credited tokens account for (capped at
            // `now_ns`). Presumably this lets leftover time that is not yet worth
            // a whole token carry over to the next refill.
            let credited_until = min(
                last.saturating_add(elapsed_for_tokens(add, self.refill_per_sec)),
                now_ns,
            );
            if self
                .last_refill_ns
                .compare_exchange_weak(last, credited_until, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                self.add_tokens(add);
                return;
            }
            core::hint::spin_loop();
        }
    }
}
#[cfg(feature = "std")]
impl RateLimiter<StdClock> {
    /// Creates a limiter driven by the standard-library clock, [`StdClock`].
    ///
    /// Convenience wrapper for
    /// `RateLimiter::with_clock(capacity, refill_per_sec, StdClock)`. The bucket
    /// starts full, and a `refill_per_sec` of `0` means it never refills.
    #[inline]
    pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
        RateLimiter::with_clock(capacity, refill_per_sec, StdClock)
    }
}