use crate::error::CanoError;
use crate::resource::Resource;
use cano_macros::resource;
use parking_lot::Mutex;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Configuration for a token-bucket rate limiter.
///
/// The bucket holds at most `max_tokens + burst` tokens (saturating at
/// `u32::MAX`) and refills at `tokens_per_period / refill_period` tokens per
/// second.
#[derive(Debug, Clone)]
pub struct RateLimiterPolicy {
    /// Base bucket size, before `burst` is added.
    pub max_tokens: u32,
    /// Number of tokens credited per `refill_period`.
    pub tokens_per_period: u32,
    /// Length of one refill period.
    pub refill_period: Duration,
    /// Extra capacity added on top of `max_tokens`.
    pub burst: u32,
}

impl RateLimiterPolicy {
    /// Builds a policy that refills `tokens` per `period`, with a bucket of
    /// `tokens` and no burst allowance.
    pub fn new(tokens: u32, period: Duration) -> Self {
        let (max_tokens, tokens_per_period) = (tokens, tokens);
        Self {
            burst: 0,
            refill_period: period,
            tokens_per_period,
            max_tokens,
        }
    }

    /// Shorthand for [`RateLimiterPolicy::new`] with a one-second period.
    pub fn per_second(tokens: u32) -> Self {
        Self::new(tokens, Duration::from_secs(1))
    }

    /// Returns the policy with `max_tokens` replaced.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens, ..self }
    }

    /// Returns the policy with `burst` replaced.
    pub fn with_burst(self, burst: u32) -> Self {
        Self { burst, ..self }
    }

    /// Refill rate expressed in tokens per second.
    fn refill_per_sec(&self) -> f64 {
        let tokens = f64::from(self.tokens_per_period);
        let seconds = self.refill_period.as_secs_f64();
        tokens / seconds
    }

    /// Total bucket size: `max_tokens + burst`, clamped to `u32::MAX`.
    fn capacity(&self) -> u32 {
        self.max_tokens
            .checked_add(self.burst)
            .unwrap_or(u32::MAX)
    }
}
/// Mutable token-bucket state, kept behind the limiter's mutex.
#[derive(Debug)]
struct State {
/// Current (fractional) number of tokens in the bucket.
tokens: f64,
/// Instant at which `tokens` was last topped up by `refill_locked`.
last_refill: Instant,
}
/// Token-bucket rate limiter.
///
/// `Clone` copies the `Arc`, so clones share the same bucket state.
#[derive(Debug, Clone)]
pub struct RateLimiter {
/// Shared bucket state (token count and last refill instant).
inner: Arc<Mutex<State>>,
/// Policy the limiter was built with.
policy: RateLimiterPolicy,
}
impl RateLimiter {
    /// Creates a limiter whose bucket starts full (at `policy.capacity()`).
    ///
    /// # Panics
    ///
    /// Panics if `policy.max_tokens` is 0, if `policy.tokens_per_period` is 0,
    /// or if `policy.refill_period` is zero.
    pub fn new(policy: RateLimiterPolicy) -> Self {
        assert!(policy.max_tokens >= 1, "max_tokens must be >= 1");
        assert!(
            policy.tokens_per_period >= 1 && !policy.refill_period.is_zero(),
            "refill rate must be > 0"
        );
        let initial = State {
            tokens: policy.capacity() as f64,
            last_refill: Instant::now(),
        };
        Self {
            inner: Arc::new(Mutex::new(initial)),
            policy,
        }
    }

    /// The policy this limiter was constructed with.
    pub fn policy(&self) -> &RateLimiterPolicy {
        &self.policy
    }

    /// Non-blocking acquisition of a single token.
    pub fn try_acquire(self: &Arc<Self>) -> Option<Permit> {
        self.try_acquire_n(1)
    }

    /// Non-blocking acquisition of `cost` tokens.
    ///
    /// Returns `None` without debiting anything when the bucket holds fewer
    /// than `cost` tokens. A `cost` of 0 is always admitted.
    pub fn try_acquire_n(self: &Arc<Self>, cost: u64) -> Option<Permit> {
        let admitted = self.debit(cost);
        #[cfg(feature = "metrics")]
        if admitted {
            crate::metrics::rate_limiter_acquired();
            crate::metrics::rate_limiter_tokens_consumed(cost);
        } else {
            crate::metrics::rate_limiter_throttled("rejected");
        }
        if admitted {
            Some(Permit {
                _limiter: Arc::clone(self),
            })
        } else {
            None
        }
    }

    /// Debits `cost` tokens and returns a [`Reservation`] that refunds them on
    /// drop unless it is committed. Returns `None` if the debit is rejected.
    pub fn try_reserve_n(self: &Arc<Self>, cost: u64) -> Option<Reservation> {
        if !self.debit(cost) {
            return None;
        }
        let shared: Arc<Self> = Arc::clone(self);
        let meter: Arc<dyn Meter> = shared;
        Some(Reservation {
            meter,
            cost,
            // The token bucket has no epochs; `credit` ignores this value.
            epoch: 0,
            committed: false,
        })
    }

    /// Waits until a single token can be debited.
    pub async fn acquire(self: &Arc<Self>) -> Permit {
        self.acquire_n(1).await
    }

    /// Waits until `cost` tokens can be debited, sleeping between attempts.
    ///
    /// When `cost` exceeds the bucket capacity the request can never be
    /// satisfied, so each attempt sleeps for `Duration::MAX`.
    pub async fn acquire_n(self: &Arc<Self>, cost: u64) -> Permit {
        #[cfg(feature = "metrics")]
        let started = Instant::now();
        #[cfg(feature = "metrics")]
        let mut parked = false;
        let needed = cost as f64;
        let ceiling = self.policy.capacity() as f64;
        loop {
            // The guard lives only inside this block, so it is released
            // before the `.await` below.
            let pause = {
                let mut state = self.inner.lock();
                self.refill_locked(&mut state);
                if state.tokens >= needed {
                    state.tokens -= needed;
                    None
                } else if needed > ceiling {
                    Some(Duration::MAX)
                } else {
                    Some(self.wait_for_deficit(needed - state.tokens))
                }
            };
            let Some(pause) = pause else {
                #[cfg(feature = "metrics")]
                {
                    crate::metrics::rate_limiter_acquired();
                    crate::metrics::rate_limiter_tokens_consumed(cost);
                    if parked {
                        crate::metrics::rate_limiter_wait(started.elapsed());
                    }
                }
                return Permit {
                    _limiter: Arc::clone(self),
                };
            };
            #[cfg(feature = "metrics")]
            {
                crate::metrics::rate_limiter_throttled("waited");
                parked = true;
            }
            tokio::time::sleep(pause).await;
        }
    }

    /// Whole tokens currently in the bucket, after applying any pending refill.
    pub fn tokens_available(&self) -> u64 {
        let mut state = self.inner.lock();
        self.refill_locked(&mut state);
        let whole = state.tokens.floor();
        whole.max(0.0) as u64
    }

