#[cfg(feature = "rate-limit")]
use std::time::Duration;
use async_trait::async_trait;
/// Who a rate limit is applied to.
///
/// Combined with a route in [`RateLimitKey`]. The `'a` lifetime lets the
/// string variants borrow from the caller rather than own their data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitSubject<'a> {
/// A client IP address.
Ip(std::net::IpAddr),
/// A WebID string.
WebId(&'a str),
/// Any other caller-supplied identifier.
Custom(&'a str),
}
#[cfg(feature = "rate-limit")]
impl RateLimitSubject<'_> {
    /// Renders the subject as `<kind>:<value>`, e.g. `ip:127.0.0.1`.
    ///
    /// Each variant uses a distinct tag (`ip`, `webid`, `custom`), so subjects
    /// of different kinds never render to the same string.
    fn canonical(&self) -> String {
        let (tag, value): (&str, &dyn std::fmt::Display) = match self {
            RateLimitSubject::Ip(addr) => ("ip", addr),
            RateLimitSubject::WebId(id) => ("webid", id),
            RateLimitSubject::Custom(raw) => ("custom", raw),
        };
        format!("{tag}:{value}")
    }
}
/// Identifies what is being rate limited: a route plus the subject hitting it.
#[derive(Debug, Clone)]
pub struct RateLimitKey<'a> {
/// Route name. `LruRateLimiter` selects its policy by exact string match on this.
pub route: &'a str,
/// Who is making the request.
pub subject: RateLimitSubject<'a>,
}
#[cfg(feature = "rate-limit")]
impl RateLimitKey<'_> {
    /// Builds the string used as this key's bucket identifier.
    ///
    /// Format: `<route byte length>:<route>|<subject>`. The length prefix keeps
    /// the encoding injective even when the route contains `|`. Without it,
    /// route `"a"` with subject `WebId("b|webid:c")` and route `"a|webid:b"`
    /// with subject `WebId("c")` would both encode to `"a|webid:b|webid:c"`
    /// and share one bucket.
    fn canonical(&self) -> String {
        format!(
            "{}:{}|{}",
            self.route.len(),
            self.route,
            self.subject.canonical()
        )
    }
}
/// Outcome of a [`RateLimiter::check`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitDecision {
/// The request may proceed.
Allow,
/// The request is over its limit.
Deny {
/// Whole seconds to wait before retrying. `LruRateLimiter` rounds this up
/// and never reports less than 1.
retry_after_secs: u64,
/// Maximum number of requests allowed per window.
limit: u32,
/// Window length in whole seconds (at least 1 for `LruRateLimiter`).
window_secs: u64,
},
}
/// Decides whether a request identified by a [`RateLimitKey`] may proceed.
///
/// The `Send + Sync + 'static` bounds allow one limiter instance to be shared
/// between threads.
#[async_trait]
pub trait RateLimiter: Send + Sync + 'static {
/// Checks `key` and returns whether the request is allowed.
///
/// NOTE(review): `LruRateLimiter` records the hit only when it returns
/// `Allow` (denied requests are not counted). Confirm whether other
/// implementations are expected to behave the same way.
async fn check(&self, key: &RateLimitKey<'_>) -> RateLimitDecision;
}
#[cfg(feature = "rate-limit")]
mod lru_impl {
use super::*;
use std::num::NonZeroUsize;
use std::time::Instant;
use lru::LruCache;
use parking_lot::Mutex;
/// Number of distinct keys tracked by [`LruRateLimiter::new`] and
/// [`LruRateLimiter::with_policy`].
pub const DEFAULT_LRU_CAPACITY: usize = 4096;
/// Requests allowed per window for routes with no configured policy.
const DEFAULT_MAX: u32 = 60;
/// Window length for routes with no configured policy.
const DEFAULT_WINDOW: Duration = Duration::from_secs(60);
/// Timestamps of the hits recorded for a single rate-limit key.
///
/// The limiter only pushes a hit while fewer than the route's `max` hits are
/// inside the window, so this vector stays small.
#[derive(Debug, Default)]
struct SlidingWindow {
    hits: Vec<Instant>,
}

impl SlidingWindow {
    /// Drops every hit that is `window` or more older than `now`.
    ///
    /// Hits are compared by elapsed time instead of against `now - window`.
    /// `Instant::checked_sub` returns `None` when `now - window` is not
    /// representable, e.g. for a very large window, or early in the clock's
    /// range on some platforms. In that case every recorded hit is necessarily
    /// inside the window and must be kept. Clearing the history there would let
    /// the key bypass the limit. Hits stamped after `now` count as zero elapsed
    /// time and are kept.
    fn prune(&mut self, now: Instant, window: Duration) {
        self.hits
            .retain(|&hit| now.saturating_duration_since(hit) < window);
    }
}
/// In-memory sliding-window [`RateLimiter`].
///
/// Hit histories are stored per canonical key (route + subject) in a bounded
/// `LruCache` behind a `parking_lot::Mutex`.
///
/// NOTE(review): when the cache is full, tracking a new key evicts another
/// key and forgets its history. A flood of distinct subjects can therefore
/// reset other subjects' counters. Size the capacity for the expected load.
pub struct LruRateLimiter {
// Hit history per canonical key.
buckets: Mutex<LruCache<String, SlidingWindow>>,
// Route-specific policies, searched in order; the first match wins.
policies: Vec<RoutePolicy>,
// Policy for routes that have no entry in `policies`.
default_policy: RoutePolicy,
}
/// Limit for one route: at most `max` hits within any `window`.
#[derive(Debug, Clone)]
struct RoutePolicy {
// Compared by exact string equality against `RateLimitKey::route`.
route: String,
// Maximum hits allowed within `window`; non-zero for configured policies.
max: u32,
// Length of the sliding window; non-zero for configured policies.
window: Duration,
}
impl LruRateLimiter {
    /// Creates a limiter that tracks up to [`DEFAULT_LRU_CAPACITY`] keys and
    /// applies the default policy (`DEFAULT_MAX` hits per `DEFAULT_WINDOW`)
    /// to every route.
    pub fn new() -> Self {
        Self::with_capacity_and_policies(DEFAULT_LRU_CAPACITY, Vec::new())
    }

