use super::{
config::{MemoryOrdering, RateLimiterConfig, MAX_REFILL_PERIODS},
metrics::RateLimiterMetrics,
utils::{cpu_relax, current_time_ms, CacheAligned},
};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use tracing::debug;
// --- compare-exchange retry tuning -------------------------------------------

/// Maximum number of failed compare-exchange attempts a bounded retry loop makes
/// before it gives up.
const MAX_CAS_RETRIES: usize = 16;

/// Retry count up to which `backoff` issues a single spin hint; beyond it the
/// number of spin hints grows exponentially.
const CAS_BACKOFF_THRESHOLD: usize = 4;

/// How many times in a row the multi-token acquisition loop may re-load an
/// unchanged token count before it settles the outcome with one strong
/// compare-exchange.
const MAX_REPEAT_COUNT: usize = 3;

// --- last-access throttling ---------------------------------------------------

/// Minimum age, in milliseconds, that the stored last-access timestamp must
/// reach before it is overwritten; limits writes to that shared word.
const LAST_ACCESS_UPDATE_INTERVAL_MS: u64 = 100;
/// Atomic integer that stores the bucket's token count.
///
/// On 64-bit targets this is simply `AtomicU64`.
#[cfg(target_pointer_width = "64")]
type TokenCounter = AtomicU64;

/// Atomic token count for targets whose pointer width is below 64 bits.
///
/// This is a newtype rather than a `type` alias for `AtomicU32`. With an alias,
/// `TokenCounter::new(..)`, `.load(..)`, `.store(..)` and the compare-exchange
/// calls would resolve to `AtomicU32`'s *inherent* `u32` methods, because
/// inherent methods take precedence over trait methods. The `u64`-typed call
/// sites in `RateLimiter` would then fail to type-check.
///
/// The newtype has no inherent methods, so those calls resolve to the `u64`-based
/// [`TokenOps`] implementation below.
///
/// Counts above `u32::MAX` cannot be represented and saturate to `u32::MAX`.
#[cfg(not(target_pointer_width = "64"))]
#[derive(Debug)]
struct TokenCounter(AtomicU32);