    /// Estimated wait before `cost` tokens become available.
    ///
    /// Returns `Duration::ZERO` when `cost` is 0 or already affordable, and
    /// `Duration::MAX` when `cost` exceeds the bucket capacity.
    pub fn time_until(&self, cost: u64) -> Duration {
        if cost == 0 {
            return Duration::ZERO;
        }
        let needed = cost as f64;
        if needed > self.policy.capacity() as f64 {
            return Duration::MAX;
        }
        let mut state = self.inner.lock();
        self.refill_locked(&mut state);
        if state.tokens >= needed {
            return Duration::ZERO;
        }
        self.wait_for_deficit(needed - state.tokens)
    }

    /// Converts a token shortfall into the time the refill rate needs to cover
    /// it, clamping negative or unrepresentable values.
    fn wait_for_deficit(&self, deficit: f64) -> Duration {
        let seconds = (deficit / self.policy.refill_per_sec()).max(0.0);
        Duration::try_from_secs_f64(seconds).unwrap_or(Duration::MAX)
    }

    /// Refills, then removes `cost` tokens if they are available.
    /// Returns whether the debit happened; a zero `cost` always succeeds.
    fn debit(&self, cost: u64) -> bool {
        if cost == 0 {
            return true;
        }
        let needed = cost as f64;
        let mut state = self.inner.lock();
        self.refill_locked(&mut state);
        let affordable = state.tokens >= needed;
        if affordable {
            state.tokens -= needed;
        }
        affordable
    }

    /// Gives `cost` tokens back to the bucket, never exceeding capacity.
    fn refund(&self, cost: u64) {
        if cost == 0 {
            return;
        }
        let ceiling = self.policy.capacity() as f64;
        let mut state = self.inner.lock();
        state.tokens = f64::min(state.tokens + cost as f64, ceiling);
    }

    /// Credits tokens for the time elapsed since the last refill, capped at
    /// capacity. Must be called with the state lock held.
    fn refill_locked(&self, st: &mut State) {
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(st.last_refill);
        if !elapsed.is_zero() {
            let credit = elapsed.as_secs_f64() * self.policy.refill_per_sec();
            let ceiling = self.policy.capacity() as f64;
            st.tokens = f64::min(st.tokens + credit, ceiling);
            st.last_refill = now;
        }
    }
}
// Marks `RateLimiter` as a `Resource`. The impl body is empty; whatever the
// `#[resource]` attribute macro (from `cano_macros`) generates is not visible here.
#[resource]
impl Resource for RateLimiter {}
/// Exposes the token bucket through the generic [`Meter`] interface.
impl Meter for RateLimiter {
    /// Debits `cost` tokens; the bucket has no epochs, so success is always
    /// reported as `Some(0)`.
    fn try_debit(&self, cost: u64) -> Option<u64> {
        if self.debit(cost) {
            Some(0)
        } else {
            None
        }
    }

    /// Refunds `cost` tokens; the epoch is ignored.
    fn credit(&self, cost: u64, _epoch: u64) {
        self.refund(cost)
    }

    /// Delegates to the inherent [`RateLimiter::time_until`].
    fn time_until(&self, cost: u64) -> Duration {
        RateLimiter::time_until(self, cost)
    }

    /// Reports capacity, usage (capacity minus whole tokens available) and the
    /// wait for one token. A bucket has no reset boundary, so `resets_at` is
    /// `None`.
    fn snapshot(&self) -> MeterStatus {
        let limit = u64::from(self.policy.capacity());
        let available = self.tokens_available();
        let available_in = RateLimiter::time_until(self, 1);
        MeterStatus {
            used: limit.saturating_sub(available),
            limit,
            available_in,
            resets_at: None,
        }
    }
}
/// Common interface over the limiters in this file, used as `Arc<dyn Meter>`
/// by `Reservation` and `MultiRateLimiter`.
pub trait Meter: Send + Sync + std::fmt::Debug {
/// Attempts to debit `cost` units. Returns `Some(epoch)` on success and `None`
/// when rejected. In this file the token bucket returns `0` and the windowed
/// limiter returns its current window generation.
fn try_debit(&self, cost: u64) -> Option<u64>;
/// Gives back `cost` units from an earlier debit; `epoch` is the value that
/// `try_debit` returned for it.
fn credit(&self, cost: u64, epoch: u64);
/// Estimated wait until a debit of `cost` units could succeed.
fn time_until(&self, cost: u64) -> Duration;
/// Point-in-time usage report.
fn snapshot(&self) -> MeterStatus;
}
/// Usage report returned by `Meter::snapshot`.
#[derive(Debug, Clone)]
pub struct MeterStatus {
/// Units currently consumed.
pub used: u64,
/// Maximum units the meter admits (bucket capacity or window limit).
pub limit: u64,
/// Wait until at least one unit is available; zero when capacity remains.
pub available_in: Duration,
/// End of the current window; `None` for the token bucket.
pub resets_at: Option<Instant>,
}
/// A debit that is refunded when dropped unless `commit` was called first.
#[must_use = "a dropped, uncommitted Reservation refunds its units; commit it to keep the debit"]
pub struct Reservation {
/// Meter the units were debited from, and where a refund is sent.
meter: Arc<dyn Meter>,
/// Number of units debited.
cost: u64,
/// Epoch returned by the debit; passed back to `Meter::credit` on refund.
epoch: u64,
/// When true, dropping the reservation keeps the debit.
committed: bool,
}
impl Reservation {
/// Makes the debit permanent: after this call, dropping the reservation no
/// longer refunds its units.
pub fn commit(&mut self) {
self.committed = true;
}
}
impl Drop for Reservation {
    /// Refunds the reserved units to the meter unless the reservation was
    /// committed.
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        self.meter.credit(self.cost, self.epoch);
    }
}
impl std::fmt::Debug for Reservation {
    /// Prints `cost` and `committed`; the remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Reservation");
        out.field("cost", &self.cost);
        out.field("committed", &self.committed);
        out.finish_non_exhaustive()
    }
}
/// Handle returned by a successful `RateLimiter` acquisition. It has no `Drop`
/// impl in this file, so dropping it does not return tokens to the bucket.
#[must_use = "the permit marks the rate-limited call's scope"]
pub struct Permit {
/// Keeps the issuing limiter alive for as long as the permit exists.
_limiter: Arc<RateLimiter>,
}
impl std::fmt::Debug for Permit {
    /// Prints only the type name; the limiter handle is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Permit");
        out.finish_non_exhaustive()
    }
}
/// Configuration for a fixed-window limiter: at most `limit` units per `window`.
#[derive(Debug, Clone)]
pub struct WindowPolicy {
    /// Maximum units admitted within one window.
    pub limit: u64,
    /// Length of one window.
    pub window: Duration,
}

impl WindowPolicy {
    /// Builds a policy admitting `limit` units per `window`.
    pub fn new(limit: u64, window: Duration) -> Self {
        WindowPolicy { window, limit }
    }

    /// `limit` units per `hours` hours (the second count saturates at `u64::MAX`).
    pub fn per_hours(limit: u64, hours: u64) -> Self {
        const SECS_PER_HOUR: u64 = 3600;
        let secs = hours.saturating_mul(SECS_PER_HOUR);
        Self::new(limit, Duration::from_secs(secs))
    }

