use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::{Context, Result};
use tracing::warn;
use crate::config::{parse_rate, RateLimitCfg};
/// Which backing store the rate limiter is configured to use.
///
/// Values are produced from the `ratelimit.store` setting by `StoreMode::parse`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StoreMode {
    /// Selected by `"local"`, `"governor"`, or an empty setting. This is the only
    /// mode that is not considered distributed.
    Local,
    /// Selected by `"memory"` or `"in-memory"`.
    Memory,
    /// Selected by `"redis"`.
    Redis,
}
impl StoreMode {
    /// Parses the `ratelimit.store` setting.
    ///
    /// Matching ignores surrounding whitespace and ASCII case. An empty value
    /// and `"governor"` are aliases for `"local"`; `"in-memory"` is an alias
    /// for `"memory"`.
    ///
    /// # Errors
    ///
    /// Returns an error naming the (normalized) value when it is not one of
    /// the recognized spellings.
    pub fn parse(s: &str) -> Result<StoreMode> {
        let normalized = s.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "" | "local" | "governor" => StoreMode::Local,
            "memory" | "in-memory" => StoreMode::Memory,
            "redis" => StoreMode::Redis,
            other => {
                anyhow::bail!("invalid ratelimit.store {other:?} (expected local|memory|redis)")
            }
        };
        Ok(mode)
    }

    /// Returns `true` for every mode except [`StoreMode::Local`].
    pub fn is_distributed(self) -> bool {
        match self {
            StoreMode::Local => false,
            StoreMode::Memory | StoreMode::Redis => true,
        }
    }
}
/// Parameters of one GCRA (generic cell rate algorithm) limit.
///
/// Both values are expressed in microseconds.
#[derive(Clone, Copy, Debug)]
pub struct Gcra {
    /// Time budget charged for each admitted request (`period / count`).
    emission_interval: u64,
    /// How far the theoretical arrival time may run ahead of "now" before
    /// requests are rejected (`emission_interval * burst`).
    tolerance: u64,
}
impl Gcra {
    /// Builds GCRA parameters from a rate string and a burst size.
    ///
    /// `rate` is handed to `parse_rate`, which yields a request count and the
    /// period that count applies to. The emission interval is the period in
    /// microseconds divided by the count (integer division); the tolerance is
    /// `burst` emission intervals, saturating at `u64::MAX`.
    ///
    /// # Errors
    ///
    /// Fails when `rate` does not parse, when the count or `burst` is zero,
    /// when the period does not fit in a `u64` number of microseconds, or when
    /// the rate is so high that the emission interval rounds down to zero.
    pub fn from_rate(rate: &str, burst: u32) -> Result<Gcra> {
        let (count, period) = parse_rate(rate)?;
        anyhow::ensure!(count > 0, "rate count must be > 0 (got {rate:?})");
        anyhow::ensure!(burst > 0, "burst must be > 0 (rate {rate:?})");

        // `Duration::as_micros` returns a u128; convert with a check instead of
        // an `as` cast so an oversized period is an error, not a silently
        // truncated interval.
        let period_us = u64::try_from(period.as_micros()).with_context(|| {
            format!("rate period does not fit in 64-bit microseconds: {rate:?}")
        })?;

        // NOTE(review): `count`'s concrete integer type comes from `parse_rate`,
        // which is not visible here; the cast is kept as in the original.
        let emission_interval = period_us / count as u64;
        anyhow::ensure!(
            emission_interval > 0,
            "rate too high for a usable sub-microsecond interval: {rate:?}"
        );

        let tolerance = emission_interval.saturating_mul(u64::from(burst));
        Ok(Gcra {
            emission_interval,
            tolerance,
        })
    }
}
/// Decides whether a request arriving at `now` is admitted under `g`.
///
/// `stored_tat` is the theoretical arrival time recorded by the previous
/// admitted request for the same key, if any. A missing value, or one that is
/// already in the past, is treated as `now`.
///
/// Returns `Some(new_tat)` — the value the caller should store — when the
/// request is admitted, and `None` when it is rejected. A rejection does not
/// produce a new value, so the stored state must be left untouched.
///
/// All arithmetic saturates, so extreme timestamps cannot panic or wrap.
fn gcra_admit(stored_tat: Option<u64>, now: u64, g: &Gcra) -> Option<u64> {
    let tat = stored_tat.map_or(now, |stored| stored.max(now));
    // Saturate instead of `+`: an unchecked add panics in debug builds and
    // wraps to a tiny TAT in release builds.
    let new_tat = tat.saturating_add(g.emission_interval);
    // Earliest instant at which a request producing `new_tat` is acceptable.
    let allow_at = new_tat.saturating_sub(g.tolerance);
    if now < allow_at {
        None
    } else {
        Some(new_tat)
    }
}
/// The backing store a `DistributedLimiter` keeps its GCRA state in.
enum Store {
    /// State held in a process-local map.
    Memory(MemoryStore),
    /// State held in Redis; boxed so the enum stays small.
    Redis(Box<RedisStore>),
}
impl Store {
    /// Runs one GCRA admission check for `key` against whichever store is
    /// configured.
    ///
    /// Returns `Ok(true)` when the request is admitted and `Ok(false)` when it
    /// is rate limited. Only the Redis store can fail; the memory store always
    /// yields `Ok`.
    async fn admit(&self, key: &str, g: &Gcra, now: u64) -> Result<bool> {
        let admitted = match self {
            Store::Redis(redis) => redis.admit(key, g, now).await?,
            Store::Memory(memory) => memory.admit(key, g, now),
        };
        Ok(admitted)
    }
}
/// Process-local GCRA state: one theoretical arrival time (TAT) per limiter key.
///
/// Keys are derived from client IPs and principals, so without pruning the map
/// would grow with every distinct caller ever seen. Expired entries are
/// therefore swept out periodically (see [`MemoryStore::sweep_if_due`]).
#[derive(Default)]
struct MemoryStore {
    /// TAT per key, in the same time unit as the `now` passed to `admit`.
    tats: Mutex<HashMap<String, u64>>,
    /// Map size at which the next sweep runs. Only read and written while the
    /// `tats` lock is held, so relaxed atomic ordering is sufficient.
    sweep_at: std::sync::atomic::AtomicUsize,
}

