use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(test)]
use std::time::Duration;
use std::time::Instant;
/// Token-bucket rate limiter whose state lives entirely in atomics (no locks).
///
/// The bucket holds at most `capacity` tokens and is refilled continuously at
/// `refill_per_sec` tokens per second. The fill level is tracked in thousandths of
/// a token ("millitokens"), and elapsed time that has not yet produced a whole
/// millitoken is carried over to the next refill instead of being discarded.
/// A limiter with `capacity == 0` is disabled and admits every request.
pub struct RateLimiter {
    /// Maximum number of whole tokens the bucket can hold; `0` disables limiting.
    capacity: u64,
    /// Refill rate in whole tokens per second.
    refill_per_sec: u64,
    /// Current fill level, in millitokens.
    millitokens: AtomicU64,
    /// Time origin for `last_refill_micros`.
    base: Instant,
    /// Microseconds after `base` up to which elapsed time has been converted into tokens.
    last_refill_micros: AtomicU64,
}

impl RateLimiter {
    /// Millitokens per whole token; the bucket is tracked at this resolution.
    const MILLI_PER_TOKEN: u64 = 1000;

    /// Creates a limiter that starts full (`capacity` tokens) and refills at
    /// `refill_per_sec` tokens per second.
    ///
    /// `capacity == 0` yields a disabled limiter; with `refill_per_sec == 0` the bucket
    /// is never topped up again once drained.
    pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
        Self {
            capacity,
            refill_per_sec,
            millitokens: AtomicU64::new(capacity.saturating_mul(Self::MILLI_PER_TOKEN)),
            base: Instant::now(),
            last_refill_micros: AtomicU64::new(0),
        }
    }

    /// Returns `true` when limiting is active, i.e. `capacity > 0`.
    pub fn enabled(&self) -> bool {
        self.capacity > 0
    }

    /// Attempts to take one whole token from the bucket.
    ///
    /// The bucket is first credited for elapsed time, then one token is removed
    /// atomically.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when less than one whole token is available. A disabled
    /// limiter always returns `Ok(())`.
    // The unit error is part of the existing public API: callers only need pass/fail.
    #[allow(clippy::result_unit_err)]
    pub fn try_acquire(&self) -> Result<(), ()> {
        if !self.enabled() {
            return Ok(());
        }
        self.refill();
        self.millitokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                // `None` (fewer than one whole token) aborts the update -> Err.
                cur.checked_sub(Self::MILLI_PER_TOKEN)
            })
            .map(|_| ())
            .map_err(|_| ())
    }

    /// Credits the bucket for time elapsed since the last refill, capped at `capacity`.
    ///
    /// Only the elapsed time that was actually converted into whole millitokens is
    /// consumed; the fractional remainder stays on the clock for the next call, so the
    /// long-run rate matches `refill_per_sec` regardless of how often this runs.
    /// When threads race, only the one that advances `last_refill_micros` credits the
    /// bucket, so an interval is never counted twice.
    fn refill(&self) {
        // Clamp instead of truncating; u64 microseconds covers ~584k years anyway.
        let now_micros = self.base.elapsed().as_micros().min(u128::from(u64::MAX)) as u64;
        let last = self.last_refill_micros.load(Ordering::Acquire);
        if now_micros <= last {
            return;
        }
        let delta_micros = now_micros - last;
        let cap_milli = self.capacity.saturating_mul(Self::MILLI_PER_TOKEN);

        // `refill_per_sec` tokens/s == `refill_per_sec` millitokens per 1000 µs.
        // u128 keeps the product exact even after an arbitrarily long idle period.
        let rate = u128::from(self.refill_per_sec);
        let earned_milli = rate * u128::from(delta_micros) / 1000;
        if earned_milli == 0 {
            // Not enough time for a whole millitoken yet; leave `last` untouched so
            // the elapsed time keeps accumulating.
            return;
        }

        let (added_milli, new_last) = if earned_milli >= u128::from(cap_milli) {
            // Enough to fill the bucket from empty: any remainder would be cut off by
            // the cap anyway, so consume the whole interval.
            (cap_milli, now_micros)
        } else {
            // Advance only by the time that produced `earned_milli`, rounded up so no
            // microsecond is credited twice. `rate > 0` here because `earned_milli > 0`,
            // and `earned_milli * 1000 <= rate * delta_micros` gives
            // `used_micros <= delta_micros`, hence `new_last <= now_micros`.
            let used_micros = (earned_milli * 1000 + rate - 1) / rate;
            (earned_milli as u64, last + used_micros as u64)
        };

        if self
            .last_refill_micros
            .compare_exchange(last, new_last, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            // Another thread refilled concurrently and already claimed this interval.
            return;
        }
        // The closure always returns `Some`, so this cannot fail.
        let _ = self
            .millitokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                Some(cur.saturating_add(added_milli).min(cap_milli))
            });
    }
}
/// Rate-limiting settings used to construct a limiter.
///
/// The derived `Default` sets every field to `0`, a configuration that produces
/// no limiter (limiting disabled).
#[derive(Clone, Copy, Debug, Default)]
pub struct RateLimitConfig {
    /// Sustained rate in messages (tokens) per second; `0` disables limiting.
    pub msgs_per_sec: u64,
    /// Bucket capacity, i.e. the largest burst admitted at once; `0` disables limiting.
    pub burst: u64,
    /// Strike threshold. Not interpreted in this file; presumably the number of
    /// rejections tolerated before a caller acts on a client (confirm with callers).
    pub strike_threshold: u32,
}
impl RateLimitConfig {
    /// Reads the rate-limit settings from the environment.
    ///
    /// - `EPICS_CAS_RATE_LIMIT_MSGS_PER_SEC` -> `msgs_per_sec`; unset or unparsable
    ///   gives `0` (limiting disabled).
    /// - `EPICS_CAS_RATE_LIMIT_BURST` -> `burst`; unset or unparsable gives four times
    ///   `msgs_per_sec` (saturating at `u64::MAX`), which is `0` when the rate is `0`.
    /// - `EPICS_CAS_RATE_LIMIT_STRIKES` -> `strike_threshold`; unset or unparsable
    ///   gives `100`.
    ///
    /// Leading/trailing whitespace around a value is ignored, so e.g. `"100\n"` is
    /// accepted rather than silently falling back to the default.
    pub fn from_env() -> Self {
        let msgs_per_sec: u64 =
            epics_base_rs::runtime::env::get("EPICS_CAS_RATE_LIMIT_MSGS_PER_SEC")
                .and_then(|s| s.trim().parse().ok())
                .unwrap_or(0);
        // `saturating_mul` avoids an overflow panic (debug) / wraparound (release) for
        // huge rates; 0 * 4 == 0 keeps limiting disabled when no rate is set.
        let burst: u64 = epics_base_rs::runtime::env::get("EPICS_CAS_RATE_LIMIT_BURST")
            .and_then(|s| s.trim().parse().ok())
            .unwrap_or_else(|| msgs_per_sec.saturating_mul(4));
        let strike_threshold: u32 =
            epics_base_rs::runtime::env::get("EPICS_CAS_RATE_LIMIT_STRIKES")
                .and_then(|s| s.trim().parse().ok())
                .unwrap_or(100);
        Self {
            msgs_per_sec,
            burst,
            strike_threshold,
        }
    }