    /// `limit` units per `days` days (the second count saturates at `u64::MAX`).
    pub fn per_days(limit: u64, days: u64) -> Self {
        const SECS_PER_DAY: u64 = 86_400;
        let secs = days.saturating_mul(SECS_PER_DAY);
        Self::new(limit, Duration::from_secs(secs))
    }
}
/// Mutable fixed-window state, kept behind the limiter's mutex.
#[derive(Debug)]
struct WindowState {
/// Instant at which the current window began.
window_start: Instant,
/// Units consumed in the current window.
used: u64,
/// Counter bumped (wrapping) each time the window rolls over; used to
/// recognise refunds that belong to an earlier window.
generation: u64,
}
/// Fixed-window rate limiter.
///
/// `Clone` copies the `Arc`, so clones share the same window state.
#[derive(Debug, Clone)]
pub struct WindowedRateLimiter {
/// Shared window state (start, usage, generation).
inner: Arc<Mutex<WindowState>>,
/// Policy the limiter was built with.
policy: WindowPolicy,
}
impl WindowedRateLimiter {
    /// Creates a limiter whose first window starts now with zero usage.
    ///
    /// # Panics
    ///
    /// Panics if `policy.limit` is 0 or `policy.window` is zero.
    pub fn new(policy: WindowPolicy) -> Self {
        assert!(policy.limit >= 1, "WindowPolicy::limit must be >= 1");
        assert!(!policy.window.is_zero(), "WindowPolicy::window must be > 0");
        Self {
            inner: Arc::new(Mutex::new(WindowState {
                window_start: Instant::now(),
                used: 0,
                generation: 0,
            })),
            policy,
        }
    }

    /// The policy this limiter was constructed with.
    pub fn policy(&self) -> &WindowPolicy {
        &self.policy
    }

    /// Non-blocking acquisition of a single unit.
    pub fn try_acquire(self: &Arc<Self>) -> Option<WindowPermit> {
        self.try_acquire_n(1)
    }

    /// Non-blocking acquisition of `cost` units from the current window.
    ///
    /// Returns `None` without debiting when the window cannot fit `cost`.
    /// A `cost` of 0 is always admitted.
    pub fn try_acquire_n(self: &Arc<Self>, cost: u64) -> Option<WindowPermit> {
        let acquired = self.debit(cost).is_some();
        #[cfg(feature = "metrics")]
        if acquired {
            crate::metrics::rate_limiter_acquired();
            crate::metrics::rate_limiter_tokens_consumed(cost);
        } else {
            crate::metrics::rate_limiter_throttled("rejected");
        }
        acquired.then(|| WindowPermit {
            _limiter: Arc::clone(self),
        })
    }

    /// Debits `cost` units and returns a [`Reservation`] that refunds them on
    /// drop unless committed. The reservation carries the window generation so
    /// that a refund arriving after the window rolled over is ignored.
    pub fn try_reserve_n(self: &Arc<Self>, cost: u64) -> Option<Reservation> {
        self.debit(cost).map(|epoch| {
            let arc = Arc::clone(self);
            let meter: Arc<dyn Meter> = arc;
            Reservation {
                meter,
                cost,
                epoch,
                committed: false,
            }
        })
    }

    /// Waits until a single unit can be debited.
    pub async fn acquire(self: &Arc<Self>) -> WindowPermit {
        self.acquire_n(1).await
    }

    /// Waits until `cost` units fit in a window, sleeping until the current
    /// window's reset between attempts.
    ///
    /// When `cost` exceeds the window limit the request can never be
    /// satisfied, so each attempt sleeps for `Duration::MAX`.
    pub async fn acquire_n(self: &Arc<Self>, cost: u64) -> WindowPermit {
        #[cfg(feature = "metrics")]
        let start = Instant::now();
        #[cfg(feature = "metrics")]
        let mut waited = false;
        loop {
            // The guard lives only inside this block, so it is released
            // before the `.await` below.
            let wait = {
                let mut st = self.inner.lock();
                self.roll_locked(&mut st);
                let admitted = self.admitted_total(&st, cost);
                if let Some(total) = admitted {
                    st.used = total;
                    None
                } else if cost > self.policy.limit {
                    Some(Duration::MAX)
                } else {
                    Some(self.time_until_reset(&st))
                }
            };
            match wait {
                None => {
                    #[cfg(feature = "metrics")]
                    {
                        crate::metrics::rate_limiter_acquired();
                        crate::metrics::rate_limiter_tokens_consumed(cost);
                        if waited {
                            crate::metrics::rate_limiter_wait(start.elapsed());
                        }
                    }
                    return WindowPermit {
                        _limiter: Arc::clone(self),
                    };
                }
                Some(dur) => {
                    #[cfg(feature = "metrics")]
                    {
                        crate::metrics::rate_limiter_throttled("waited");
                        waited = true;
                    }
                    tokio::time::sleep(dur).await;
                }
            }
        }
    }

    /// Units consumed in the current window (rolling the window first if it
    /// has expired).
    pub fn used(&self) -> u64 {
        let mut st = self.inner.lock();
        self.roll_locked(&mut st);
        st.used
    }

    /// Units still available in the current window.
    pub fn remaining(&self) -> u64 {
        let mut st = self.inner.lock();
        self.roll_locked(&mut st);
        self.policy.limit.saturating_sub(st.used)
    }

    /// Instant at which the current window ends. Falls back to the window
    /// start if that instant is not representable.
    pub fn resets_at(&self) -> Instant {
        let mut st = self.inner.lock();
        self.roll_locked(&mut st);
        self.boundary(&st).unwrap_or(st.window_start)
    }

    /// Estimated wait before `cost` units can be admitted.
    ///
    /// Returns `Duration::ZERO` when `cost` is 0 or fits in the current
    /// window, `Duration::MAX` when `cost` exceeds the limit, and otherwise
    /// the time until the current window resets.
    pub fn time_until(&self, cost: u64) -> Duration {
        if cost == 0 {
            return Duration::ZERO;
        }
        let mut st = self.inner.lock();
        self.roll_locked(&mut st);
        if self.admitted_total(&st, cost).is_some() {
            Duration::ZERO
        } else if cost > self.policy.limit {
            Duration::MAX
        } else {
            self.time_until_reset(&st)
        }
    }

    /// Usage after admitting `cost`, or `None` if that would exceed the limit.
    ///
    /// Uses `checked_add` rather than `saturating_add`: with a limit of
    /// `u64::MAX` a saturated sum would pass the `<= limit` test and the
    /// subsequent update of `used` would overflow.
    fn admitted_total(&self, st: &WindowState, cost: u64) -> Option<u64> {
        st.used
            .checked_add(cost)
            .filter(|&total| total <= self.policy.limit)
    }

    /// Rolls the window if needed, then debits `cost` if it fits. Returns the
    /// window generation the debit was made in; a zero `cost` always succeeds.
    fn debit(&self, cost: u64) -> Option<u64> {
        let mut st = self.inner.lock();
        self.roll_locked(&mut st);
        if cost == 0 {
            return Some(st.generation);
        }
        let total = self.admitted_total(&st, cost)?;
        st.used = total;
        Some(st.generation)
    }

