1use std::collections::HashMap;
29use std::net::IpAddr;
30use std::sync::Mutex;
31use std::time::{SystemTime, UNIX_EPOCH};
32
33use anyhow::{Context, Result};
34use tracing::warn;
35
36use crate::config::{parse_rate, RateLimitCfg};
37
/// Which backing store the rate limiter uses, as selected by the
/// `ratelimit.store` setting (see [`StoreMode::parse`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StoreMode {
    /// Non-distributed mode. `DistributedLimiter::build` rejects it, so a
    /// different limiter is presumably used for it elsewhere.
    Local,
    /// Shared-state GCRA backed by an in-process map (`MemoryStore`).
    Memory,
    /// Shared-state GCRA backed by Redis (`RedisStore`).
    Redis,
}
48
49impl StoreMode {
50 pub fn parse(s: &str) -> Result<StoreMode> {
51 match s.trim().to_ascii_lowercase().as_str() {
52 "local" | "governor" | "" => Ok(StoreMode::Local),
53 "memory" | "in-memory" => Ok(StoreMode::Memory),
54 "redis" => Ok(StoreMode::Redis),
55 other => {
56 anyhow::bail!("invalid ratelimit.store {other:?} (expected local|memory|redis)")
57 }
58 }
59 }
60
61 pub fn is_distributed(self) -> bool {
64 matches!(self, StoreMode::Memory | StoreMode::Redis)
65 }
66}
67
/// Parameters of one GCRA (generic cell rate algorithm) limit, in
/// microseconds. Built by [`Gcra::from_rate`].
#[derive(Debug, Clone, Copy)]
pub struct Gcra {
    /// Spacing between requests at the sustained rate: period / count, in
    /// microseconds.
    emission_interval: u64,
    /// Burst allowance: `emission_interval * burst` (saturating), in
    /// microseconds.
    tolerance: u64,
}
76
77impl Gcra {
78 pub fn from_rate(rate: &str, burst: u32) -> Result<Gcra> {
81 let (count, period) = parse_rate(rate)?;
82 anyhow::ensure!(count > 0, "rate count must be > 0 (got {rate:?})");
83 anyhow::ensure!(burst > 0, "burst must be > 0 (rate {rate:?})");
84 let period_us = period.as_micros() as u64;
85 let emission_interval = period_us / count as u64;
86 anyhow::ensure!(
87 emission_interval > 0,
88 "rate too high for a usable sub-microsecond interval: {rate:?}"
89 );
90 let tolerance = emission_interval.saturating_mul(burst as u64);
91 Ok(Gcra {
92 emission_interval,
93 tolerance,
94 })
95 }
96}
97
98fn gcra_admit(stored_tat: Option<u64>, now: u64, g: &Gcra) -> Option<u64> {
104 let tat = stored_tat.unwrap_or(now).max(now);
106 let new_tat = tat + g.emission_interval;
107 let allow_at = new_tat.saturating_sub(g.tolerance);
108 if now < allow_at {
109 None
110 } else {
111 Some(new_tat)
112 }
113}
114
/// The concrete backing store behind a `DistributedLimiter`.
enum Store {
    /// In-process map of TATs.
    Memory(MemoryStore),
    /// Redis-backed store (boxed).
    Redis(Box<RedisStore>),
}
121
122impl Store {
123 async fn admit(&self, key: &str, g: &Gcra, now: u64) -> Result<bool> {
125 match self {
126 Store::Memory(s) => Ok(s.admit(key, g, now)),
127 Store::Redis(s) => s.admit(key, g, now).await,
128 }
129 }
130}
131
/// In-process GCRA state: one stored TAT per limiter key.
#[derive(Default)]
struct MemoryStore {
    /// Limiter key -> TAT written by the last admitted request, guarded by a
    /// standard (blocking) mutex.
    tats: Mutex<HashMap<String, u64>>,
}
139
140impl MemoryStore {
141 fn admit(&self, key: &str, g: &Gcra, now: u64) -> bool {
142 let mut map = self.tats.lock().expect("limiter store mutex poisoned");
143 match gcra_admit(map.get(key).copied(), now, g) {
144 Some(new_tat) => {
145 map.insert(key.to_string(), new_tat);
146 true
147 }
148 None => false,
149 }
150 }
151}
152
/// Redis-side GCRA check; it follows the same steps as `gcra_admit`.
///
/// Inputs: `KEYS[1]` is the limiter key; `ARGV[1]` is the current time,
/// `ARGV[2]` the emission interval and `ARGV[3]` the tolerance, all in
/// microseconds (see `RedisStore::admit`).
///
/// Returns `0` when the request is rejected (the key is not written) and `1`
/// when it is admitted. On admission the new TAT is stored with a `PX` expiry
/// of `ceil((new_tat - now) / 1000)` milliseconds, at least 1 ms, so the key
/// expires once its TAT has passed.
const GCRA_LUA: &str = r#"
local tat = redis.call('GET', KEYS[1])
local now = tonumber(ARGV[1])
local interval = tonumber(ARGV[2])
local tolerance = tonumber(ARGV[3])
if tat == false then
 tat = now