    /// Creates a limiter that tracks up to [`DEFAULT_LRU_CAPACITY`] keys with
    /// the given `(route, max, window)` policies.
    ///
    /// # Panics
    ///
    /// Panics if any policy has `max == 0` or a zero `window`.
    pub fn with_policy(policies: Vec<(String, u32, Duration)>) -> Self {
        Self::with_capacity_and_policies(DEFAULT_LRU_CAPACITY, policies)
    }

    /// Creates a limiter that tracks up to `capacity` keys with the given
    /// `(route, max, window)` policies. A `capacity` of 0 is treated as 1.
    ///
    /// Routes are matched by exact string equality. If a route is listed more
    /// than once, the first entry wins. Routes with no entry use the default
    /// policy (`DEFAULT_MAX` hits per `DEFAULT_WINDOW`).
    ///
    /// Keys live in an `LruCache` of `capacity` entries. Once it is full,
    /// tracking a new key evicts the least-recently-used one and discards that
    /// key's history.
    ///
    /// # Panics
    ///
    /// Panics if any policy has `max == 0` or a zero `window`.
    pub fn with_capacity_and_policies(
        capacity: usize,
        policies: Vec<(String, u32, Duration)>,
    ) -> Self {
        let capacity =
            NonZeroUsize::new(capacity.max(1)).expect("capacity.max(1) is always non-zero");
        let policies = policies
            .into_iter()
            .map(|(route, max, window)| {
                assert!(max > 0, "rate-limit max must be non-zero");
                assert!(!window.is_zero(), "rate-limit window must be non-zero");
                RoutePolicy { route, max, window }
            })
            .collect();
        Self {
            buckets: Mutex::new(LruCache::new(capacity)),
            policies,
            default_policy: RoutePolicy {
                route: String::new(),
                max: DEFAULT_MAX,
                window: DEFAULT_WINDOW,
            },
        }
    }