    /// Returns `cost` units to the window, but only if the window that was
    /// debited (`epoch`) is still the current one.
    fn refund(&self, cost: u64, epoch: u64) {
        if cost == 0 {
            return;
        }
        let mut st = self.inner.lock();
        self.roll_locked(&mut st);
        if epoch == st.generation {
            st.used = st.used.saturating_sub(cost);
        }
    }

    /// End of the current window, or `None` if it is not representable.
    fn boundary(&self, st: &WindowState) -> Option<Instant> {
        st.window_start.checked_add(self.policy.window)
    }

    /// Time from now until the current window ends (`Duration::MAX` when the
    /// boundary is not representable, zero when it has already passed).
    fn time_until_reset(&self, st: &WindowState) -> Duration {
        self.boundary(st)
            .map(|b| b.saturating_duration_since(Instant::now()))
            .unwrap_or(Duration::MAX)
    }

    /// Starts a fresh window at `now` when the current one has expired:
    /// usage is cleared and the generation is bumped. Must be called with the
    /// state lock held.
    fn roll_locked(&self, st: &mut WindowState) {
        let now = Instant::now();
        if now.saturating_duration_since(st.window_start) >= self.policy.window {
            st.window_start = now;
            st.used = 0;
            st.generation = st.generation.wrapping_add(1);
        }
    }
}
// Marks `WindowedRateLimiter` as a `Resource`. The impl body is empty; whatever the
// `#[resource]` attribute macro (from `cano_macros`) generates is not visible here.
#[resource]
impl Resource for WindowedRateLimiter {}
/// Exposes the fixed-window limiter through the generic [`Meter`] interface.
impl Meter for WindowedRateLimiter {
    /// Debits `cost` units; on success returns the window generation.
    fn try_debit(&self, cost: u64) -> Option<u64> {
        self.debit(cost)
    }

    /// Refunds `cost` units if `epoch` still names the current window.
    fn credit(&self, cost: u64, epoch: u64) {
        self.refund(cost, epoch)
    }

    /// Delegates to the inherent [`WindowedRateLimiter::time_until`].
    fn time_until(&self, cost: u64) -> Duration {
        WindowedRateLimiter::time_until(self, cost)
    }