else
 tat = tonumber(tat)
 if tat < now then tat = now end
end
local new_tat = tat + interval
local allow_at = new_tat - tolerance
if now < allow_at then
 return 0
end
local ttl_ms = math.ceil((new_tat - now) / 1000)
if ttl_ms < 1 then ttl_ms = 1 end
redis.call('SET', KEYS[1], new_tat, 'PX', ttl_ms)
return 1
"#;
178
/// Redis-backed GCRA state.
struct RedisStore {
    /// Client built from `ratelimit.redis_url`; used to create the connection
    /// manager on first use.
    client: redis::Client,
    /// Connection manager, initialised on the first `admit` call.
    conn: tokio::sync::OnceCell<redis::aio::ConnectionManager>,
    /// Script handle wrapping `GCRA_LUA`.
    script: redis::Script,
}
187
188impl RedisStore {
189 fn new(url: &str) -> Result<RedisStore> {
190 anyhow::ensure!(
191 !url.trim().is_empty(),
192 "ratelimit.redis_url is required when ratelimit.store = \"redis\""
193 );
194 let client = redis::Client::open(url)
195 .with_context(|| format!("opening redis client for {url:?} (ratelimit.redis_url)"))?;
196 Ok(RedisStore {
197 client,
198 conn: tokio::sync::OnceCell::new(),
199 script: redis::Script::new(GCRA_LUA),
200 })
201 }
202
203 async fn admit(&self, key: &str, g: &Gcra, now: u64) -> Result<bool> {
204 let manager = self
205 .conn
206 .get_or_try_init(|| redis::aio::ConnectionManager::new(self.client.clone()))
207 .await
208 .context("connecting to redis rate-limit store")?;
209 let mut conn = manager.clone();
210 let admitted: i64 = self
211 .script
212 .key(key)
213 .arg(now)
214 .arg(g.emission_interval)
215 .arg(g.tolerance)
216 .invoke_async(&mut conn)
217 .await
218 .context("evaluating redis GCRA script")?;
219 Ok(admitted == 1)
220 }
221}
222
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Admit {
    /// The request may proceed. Also returned when the store failed and the
    /// limiter is configured to fail open.
    Allowed,
    /// The request exceeded a limit. The label names the limit that fired:
    /// `"route"`, `"ip"` or `"key"`.
    Limited(&'static str),
    /// The store failed and the limiter is configured to fail closed. The log
    /// message mentions a 503; presumably the caller maps this to HTTP 503.
    Error,
}
233
/// A per-route limit: requests whose path starts with `prefix` use `gcra`.
struct RouteGcra {
    /// Path prefix, copied from the route's configured `path` (never empty).
    prefix: String,
    /// GCRA parameters built from the route's `rate` and `burst`.
    gcra: Gcra,
}
239
/// GCRA rate limiter over a shared store (in-memory map or Redis).
///
/// Construct it with [`DistributedLimiter::build`].
pub struct DistributedLimiter {
    /// Backing store holding the per-key TATs.
    store: Store,
    /// Prefix prepended to every store key; taken from `redis_prefix` in the
    /// configuration and used for both store kinds.
    key_prefix: String,
    /// On a store error: `true` allows the request, `false` yields
    /// [`Admit::Error`].
    fail_open: bool,
    /// Per-IP limit applied when no route prefix matches the path.
    global: Gcra,
    /// Per-route overrides; the longest matching prefix wins.
    routes: Vec<RouteGcra>,
    /// Per-principal limit, or `None` when per-key limiting is disabled.
    per_key: Option<Gcra>,
}
251
252impl DistributedLimiter {
253 pub fn build(rl: &RateLimitCfg, mode: StoreMode) -> Result<DistributedLimiter> {
257 let store = match mode {
258 StoreMode::Memory => Store::Memory(MemoryStore::default()),
259 StoreMode::Redis => Store::Redis(Box::new(RedisStore::new(&rl.redis_url)?)),
260 StoreMode::Local => {
261 anyhow::bail!("DistributedLimiter::build called for the local store")
262 }
263 };
264
265 let global = Gcra::from_rate(&rl.rate, rl.burst)?;
266 let mut routes = Vec::new();
267 for route in &rl.routes {
268 anyhow::ensure!(
269 !route.path.is_empty(),
270 "ratelimit.routes[].path must not be empty"
271 );
272 routes.push(RouteGcra {
273 prefix: route.path.clone(),
274 gcra: Gcra::from_rate(&route.rate, route.burst)?,
275 });
276 }
277 let per_key = if rl.per_key.enabled {
278 Some(Gcra::from_rate(&rl.per_key.rate, rl.per_key.burst)?)
279 } else {
280 None
281 };
282
283 Ok(DistributedLimiter {
284 store,
285 key_prefix: rl.redis_prefix.clone(),
286 fail_open: rl.fail_open,
287 global,
288 routes,
289 per_key,
290 })
291 }
292
293 pub async fn check_ip_route(&self, ip: IpAddr, path: &str) -> Admit {
296 let now = now_micros();
297 if let Some(route) = self
298 .routes
299 .iter()
300 .filter(|r| path.starts_with(&r.prefix))
301 .max_by_key(|r| r.prefix.len())
302 {
303 let key = format!("{}:route:{}:{}", self.key_prefix, route.prefix, ip);
304 self.admit(&key, &route.gcra, now, "route").await
305 } else {
306 let key = format!("{}:ip:{}", self.key_prefix, ip);
307 self.admit(&key, &self.global, now, "ip").await
308 }
309 }
310
311 pub async fn check_key(&self, principal: &str) -> Admit {
314 match &self.per_key {
315 Some(gcra) => {
316 let now = now_micros();
317 let key = format!("{}:key:{}", self.key_prefix, principal);
318 self.admit(&key, gcra, now, "key").await
319 }
320 None => Admit::Allowed,
321 }
322 }
323
324 async fn admit(&self, key: &str, g: &Gcra, now: u64, scope: &'static str) -> Admit {
325 match self.store.admit(key, g, now).await {
326 Ok(true) => Admit::Allowed,
327 Ok(false) => Admit::Limited(scope),
328 Err(e) => {
329 if self.fail_open {
330 warn!(error = %format!("{e:#}"), scope, "rate-limit store error; failing open (allowing request)");
331 Admit::Allowed
332 } else {
333 warn!(error = %format!("{e:#}"), scope, "rate-limit store error; failing closed (503)");
334 Admit::Error
335 }
336 }
337 }
338 }
339}
340
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch, and
/// `u64::MAX` if the microsecond count does not fit in a `u64`.
fn now_micros() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PerKeyRateLimit, RouteRateLimit};

    /// Shorthand for building a `Gcra` from known-good test input.
    fn gcra(rate: &str, burst: u32) -> Gcra {
        Gcra::from_rate(rate, burst).unwrap()
    }

    #[test]
    fn store_mode_parses_and_classifies() {
        assert_eq!(StoreMode::parse("local").unwrap(), StoreMode::Local);
        assert_eq!(StoreMode::parse("").unwrap(), StoreMode::Local);
        assert_eq!(StoreMode::parse("REDIS").unwrap(), StoreMode::Redis);
        assert_eq!(StoreMode::parse(" memory ").unwrap(), StoreMode::Memory);
        assert!(StoreMode::parse("dynamo").is_err());
        assert!(!StoreMode::parse("local").unwrap().is_distributed());
        assert!(StoreMode::parse("redis").unwrap().is_distributed());
        assert!(StoreMode::parse("memory").unwrap().is_distributed());
    }

    #[test]
    fn gcra_from_rate_rejects_degenerate_input() {
        // Zero count, zero burst and an unparsable rate are all errors.
        assert!(Gcra::from_rate("0/sec", 5).is_err());
        assert!(Gcra::from_rate("10/sec", 0).is_err());
        assert!(Gcra::from_rate("nonsense", 5).is_err());
    }

    #[test]
    fn gcra_admit_allows_burst_then_rejects_at_same_instant() {
        let g = gcra("1/sec", 3);
        let now = 1_000_000_000;
        let mut tat = None;
        for _ in 0..3 {
            let next = gcra_admit(tat, now, &g);
            assert!(next.is_some(), "within-burst request should be admitted");
            tat = next;
        }
        assert!(
            gcra_admit(tat, now, &g).is_none(),
            "the request past the burst must be rejected"
        );
    }

    #[test]
    fn gcra_admit_recovers_after_emission_interval() {
        let g = gcra("1/sec", 1);
        let t0 = 5_000_000_000;
        let tat = gcra_admit(None, t0, &g).expect("first admitted");
        assert!(
            gcra_admit(Some(tat), t0, &g).is_none(),
            "immediate second rejected"
        );
        assert!(
            gcra_admit(Some(tat), t0 + 1_000_000, &g).is_some(),
            "request after the interval admitted"
        );
    }

    #[test]
    fn gcra_admit_does_not_advance_tat_on_rejection() {
        let g = gcra("1/min", 1);
        let now = 2_000_000_000;

        // The pure function signals rejection with `None`, so it yields no
        // TAT to store.
        let tat = gcra_admit(None, now, &g).unwrap();
        assert!(gcra_admit(Some(tat), now, &g).is_none());
        assert!(gcra_admit(Some(tat), now, &g).is_none());

        // The stateful store must not move the TAT on rejection either:
        // after two rejected attempts, a request exactly one emission
        // interval after the first is admitted. It would be rejected if a
        // rejection had pushed the TAT forward.
        let store = MemoryStore::default();
        let interval = g.emission_interval;
        assert!(store.admit("k", &g, now));
        assert!(!store.admit("k", &g, now));
        assert!(!store.admit("k", &g, now + interval / 2));
        assert!(
            store.admit("k", &g, now + interval),
            "rejections must not delay the next admission"
        );
    }

    #[tokio::test]
    async fn memory_store_enforces_global_limit() {
        let rl = RateLimitCfg {
            enabled: true,
            rate: "1/min".into(),
            burst: 1,
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = DistributedLimiter::build(&rl, StoreMode::Memory).unwrap();
        let ip: IpAddr = "203.0.113.7".parse().unwrap();

        assert_eq!(limiter.check_ip_route(ip, "/").await, Admit::Allowed);
        assert_eq!(limiter.check_ip_route(ip, "/").await, Admit::Limited("ip"));

        // A different IP has its own budget.
        let ip2: IpAddr = "203.0.113.8".parse().unwrap();
        assert_eq!(limiter.check_ip_route(ip2, "/").await, Admit::Allowed);
    }

    #[tokio::test]
    async fn memory_store_applies_per_route_override() {
        let rl = RateLimitCfg {
            enabled: true,
            rate: "1000/min".into(),
            burst: 1000,
            routes: vec![RouteRateLimit {
                path: "/api/".into(),
                rate: "1/min".into(),
                burst: 1,
            }],
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = DistributedLimiter::build(&rl, StoreMode::Memory).unwrap();
        let ip: IpAddr = "198.51.100.4".parse().unwrap();

        assert_eq!(limiter.check_ip_route(ip, "/api/x").await, Admit::Allowed);
        assert_eq!(
            limiter.check_ip_route(ip, "/api/x").await,
            Admit::Limited("route")
        );
        // Paths outside the route prefix fall back to the global limit.
        assert_eq!(limiter.check_ip_route(ip, "/public").await, Admit::Allowed);
    }

    #[tokio::test]
    async fn memory_store_per_key_limit() {
        let rl = RateLimitCfg {
            enabled: true,
            rate: "1000/min".into(),
            burst: 1000,
            per_key: PerKeyRateLimit {
                enabled: true,
                rate: "1/min".into(),
                burst: 1,
            },
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = DistributedLimiter::build(&rl, StoreMode::Memory).unwrap();

        assert_eq!(limiter.check_key("apikey:abc").await, Admit::Allowed);
        assert_eq!(limiter.check_key("apikey:abc").await, Admit::Limited("key"));
        // A different principal has its own budget.
        assert_eq!(limiter.check_key("apikey:def").await, Admit::Allowed);
    }

    #[tokio::test]
    async fn per_key_disabled_always_allows() {
        let rl = RateLimitCfg {
            enabled: true,
            store: "memory".into(),
            ..Default::default()
        };
        let limiter = DistributedLimiter::build(&rl, StoreMode::Memory).unwrap();
        assert_eq!(limiter.check_key("whoever").await, Admit::Allowed);
    }

    #[test]
    fn redis_store_requires_a_url() {
        // An empty URL is rejected before a client is opened.
        let rl = RateLimitCfg {
            enabled: true,
            store: "redis".into(),
            redis_url: "".into(),
            ..Default::default()
        };
        assert!(DistributedLimiter::build(&rl, StoreMode::Redis).is_err());

        // A malformed URL is rejected when opening the client.
        let bad = RateLimitCfg {
            enabled: true,
            store: "redis".into(),
            redis_url: "not-a-redis-url".into(),
            ..Default::default()
        };
        assert!(DistributedLimiter::build(&bad, StoreMode::Redis).is_err());
    }
}