use super::{CellError, Rate, store::Store};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Maximum number of units the limiter allows in one burst.
    pub limit: i64,
    /// Units that could still be consumed immediately after this call.
    pub remaining: i64,
    /// How long until the limiter state for this key resets.
    pub reset_after: Duration,
    /// How long to wait before retrying a denied request; zero when allowed.
    pub retry_after: Duration,
}
/// GCRA rate limiter that keeps its per-key state in a backing store `S`.
///
/// The `Store` bound is placed on the `impl` blocks rather than on the struct
/// definition, so the type can be named in generic code without repeating it.
pub struct RateLimiter<S> {
    store: S,
}
impl<S: Store> RateLimiter<S> {
pub fn new(store: S) -> Self {
RateLimiter { store }
}
pub fn rate_limit(
&mut self,
key: &str,
max_burst: i64,
count_per_period: i64,
period: i64,
quantity: i64,
now: SystemTime,
) -> Result<(bool, RateLimitResult), CellError> {
if quantity < 0 {
return Err(CellError::NegativeQuantity(quantity));
}
if max_burst <= 0 || count_per_period <= 0 || period <= 0 {
return Err(CellError::InvalidRateLimit);
}
let rate = Rate::from_count_and_period(count_per_period, period);
let emission_interval = rate.period();
let delay_variation_tolerance = emission_interval * (max_burst - 1) as u32;
let limit = max_burst;
let now_ns = match now.duration_since(UNIX_EPOCH) {
Ok(duration) => duration.as_nanos() as i64,
Err(e) => {
match SystemTime::now().duration_since(UNIX_EPOCH) {
Ok(current) => {
let period_ns = (period as u64).saturating_mul(1_000_000_000);
current.as_nanos().saturating_sub(period_ns as u128) as i64
}
Err(_) => {
return Err(CellError::Internal(format!("System time error: {e}")));
}
}
}
};
const MAX_RETRIES: u32 = 10;
let mut retries = 0;
loop {
let tat_val = self.store.get(key, now).map_err(CellError::Internal)?;
let emission_interval_ns = emission_interval.as_nanos() as i64;
let delay_variation_tolerance_ns = delay_variation_tolerance.as_nanos() as i64;
let tat = if let Some(stored_tat) = tat_val {
let min_tat = now_ns.saturating_sub(delay_variation_tolerance_ns);
stored_tat.max(min_tat)
} else {
now_ns.saturating_sub(emission_interval_ns)
};
let increment = emission_interval_ns.saturating_mul(quantity);
let new_tat = tat.saturating_add(increment);
let allow_at = new_tat.saturating_sub(delay_variation_tolerance_ns);
let allowed = now_ns >= allow_at;
if allowed {
let ttl = Duration::from_nanos(
new_tat
.saturating_sub(now_ns)
.saturating_add(delay_variation_tolerance_ns) as u64,
);
let success = if let Some(old_tat) = tat_val {
self.store
.compare_and_swap_with_ttl(key, old_tat, new_tat, ttl, now)
.map_err(CellError::Internal)?
} else {
self.store
.set_if_not_exists_with_ttl(key, new_tat, ttl, now)
.map_err(CellError::Internal)?
};
if !success {
retries += 1;
if retries >= MAX_RETRIES {
return Err(CellError::Internal("Max retries exceeded".into()));
}
continue;
}
}
let current_tat = if allowed { new_tat } else { tat };
let burst_limit = now_ns + delay_variation_tolerance_ns;
let room_until_limit = burst_limit.saturating_sub(current_tat);
let remaining = if emission_interval_ns > 0 {
(room_until_limit / emission_interval_ns).max(0)
} else {
0
};
let reset_after = Duration::from_nanos(
current_tat
.saturating_sub(now_ns)
.saturating_add(delay_variation_tolerance_ns)
.max(0) as u64,
);
let retry_after = if allowed {
Duration::ZERO
} else {
Duration::from_nanos(allow_at.saturating_sub(now_ns).max(0) as u64)
};
return Ok((
allowed,
RateLimitResult {
limit,
remaining,
reset_after,
retry_after,
},
));
}
}
}