use std::num::NonZeroU32;
use std::path::Path;
use std::sync::Arc;
use dashmap::DashMap;
use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter};
use serde::Deserialize;
use crate::policy;
/// One rate-limit rule, deserialized from an element of the JSON rule array.
///
/// Rules are evaluated in the order they appear; the first rule whose
/// `principal` and `bucket` patterns both match a request is the one applied.
#[derive(Debug, Clone, Deserialize)]
pub struct Rule {
/// Glob pattern (`*` and `?` wildcards) matched against the request's
/// principal id; a request without a principal is matched as `""`.
pub principal: String,
/// Glob pattern (`*` and `?` wildcards) matched against the bucket name.
pub bucket: String,
/// Per-second rate handed to the limiter quota. Must be non-zero; zero is
/// rejected when the rules are loaded.
pub rps: u32,
/// Burst size handed to the limiter quota. Must be non-zero; zero is
/// rejected when the rules are loaded.
pub burst: u32,
}
/// Ordered rule set plus the limiters created for it so far.
///
/// Both fields sit behind `Arc`, so cloning a `RateLimits` yields a handle
/// that shares the same rules and the same limiter map.
#[derive(Clone)]
pub struct RateLimits {
// Rules in configuration order; a rule's index is part of the limiter key.
rules: Arc<Vec<Rule>>,
// Limiters keyed by (rule index, principal, bucket), filled in on first use.
limiters: Arc<DashMap<(usize, String, String), Arc<KeyLimiter>>>,
}
/// Limiter type stored per map entry: a `governor` `RateLimiter` with
/// `NotKeyed` in-memory state and the default clock. Keying by
/// principal/bucket is handled by the surrounding map, not by `governor`.
type KeyLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
impl RateLimits {
    /// Builds a rule set from a JSON array of [`Rule`] objects.
    ///
    /// # Errors
    ///
    /// Returns a message string when the JSON cannot be deserialized, or when
    /// any rule has `rps == 0` or `burst == 0`.
    pub fn from_json_str(s: &str) -> Result<Self, String> {
        let parsed = serde_json::from_str::<Vec<Rule>>(s)
            .map_err(|e| format!("rate-limit JSON parse error: {e}"))?;

        // Report the first offending rule, in configuration order.
        if let Some(r) = parsed.iter().find(|r| r.rps == 0 || r.burst == 0) {
            return Err(format!(
                "rate-limit rule has rps=0 or burst=0 (would deny everything): {r:?}"
            ));
        }

        Ok(Self {
            rules: Arc::new(parsed),
            limiters: Arc::new(DashMap::new()),
        })
    }

    /// Reads `path` as UTF-8 text and parses it with [`Self::from_json_str`].
    ///
    /// # Errors
    ///
    /// Returns a message string when the file cannot be read, or any error
    /// produced by [`Self::from_json_str`].
    pub fn from_path(path: &Path) -> Result<Self, String> {
        match std::fs::read_to_string(path) {
            Ok(contents) => Self::from_json_str(&contents),
            Err(e) => Err(format!("failed to read {}: {e}", path.display())),
        }
    }

    /// Decides whether a request from `principal_id` against `bucket` may
    /// proceed right now.
    ///
    /// A missing principal is matched as the empty string. The first rule
    /// whose principal and bucket globs both match is applied; its limiter is
    /// looked up (or created on first use) under the key
    /// `(rule index, principal, bucket)`. Returns `true` when that limiter
    /// admits the request, and also when no rule matches at all.
    pub fn check(&self, principal_id: Option<&str>, bucket: &str) -> bool {
        let principal = principal_id.unwrap_or("");

        let first_match = self.rules.iter().enumerate().find(|(_, rule)| {
            glob_match(&rule.principal, principal) && glob_match(&rule.bucket, bucket)
        });
        let (idx, rule) = match first_match {
            Some(hit) => hit,
            // No rule applies to this request: it is not limited.
            None => return true,
        };

        // Only invoked when the map has no limiter for this key yet.
        let build_limiter = || {
            let burst = NonZeroU32::new(rule.burst).expect("burst > 0 (validated)");
            let rps = NonZeroU32::new(rule.rps).expect("rps > 0 (validated)");
            Arc::new(RateLimiter::direct(
                Quota::per_second(rps).allow_burst(burst),
            ))
        };

        // Clone the `Arc` out so the map entry guard is released at the end of
        // this statement, before the limiter itself is consulted.
        let limiter = self
            .limiters
            .entry((idx, principal.to_owned(), bucket.to_owned()))
            .or_insert_with(build_limiter)
            .clone();

        limiter.check().is_ok()
    }
}
/// Compact `Debug` output: prints counts rather than the full rule list and
/// limiter map.
impl std::fmt::Debug for RateLimits {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rule_count = self.rules.len();
        let limiter_count = self.limiters.len();