impl MemoryStore {
    /// Sweeps never run while the map holds fewer entries than this.
    const SWEEP_FLOOR: usize = 1024;

    /// Runs one GCRA admission check for `key`, recording the new TAT when the
    /// request is admitted. Returns `true` when admitted, `false` when limited.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex was poisoned by a panic in another thread.
    fn admit(&self, key: &str, g: &Gcra, now: u64) -> bool {
        let mut map = self.tats.lock().expect("limiter store mutex poisoned");
        let new_tat = match gcra_admit(map.get(key).copied(), now, g) {
            Some(new_tat) => new_tat,
            None => return false,
        };

        // Update in place when the key exists so the common path does not
        // allocate a fresh `String`.
        let is_new_key = match map.get_mut(key) {
            Some(slot) => {
                *slot = new_tat;
                false
            }
            None => true,
        };
        if is_new_key {
            map.insert(key.to_owned(), new_tat);
            // The map only grows on this path, so this is the only place a
            // sweep can become due.
            self.sweep_if_due(&mut map, now);
        }
        true
    }

    /// Drops expired entries once the map has reached the current threshold.
    ///
    /// An entry whose TAT is `<= now` is indistinguishable from a missing one,
    /// because `gcra_admit` clamps the stored TAT up to `now`; removing it does
    /// not change any admission decision (barring the wall clock moving
    /// backwards). After a sweep the threshold is set to twice the surviving
    /// size, which keeps the amortized cost per insert constant.
    fn sweep_if_due(&self, map: &mut HashMap<String, u64>, now: u64) {
        use std::sync::atomic::Ordering::Relaxed;

        let due_at = self.sweep_at.load(Relaxed).max(Self::SWEEP_FLOOR);
        if map.len() < due_at {
            return;
        }
        map.retain(|_, tat| *tat > now);
        self.sweep_at.store(map.len().saturating_mul(2), Relaxed);
    }
}
/// Lua script that performs one GCRA admission check inside Redis.
///
/// Inputs: `KEYS[1]` is the limiter key; `ARGV[1]` is the current time,
/// `ARGV[2]` the emission interval and `ARGV[3]` the tolerance (the caller in
/// this file passes all three in microseconds).
///
/// The script reads the stored theoretical arrival time (TAT), treating a
/// missing or already-past value as "now", and computes the new TAT by adding
/// the interval. If the request is too early it returns `0` without writing.
/// Otherwise it stores the new TAT with a `PX` expiry of
/// `ceil((new_tat - now) / 1000)` (at least 1) and returns `1`.
const GCRA_LUA: &str = r#"
local tat = redis.call('GET', KEYS[1])
local now = tonumber(ARGV[1])
local interval = tonumber(ARGV[2])
local tolerance = tonumber(ARGV[3])
if tat == false then
tat = now
else
tat = tonumber(tat)
if tat < now then tat = now end
end
local new_tat = tat + interval
local allow_at = new_tat - tolerance
if now < allow_at then
return 0
end
local ttl_ms = math.ceil((new_tat - now) / 1000)
if ttl_ms < 1 then ttl_ms = 1 end
redis.call('SET', KEYS[1], new_tat, 'PX', ttl_ms)
return 1
"#;
/// GCRA state kept in Redis and updated through the `GCRA_LUA` script.
struct RedisStore {
    /// Client built from the configured URL; used to create the connection manager.
    client: redis::Client,
    /// Connection manager, created on the first `admit` call rather than at
    /// construction time.
    conn: tokio::sync::OnceCell<redis::aio::ConnectionManager>,
    /// Script handle wrapping `GCRA_LUA`.
    script: redis::Script,
}
impl RedisStore {
    /// Creates a store for the Redis instance at `url` without connecting yet;
    /// the connection is established by the first [`RedisStore::admit`] call.
    ///
    /// # Errors
    ///
    /// Fails when `url` is empty or whitespace, or when the Redis client
    /// rejects it. The error names the URL with any embedded credentials
    /// redacted.
    fn new(url: &str) -> Result<RedisStore> {
        anyhow::ensure!(
            !url.trim().is_empty(),
            "ratelimit.redis_url is required when ratelimit.store = \"redis\""
        );
        let client = redis::Client::open(url).with_context(|| {
            format!(
                "opening redis client for {:?} (ratelimit.redis_url)",
                Self::redact_credentials(url)
            )
        })?;
        Ok(RedisStore {
            client,
            conn: tokio::sync::OnceCell::new(),
            script: redis::Script::new(GCRA_LUA),
        })
    }