/// `u64`-typed view of a narrower atomic counter. It mirrors the subset of the
/// `AtomicU64` API that the rate limiter uses, so the same call sites compile on
/// every target.
#[cfg(not(target_pointer_width = "64"))]
trait TokenOps {
    /// Creates a counter holding `val` (saturated to the counter's range).
    fn new(val: u64) -> Self;
    /// Loads the current value, widened to `u64`.
    fn load(&self, ordering: Ordering) -> u64;
    /// Stores `val` (saturated to the counter's range).
    fn store(&self, val: u64, ordering: Ordering);
    /// Strong compare-exchange; `Ok`/`Err` carry the previous value widened to `u64`.
    fn compare_exchange(
        &self,
        current: u64,
        new: u64,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u64, u64>;
    /// Weak compare-exchange (may fail spuriously); same value conventions as
    /// `compare_exchange`.
    fn compare_exchange_weak(
        &self,
        current: u64,
        new: u64,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u64, u64>;
}

#[cfg(not(target_pointer_width = "64"))]
impl TokenOps for TokenCounter {
    #[inline(always)]
    fn new(val: u64) -> Self {
        TokenCounter(AtomicU32::new(saturate_to_u32(val)))
    }

    #[inline(always)]
    fn load(&self, ordering: Ordering) -> u64 {
        u64::from(self.0.load(ordering))
    }

    #[inline(always)]
    fn store(&self, val: u64, ordering: Ordering) {
        self.0.store(saturate_to_u32(val), ordering)
    }

    #[inline(always)]
    fn compare_exchange(
        &self,
        current: u64,
        new: u64,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u64, u64> {
        // The stored value never exceeds u32::MAX, so a larger `current` can never
        // match. Report a failure carrying the real value instead of truncating
        // `current` into something that could match by accident.
        let expected = match u32::try_from(current) {
            Ok(v) => v,
            Err(_) => return Err(u64::from(self.0.load(failure))),
        };
        self.0
            .compare_exchange(expected, saturate_to_u32(new), success, failure)
            .map(u64::from)
            .map_err(u64::from)
    }

    #[inline(always)]
    fn compare_exchange_weak(
        &self,
        current: u64,
        new: u64,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u64, u64> {
        // See `compare_exchange` for why an out-of-range `current` fails outright.
        let expected = match u32::try_from(current) {
            Ok(v) => v,
            Err(_) => return Err(u64::from(self.0.load(failure))),
        };
        self.0
            .compare_exchange_weak(expected, saturate_to_u32(new), success, failure)
            .map(u64::from)
            .map_err(u64::from)
    }
}

/// Clamps a `u64` token count into the `u32` range used on narrow targets.
#[cfg(not(target_pointer_width = "64"))]
#[inline(always)]
fn saturate_to_u32(val: u64) -> u32 {
    u32::try_from(val).unwrap_or(u32::MAX)
}
/// Lock-free token-bucket rate limiter.
///
/// The bucket holds at most `max_tokens` tokens. `refill_rate` tokens are
/// credited for every elapsed `refill_interval_ms`. Crediting happens lazily,
/// when an acquisition or `available_tokens` observes that an interval has
/// passed.
///
/// The three frequently written shared words are wrapped in `CacheAligned`,
/// presumably to give each its own cache line (confirm in `utils`).
pub struct RateLimiter {
    // ----- configuration (written only by the constructor) -----
    /// Bucket capacity.
    max_tokens: u64,
    /// Tokens credited per elapsed refill interval (before adaptive scaling).
    refill_rate: u32,
    /// Length of one refill interval, in milliseconds.
    refill_interval_ms: u64,
    /// The configured ordering mode; kept for `Debug` output.
    ordering: MemoryOrdering,
    /// Ordering used for plain loads of the token counter.
    ordering_load: Ordering,
    /// Success ordering used for read-modify-write operations.
    ordering_rmw: Ordering,
    /// Ordering used by `reset` for its stores.
    ordering_store: Ordering,

    // ----- hot shared state -----
    /// Tokens currently in the bucket.
    tokens: CacheAligned<TokenCounter>,
    /// Time (ms, from `current_time_ms`) up to which refills have been credited.
    last_refill_ms: CacheAligned<AtomicU64>,
    /// Time (ms) of the last recorded access; updated at a throttled rate.
    last_access_ms: CacheAligned<AtomicU64>,

    // ----- statistics -----
    /// Rejections since the most recent successful acquisition.
    consecutive_rejections: AtomicU32,
    /// Reported in metrics; within this file it is only initialised and reset.
    max_wait_time_ns: AtomicU64,
    /// Successful acquisition attempts (attempts, not tokens).
    total_acquired: AtomicU64,
    /// Rejected acquisition attempts (attempts, not tokens).
    total_rejected: AtomicU64,
    /// Refills that were performed.
    total_refills: AtomicU64,
}
impl RateLimiter {
    /// Creates a limiter with room for `max_tokens` tokens that regains
    /// `refill_rate` tokens per refill interval. All remaining settings come from
    /// `RateLimiterConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the resulting configuration is rejected by
    /// `RateLimiterConfig::validate`.
    #[inline]
    pub fn new(max_tokens: u64, refill_rate: u32) -> Self {
        Self::with_config(RateLimiterConfig {
            max_tokens,
            refill_rate,
            ..Default::default()
        })
    }

    /// Creates a limiter from an explicit configuration.
    ///
    /// The bucket starts full. The refill clock and the last-access timestamp
    /// both start at the current time, and every statistic starts at zero.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid rate limiter configuration" if `config.validate()`
    /// returns an error.
    pub fn with_config(config: RateLimiterConfig) -> Self {
        config
            .validate()
            .expect("Invalid rate limiter configuration");
        let now_ms = current_time_ms();
        Self {
            tokens: CacheAligned::new(TokenCounter::new(config.max_tokens)),
            last_refill_ms: CacheAligned::new(AtomicU64::new(now_ms)),
            last_access_ms: CacheAligned::new(AtomicU64::new(now_ms)),
            consecutive_rejections: AtomicU32::new(0),
            max_wait_time_ns: AtomicU64::new(0),
            max_tokens: config.max_tokens,
            refill_rate: config.refill_rate,
            refill_interval_ms: config.refill_interval_ms,
            ordering_load: config.ordering.load(),
            ordering_rmw: config.ordering.rmw(),
            ordering_store: config.ordering.store(),
            ordering: config.ordering,
            total_acquired: AtomicU64::new(0),
            total_rejected: AtomicU64::new(0),
            total_refills: AtomicU64::new(0),
        }
    }

    /// Attempts to take one token without blocking; returns `true` on success.
    ///
    /// The fast path makes a single weak compare-exchange on the current count
    /// and does not check whether a refill is due. The slow path
    /// `try_acquire_full` runs if the bucket looks empty or that one exchange
    /// fails (contention or a spurious weak failure). It refills first and then
    /// retries a bounded number of times.
    #[inline(always)]
    pub fn try_acquire(&self) -> bool {
        let current = self.tokens.0.load(self.ordering_load);
        if current > 0
            && self
                .tokens
                .0
                .compare_exchange_weak(current, current - 1, self.ordering_rmw, Ordering::Relaxed)
                .is_ok()
        {
            self.on_acquisition(1);
            self.touch_last_access_lazy();
            return true;
        }
        self.try_acquire_full()
    }

    /// Slow path of `try_acquire`. It records the access, applies any due refill,
    /// then retries the single-token decrement up to `MAX_CAS_RETRIES` times.
    ///
    /// Returns `false` and records a rejection when the bucket is empty or the
    /// retry budget runs out under contention.
    #[cold]
    #[inline(never)]
    fn try_acquire_full(&self) -> bool {
        let now_ms = current_time_ms();
        self.touch_last_access(now_ms);
        // Cheap pre-check so the common "not due yet" case skips the call entirely.
        let last_refill = self.last_refill_ms.0.load(Ordering::Relaxed);
        if now_ms.wrapping_sub(last_refill) >= self.refill_interval_ms {
            self.refill_if_needed(now_ms);
        }
        let mut retries = 0;
        loop {
            let current = self.tokens.0.load(self.ordering_load);
            if current == 0 {
                self.on_rejection(1);
                return false;
            }
            match self.tokens.0.compare_exchange_weak(
                current,
                current - 1,
                self.ordering_rmw,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    self.on_acquisition(1);
                    return true;
                }
                // The bucket was drained between our load and the exchange.
                Err(0) => {
                    self.on_rejection(1);
                    return false;
                }
                Err(_) => {
                    retries += 1;
                    if retries >= MAX_CAS_RETRIES {
                        self.on_rejection(1);
                        return false;
                    }
                    Self::backoff(retries);
                }
            }
        }
    }

    /// Records an access at the current time, throttled exactly like
    /// `touch_last_access`.
    #[inline(always)]
    fn touch_last_access_lazy(&self) {
        self.touch_last_access(current_time_ms());
    }

    /// Stores `now_ms` as the last-access time, but only if the stored value is
    /// more than `LAST_ACCESS_UPDATE_INTERVAL_MS` old. This avoids a shared write
    /// on every acquisition.
    ///
    /// The load and store are separate operations, so a racing thread may
    /// overwrite the value with a slightly older timestamp.
    #[inline(always)]
    fn touch_last_access(&self, now_ms: u64) {
        let last = self.last_access_ms.0.load(Ordering::Relaxed);
        if now_ms.wrapping_sub(last) > LAST_ACCESS_UPDATE_INTERVAL_MS {
            self.last_access_ms.0.store(now_ms, Ordering::Relaxed);
        }
    }

    /// Spin-waits between failed compare-exchange attempts.
    ///
    /// Issues one `cpu_relax` while `retries <= CAS_BACKOFF_THRESHOLD`, then 2, 4
    /// and 8 spins, capped at 16.
    #[inline(always)]
    fn backoff(retries: usize) {
        if retries > CAS_BACKOFF_THRESHOLD {
            for _ in 0..(1 << (retries - CAS_BACKOFF_THRESHOLD).min(4)) {
                cpu_relax();
            }
        } else {
            cpu_relax();
        }
    }

    /// Attempts to take `n` tokens at once (all or nothing); returns `true` on
    /// success.
    ///
    /// - `n == 0` always succeeds and touches no state.
    /// - `n == 1` delegates to `try_acquire`.
    /// - `n > max_tokens` can never be satisfied and is rejected immediately.
    ///
    /// Otherwise it records the access, applies any due refill, and runs the
    /// bounded compare-exchange loop.
    #[inline]
    pub fn try_acquire_n(&self, n: u64) -> bool {
        if n == 0 {
            return true;
        }
        if n == 1 {
            return self.try_acquire();
        }
        if n > self.max_tokens {
            self.on_rejection(n);
            return false;
        }
        let now_ms = current_time_ms();
        self.touch_last_access(now_ms);
        self.refill_if_needed(now_ms);
        self.try_acquire_with_bounded_cas(n)
    }

    /// Bounded compare-exchange loop that takes `n` tokens or records a rejection.
    ///
    /// - It rejects as soon as fewer than `n` tokens are observed.
    /// - Each failed weak exchange counts toward `MAX_CAS_RETRIES`.
    /// - If the same count is re-loaded `MAX_REPEAT_COUNT` times in a row, the
    ///   failures are presumably spurious (the value is not moving). One *strong*
    ///   exchange then decides the outcome, and its failure is reported as a
    ///   rejection.
    #[inline(always)]
    fn try_acquire_with_bounded_cas(&self, n: u64) -> bool {
        let mut retries = 0;
        let mut last_seen = u64::MAX;
        let mut repeat_count: usize = 0;
        loop {
            let current = self.tokens.0.load(self.ordering_load);
            if current == last_seen {
                repeat_count += 1;
                if repeat_count >= MAX_REPEAT_COUNT {
                    if current < n {
                        self.on_rejection(n);
                        return false;
                    }
                    match self.tokens.0.compare_exchange(
                        current,
                        current - n,
                        self.ordering_rmw,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            self.on_acquisition(n);
                            return true;
                        }
                        Err(_) => {
                            self.on_rejection(n);
                            return false;
                        }
                    }
                }
            } else {
                last_seen = current;
                repeat_count = 0;
            }
            if current < n {
                self.on_rejection(n);
                return false;
            }
            match self.tokens.0.compare_exchange_weak(
                current,
                current - n,
                self.ordering_rmw,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    self.on_acquisition(n);
                    return true;
                }
                Err(actual) => {
                    if actual < n {
                        self.on_rejection(n);
                        return false;
                    }
                    retries += 1;
                    if retries >= MAX_CAS_RETRIES {
                        self.on_rejection(n);
                        return false;
                    }
                    Self::backoff(retries);
                }
            }
        }
    }

    /// Bookkeeping for a successful acquisition. It counts *attempts*, not tokens
    /// (`_n` is unused), and ends the consecutive-rejection streak.
    #[inline(always)]
    fn on_acquisition(&self, _n: u64) {
        self.total_acquired.fetch_add(1, Ordering::Relaxed);
        // Skip the store when already zero, to avoid a shared write on the hot path.
        if self.consecutive_rejections.load(Ordering::Relaxed) != 0 {
            self.consecutive_rejections.store(0, Ordering::Relaxed);
        }
    }

    /// Bookkeeping for a rejected acquisition; counts attempts, not tokens.
    #[inline(always)]
    fn on_rejection(&self, _n: u64) {
        self.total_rejected.fetch_add(1, Ordering::Relaxed);
        self.consecutive_rejections.fetch_add(1, Ordering::Relaxed);
    }

    /// Credits tokens for the whole refill intervals that have elapsed since
    /// `last_refill_ms`.
    ///
    /// For a given starting value, only one caller can win the compare-exchange
    /// on `last_refill_ms`. Each elapsed interval is therefore credited at most
    /// once, even when many threads notice it concurrently.
    ///
    /// The refill clock is advanced past *every* elapsed interval, but tokens
    /// are credited for at most `MAX_REFILL_PERIODS` of them; the bucket is
    /// clamped at `max_tokens` regardless. If the clock advanced by only the
    /// capped amount, it would stay behind real time after a long idle gap.
    /// Every later call would then refill again, and the limiter would stop
    /// limiting until the clock caught up.
    ///
    /// If `now_ms` is ever behind `last_refill_ms`, `wrapping_sub` yields a huge
    /// elapsed value. In that case the capped credit is applied and the clock is
    /// re-synchronised to within one interval of `now_ms`.
    ///
    /// Assumes `refill_interval_ms > 0`, presumably enforced by
    /// `RateLimiterConfig::validate` (confirm).
    #[inline]
    fn refill_if_needed(&self, now_ms: u64) {
        let last_refill = self.last_refill_ms.0.load(Ordering::Relaxed);
        let elapsed = now_ms.wrapping_sub(last_refill);
        if elapsed < self.refill_interval_ms {
            return;
        }
        let elapsed_periods = elapsed / self.refill_interval_ms;
        let credited_periods = elapsed_periods.min(MAX_REFILL_PERIODS);
        if credited_periods == 0 {
            return;
        }
        // `elapsed_periods * refill_interval_ms <= elapsed`, so this product cannot
        // overflow. The sub-interval remainder stays pending for the next refill.
        let new_refill_time =
            last_refill.wrapping_add(elapsed_periods * self.refill_interval_ms);
        if self
            .last_refill_ms
            .0
            .compare_exchange(
                last_refill,
                new_refill_time,
                self.ordering_rmw,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            self.perform_refill(credited_periods);
        }
    }

    /// Adds the tokens owed for `periods` refill intervals, capped at `max_tokens`.
    ///
    /// Uses the adaptive rate while under sustained pressure. It always counts
    /// one refill, even when the bucket was already full. The credit is retried
    /// until applied, because the caller has already advanced the refill clock
    /// and abandoning the update would lose these tokens permanently.
    #[inline]
    fn perform_refill(&self, periods: u64) {
        let refill_rate = if self.is_under_sustained_pressure() {
            self.adaptive_refill_rate()
        } else {
            self.refill_rate
        };
        let tokens_to_add = u64::from(refill_rate)
            .saturating_mul(periods)
            .min(self.max_tokens);
        let added = self.credit_tokens(tokens_to_add);
        self.total_refills.fetch_add(1, Ordering::Relaxed);
        if added > 0 {
            debug!(
                "Refilled {} tokens (periods: {})",
                added,
                periods
            );
        }
    }

    /// Atomically adds up to `n` tokens, clamping the bucket at `max_tokens`.
    /// Returns how many tokens were actually added.
    ///
    /// It keeps retrying (with backoff) until the update lands or the bucket is
    /// observed full, so a credit is never silently dropped. Every failed
    /// exchange means another thread changed the count (or a spurious weak
    /// failure occurred), so the loop is lock-free.
    #[inline]
    fn credit_tokens(&self, n: u64) -> u64 {
        let mut retries: usize = 0;
        let mut current = self.tokens.0.load(self.ordering_load);
        loop {
            let target = current.saturating_add(n).min(self.max_tokens);
            if target == current {
                return 0;
            }
            match self.tokens.0.compare_exchange_weak(
                current,
                target,
                self.ordering_rmw,
                Ordering::Relaxed,
            ) {
                Ok(_) => return target.saturating_sub(current),
                Err(actual) => {
                    current = actual;
                    retries = retries.saturating_add(1);
                    Self::backoff(retries);
                }
            }
        }
    }

    /// `true` once more than 10 acquisition attempts in a row have been rejected.
    #[inline]
    fn is_under_sustained_pressure(&self) -> bool {
        self.consecutive_rejections.load(Ordering::Relaxed) > 10
    }

    /// Refill rate used while under sustained pressure, based on the lifetime
    /// rejection ratio:
    ///
    /// - 80 % of `refill_rate` if more than half of all attempts were rejected;
    /// - 90 % if more than 30 % were;
    /// - otherwise unchanged.
    #[inline]
    fn adaptive_refill_rate(&self) -> u32 {
        // Widen to u128 so the ratio comparisons cannot overflow.
        let rejected = u128::from(self.total_rejected.load(Ordering::Relaxed));
        let acquired = u128::from(self.total_acquired.load(Ordering::Relaxed));
        let total = acquired + rejected;
        if total == 0 {
            return self.refill_rate;
        }
        let (num, den): (u64, u64) = if rejected * 2 > total {
            (4, 5)
        } else if rejected * 10 > total * 3 {
            (9, 10)
        } else {
            return self.refill_rate;
        };
        // Multiply in u64: `refill_rate * 4` in u32 overflows for rates above u32::MAX / 4.
        let reduced = u64::from(self.refill_rate) * num / den;
        // `reduced <= refill_rate`, so the conversion always succeeds.
        let reduced = u32::try_from(reduced).unwrap_or(self.refill_rate);
        // Never scale a positive rate down to zero. A zero rate adds no tokens,
        // so no acquisition could succeed and clear `consecutive_rejections`.
        // The limiter would stay starved until `add_tokens` or `reset` is called.
        reduced.max(self.refill_rate.min(1))
    }

    /// Fraction of all recorded attempts that were rejected (0.0 when none).
    #[inline]
    fn calculate_pressure_ratio(&self) -> f64 {
        let total_rejected = self.total_rejected.load(Ordering::Relaxed);
        let total_acquired = self.total_acquired.load(Ordering::Relaxed);
        let total = total_acquired.saturating_add(total_rejected);
        if total == 0 {
            0.0
        } else {
            total_rejected as f64 / total as f64
        }
    }

    /// Applies any due refill, then returns the number of tokens in the bucket.
    ///
    /// Note that this is not a pure read: it may advance the refill clock and
    /// add tokens.
    #[inline(always)]
    pub fn available_tokens(&self) -> u64 {
        self.refill_if_needed(current_time_ms());
        self.tokens.0.load(self.ordering_load)
    }

    /// `true` if more than `inactive_duration_ms` have passed since the recorded
    /// last access.
    ///
    /// Accesses are recorded by acquisitions only, and at most every
    /// `LAST_ACCESS_UPDATE_INTERVAL_MS`, so the answer has that granularity.
    #[inline(always)]
    pub fn is_inactive(&self, inactive_duration_ms: u64) -> bool {
        let now_ms = current_time_ms();
        let last_ms = self.last_access_ms.0.load(Ordering::Relaxed);
        now_ms.wrapping_sub(last_ms) > inactive_duration_ms
    }

    /// Returns a snapshot of the counters and the stored token count.
    ///
    /// Each field is read independently (Relaxed), so the snapshot is not
    /// mutually consistent under concurrency. No refill is applied.
    /// `max_wait_time_ns` is reported as stored; within this type it is only
    /// ever initialised and reset to 0.
    pub fn metrics(&self) -> RateLimiterMetrics {
        RateLimiterMetrics {
            total_acquired: self.total_acquired.load(Ordering::Relaxed),
            total_rejected: self.total_rejected.load(Ordering::Relaxed),
            total_refills: self.total_refills.load(Ordering::Relaxed),
            current_tokens: self.tokens.0.load(Ordering::Relaxed),
            max_tokens: self.max_tokens,
            consecutive_rejections: self.consecutive_rejections.load(Ordering::Relaxed),
            max_wait_time_ns: self.max_wait_time_ns.load(Ordering::Relaxed),
            pressure_ratio: self.calculate_pressure_ratio(),
        }
    }

    /// Adds `n` tokens, clamping the bucket at `max_tokens`.
    ///
    /// Retries until the tokens are applied or the bucket is full, so the credit
    /// is not dropped under contention.
    #[inline]
    pub fn add_tokens(&self, n: u64) {
        self.credit_tokens(n);
    }

    /// Refills the bucket to `max_tokens`, restarts the refill and last-access
    /// clocks at the current time, and zeroes every statistic.
    ///
    /// Each field is stored separately, so a concurrent caller may observe a
    /// partially reset limiter.
    pub fn reset(&self) {
        let now_ms = current_time_ms();
        self.tokens.0.store(self.max_tokens, self.ordering_store);
        self.last_refill_ms.0.store(now_ms, self.ordering_store);
        self.last_access_ms.0.store(now_ms, self.ordering_store);
        self.consecutive_rejections.store(0, self.ordering_store);
        self.max_wait_time_ns.store(0, self.ordering_store);
        self.total_acquired.store(0, self.ordering_store);
        self.total_rejected.store(0, self.ordering_store);
        self.total_refills.store(0, self.ordering_store);
    }

    /// Bucket capacity.
    #[inline]
    pub fn get_max_tokens(&self) -> u64 {
        self.max_tokens
    }

    /// Recorded last-access time, in the units of `current_time_ms`.
    #[inline]
    pub fn get_last_access_ms(&self) -> u64 {
        self.last_access_ms.0.load(Ordering::Relaxed)
    }
}
impl std::fmt::Debug for RateLimiter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RateLimiter")
.field("max_tokens", &self.max_tokens)
.field("refill_rate", &self.refill_rate)
.field("refill_interval_ms", &self.refill_interval_ms)
.field("ordering", &self.ordering)
.field("current_tokens", &self.available_tokens())
.finish()
}
}
// SAFETY: the state `RateLimiter` mutates after construction is held in atomics:
// `AtomicU32`, `AtomicU64`, and `TokenCounter`, an atomic integer. All of its
// methods take `&self`. The remaining fields (`u64`, `u32`,
// `std::sync::atomic::Ordering`, `MemoryOrdering`) are written only when the
// value is constructed.
// NOTE(review): if `CacheAligned<T>` and `MemoryOrdering` are themselves
// `Send + Sync`, the compiler derives these impls automatically and the `unsafe`
// overrides are redundant. They would also silently mask any non-thread-safe
// field added later. The two types are presumably a plain wrapper and a plain
// enum; confirm in `utils`/`config` before removing.
unsafe impl Send for RateLimiter {}
unsafe impl Sync for RateLimiter {}
// Unit tests for `RateLimiter`.
// Several tests sleep, so they depend on wall-clock time. Most are skipped
// under Miri via `#[cfg_attr(miri, ignore)]`. Configs with
// `refill_interval_ms: 600_000` (10 minutes) keep refills from happening during
// a test.
#[cfg(test)]
mod tests {
use super::*;
// The bucket starts full: exactly `max_tokens` single acquisitions succeed.
#[test]
#[cfg_attr(miri, ignore)]
fn test_basic_acquisition() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..10 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
}
// Multi-token requests are all-or-nothing: the request for 5 fails with only
// 2 left, and those 2 stay available.
#[test]
#[cfg_attr(miri, ignore)]
fn test_bulk_acquisition() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
assert!(limiter.try_acquire_n(5));
assert!(limiter.try_acquire_n(3));
assert!(!limiter.try_acquire_n(5));
assert!(limiter.try_acquire_n(2));
}
// Adding tokens to a full bucket at the u64 limit must saturate, not overflow.
// NOTE(review): on targets without 64-bit pointers the token counter is
// 32 bits wide, so the stored count cannot reach u64::MAX. This assertion
// presumably holds only on 64-bit targets; consider cfg-gating it.
#[test]
fn test_overflow_protection() {
let limiter = RateLimiter::new(u64::MAX, u32::MAX);
limiter.add_tokens(u64::MAX);
assert_eq!(limiter.available_tokens(), u64::MAX);
}
// Uses the default refill interval, which is not visible here. Expects at
// least 5 tokens after ~1.1 s, which presumably requires a default interval
// of at most ~1.1 s (TODO confirm against `RateLimiterConfig::default`).
#[test]
#[cfg_attr(miri, ignore)]
fn test_refill_mechanism() {
let limiter = RateLimiter::new(10, 5);
for _ in 0..10 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
std::thread::sleep(std::time::Duration::from_millis(1100));
assert!(limiter.try_acquire_n(5));
}
// 450 ms at a 100 ms interval is 4 whole periods; 4 x 5 = 20 tokens refill the
// bucket in one call. This assumes MAX_REFILL_PERIODS >= 4 (defined in config).
#[test]
#[cfg_attr(miri, ignore)]
fn test_multiple_refill_periods() {
let config = RateLimiterConfig {
max_tokens: 20,
refill_rate: 5,
refill_interval_ms: 100,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
assert!(limiter.try_acquire_n(20));
assert!(!limiter.try_acquire());
std::thread::sleep(std::time::Duration::from_millis(450));
assert_eq!(limiter.available_tokens(), 20);
}
// `add_tokens` adds exactly `n` when there is room and clamps at `max_tokens`.
#[test]
#[cfg_attr(miri, ignore)]
fn test_add_tokens() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
limiter.add_tokens(3);
assert_eq!(limiter.available_tokens(), 8);
limiter.add_tokens(20);
assert_eq!(limiter.available_tokens(), 10);
}
// `reset` refills the bucket and zeroes the counters.
// The three `try_acquire_n(10)` calls are rejected because only 5 tokens remain.
#[test]
#[cfg_attr(miri, ignore)]
fn test_reset() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
for _ in 0..3 {
assert!(!limiter.try_acquire_n(10));
}
let metrics_before = limiter.metrics();
assert!(metrics_before.total_acquired > 0);
assert!(metrics_before.total_rejected > 0);
limiter.reset();
assert_eq!(limiter.available_tokens(), 10);
let metrics_after = limiter.metrics();
assert_eq!(metrics_after.total_acquired, 0);
assert_eq!(metrics_after.total_rejected, 0);
assert_eq!(metrics_after.consecutive_rejections, 0);
}
// `available_tokens` does not record an access, so after a 200 ms sleep the
// limiter counts as inactive for a 100 ms threshold but not for a 1000 ms one.
#[test]
#[cfg_attr(miri, ignore)]
fn test_is_inactive() {
let limiter = RateLimiter::new(10, 1);
assert!(!limiter.is_inactive(1000));
let _ = limiter.available_tokens();
std::thread::sleep(std::time::Duration::from_millis(200));
assert!(limiter.is_inactive(100));
assert!(!limiter.is_inactive(1000));
}
// More than 10 consecutive rejections puts the limiter under sustained pressure.
#[test]
#[cfg_attr(miri, ignore)]
fn test_sustained_pressure() {
let config = RateLimiterConfig {
max_tokens: 5,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
for _ in 0..15 {
assert!(!limiter.try_acquire());
}
let metrics = limiter.metrics();
assert!(metrics.consecutive_rejections > 10);
assert!(limiter.is_under_sustained_pressure());
}
// 20 rejections push consecutive_rejections past 10, so the next refill uses
// the adaptive (reduced) rate. The assertion only checks the `max_tokens` cap,
// not the reduced amount.
#[test]
#[cfg_attr(miri, ignore)]
fn test_adaptive_refill_under_pressure() {
let limiter = RateLimiter::new(10, 10);
for _ in 0..10 {
limiter.try_acquire();
}
for _ in 0..20 {
limiter.try_acquire(); }
std::thread::sleep(std::time::Duration::from_millis(1100));
let tokens = limiter.available_tokens();
assert!(tokens <= 10); }
// NOTE(review): despite its name, this test asserts nothing about wait time.
// `max_wait_time_ns` is only initialised and reset inside `RateLimiter`.
#[test]
fn test_max_wait_time_tracking() {
let limiter = RateLimiter::new(100, 10);
assert!(limiter.try_acquire_n(50));
assert!(limiter.try_acquire_n(30));
let _metrics = limiter.metrics();
}
// Heavy contention from 50 threads. It only checks that the run completes and
// that some outcome was recorded; it does not force retry exhaustion.
#[test]
#[cfg_attr(miri, ignore)]
fn test_cas_retry_exhaustion() {
use std::sync::Arc;
use std::thread;
let limiter = Arc::new(RateLimiter::new(100, 10));
let mut handles = vec![];
for _ in 0..50 {
let limiter_clone = limiter.clone();
handles.push(thread::spawn(move || {
for _ in 0..100 {
limiter_clone.try_acquire_n(2);
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let metrics = limiter.metrics();
assert!(metrics.total_acquired > 0 || metrics.total_rejected > 0);
}
// Requesting zero tokens always succeeds, even with an empty bucket.
#[test]
#[cfg_attr(miri, ignore)]
fn test_acquire_zero_tokens() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
assert!(limiter.try_acquire_n(0));
for _ in 0..10 {
assert!(limiter.try_acquire());
}
assert!(limiter.try_acquire_n(0));
}
// A request larger than the capacity is rejected up front and counted once.
#[test]
#[cfg_attr(miri, ignore)]
fn test_acquire_more_than_max() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
assert!(!limiter.try_acquire_n(11));
let metrics = limiter.metrics();
assert_eq!(metrics.total_rejected, 1);
}
// The last-access timestamp never moves backwards across a series of
// acquisitions. Updates are throttled, so it may not change on every call.
#[test]
#[cfg_attr(miri, ignore)]
fn test_last_access_update_throttling() {
let limiter = RateLimiter::new(100, 10);
let start_access = limiter.last_access_ms.0.load(Ordering::Relaxed);
for _ in 0..50 {
assert!(limiter.try_acquire());
std::thread::sleep(std::time::Duration::from_millis(1));
}
let end_access = limiter.last_access_ms.0.load(Ordering::Relaxed);
assert!(end_access >= start_access);
}
// Narrow-target counter: expects values above u32::MAX to be clamped to
// u32::MAX on construction and on store, rather than rejected.
#[cfg(not(target_pointer_width = "64"))]
#[test]
fn test_32bit_token_ops() {
use super::TokenOps;
let counter = TokenCounter::new(u32::MAX as u64 + 1);
assert_eq!(counter.load(Ordering::Relaxed), u32::MAX as u64);
counter.store(u32::MAX as u64 + 100, Ordering::Relaxed);
assert_eq!(counter.load(Ordering::Relaxed), u32::MAX as u64);
}
// The Debug output includes the type name and the configuration fields.
#[test]
fn test_debug_impl() {
let limiter = RateLimiter::new(10, 5);
let debug_str = format!("{:?}", limiter);
assert!(debug_str.contains("RateLimiter"));
assert!(debug_str.contains("max_tokens: 10"));
assert!(debug_str.contains("refill_rate: 5"));
}
// A failed 3-token request must not consume the 1 remaining token. Metrics
// count attempts: two successes (n = 2 and n = 1) and one rejection.
#[test]
#[cfg_attr(miri, ignore)]
fn test_aba_fix_no_silent_drain() {
let config = RateLimiterConfig {
max_tokens: 3,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
assert!(limiter.try_acquire_n(2));
assert!(!limiter.try_acquire_n(3));
assert_eq!(limiter.available_tokens(), 1);
assert!(limiter.try_acquire());
let metrics = limiter.metrics();
assert_eq!(metrics.total_acquired, 2); assert_eq!(metrics.total_rejected, 1); }
// total_acquired counts successful calls, not tokens: 3 calls, 81 tokens.
#[test]
#[cfg_attr(miri, ignore)]
fn test_metrics_count_attempts_not_tokens() {
let limiter = RateLimiter::new(100, 10);
assert!(limiter.try_acquire_n(50));
assert!(limiter.try_acquire_n(30));
assert!(limiter.try_acquire());
let metrics = limiter.metrics();
assert_eq!(metrics.total_acquired, 3);
}
// 8 rejections keep consecutive_rejections at or below 10, so the later refill
// uses the normal rate. One success then clears the streak.
#[test]
#[cfg_attr(miri, ignore)]
fn test_rejection_resets_consecutive_on_acquire() {
let limiter = RateLimiter::new(5, 1);
for _ in 0..5 {
limiter.try_acquire();
}
for _ in 0..8 {
assert!(!limiter.try_acquire());
}
assert!(limiter.metrics().consecutive_rejections >= 8);
std::thread::sleep(std::time::Duration::from_millis(1100));
assert!(limiter.try_acquire());
assert_eq!(limiter.metrics().consecutive_rejections, 0);
}
// Single-threaded accounting is exact under the Relaxed ordering mode.
#[test]
#[cfg_attr(miri, ignore)]
fn test_relaxed_ordering_mode() {
let config = RateLimiterConfig {
max_tokens: 20,
refill_rate: 5,
refill_interval_ms: 1000,
ordering: MemoryOrdering::Relaxed,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..20 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
let metrics = limiter.metrics();
assert_eq!(metrics.total_acquired, 20);
assert_eq!(metrics.total_rejected, 1);
}
// Same capacity behaviour under the Sequential ordering mode.
#[test]
#[cfg_attr(miri, ignore)]
fn test_sequential_ordering_mode() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 5,
refill_interval_ms: 1000,
ordering: MemoryOrdering::Sequential,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..10 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
}
// Mixed single- and two-token callers. Every call is counted exactly once, and
// per-thread success counts add up to total_acquired. `total_requests()` is
// defined in the metrics module, presumably acquired + rejected (confirm).
#[test]
#[cfg_attr(miri, ignore)]
fn test_concurrent_single_and_multi_acquire() {
use std::sync::Arc;
use std::thread;
let limiter = Arc::new(RateLimiter::new(500, 10));
let mut handles = vec![];
for i in 0..20 {
let l = limiter.clone();
handles.push(thread::spawn(move || {
let mut acquired = 0u64;
for _ in 0..50 {
if i % 2 == 0 {
if l.try_acquire() {
acquired += 1;
}
} else if l.try_acquire_n(2) {
acquired += 1;
}
}
acquired
}));
}
let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
let metrics = limiter.metrics();
assert_eq!(metrics.total_acquired, total);
assert_eq!(metrics.total_requests(), 20 * 50);
}
// The getter returns the configured capacity.
#[test]
fn test_get_max_tokens() {
let limiter = RateLimiter::new(42, 5);
assert_eq!(limiter.get_max_tokens(), 42);
}
// `try_acquire_n(1)` goes through the single-token path; the recorded access
// time must not go backwards.
#[test]
fn test_get_last_access_ms_updates() {
let limiter = RateLimiter::new(100, 10);
let before = limiter.get_last_access_ms();
assert!(limiter.try_acquire_n(1));
let after = limiter.get_last_access_ms();
assert!(after >= before);
}
// Adding tokens to an empty bucket makes exactly that many available.
#[test]
#[cfg_attr(miri, ignore)]
fn test_add_tokens_when_empty() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..10 {
limiter.try_acquire();
}
assert_eq!(limiter.available_tokens(), 0);
limiter.add_tokens(5);
assert_eq!(limiter.available_tokens(), 5);
}
// Adding zero tokens leaves the count unchanged.
#[test]
fn test_add_tokens_zero() {
let limiter = RateLimiter::new(10, 1);
let before = limiter.available_tokens();
limiter.add_tokens(0);
assert_eq!(limiter.available_tokens(), before);
}
// 5 successes followed by 5 rejections give a rejection ratio of 0.5.
#[test]
#[cfg_attr(miri, ignore)]
fn test_pressure_ratio_calculation() {
let config = RateLimiterConfig {
max_tokens: 5,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..5 {
limiter.try_acquire();
}
for _ in 0..5 {
limiter.try_acquire();
}
let metrics = limiter.metrics();
let ratio = metrics.pressure_ratio;
assert!(ratio > 0.4 && ratio < 0.6, "Expected ~0.5, got {}", ratio);
}
// Several elapsed 50 ms intervals would credit well over the 5 missing tokens;
// the bucket must clamp at max_tokens.
#[test]
#[cfg_attr(miri, ignore)]
fn test_refill_caps_at_max() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 10,
refill_interval_ms: 50,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..5 {
limiter.try_acquire();
}
std::thread::sleep(std::time::Duration::from_millis(200));
assert_eq!(limiter.available_tokens(), 10);
}
// `reset` clears every statistic exposed through `metrics()`, including
// total_refills and max_wait_time_ns.
#[test]
#[cfg_attr(miri, ignore)]
fn test_reset_clears_all_state() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 5,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
for _ in 0..10 {
limiter.try_acquire();
}
for _ in 0..5 {
limiter.try_acquire();
}
limiter.reset();
assert_eq!(limiter.available_tokens(), 10);
let m = limiter.metrics();
assert_eq!(m.total_acquired, 0);
assert_eq!(m.total_rejected, 0);
assert_eq!(m.total_refills, 0);
assert_eq!(m.consecutive_rejections, 0);
assert_eq!(m.max_wait_time_ns, 0);
}
// A request for exactly max_tokens is allowed and empties the bucket.
#[test]
#[cfg_attr(miri, ignore)]
fn test_try_acquire_n_exactly_max() {
let config = RateLimiterConfig {
max_tokens: 10,
refill_rate: 1,
refill_interval_ms: 600_000,
ordering: MemoryOrdering::AcquireRelease,
};
let limiter = RateLimiter::with_config(config);
assert!(limiter.try_acquire_n(10));
assert_eq!(limiter.available_tokens(), 0);
assert!(!limiter.try_acquire());
}
}