use clap::Parser;
use trillium_ratelimit::{Quota, RateLimiter};
// Command-line flags for an optional request rate limit.
//
// `--rate-limit` turns limiting on; `--rate-limit-burst` is only accepted
// together with it (enforced by `requires = "quota"`). Both flags are listed
// under the "Rate limit" help heading.
#[derive(Parser, Debug, Clone, Copy)]
pub struct RateLimit {
    /// Limit requests to COUNT per WINDOW, written as COUNT/WINDOW.
    ///
    /// WINDOW is s, min, or h (longer spellings such as "seconds",
    /// "minute" or "hours" also work). Example: --rate-limit 100/min
    #[arg(
        long = "rate-limit",
        value_name = "RATE",
        value_parser = parse_quota,
        verbatim_doc_comment,
        help_heading = "Rate limit"
    )]
    quota: Option<Quota>,

    /// Burst allowance applied to --rate-limit (requires --rate-limit).
    #[arg(
        long = "rate-limit-burst",
        requires = "quota",
        help_heading = "Rate limit"
    )]
    burst: Option<u64>,
}
impl RateLimit {
    /// Builds a per-network rate-limiting handler from the parsed flags.
    ///
    /// Returns `None` when `--rate-limit` was not given. When
    /// `--rate-limit-burst` was given, its value is applied to the quota with
    /// `allow_burst` before the limiter is built.
    #[cfg(any(feature = "serve", feature = "proxy"))]
    pub fn limiter(self) -> Option<impl trillium::Handler> {
        let Self { quota, burst } = self;
        // No quota configured means rate limiting is disabled.
        let base = quota?;
        let effective = if let Some(extra) = burst {
            base.allow_burst(extra)
        } else {
            base
        };
        Some(RateLimiter::by_network(effective))
    }
}
/// Builds a per-network rate-limiting handler from a textual `COUNT/WINDOW` rate.
///
/// `rate` uses the same syntax as `--rate-limit` (it is parsed by
/// `parse_quota`). When `burst` is `Some`, that value is applied to the
/// parsed quota with `allow_burst`.
///
/// # Errors
///
/// Returns `parse_quota`'s error message unchanged if `rate` is malformed.
#[cfg(feature = "gateway")]
pub(crate) fn limiter_for(
    rate: &str,
    burst: Option<u64>,
) -> Result<impl trillium::Handler, String> {
    parse_quota(rate).map(|base| {
        let quota = match burst {
            None => base,
            Some(extra) => base.allow_burst(extra),
        };
        RateLimiter::by_network(quota)
    })
}
/// Parses a `COUNT/WINDOW` rate specification such as `100/min` into a [`Quota`].
///
/// `COUNT` must parse as a `u64`. `WINDOW` is matched case-insensitively.
/// Whitespace around either part is ignored. Accepted windows:
///
/// * seconds: `s`, `sec`, `secs`, `second`, `seconds`
/// * minutes: `m`, `min`, `mins`, `minute`, `minutes`
/// * hours:   `h`, `hr`, `hrs`, `hour`, `hours`
///
/// This function is the clap `value_parser` for `--rate-limit`, so its error
/// string appears in the CLI's error output.
///
/// # Errors
///
/// Returns a human-readable message when:
/// * `s` contains no `/`;
/// * `COUNT` is not a valid `u64`;
/// * `WINDOW` is not one of the units above.
fn parse_quota(s: &str) -> Result<Quota, String> {
    // Only the first '/' splits; anything after it belongs to the window
    // and is rejected below as an unknown unit.
    let (raw_count, raw_window) = s
        .split_once('/')
        .ok_or_else(|| format!("expected COUNT/WINDOW, e.g. 100/min (got {s:?})"))?;

    let raw_count = raw_count.trim();
    // NOTE(review): a count of 0 is accepted and passed straight to `Quota`;
    // confirm that this is the intended behavior.
    let count: u64 = raw_count
        .parse()
        .map_err(|_| format!("invalid request count {raw_count:?}"))?;

    match raw_window.trim().to_ascii_lowercase().as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => Ok(Quota::per_second(count)),
        "m" | "min" | "mins" | "minute" | "minutes" => Ok(Quota::per_minute(count)),
        // `hrs` keeps the hour aliases consistent with `secs` / `mins`.
        "h" | "hr" | "hrs" | "hour" | "hours" => Ok(Quota::per_hour(count)),
        other => Err(format!(
            "unknown window {other:?}; use s, min, or h (e.g. 100/min)"
        )),
    }
}