    /// Returns `url` with everything between the scheme separator and the last
    /// `@` replaced by `***`, so `user:password@` never reaches logs or error
    /// messages. URLs without an `@` are returned unchanged.
    ///
    /// The last `@` in the whole string is used deliberately: this runs on URLs
    /// that failed to parse, so over-redacting is preferred to leaking a
    /// password that contains unescaped delimiters.
    fn redact_credentials(url: &str) -> String {
        let at = match url.rfind('@') {
            Some(at) => at,
            None => return url.to_owned(),
        };
        // Keep the scheme prefix when there is one before the `@`.
        let keep_until = url
            .find("://")
            .map(|idx| idx + 3)
            .filter(|&start| start <= at)
            .unwrap_or(0);
        format!("{}***{}", &url[..keep_until], &url[at..])
    }

    /// Runs the GCRA script for `key` and reports whether the request was
    /// admitted.
    ///
    /// # Errors
    ///
    /// Fails when the connection manager cannot be created or when invoking
    /// the script fails.
    async fn admit(&self, key: &str, g: &Gcra, now: u64) -> Result<bool> {
        let manager = self
            .conn
            .get_or_try_init(|| redis::aio::ConnectionManager::new(self.client.clone()))
            .await
            .context("connecting to redis rate-limit store")?;
        // `invoke_async` needs a mutable connection, so work on a clone of the
        // shared manager.
        let mut conn = manager.clone();
        let admitted: i64 = self
            .script
            .key(key)
            .arg(now)
            .arg(g.emission_interval)
            .arg(g.tolerance)
            .invoke_async(&mut conn)
            .await
            .context("evaluating redis GCRA script")?;
        Ok(admitted == 1)
    }
}
/// Outcome of a rate-limit check.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Admit {
    /// The request may proceed.
    Allowed,
    /// The request exceeded a limit; the payload names the scope that
    /// rejected it (`"route"`, `"ip"` or `"key"`).
    Limited(&'static str),
    /// The store failed and the limiter is configured to fail closed.
    Error,
}
/// A per-route limit: requests whose path starts with `prefix` use `gcra`.
struct RouteGcra {
    /// Path prefix this limit applies to.
    prefix: String,
    /// GCRA parameters for matching requests.
    gcra: Gcra,
}
/// Rate limiter whose state lives in a `Store` (memory or Redis).
pub struct DistributedLimiter {
    /// Where GCRA state is kept.
    store: Store,
    /// Prepended to every store key.
    key_prefix: String,
    /// When the store errors: `true` admits the request, `false` reports `Admit::Error`.
    fail_open: bool,
    /// Per-IP limit used when no route prefix matches.
    global: Gcra,
    /// Per-route limits; the longest matching prefix is used.
    routes: Vec<RouteGcra>,
    /// Per-principal limit, present only when enabled in the configuration.
    per_key: Option<Gcra>,
}
impl DistributedLimiter {
pub fn build(rl: &RateLimitCfg, mode: StoreMode) -> Result<DistributedLimiter> {
let store = match mode {
StoreMode::Memory => Store::Memory(MemoryStore::default()),
StoreMode::Redis => Store::Redis(Box::new(RedisStore::new(&rl.redis_url)?)),
StoreMode::Local => {
anyhow::bail!("DistributedLimiter::build called for the local store")
}
};
let global = Gcra::from_rate(&rl.rate, rl.burst)?;
let mut routes = Vec::new();
for route in &rl.routes {
anyhow::ensure!(
!route.path.is_empty(),
"ratelimit.routes[].path must not be empty"
);
routes.push(RouteGcra {
prefix: route.path.clone(),
gcra: Gcra::from_rate(&route.rate, route.burst)?,
});
}
let per_key = if rl.per_key.enabled {
Some(Gcra::from_rate(&rl.per_key.rate, rl.per_key.burst)?)
} else {
None
};
Ok(DistributedLimiter {
store,
key_prefix: rl.redis_prefix.clone(),
fail_open: rl.fail_open,
global,
routes,
per_key,
})
}
pub async fn check_ip_route(&self, ip: IpAddr, path: &str) -> Admit {
let now = now_micros();
if let Some(route) = self
.routes
.iter()
.filter(|r| path.starts_with(&r.prefix))
.max_by_key(|r| r.prefix.len())
{
let key = format!("{}:route:{}:{}", self.key_prefix, route.prefix, ip);
self.admit(&key, &route.gcra, now, "route").await
} else {
let key = format!("{}:ip:{}", self.key_prefix, ip);
self.admit(&key, &self.global, now, "ip").await
}
}
pub async fn check_key(&self, principal: &str) -> Admit {
match &self.per_key {
Some(gcra) => {
let now = now_micros();
let key = format!("{}:key:{}", self.key_prefix, principal);
self.admit(&key, gcra, now, "key").await
}
None => Admit::Allowed,
}
}
async fn admit(&self, key: &str, g: &Gcra, now: u64, scope: &'static str) -> Admit {
match self.store.admit(key, g, now).await {
Ok(true) => Admit::Allowed,
Ok(false) => Admit::Limited(scope),
Err(e) => {
if self.fail_open {
warn!(error = %format!("{e:#}"), scope, "rate-limit store error; failing open (allowing request)");
Admit::Allowed
} else {
warn!(error = %format!("{e:#}"), scope, "rate-limit store error; failing closed (503)");
Admit::Error
}
}
}
}
}
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch, and
/// `u64::MAX` if the microsecond count does not fit in a `u64`.
fn now_micros() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_micros()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PerKeyRateLimit, RouteRateLimit};

    /// Shorthand for a valid [`Gcra`]; panics when the inputs are rejected.
    fn gcra(rate: &str, burst: u32) -> Gcra {
        Gcra::from_rate(rate, burst).unwrap()
    }

    /// Builds a limiter backed by the in-process memory store.
    fn memory_limiter(rl: &RateLimitCfg) -> DistributedLimiter {
        DistributedLimiter::build(rl, StoreMode::Memory).unwrap()
    }

    /// Parses an IP address literal used as a test client.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn store_mode_parses_and_classifies() {
        let accepted = [
            ("local", StoreMode::Local),
            ("", StoreMode::Local),
            ("REDIS", StoreMode::Redis),
            (" memory ", StoreMode::Memory),
        ];
        for (raw, expected) in accepted {
            assert_eq!(StoreMode::parse(raw).unwrap(), expected);
        }
        assert!(StoreMode::parse("dynamo").is_err());

        assert!(!StoreMode::parse("local").unwrap().is_distributed());
        for raw in ["redis", "memory"] {
            assert!(StoreMode::parse(raw).unwrap().is_distributed());
        }
    }

    #[test]
    fn gcra_from_rate_rejects_degenerate_input() {
        let rejected = [("0/sec", 5), ("10/sec", 0), ("nonsense", 5)];
        for (rate, burst) in rejected {
            assert!(Gcra::from_rate(rate, burst).is_err());
        }
    }

    #[test]
    fn gcra_admit_allows_burst_then_rejects_at_same_instant() {
        let g = gcra("1/sec", 3);
        let now = 1_000_000_000;
        let mut tat = None;
        for _ in 0..3 {
            tat = gcra_admit(tat, now, &g);
            assert!(tat.is_some(), "within-burst request should be admitted");
        }
        let over_burst = gcra_admit(tat, now, &g);
        assert!(
            over_burst.is_none(),
            "the request past the burst must be rejected"
        );
    }

    #[test]
    fn gcra_admit_recovers_after_emission_interval() {
        let g = gcra("1/sec", 1);
        let t0 = 5_000_000_000;
        let tat = gcra_admit(None, t0, &g).expect("first admitted");

        let immediate = gcra_admit(Some(tat), t0, &g);
        assert!(immediate.is_none(), "immediate second rejected");

        let one_interval_later = gcra_admit(Some(tat), t0 + 1_000_000, &g);
        assert!(
            one_interval_later.is_some(),
            "request after the interval admitted"
        );
    }

    #[test]
    fn gcra_admit_does_not_advance_tat_on_rejection() {
        let g = gcra("1/min", 1);
        let now = 2_000_000_000;
        let tat = gcra_admit(None, now, &g).unwrap();
        // Two rejections in a row against the same stored TAT.
        for _ in 0..2 {
            assert!(gcra_admit(Some(tat), now, &g).is_none());
        }
    }

    #[tokio::test]
    async fn memory_store_enforces_global_limit() {
        let rl = RateLimitCfg {
            enabled: true,
            rate: "1/min".into(),
            burst: 1,
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = memory_limiter(&rl);

        let first_client = addr("203.0.113.7");
        assert_eq!(
            limiter.check_ip_route(first_client, "/").await,
            Admit::Allowed
        );
        assert_eq!(
            limiter.check_ip_route(first_client, "/").await,
            Admit::Limited("ip")
        );

        // A different address has its own budget.
        let second_client = addr("203.0.113.8");
        assert_eq!(
            limiter.check_ip_route(second_client, "/").await,
            Admit::Allowed
        );
    }

    #[tokio::test]
    async fn memory_store_applies_per_route_override() {
        let api_route = RouteRateLimit {
            path: "/api/".into(),
            rate: "1/min".into(),
            burst: 1,
        };
        let rl = RateLimitCfg {
            enabled: true,
            rate: "1000/min".into(),
            burst: 1000,
            routes: vec![api_route],
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = memory_limiter(&rl);
        let client = addr("198.51.100.4");

        assert_eq!(
            limiter.check_ip_route(client, "/api/x").await,
            Admit::Allowed
        );
        assert_eq!(
            limiter.check_ip_route(client, "/api/x").await,
            Admit::Limited("route")
        );
        // Paths outside the route still use the global limit.
        assert_eq!(
            limiter.check_ip_route(client, "/public").await,
            Admit::Allowed
        );
    }

    #[tokio::test]
    async fn memory_store_per_key_limit() {
        let per_key = PerKeyRateLimit {
            enabled: true,
            rate: "1/min".into(),
            burst: 1,
        };
        let rl = RateLimitCfg {
            enabled: true,
            rate: "1000/min".into(),
            burst: 1000,
            per_key,
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = memory_limiter(&rl);

        assert_eq!(limiter.check_key("apikey:abc").await, Admit::Allowed);
        assert_eq!(limiter.check_key("apikey:abc").await, Admit::Limited("key"));
        assert_eq!(limiter.check_key("apikey:def").await, Admit::Allowed);
    }

    #[tokio::test]
    async fn per_key_disabled_always_allows() {
        let rl = RateLimitCfg {
            enabled: true,
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = memory_limiter(&rl);
        assert_eq!(limiter.check_key("whoever").await, Admit::Allowed);
    }

    #[test]
    fn redis_store_requires_a_url() {
        // An empty URL and a string that is not a Redis URL must both fail.
        for url in ["", "not-a-redis-url"] {
            let rl = RateLimitCfg {
                enabled: true,
                store: "redis".into(),
                redis_url: url.into(),
                ..Default::default()
            };
            assert!(DistributedLimiter::build(&rl, StoreMode::Redis).is_err());
        }
    }
}