    /// Returns the first configured policy whose route equals `route`, or the
    /// default policy when none matches.
    fn policy_for(&self, route: &str) -> &RoutePolicy {
        self.policies
            .iter()
            .find(|p| p.route == route)
            .unwrap_or(&self.default_policy)
    }

    /// Decides whether a request for `key` arriving at `now` is allowed.
    ///
    /// Hits that have left the route's window are pruned first.
    ///
    /// - If fewer than the policy's `max` hits remain, `now` is recorded and
    ///   `Allow` is returned.
    /// - Otherwise the request is denied and not recorded. `retry_after_secs`
    ///   is the time until the oldest remaining hit leaves the window, rounded
    ///   up to whole seconds and at least 1.
    ///
    /// Taking `now` as a parameter keeps this deterministic under test.
    fn check_sync(&self, key: &RateLimitKey<'_>, now: Instant) -> RateLimitDecision {
        let policy = self.policy_for(key.route);
        let canonical = key.canonical();
        let mut buckets = self.buckets.lock();
        let bucket = buckets.get_or_insert_mut(canonical, SlidingWindow::default);
        bucket.prune(now, policy.window);

        // Widen `u32 -> usize` rather than truncating `usize -> u32`.
        if bucket.hits.len() < policy.max as usize {
            bucket.hits.push(now);
            return RateLimitDecision::Allow;
        }

        // `check` samples `Instant::now()` before this lock is taken, so
        // concurrent callers can push hits slightly out of order. Use the
        // minimum rather than assuming `hits[0]` is the oldest.
        let oldest = bucket.hits.iter().min().copied().unwrap_or(now);
        let elapsed = now.saturating_duration_since(oldest);
        let remaining = policy.window.saturating_sub(elapsed);
        RateLimitDecision::Deny {
            retry_after_secs: ceil_secs(remaining).max(1),
            limit: policy.max,
            // Round up like `retry_after_secs`, so a fractional window (e.g.
            // 1.5s) is reported as 2 rather than under-reported as 1.
            window_secs: ceil_secs(policy.window).max(1),
        }
    }
}
impl Default for LruRateLimiter {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl RateLimiter for LruRateLimiter {
    /// Checks `key` at the current instant.
    ///
    /// The work is fully synchronous, so the internal lock is never held
    /// across an `.await`.
    async fn check(&self, key: &RateLimitKey<'_>) -> RateLimitDecision {
        let now = Instant::now();
        self.check_sync(key, now)
    }
}
/// Converts `d` to whole seconds, rounding any fractional part up.
///
/// Saturates at `u64::MAX` instead of overflowing.
fn ceil_secs(d: Duration) -> u64 {
    let has_fraction = d.subsec_nanos() != 0;
    d.as_secs().saturating_add(u64::from(has_fraction))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback subject for tests that don't care which subject is used.
    fn ip() -> RateLimitSubject<'static> {
        RateLimitSubject::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
    }

    /// Builds a limiter with a single `(route, max, window_secs)` policy.
    fn single_policy(route: &str, max: u32, window_secs: u64) -> LruRateLimiter {
        LruRateLimiter::with_policy(vec![(
            route.to_owned(),
            max,
            Duration::from_secs(window_secs),
        )])
    }

    #[test]
    fn ceil_secs_rounds_up_fractional() {
        assert_eq!(ceil_secs(Duration::from_millis(500)), 1);
        assert_eq!(ceil_secs(Duration::from_secs(1)), 1);
        assert_eq!(ceil_secs(Duration::from_millis(1500)), 2);
        assert_eq!(ceil_secs(Duration::from_secs(0)), 0);
    }