    /// Builds a [`RateLimiter`] with capacity `burst` refilling at `msgs_per_sec`
    /// tokens per second.
    ///
    /// Returns `None` when either value is `0`, meaning limiting is disabled.
    pub fn build(&self) -> Option<RateLimiter> {
        (self.msgs_per_sec != 0 && self.burst != 0)
            .then(|| RateLimiter::new(self.burst, self.msgs_per_sec))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With zero capacity the limiter is disabled and admits every request.
    #[test]
    fn disabled_always_ok() {
        let limiter = RateLimiter::new(0, 0);
        assert!((0..1000).all(|_| limiter.try_acquire().is_ok()));
    }

    /// Two tokens at 1 token/s: two requests pass, the third is rejected.
    #[test]
    fn empty_bucket_rejects() {
        let limiter = RateLimiter::new(2, 1);
        let outcomes: Vec<bool> = (0..3).map(|_| limiter.try_acquire().is_ok()).collect();
        assert_eq!(outcomes, [true, true, false]);
    }

    /// After draining at 1000 tokens/s, a 15 ms pause restores at least 5 tokens.
    #[test]
    fn refills_over_time() {
        let limiter = RateLimiter::new(10, 1000);
        (0..10).for_each(|_| limiter.try_acquire().unwrap());
        assert!(limiter.try_acquire().is_err());

        std::thread::sleep(Duration::from_millis(15));
        for attempt in 1..=5 {
            assert!(
                limiter.try_acquire().is_ok(),
                "acquire {} of 5 after refill was rejected",
                attempt
            );
        }
    }

    /// The all-zero default configuration must not produce a limiter.
    #[test]
    fn config_from_env_defaults_disabled() {
        let config = RateLimitConfig::default();
        assert!(config.build().is_none());
    }
}