        let mut out = f.debug_struct("RateLimits");
        out.field("rules", &rule_count);
        out.field("active_limiters", &limiter_count);
        out.finish()
    }
}
/// Reference-counted handle to a [`RateLimits`].
pub type SharedRateLimits = Arc<RateLimits>;
/// Returns whether `s` matches the glob `pattern`.
///
/// The comparison is done on the UTF-8 bytes of both strings, so `?` stands
/// for one byte rather than one character.
fn glob_match(pattern: &str, s: &str) -> bool {
    let pattern_bytes = pattern.as_bytes();
    let input_bytes = s.as_bytes();
    glob_match_bytes(pattern_bytes, input_bytes)
}
/// Matches the input bytes `s` against the glob pattern `p`.
///
/// `*` matches any run of bytes (including none) and `?` matches exactly one
/// byte; every other pattern byte must equal the corresponding input byte.
/// There is no escape syntax, so `*` and `?` in the pattern are always
/// wildcards, even when the input contains those same bytes.
///
/// The match is iterative: the position of the most recent `*` is remembered
/// and, on a mismatch, that `*` is made to absorb one more input byte before
/// matching resumes just after it.
fn glob_match_bytes(p: &[u8], s: &[u8]) -> bool {
    let mut pi = 0;
    let mut si = 0;
    // (pattern index of the latest `*`, input index it currently absorbs up to)
    let mut star: Option<(usize, usize)> = None;

    while si < s.len() {
        match p.get(pi) {
            // This arm must come before the literal comparison: otherwise a
            // pattern `*` facing an input byte `*` is consumed as a one-byte
            // literal and stops being a wildcard (e.g. "*" would not match
            // "*x").
            Some(b'*') => {
                star = Some((pi, si));
                pi += 1;
            }
            Some(&c) if c == b'?' || c == s[si] => {
                pi += 1;
                si += 1;
            }
            // Mismatch or pattern exhausted: widen the last `*` by one byte,
            // or fail if there is no `*` to fall back on.
            _ => match star {
                Some((sp, ss)) => {
                    pi = sp + 1;
                    si = ss + 1;
                    star = Some((sp, si));
                }
                None => return false,
            },
        }
    }

    // Input is consumed; whatever is left of the pattern may only be `*`s.
    p[pi..].iter().all(|&c| c == b'*')
}
// Never-called stub that always returns `None`.
// NOTE(review): presumably exists only so the `crate::policy` import above is
// referenced somewhere in this file; `dead_code` is allowed because nothing
// calls it — confirm before removing.
#[allow(dead_code)]
fn _link() -> Option<policy::Effect> {
None
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Parses inline JSON into a `RateLimits`, panicking if it is rejected.
    fn rl(s: &str) -> RateLimits {
        RateLimits::from_json_str(s).expect("rate-limit parse")
    }

    /// A zero `rps` or a zero `burst` is refused when the rules are loaded.
    #[test]
    fn parse_rejects_zero_rps_or_burst() {
        let cases = [
            (
                r#"[{"principal": "*", "bucket": "*", "rps": 0, "burst": 10}]"#,
                "rps=0",
            ),
            (
                r#"[{"principal": "*", "bucket": "*", "rps": 1, "burst": 0}]"#,
                "burst=0",
            ),
        ];
        for &(json, needle) in &cases {
            let err = RateLimits::from_json_str(json).unwrap_err();
            assert!(err.contains(needle));
        }
    }

    /// The first matching rule wins; the catch-all rule has a burst of one.
    #[test]
    fn match_principal_and_bucket_globs() {
        let limits = rl(r#"[
{"principal": "AKIA*", "bucket": "tenant-a-*", "rps": 1000, "burst": 1000},
{"principal": "*", "bucket": "*", "rps": 1, "burst": 1}
]"#);

        // Served by the generous tenant rule.
        let tenant_allowed = limits.check(Some("AKIATEST"), "tenant-a-foo");
        assert!(tenant_allowed);

        // Served by the catch-all rule: one request passes, the next does not.
        let first_anonymous = limits.check(Some("anonymous"), "any");
        assert!(first_anonymous);
        let second_anonymous = limits.check(Some("anonymous"), "any");
        assert!(!second_anonymous);
    }

    /// Requests that match no rule are never limited.
    #[test]
    fn no_rule_means_no_limit() {
        let limits = rl(r#"[{"principal": "AKIATENANT", "bucket": "*", "rps": 1, "burst": 1}]"#);
        let mut attempts = 0..100;
        assert!(attempts.all(|_| limits.check(Some("AKIAOTHER"), "anything")));
    }

    /// After the burst is spent, waiting lets a further request through.
    #[test]
    fn refill_after_wait() {
        let limits = rl(r#"[{"principal": "*", "bucket": "*", "rps": 100, "burst": 1}]"#);
        assert!(limits.check(None, "b"));
        assert!(!limits.check(None, "b"));

        std::thread::sleep(Duration::from_millis(15));
        assert!(limits.check(None, "b"));
    }
}