    /// Reports usage, limit and the window boundary. `available_in` is zero
    /// while capacity remains, otherwise the time until the window resets.
    fn snapshot(&self) -> MeterStatus {
        let mut window = self.inner.lock();
        self.roll_locked(&mut window);
        let resets_at = self.boundary(&window);
        let limit = self.policy.limit;
        let exhausted = window.used >= limit;
        let available_in = if exhausted {
            self.time_until_reset(&window)
        } else {
            Duration::ZERO
        };
        MeterStatus {
            used: window.used,
            limit,
            available_in,
            resets_at,
        }
    }
}
/// Handle returned by a successful `WindowedRateLimiter` acquisition. It has no
/// `Drop` impl in this file, so dropping it does not return units to the window.
#[must_use = "the permit marks the rate-limited call's scope"]
pub struct WindowPermit {
/// Keeps the issuing limiter alive for as long as the permit exists.
_limiter: Arc<WindowedRateLimiter>,
}
impl std::fmt::Debug for WindowPermit {
    /// Prints only the type name; the limiter handle is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WindowPermit");
        out.finish_non_exhaustive()
    }
}
/// One named limit inside a `MultiRateLimiter`.
#[derive(Debug, Clone)]
pub struct Tier {
/// Tier name: reported when this tier rejects an acquisition and matched
/// against the `only` filter of `try_acquire_for` / `acquire_for`.
pub name: &'static str,
/// Meter enforcing this tier's limit.
pub meter: Arc<dyn Meter>,
/// Units debited from `meter` per acquisition; a tier with cost 0 is skipped.
pub cost: u64,
}
/// Lower bound applied to the `retry_after` that `MultiRateLimiter::reserve_all`
/// reports for a rejecting tier, so a retry never waits zero time.
const MULTI_MIN_BACKOFF: Duration = Duration::from_millis(1);
/// Composite limiter: an acquisition must be admitted by every selected tier.
/// The default value has no tiers.
#[derive(Debug, Clone, Default)]
pub struct MultiRateLimiter {
/// Tiers in the order they were added; they are debited in this order.
tiers: Vec<Tier>,
}
impl MultiRateLimiter {
pub fn new() -> Self {
Self::default()
}
pub fn with_tier(mut self, name: &'static str, meter: Arc<dyn Meter>, cost: u64) -> Self {
self.tiers.push(Tier { name, meter, cost });
self
}
pub fn tiers(&self) -> &[Tier] {
&self.tiers
}
pub fn try_acquire(&self) -> Result<MultiPermit, CanoError> {
self.try_acquire_for(&[])
}
pub fn try_acquire_for(&self, only: &[&str]) -> Result<MultiPermit, CanoError> {
self.reserve_all(only).map_err(|(tier, retry_after)| {
Self::record_throttle(tier);
CanoError::rate_limited(tier, retry_after)
})
}
pub async fn acquire(&self) -> MultiPermit {
self.acquire_for(&[]).await
}
pub async fn acquire_for(&self, only: &[&str]) -> MultiPermit {
loop {
match self.reserve_all(only) {
Ok(permit) => return permit,
Err((tier, retry_after)) => {
Self::record_throttle(tier);
tokio::time::sleep(retry_after).await;
}
}
}
}
fn record_throttle(tier: &'static str) {
#[cfg(feature = "metrics")]
crate::metrics::multi_rate_limiter_throttled(tier);
#[cfg(not(feature = "metrics"))]
let _ = tier;
}
fn reserve_all(&self, only: &[&str]) -> Result<MultiPermit, (&'static str, Duration)> {
let mut held: Vec<Reservation> = Vec::with_capacity(self.tiers.len());
for tier in &self.tiers {
if (!only.is_empty() && !only.contains(&tier.name)) || tier.cost == 0 {
continue;
}
match tier.meter.try_debit(tier.cost) {
Some(epoch) => held.push(Reservation {
meter: Arc::clone(&tier.meter),
cost: tier.cost,
epoch,
committed: false,
}),
None => {
let retry_after = tier.meter.time_until(tier.cost).max(MULTI_MIN_BACKOFF);
drop(held);
return Err((tier.name, retry_after));
}
}
}
for r in &mut held {
r.commit();
}
Ok(MultiPermit {
_reservations: held,
})
}
}
// Marks `MultiRateLimiter` as a `Resource`. The impl body is empty; whatever the
// `#[resource]` attribute macro (from `cano_macros`) generates is not visible here.
#[resource]
impl Resource for MultiRateLimiter {}
/// Handle returned by a successful `MultiRateLimiter` acquisition.
#[must_use = "the permit marks the rate-limited call's scope"]
pub struct MultiPermit {
/// One reservation per debited tier. `reserve_all` commits them before
/// building the permit, so dropping the permit does not refund anything.
_reservations: Vec<Reservation>,
}
impl std::fmt::Debug for MultiPermit {
    /// Prints the number of tiers held; the reservations themselves are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tier_count = self._reservations.len();
        let mut out = f.debug_struct("MultiPermit");
        out.field("tiers", &tier_count);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::time::Duration;
fn fast_policy() -> RateLimiterPolicy {
RateLimiterPolicy::per_second(10).with_max_tokens(5)
}
#[test]
fn try_acquire_succeeds_when_tokens_available() {
let limiter = Arc::new(RateLimiter::new(fast_policy()));
for _ in 0..5 {
assert!(limiter.try_acquire().is_some());
}
}
#[test]
fn try_acquire_returns_none_when_bucket_empty() {
let limiter = Arc::new(RateLimiter::new(fast_policy()));
for _ in 0..5 {
let _ = limiter.try_acquire().expect("first 5 succeed");
}
assert!(limiter.try_acquire().is_none());
}
#[tokio::test]
async fn acquire_parks_until_a_token_refills() {
let start = std::time::Instant::now();
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(100).with_max_tokens(1),
));
let _ = limiter.try_acquire().expect("initial token");
let _permit = limiter.acquire().await;
assert!(
start.elapsed() >= Duration::from_millis(5),
"acquire should have parked for a refill, waited {:?}",
start.elapsed()
);
}
#[tokio::test]
async fn acquire_returns_immediately_when_tokens_available() {
let limiter = Arc::new(RateLimiter::new(fast_policy()));
let _permit = limiter.acquire().await;
assert_eq!(limiter.tokens_available(), 4);
}
#[tokio::test]
async fn tokens_refill_over_real_time() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(50).with_max_tokens(1),
));
let _ = limiter.try_acquire().expect("initial token");
assert!(
limiter.try_acquire().is_none(),
"bucket is empty right after draining"
);
tokio::time::sleep(Duration::from_millis(40)).await;
assert!(
limiter.try_acquire().is_some(),
"a token should have refilled after >1 interval"
);
}
#[tokio::test]
async fn acquire_clamps_extreme_wait_without_panicking() {
let limiter = Arc::new(RateLimiter::new(RateLimiterPolicy {
max_tokens: 1,
tokens_per_period: 1,
refill_period: Duration::MAX,
burst: 0,
}));
let _ = limiter.try_acquire().expect("initial token");
let res = tokio::time::timeout(Duration::from_millis(50), limiter.acquire()).await;
assert!(
res.is_err(),
"acquire should still be parking on the clamped wait"
);
}
#[test]
fn burst_respected() {
let policy = RateLimiterPolicy::per_second(10)
.with_max_tokens(2)
.with_burst(5);
let limiter = Arc::new(RateLimiter::new(policy));
let mut acquired = 0;
while limiter.try_acquire().is_some() {
acquired += 1;
if acquired > 7 {
break;
}
}
assert_eq!(acquired, 7);
}
#[test]
#[should_panic(expected = "max_tokens must be >= 1")]
fn rejects_zero_max_tokens() {
let _ = RateLimiter::new(RateLimiterPolicy::per_second(1).with_max_tokens(0));
}
#[test]
#[should_panic(expected = "refill rate must be > 0")]
fn rejects_zero_tokens_per_period() {
let _ =
RateLimiter::new(RateLimiterPolicy::new(0, Duration::from_secs(1)).with_max_tokens(1));
}
#[test]
#[should_panic(expected = "refill rate must be > 0")]
fn rejects_zero_refill_period() {
let _ = RateLimiter::new(RateLimiterPolicy {
max_tokens: 1,
tokens_per_period: 1,
refill_period: Duration::ZERO,
burst: 0,
});
}
#[test]
fn try_acquire_n_debits_cost() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(10).with_max_tokens(10),
));
assert!(limiter.try_acquire_n(3).is_some());
assert_eq!(limiter.tokens_available(), 7);
assert!(limiter.try_acquire_n(7).is_some());
assert_eq!(limiter.tokens_available(), 0);
assert!(limiter.try_acquire_n(1).is_none());
}
#[test]
fn try_acquire_n_zero_cost_always_admits_without_debit() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(10).with_max_tokens(2),
));
for _ in 0..100 {
assert!(limiter.try_acquire_n(0).is_some());
}
assert_eq!(limiter.tokens_available(), 2);
}
#[test]
fn try_acquire_n_over_capacity_rejects_without_draining() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(10).with_max_tokens(5),
));
assert!(limiter.try_acquire_n(6).is_none());
assert_eq!(limiter.tokens_available(), 5);
}
#[tokio::test]
async fn acquire_n_parks_for_weighted_cost() {
let start = std::time::Instant::now();
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(100).with_max_tokens(5),
));
assert!(limiter.try_acquire_n(5).is_some());
let _p = limiter.acquire_n(3).await;
assert!(start.elapsed() >= Duration::from_millis(15));
}
#[test]
fn reservation_drop_refunds() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(10).with_max_tokens(10),
));
{
let _r = limiter.try_reserve_n(4).expect("reserve");
assert_eq!(limiter.tokens_available(), 6);
} assert_eq!(limiter.tokens_available(), 10);
}
#[test]
fn reservation_commit_keeps_debit() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(10).with_max_tokens(10),
));
{
let mut r = limiter.try_reserve_n(4).expect("reserve");
r.commit();
} assert_eq!(limiter.tokens_available(), 6);
}
#[test]
fn reservation_refund_is_capped_at_capacity() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(1_000_000).with_max_tokens(5),
));
let r = limiter.try_reserve_n(5).expect("reserve"); std::thread::sleep(Duration::from_millis(20));
assert_eq!(limiter.tokens_available(), 5); drop(r); assert_eq!(limiter.tokens_available(), 5);
}
#[test]
fn window_admits_up_to_limit() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
3,
Duration::from_secs(60),
)));
for _ in 0..3 {
assert!(w.try_acquire().is_some());
}
assert!(w.try_acquire().is_none());
assert_eq!(w.used(), 3);
assert_eq!(w.remaining(), 0);
}
#[test]
fn window_weighted_admits_by_cost() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
1000,
Duration::from_secs(60),
)));
assert!(w.try_acquire_n(1500).is_none()); assert!(w.try_acquire_n(600).is_some());
assert!(w.try_acquire_n(600).is_none()); assert!(w.try_acquire_n(400).is_some()); assert_eq!(w.used(), 1000);
}
#[tokio::test]
async fn window_resets_after_duration() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
2,
Duration::from_millis(50),
)));
assert!(w.try_acquire().is_some());
assert!(w.try_acquire().is_some());
assert!(w.try_acquire().is_none());
tokio::time::sleep(Duration::from_millis(70)).await;
assert!(w.try_acquire().is_some()); assert_eq!(w.used(), 1);
}
#[test]
fn window_time_until_is_reset_when_full() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
1,
Duration::from_secs(10),
)));
assert!(w.try_acquire().is_some());
let t = w.time_until(1);
assert!(t > Duration::from_secs(1), "expected ~window, got {t:?}");
assert!(t <= Duration::from_secs(10));
}
#[test]
fn meter_trait_object_works_for_both_types() {
let bucket: Arc<dyn Meter> = Arc::new(RateLimiter::new(RateLimiterPolicy::per_second(5)));
let window: Arc<dyn Meter> = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
5,
Duration::from_secs(60),
)));
assert!(bucket.try_debit(1).is_some());
assert!(window.try_debit(1).is_some());
assert!(bucket.snapshot().resets_at.is_none()); assert!(window.snapshot().resets_at.is_some()); }
#[tokio::test]
async fn window_reservation_across_boundary_does_not_corrupt_new_window() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
5,
Duration::from_millis(50),
)));
let r = w.try_reserve_n(3).expect("reserve in window 1"); tokio::time::sleep(Duration::from_millis(70)).await; assert!(w.try_acquire_n(2).is_some()); drop(r); assert_eq!(
w.used(),
2,
"stale reservation refund corrupted the new window"
);
assert!(w.try_acquire_n(3).is_some());
assert!(w.try_acquire_n(1).is_none());
}
#[test]
fn time_until_is_max_for_cost_over_capacity() {
let bucket = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(1_000_000).with_max_tokens(5),
));
assert_eq!(bucket.time_until(6), Duration::MAX);
let window = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
5,
Duration::from_secs(60),
)));
assert_eq!(window.time_until(6), Duration::MAX);
}
#[tokio::test]
async fn acquire_n_over_capacity_parks_rather_than_busy_looping() {
let bucket = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(1_000_000).with_max_tokens(5),
));
let res = tokio::time::timeout(Duration::from_millis(50), bucket.acquire_n(6)).await;
assert!(
res.is_err(),
"acquire_n(cost>capacity) should park, not return"
);
}
fn rl(tokens: u32) -> Arc<dyn Meter> {
Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(tokens).with_max_tokens(tokens),
))
}
fn win(limit: u64) -> Arc<dyn Meter> {
Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
limit,
Duration::from_secs(60),
)))
}
#[test]
fn multi_admits_when_all_have_capacity() {
let m = MultiRateLimiter::new()
.with_tier("a", rl(5), 1)
.with_tier("b", win(5), 1);
assert!(m.try_acquire().is_ok());
}
#[test]
fn multi_rejects_and_reports_binding_tier() {
let m = MultiRateLimiter::new()
.with_tier("a", rl(100), 1)
.with_tier("b", win(1), 1);
assert!(m.try_acquire().is_ok());
match m.try_acquire().unwrap_err() {
CanoError::RateLimited { tier, retry_after } => {
assert_eq!(tier, "b");
assert!(retry_after > Duration::ZERO);
}
other => panic!("expected RateLimited, got {other:?}"),
}
}
#[test]
fn multi_no_leak_on_partial_rejection() {
let a = win(100);
let m = MultiRateLimiter::new()
.with_tier("a", Arc::clone(&a), 1)
.with_tier("b", win(1), 1);
assert!(m.try_acquire().is_ok()); let a_used_before = a.snapshot().used;
for _ in 0..10 {
assert!(m.try_acquire().is_err()); }
assert_eq!(a.snapshot().used, a_used_before, "tier a leaked budget");
}
#[test]
fn multi_weighted_mixed_tiers() {
let m = MultiRateLimiter::new()
.with_tier("requests", win(2), 1)
.with_tier("tokens", win(1000), 600);
assert!(m.try_acquire().is_ok()); assert!(
matches!(m.try_acquire().unwrap_err(), CanoError::RateLimited { tier, .. } if tier == "tokens")
);
}
#[test]
fn multi_per_request_tier_selection() {
let m = MultiRateLimiter::new()
.with_tier("shared", win(10), 1)
.with_tier("model_x", win(1), 1);
assert!(m.try_acquire_for(&["shared", "model_x"]).is_ok());
assert!(m.try_acquire_for(&["shared"]).is_ok());
assert!(m.try_acquire_for(&["shared", "model_x"]).is_err());
}
#[test]
fn multi_zero_cost_tier_is_inert() {
let m = MultiRateLimiter::new()
.with_tier("off", win(1), 0) .with_tier("on", win(2), 1);
for _ in 0..2 {
assert!(m.try_acquire().is_ok());
}
assert!(m.try_acquire().is_err()); }
#[tokio::test]
async fn multi_acquire_parks_then_admits() {
let start = std::time::Instant::now();
let b: Arc<dyn Meter> = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(100).with_max_tokens(1),
));
let m = MultiRateLimiter::new()
.with_tier("a", rl(1000), 1)
.with_tier("b", b, 1);
assert!(m.try_acquire().is_ok()); let _p = m.acquire().await; assert!(start.elapsed() >= Duration::from_millis(5));
}
#[test]
fn multi_commit_makes_all_tiers_permanent() {
let a = win(2);
let b = win(2);
let m = MultiRateLimiter::new()
.with_tier("a", Arc::clone(&a), 1)
.with_tier("b", Arc::clone(&b), 1);
let permit = m.try_acquire().expect("ok");
drop(permit); assert_eq!(a.snapshot().used, 1);
assert_eq!(b.snapshot().used, 1);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn multi_concurrent_never_exceeds_scarcest() {
let scarce = win(5);
let m = Arc::new(
MultiRateLimiter::new()
.with_tier("a", rl(1000), 1)
.with_tier("scarce", Arc::clone(&scarce), 1),
);
let mut handles = Vec::new();
for _ in 0..24 {
let m = Arc::clone(&m);
handles.push(tokio::spawn(async move { m.try_acquire().is_ok() }));
}
let mut ok = 0;
for h in handles {
if h.await.unwrap() {
ok += 1;
}
}
assert_eq!(
ok, 5,
"exactly the scarcest tier's capacity should be admitted"
);
assert_eq!(scarce.snapshot().used, 5);
}
#[test]
#[should_panic(expected = "WindowPolicy::limit must be >= 1")]
fn window_rejects_zero_limit() {
let _ = WindowedRateLimiter::new(WindowPolicy::new(0, Duration::from_secs(1)));
}
#[test]
#[should_panic(expected = "WindowPolicy::window must be > 0")]
fn window_rejects_zero_window() {
let _ = WindowedRateLimiter::new(WindowPolicy::new(5, Duration::ZERO));
}
#[test]
fn window_policy_constructors_scale_durations() {
assert_eq!(
WindowPolicy::per_hours(100, 5).window,
Duration::from_secs(5 * 3600)
);
assert_eq!(
WindowPolicy::per_days(100, 7).window,
Duration::from_secs(7 * 86_400)
);
assert_eq!(WindowPolicy::per_hours(42, 5).limit, 42);
}
#[test]
fn bucket_cost_equal_to_capacity_succeeds_then_empty() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::new(5, Duration::from_secs(3600)).with_max_tokens(5),
));
assert!(limiter.try_acquire_n(5).is_some()); assert!(limiter.try_acquire_n(1).is_none());
}
#[test]
fn window_cost_equal_to_limit_succeeds_then_full() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
5,
Duration::from_secs(60),
)));
assert!(w.try_acquire_n(5).is_some()); assert!(w.try_acquire_n(1).is_none());
}
#[test]
fn window_zero_cost_is_inert() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
1,
Duration::from_secs(60),
)));
for _ in 0..100 {
assert!(w.try_acquire_n(0).is_some());
}
assert_eq!(w.used(), 0);
}
#[tokio::test]
async fn window_acquire_n_parks_until_reset() {
let start = std::time::Instant::now();
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
1,
Duration::from_millis(100),
)));
assert!(w.try_acquire().is_some());
let _p = w.acquire().await;
assert!(
start.elapsed() >= Duration::from_millis(50),
"should park until the window reset, waited {:?}",
start.elapsed()
);
}
#[test]
fn window_reservation_within_window_refunds() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
5,
Duration::from_secs(60),
)));
{
let _r = w.try_reserve_n(3).expect("reserve");
assert_eq!(w.used(), 3);
} assert_eq!(w.used(), 0);
}
#[test]
fn window_reservation_commit_keeps_usage() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
5,
Duration::from_secs(60),
)));
{
let mut r = w.try_reserve_n(3).expect("reserve");
r.commit();
}
assert_eq!(w.used(), 3);
}
#[test]
fn bucket_snapshot_reports_no_reset_boundary() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::per_second(10).with_max_tokens(10),
));
let _ = limiter.try_acquire_n(4);
let s = limiter.snapshot();
assert_eq!(s.limit, 10);
assert_eq!(s.used, 4);
assert_eq!(s.available_in, Duration::ZERO); assert!(s.resets_at.is_none()); }
#[test]
fn window_snapshot_reports_reset_boundary() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
10,
Duration::from_secs(60),
)));
let _ = w.try_acquire_n(4);
let s = w.snapshot();
assert_eq!(s.limit, 10);
assert_eq!(s.used, 4);
assert_eq!(s.available_in, Duration::ZERO); assert!(s.resets_at.is_some()); }
#[test]
fn exhausted_window_snapshot_available_in_is_until_reset() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
1,
Duration::from_secs(10),
)));
assert!(w.try_acquire().is_some());
let s = w.snapshot();
assert_eq!(s.used, 1);
assert!(
s.available_in > Duration::from_secs(1),
"exhausted window should report ~window until free, got {:?}",
s.available_in
);
}
#[test]
fn cloning_rate_limiter_shares_the_bucket() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::new(5, Duration::from_secs(3600)).with_max_tokens(5),
));
let clone = Arc::new((*limiter).clone());
assert!(limiter.try_acquire_n(5).is_some()); assert!(clone.try_acquire().is_none()); }
#[test]
fn cloning_window_shares_state() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
2,
Duration::from_secs(60),
)));
let clone = Arc::new((*w).clone());
assert!(w.try_acquire().is_some());
assert!(w.try_acquire().is_some());
assert!(clone.try_acquire().is_none()); }
#[test]
fn empty_multi_rate_limiter_admits() {
let m = MultiRateLimiter::new();
assert!(m.try_acquire().is_ok());
}
#[test]
fn multi_acquire_for_unknown_tier_enforces_nothing() {
let m = MultiRateLimiter::new().with_tier("a", win(1), 1);
assert!(m.try_acquire().is_ok());
assert!(m.try_acquire().is_err()); assert!(m.try_acquire_for(&["nonexistent"]).is_ok());
}
#[test]
fn multi_first_tier_rejection_leaves_later_tiers_untouched() {
let first = win(1);
let second = win(100);
let m = MultiRateLimiter::new()
.with_tier("first", Arc::clone(&first), 1)
.with_tier("second", Arc::clone(&second), 1);
assert!(m.try_acquire().is_ok()); let second_used = second.snapshot().used;
for _ in 0..5 {
match m.try_acquire().unwrap_err() {
CanoError::RateLimited { tier, .. } => assert_eq!(tier, "first"),
other => panic!("expected RateLimited(first), got {other:?}"),
}
}
assert_eq!(second.snapshot().used, second_used);
}
#[test]
fn same_meter_in_two_tiers_debits_once_per_tier() {
let shared = win(10);
let m = MultiRateLimiter::new()
.with_tier("a", Arc::clone(&shared), 1)
.with_tier("b", Arc::clone(&shared), 1);
assert!(m.try_acquire().is_ok());
assert_eq!(shared.snapshot().used, 2);
}
#[tokio::test]
async fn multi_acquire_parks_on_window_tier_reset() {
let start = std::time::Instant::now();
let w: Arc<dyn Meter> = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
1,
Duration::from_millis(100),
)));
let m = MultiRateLimiter::new().with_tier("w", w, 1);
assert!(m.try_acquire().is_ok()); let _p = m.acquire().await;
assert!(
start.elapsed() >= Duration::from_millis(50),
"should park until the window tier reset, waited {:?}",
start.elapsed()
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bucket_concurrent_admits_exactly_capacity() {
let limiter = Arc::new(RateLimiter::new(
RateLimiterPolicy::new(10, Duration::from_secs(86_400)).with_max_tokens(10),
));
let mut handles = Vec::new();
for _ in 0..40 {
let l = Arc::clone(&limiter);
handles.push(tokio::spawn(async move { l.try_acquire().is_some() }));
}
let mut ok = 0;
for h in handles {
if h.await.unwrap() {
ok += 1;
}
}
assert_eq!(ok, 10);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn window_concurrent_admits_exactly_limit() {
let w = Arc::new(WindowedRateLimiter::new(WindowPolicy::new(
10,
Duration::from_secs(60),
)));
let mut handles = Vec::new();
for _ in 0..40 {
let w = Arc::clone(&w);
handles.push(tokio::spawn(async move { w.try_acquire().is_some() }));
}
let mut ok = 0;
for h in handles {
if h.await.unwrap() {
ok += 1;
}
}
assert_eq!(ok, 10);
assert_eq!(w.used(), 10);
}
/// A tier whose per-call cost (10) is larger than what its meter was built with
/// (`win(5)`) is rejected with `retry_after == Duration::MAX` and names the tier.
#[test]
fn multi_tier_cost_over_meter_capacity_is_permanently_rejected() {
    let limiter = MultiRateLimiter::new().with_tier("impossible", win(5), 10);

    let rejection = limiter.try_acquire().unwrap_err();
    match rejection {
        CanoError::RateLimited { tier, retry_after } => {
            assert_eq!(tier, "impossible");
            // The test expects `Duration::MAX` as the retry hint for this case.
            assert_eq!(retry_after, Duration::MAX);
        }
        unexpected => panic!("expected RateLimited, got {unexpected:?}"),
    }
}
/// A `u64::MAX` cost is refused by both the token bucket and the window limiter,
/// and `time_until` for that cost returns `Duration::MAX` on each.
#[test]
fn huge_cost_rejects_without_panicking() {
    let cost = u64::MAX;

    let bucket_policy = RateLimiterPolicy::per_second(5);
    let bucket = Arc::new(RateLimiter::new(bucket_policy));
    assert!(bucket.try_acquire_n(cost).is_none());
    assert_eq!(bucket.time_until(cost), Duration::MAX);

    let window_policy = WindowPolicy::new(5, Duration::from_secs(60));
    let window = Arc::new(WindowedRateLimiter::new(window_policy));
    assert!(window.try_acquire_n(cost).is_none());
    assert_eq!(window.time_until(cost), Duration::MAX);
}
}
#[cfg(all(test, feature = "metrics"))]
mod metrics_tests {
use super::*;
use crate::metrics::test_support::*;
use std::sync::Arc;
use std::time::Duration;
/// Two admitted and one refused `try_acquire` call must show up as
/// `acquired_total == 2` and `throttled_total{result="rejected"} == 1`.
#[test]
fn try_acquire_records_acquired_and_rejected() {
    let ((), rows) = run_with_recorder(|| async {
        let policy = RateLimiterPolicy::per_second(10).with_max_tokens(2);
        let limiter = Arc::new(RateLimiter::new(policy));

        // The bucket starts with two tokens: two calls pass, the third is refused.
        for _ in 0..2 {
            assert!(limiter.try_acquire().is_some());
        }
        assert!(limiter.try_acquire().is_none());
    });

    let acquired = counter(&rows, "cano_rate_limiter_acquired_total", &[]);
    assert_eq!(acquired, 2);

    let rejected = counter(
        &rows,
        "cano_rate_limiter_throttled_total",
        &[("result", "rejected")],
    );
    assert_eq!(rejected, 1);
}
/// After the only token is taken, an awaited `acquire` must be counted as
/// acquired, bump `throttled_total{result="waited"}` at least once, and add
/// exactly one sample to the wait-seconds histogram.
#[test]
fn acquire_waits_then_records_acquired_and_histogram() {
    let ((), rows) = run_with_recorder(|| async {
        let policy = RateLimiterPolicy::per_second(100).with_max_tokens(1);
        let limiter = Arc::new(RateLimiter::new(policy));

        // Take the single token first so the awaited acquire cannot succeed at once.
        let _ = limiter.try_acquire().expect("first token is available");
        let _ = limiter.acquire().await;
    });

    let acquired = counter(&rows, "cano_rate_limiter_acquired_total", &[]);
    assert_eq!(acquired, 2);

    let waited = counter(
        &rows,
        "cano_rate_limiter_throttled_total",
        &[("result", "waited")],
    );
    assert!(waited >= 1);

    let wait_samples = histogram_count(&rows, "cano_rate_limiter_wait_seconds", &[]);
    assert_eq!(wait_samples, 1);
}
/// An `acquire` that finds a token available must count as acquired without
/// producing a wait-seconds sample or a `result="waited"` throttle counter.
#[test]
fn immediate_acquire_does_not_record_wait_histogram() {
    let ((), rows) = run_with_recorder(|| async {
        let policy = RateLimiterPolicy::per_second(10).with_max_tokens(2);
        let limiter = Arc::new(RateLimiter::new(policy));
        let _ = limiter.acquire().await;
    });

    let acquired = counter(&rows, "cano_rate_limiter_acquired_total", &[]);
    assert_eq!(acquired, 1);

    let wait_samples = histogram_count_opt(&rows, "cano_rate_limiter_wait_seconds", &[]);
    assert_eq!(
        wait_samples, None,
        "wait_seconds must not record a sample when acquire did not park"
    );

    let waited = counter_opt(
        &rows,
        "cano_rate_limiter_throttled_total",
        &[("result", "waited")],
    );
    assert_eq!(waited, None);
}
/// Weighted acquisitions of 7 and 3 tokens must add up to
/// `tokens_consumed_total == 10` while counting as two acquisitions.
#[test]
fn weighted_acquire_records_tokens_consumed() {
    let ((), rows) = run_with_recorder(|| async {
        let policy = RateLimiterPolicy::per_second(100).with_max_tokens(100);
        let limiter = Arc::new(RateLimiter::new(policy));

        // The results are deliberately discarded; the counters below are the check.
        for cost in [7, 3] {
            let _ = limiter.try_acquire_n(cost);
        }
    });

    let consumed = counter(&rows, "cano_rate_limiter_tokens_consumed_total", &[]);
    assert_eq!(consumed, 10);

    let acquired = counter(&rows, "cano_rate_limiter_acquired_total", &[]);
    assert_eq!(acquired, 2);
}
/// When the second `try_acquire` on a two-tier limiter fails, the multi-limiter
/// throttle counter must be recorded once under the label `tier="b"`.
#[test]
fn multi_rejection_records_throttled_with_tier() {
    let ((), rows) = run_with_recorder(|| async {
        let bucket_policy = RateLimiterPolicy::per_second(100).with_max_tokens(100);
        let bucket: Arc<dyn Meter> = Arc::new(RateLimiter::new(bucket_policy));

        let window_policy = WindowPolicy::new(1, Duration::from_secs(60));
        let window: Arc<dyn Meter> = Arc::new(WindowedRateLimiter::new(window_policy));

        let limiter = MultiRateLimiter::new()
            .with_tier("a", bucket, 1)
            .with_tier("b", window, 1);

        // First call passes both tiers; the second is expected to fail.
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_err());
    });

    let throttled = counter(
        &rows,
        "cano_multi_rate_limiter_throttled_total",
        &[("tier", "b")],
    );
    assert_eq!(throttled, 1);
}
/// On a window limiter, one weighted acquisition of 2 followed by a refused
/// single acquisition must yield `acquired_total == 1`,
/// `tokens_consumed_total == 2` and `throttled_total{result="rejected"} == 1`.
#[test]
fn window_acquire_records_acquired_tokens_and_rejected() {
    let ((), rows) = run_with_recorder(|| async {
        let policy = WindowPolicy::new(2, Duration::from_secs(60));
        let window = Arc::new(WindowedRateLimiter::new(policy));

        // The weighted call is admitted; the follow-up single acquire is refused.
        assert!(window.try_acquire_n(2).is_some());
        assert!(window.try_acquire().is_none());
    });

    let acquired = counter(&rows, "cano_rate_limiter_acquired_total", &[]);
    assert_eq!(acquired, 1);

    let consumed = counter(&rows, "cano_rate_limiter_tokens_consumed_total", &[]);
    assert_eq!(consumed, 2);

    let rejected = counter(
        &rows,
        "cano_rate_limiter_throttled_total",
        &[("result", "rejected")],
    );
    assert_eq!(rejected, 1);
}
/// An awaited `acquire` on a multi-limiter whose single-token bucket tier was
/// just emptied must record at least one throttle under the label `tier="b"`.
#[test]
fn multi_acquire_for_records_throttle_while_parking() {
    let ((), rows) = run_with_recorder(|| async {
        let policy = RateLimiterPolicy::per_second(100).with_max_tokens(1);
        let bucket: Arc<dyn Meter> = Arc::new(RateLimiter::new(policy));
        let limiter = MultiRateLimiter::new().with_tier("b", bucket, 1);

        // Empty the tier first so the awaited acquire is blocked at least once.
        assert!(limiter.try_acquire().is_ok());
        let _permit = limiter.acquire().await;
    });

    let throttled = counter(
        &rows,
        "cano_multi_rate_limiter_throttled_total",
        &[("tier", "b")],
    );
    assert!(
        throttled >= 1,
        "acquire() should record a throttle for each blocked attempt while parking"
    );
}
}