use std::{sync::Arc, time::Duration};
use crate::backend::MemoryBackend;
/// Outcome of a single rate-limit check.
///
/// Marked `#[must_use]` because dropping the decision means the caller never
/// actually enforces the limit.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use = "a rate-limit decision must be inspected to be enforced"]
pub enum LimitResult {
    /// The request fits within the current window.
    Allowed {
        /// Requests still permitted in the current window after this one.
        remaining: u64,
        /// Window reset time reported by the backend, on the same
        /// seconds-since-Unix-epoch scale the limiter compares against.
        reset_epoch: u64,
    },
    /// The request was rejected.
    Exceeded {
        /// Suggested number of seconds to wait before retrying.
        retry_after_secs: u64,
    },
}

impl LimitResult {
    /// Returns `true` if this result is [`LimitResult::Allowed`].
    #[must_use]
    pub const fn is_allowed(&self) -> bool {
        matches!(self, Self::Allowed { .. })
    }

    /// Returns `true` if this result is [`LimitResult::Exceeded`].
    #[must_use]
    pub const fn is_exceeded(&self) -> bool {
        matches!(self, Self::Exceeded { .. })
    }
}
/// Rate limiter handle that owns the counter storage for every key.
///
/// Cloning is cheap (it clones two `Arc`s), and every clone shares the same
/// counters, so limits are enforced across all clones.
#[derive(Clone)]
pub struct Limiter {
    /// Primary per-key request counters.
    backend: Arc<MemoryBackend>,
    /// Counters consulted only when a builder has a block threshold set.
    block_backend: Arc<MemoryBackend>,
}

impl std::fmt::Debug for Limiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The backends are omitted so `Limiter` is `Debug` without
        // requiring `MemoryBackend: Debug`.
        f.debug_struct("Limiter").finish_non_exhaustive()
    }
}

impl Default for Limiter {
    /// Equivalent to [`Limiter::memory`].
    fn default() -> Self {
        Self::memory()
    }
}

impl Limiter {
    /// Creates a limiter backed by two fresh `MemoryBackend` instances.
    #[must_use]
    pub fn memory() -> Self {
        Self {
            backend: Arc::new(MemoryBackend::new()),
            block_backend: Arc::new(MemoryBackend::new()),
        }
    }

    /// Starts configuring a limit for `key`.
    ///
    /// Defaults: 60 requests per 60-second window, with no in-memory block
    /// threshold. The returned builder shares this limiter's counters.
    #[must_use = "the limit is only applied when `check()` is called on the builder"]
    pub fn for_key(&self, key: impl Into<String>) -> LimitBuilder {
        LimitBuilder {
            limiter: self.clone(),
            key: key.into(),
            requests: 60,
            window_secs: 60,
            block_after: None,
        }
    }
}
/// Per-key rate-limit configuration produced by [`Limiter::for_key`].
///
/// Configure it with the chained setters, then call [`LimitBuilder::check`]
/// to record a request and obtain a [`LimitResult`]. Nothing is counted until
/// `check` is called.
#[must_use = "a LimitBuilder does nothing until `check()` is called"]
pub struct LimitBuilder {
    /// Handle to the shared counter storage.
    limiter: Limiter,
    /// Key whose requests are counted.
    key: String,
    /// Maximum number of requests allowed per window.
    requests: u64,
    /// Window length in whole seconds.
    window_secs: u64,
    /// Optional threshold for the separate block counter.
    block_after: Option<u64>,
}

impl LimitBuilder {
    /// Sets the maximum number of requests allowed per window.
    pub fn requests(mut self, n: u64) -> Self {
        self.requests = n;
        self
    }

    /// Sets the window length.
    ///
    /// Only whole seconds are kept (`Duration::as_secs` truncates the
    /// fractional part), and anything shorter than one second becomes one
    /// second so the window is never zero.
    pub fn per(mut self, d: Duration) -> Self {
        self.window_secs = d.as_secs().max(1);
        self
    }

    /// Enables a second, independent counter for this key that is checked
    /// before the primary one.
    ///
    /// Every `check` increments it over the same window. Once it exceeds `n`,
    /// `check` returns `Exceeded` without touching the primary counter.
    pub fn block_in_memory_after(mut self, n: u64) -> Self {
        self.block_after = Some(n);
        self
    }

    /// Records one request for the key and reports whether it is allowed.
    ///
    /// A request rejected by the primary limit still counts toward it. When
    /// the block counter (see [`LimitBuilder::block_in_memory_after`]) rejects
    /// first, the primary counter is not incremented at all.
    pub fn check(self) -> LimitResult {
        let now_epoch = epoch_secs();

        if let Some(block_limit) = self.block_after {
            // Namespaced key; block counters live in their own backend.
            let block_key = format!("block:{}", self.key);
            let (block_count, reset_epoch) = self
                .limiter
                .block_backend
                .increment(&block_key, self.window_secs);
            if block_count > block_limit {
                return Self::exceeded_until(reset_epoch, now_epoch);
            }
        }

        let (count, reset_epoch) = self.limiter.backend.increment(&self.key, self.window_secs);
        if count > self.requests {
            Self::exceeded_until(reset_epoch, now_epoch)
        } else {
            LimitResult::Allowed {
                remaining: self.requests.saturating_sub(count),
                reset_epoch,
            }
        }
    }

    /// Builds an `Exceeded` result whose retry delay is the time left until
    /// `reset_epoch`. The delay is clamped to at least one second so clients
    /// never receive a zero retry hint for a window that has already passed.
    fn exceeded_until(reset_epoch: u64, now_epoch: u64) -> LimitResult {
        let retry_after_secs = reset_epoch.saturating_sub(now_epoch).max(1);
        LimitResult::Exceeded { retry_after_secs }
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, this yields `0`
/// instead of panicking.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}