use std::net::IpAddr;
use futures::future::BoxFuture;
use crate::web::{Error, RequestContext};
/// Outcome of recording one hit against a [`RateLimitBackend`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateDecision {
    /// The hit is within the limit.
    Allow {
        /// Remaining allowance reported by the backend (presumably hits left in the current
        /// window — backend-defined).
        remaining: u32,
    },
    /// The limit is exhausted for this key.
    Deny {
        /// Retry hint reported by the backend, in seconds.
        retry_after_secs: u32,
    },
    /// The backend could not produce a decision; `DistributedRateLimit::check` then applies
    /// its [`FailurePolicy`].
    Unavailable,
}
/// A store that counts rate-limit hits per key and decides whether the key is still allowed.
///
/// `DistributedRateLimit::check` looks an implementation up from the request context as a
/// `Box<dyn RateLimitBackend>`, hence the `Send + Sync + 'static` bounds.
pub trait RateLimitBackend: Send + Sync + 'static {
    /// Counts one hit for `key` against a limit of `max` hits per `window_secs`-second window
    /// and returns the resulting [`RateDecision`].
    ///
    /// Implementations should return [`RateDecision::Unavailable`] when the store cannot be
    /// consulted, so the caller's failure policy can decide what happens.
    fn hit<'s>(&'s self, key: &'s str, max: u32, window_secs: u32) -> BoxFuture<'s, RateDecision>;
}
/// What `DistributedRateLimit::check` does when the backend answers
/// [`RateDecision::Unavailable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailurePolicy {
    /// Allow the request: availability wins over enforcement.
    FailOpen,
    /// Reject the request with a service-unavailable error: enforcement wins over availability.
    FailClosed,
}

/// A named request limit enforced through a shared [`RateLimitBackend`].
///
/// `max` and `window_secs` are passed to the backend on every hit; the backend decides how the
/// window is counted.
#[derive(Debug, Clone, Copy)]
pub struct DistributedRateLimit {
    /// Limit identifier. It is embedded in every backend key, so each limit should use a
    /// distinct name to avoid sharing counters.
    pub name: &'static str,
    /// Maximum number of hits per window.
    pub max: u32,
    /// Window length in seconds.
    pub window_secs: u32,
    /// Behaviour when the backend cannot produce a decision.
    pub policy: FailurePolicy,
}
impl DistributedRateLimit {
    /// Number of trusted reverse proxies in front of this service that append to
    /// `X-Forwarded-For`. The client address is the entry this many hops from the right; every
    /// entry further left was written by the client and cannot be trusted. Must be at least 1.
    // NOTE(review): assumes a single trusted proxy — adjust to match the deployment topology.
    const TRUSTED_PROXY_HOPS: usize = 1;

    /// Creates a limiter named `name` allowing `max` hits per `window_secs`-second window.
    ///
    /// The limiter fails open by default: when the limit cannot be checked, requests are
    /// allowed. Chain [`Self::fail_closed`] to reject them instead.
    pub const fn new(name: &'static str, max: u32, window_secs: u32) -> Self {
        Self {
            name,
            max,
            window_secs,
            policy: FailurePolicy::FailOpen,
        }
    }

    /// Returns this limiter switched to [`FailurePolicy::FailClosed`], so requests are rejected
    /// with `Error::ServiceUnavailable` whenever the limit cannot be checked.
    #[must_use]
    pub const fn fail_closed(mut self) -> Self {
        self.policy = FailurePolicy::FailClosed;
        self
    }

    /// Builds the identity a request is counted against, in order of preference:
    ///
    /// 1. `sub:<subject>` from the request's `sub` claim, when present and a string.
    /// 2. `ip:<addr>` from the `X-Forwarded-For` entry appended by the nearest trusted proxy
    ///    (see [`Self::TRUSTED_PROXY_HOPS`]).
    /// 3. `anon` — a single bucket shared by every request that matches neither.
    fn principal(ctx: &RequestContext) -> String {
        if let Some(sub) = ctx
            .claims()
            .and_then(|c| c.get("sub"))
            .and_then(|v| v.as_str())
        {
            return format!("sub:{sub}");
        }
        // Read from the right: left-most entries are client-supplied, so keying on them would
        // let a client pick a fresh bucket for every request.
        ctx.header("x-forwarded-for")
            .and_then(|h| h.rsplit(',').nth(Self::TRUSTED_PROXY_HOPS - 1))
            .and_then(|entry| Self::parse_forwarded_addr(entry.trim()))
            .map(|ip| format!("ip:{ip}"))
            .unwrap_or_else(|| "anon".to_owned())
    }

    /// Parses one `X-Forwarded-For` entry. Accepts a bare address (`203.0.113.7`,
    /// `2001:db8::1`) or an address with a port (`203.0.113.7:443`, `[2001:db8::1]:443`), which
    /// some proxies emit.
    fn parse_forwarded_addr(entry: &str) -> Option<IpAddr> {
        entry
            .parse::<IpAddr>()
            .or_else(|_| entry.parse::<std::net::SocketAddr>().map(|sa| sa.ip()))
            .ok()
    }

    /// Records one hit for the current request's principal and decides whether it may proceed.
    ///
    /// The hit is sent to the injected [`RateLimitBackend`] under the key
    /// `rl:<name>:<principal>`.
    ///
    /// # Errors
    ///
    /// - `Error::TooManyRequests` when the backend denies the hit.
    /// - `Error::ServiceUnavailable` when the limit cannot be checked and the policy is
    ///   [`FailurePolicy::FailClosed`]. This covers both a backend that answers
    ///   [`RateDecision::Unavailable`] and no backend being injected at all.
    pub async fn check(&self, ctx: &RequestContext) -> Result<(), Error> {
        let decision = match ctx.try_inject::<Box<dyn RateLimitBackend>>() {
            Some(backend) => {
                let key = format!("rl:{}:{}", self.name, Self::principal(ctx));
                backend.hit(&key, self.max, self.window_secs).await
            }
            // A missing backend is another way of being unable to enforce the limit; route it
            // through the failure policy instead of silently disabling fail-closed limits.
            None => RateDecision::Unavailable,
        };
        match decision {
            RateDecision::Allow { .. } => Ok(()),
            // NOTE(review): `retry_after_secs` is dropped here; `Error::TooManyRequests` has no
            // field to carry it as a Retry-After hint.
            RateDecision::Deny { .. } => Err(Error::TooManyRequests),
            RateDecision::Unavailable => match self.policy {
                FailurePolicy::FailOpen => Ok(()),
                FailurePolicy::FailClosed => {
                    Err(Error::ServiceUnavailable("rate limit backend unavailable"))
                }
            },
        }
    }
}