    #[test]
    fn default_policy_used_when_route_unknown() {
        let limiter = single_policy("foo", 1, 5);
        let key = RateLimitKey {
            route: "bar",
            subject: ip(),
        };
        let now = Instant::now();
        for _ in 0..DEFAULT_MAX {
            assert_eq!(limiter.check_sync(&key, now), RateLimitDecision::Allow);
        }
        let d = limiter.check_sync(&key, now);
        assert!(matches!(d, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn deny_reports_policy_and_retry_after() {
        let limiter = single_policy("r", 2, 10);
        let key = RateLimitKey {
            route: "r",
            subject: ip(),
        };
        let t0 = Instant::now();
        assert_eq!(limiter.check_sync(&key, t0), RateLimitDecision::Allow);
        assert_eq!(
            limiter.check_sync(&key, t0 + Duration::from_secs(3)),
            RateLimitDecision::Allow
        );
        // The oldest hit (t0) leaves the 10s window 6s after t0 + 4s.
        assert_eq!(
            limiter.check_sync(&key, t0 + Duration::from_secs(4)),
            RateLimitDecision::Deny {
                retry_after_secs: 6,
                limit: 2,
                window_secs: 10,
            }
        );
    }

    #[test]
    fn hits_expire_after_window() {
        let limiter = single_policy("r", 1, 10);
        let key = RateLimitKey {
            route: "r",
            subject: ip(),
        };
        let t0 = Instant::now();
        assert_eq!(limiter.check_sync(&key, t0), RateLimitDecision::Allow);
        assert_eq!(
            limiter.check_sync(&key, t0 + Duration::from_secs(9)),
            RateLimitDecision::Deny {
                retry_after_secs: 1,
                limit: 1,
                window_secs: 10,
            }
        );
        // Exactly one window after the first hit, that hit no longer counts.
        assert_eq!(
            limiter.check_sync(&key, t0 + Duration::from_secs(10)),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn subjects_have_independent_buckets() {
        let limiter = single_policy("r", 1, 60);
        let a = RateLimitKey {
            route: "r",
            subject: RateLimitSubject::Custom("a"),
        };
        let b = RateLimitKey {
            route: "r",
            subject: RateLimitSubject::Custom("b"),
        };
        let now = Instant::now();
        assert_eq!(limiter.check_sync(&a, now), RateLimitDecision::Allow);
        assert!(matches!(
            limiter.check_sync(&a, now),
            RateLimitDecision::Deny { .. }
        ));
        assert_eq!(limiter.check_sync(&b, now), RateLimitDecision::Allow);
    }

    #[test]
    fn zero_capacity_is_treated_as_one() {
        let limiter = LruRateLimiter::with_capacity_and_policies(
            0,
            vec![("r".into(), 1, Duration::from_secs(60))],
        );
        let a = RateLimitKey {
            route: "r",
            subject: RateLimitSubject::Custom("a"),
        };
        let b = RateLimitKey {
            route: "r",
            subject: RateLimitSubject::Custom("b"),
        };
        let now = Instant::now();
        assert_eq!(limiter.check_sync(&a, now), RateLimitDecision::Allow);
        assert!(matches!(
            limiter.check_sync(&a, now),
            RateLimitDecision::Deny { .. }
        ));
        // With room for one key, tracking `b` evicts `a` and its history.
        assert_eq!(limiter.check_sync(&b, now), RateLimitDecision::Allow);
        assert_eq!(limiter.check_sync(&a, now), RateLimitDecision::Allow);
    }

    #[test]
    #[should_panic(expected = "rate-limit max must be non-zero")]
    fn zero_max_panics() {
        let _ = single_policy("r", 0, 10);
    }

    #[test]
    #[should_panic(expected = "rate-limit window must be non-zero")]
    fn zero_window_panics() {
        let _ = single_policy("r", 1, 0);
    }

    #[test]
    fn canonical_keys_separate_subjects() {
        let a = RateLimitKey {
            route: "r",
            subject: RateLimitSubject::Ip(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
        };
        let b = RateLimitKey {
            route: "r",
            subject: RateLimitSubject::Ip(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 5))),
        };
        assert_ne!(a.canonical(), b.canonical());
    }

    #[test]
    fn canonical_keys_separate_routes() {
        let a = RateLimitKey {
            route: "r1",
            subject: ip(),
        };
        let b = RateLimitKey {
            route: "r2",
            subject: ip(),
        };
        assert_ne!(a.canonical(), b.canonical());
    }
}
}
#[cfg(feature = "rate-limit")]
pub use lru_impl::{LruRateLimiter, DEFAULT_LRU_CAPACITY};