use actix_web::{HttpRequest, HttpResponse};
use governor::clock::DefaultClock;
use governor::middleware::NoOpMiddleware;
use governor::state::keyed::DefaultKeyedStateStore;
use governor::{Quota, RateLimiter};
use serde_json::json;
use std::net::IpAddr;
use std::num::NonZeroU32;
use std::sync::Arc;
/// Keyed governor rate limiter whose keys are owned `String`s, backed by governor's default
/// keyed state store and default clock, with no middleware.
pub type KeyedStringRateLimiter =
    RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock, NoOpMiddleware>;
/// Builds a shareable keyed rate limiter from a `Quota::per_second(per_second)` quota with its
/// burst size overridden to `burst`.
///
/// A value of `0` for either argument is treated as `1`.
///
/// NOTE(review): the keyed state store gains an entry per distinct key; confirm that some
/// caller periodically prunes it (governor offers `retain_recent`) so memory stays bounded.
pub fn build_keyed_limiter(per_second: u32, burst: u32) -> Arc<KeyedStringRateLimiter> {
    // Both inputs are clamped to >= 1 first, so these conversions can never hit the `expect`.
    let replenish_rate = NonZeroU32::new(per_second.max(1)).expect("per_second");
    let max_burst = NonZeroU32::new(burst.max(1)).expect("burst");

    let limiter: KeyedStringRateLimiter =
        RateLimiter::keyed(Quota::per_second(replenish_rate).allow_burst(max_burst));
    Arc::new(limiter)
}
/// Derives the string key that identifies a client for inbound rate limiting.
///
/// If `trust_x_forwarded_for` is set, the first (left-most) comma-separated entry of the
/// `x-forwarded-for` header is used, trimmed, as long as it is non-empty and passes
/// `looks_like_ip_or_host`. In every other case the key falls back to the connection's peer
/// address:
/// - the flag is off,
/// - the header is missing,
/// - the header is not representable as a `&str`, or
/// - the first entry is rejected.
///
/// If no peer address is available either, the literal `"unknown"` is used.
///
/// NOTE(review): the left-most entry is usually whatever the original client sent, so with
/// the flag on a client may be able to pick its own key; confirm the fronting proxy
/// overwrites (rather than appends to) this header.
pub fn rate_limit_client_key(req: &HttpRequest, trust_x_forwarded_for: bool) -> String {
    let header_value = if trust_x_forwarded_for {
        req.headers().get("x-forwarded-for")
    } else {
        None
    };

    let forwarded_client: Option<String> = header_value
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| raw.split(',').next())
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty() && looks_like_ip_or_host(candidate))
        .map(str::to_owned);

    forwarded_client.unwrap_or_else(|| {
        req.connection_info()
            .peer_addr()
            .map(|addr| addr.to_string())
            .unwrap_or_else(|| "unknown".to_string())
    })
}
/// Returns `true` when `s` is plausible as a client address taken from a forwarding header.
///
/// Accepted forms:
/// - a bare IP address (`192.0.2.1`, `2001:db8::1`);
/// - a socket address (`192.0.2.1:8080`, `[2001:db8::1]:8080`);
/// - a hostname-like token of 1 to 253 ASCII characters drawn from letters, digits, `.`, `-`,
///   `_` and `:`.
///
/// Anything else is rejected. This includes empty strings, strings with whitespace or markup
/// characters, and over-long strings. Rejection makes the caller fall back to the peer address
/// instead of turning arbitrary header content into a rate-limit key.
fn looks_like_ip_or_host(s: &str) -> bool {
    // Upper bound for a full DNS name; also caps the size of any key a client can inject.
    const MAX_HOST_LEN: usize = 253;

    if s.parse::<IpAddr>().is_ok() || s.parse::<std::net::SocketAddr>().is_ok() {
        return true;
    }

    !s.is_empty()
        && s.len() <= MAX_HOST_LEN
        && s
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b'_' | b':'))
}
/// Consumes one unit of `key`'s quota in `limiter`.
///
/// Returns `Ok(())` when the request is allowed. When the quota is exhausted, returns `Err`
/// carrying a ready-to-send `429 Too Many Requests` response with a fixed `Retry-After: 1`
/// header and a small JSON error body.
pub fn check_keyed(limiter: &KeyedStringRateLimiter, key: &str) -> Result<(), HttpResponse> {
    // The limiter is keyed by `String`, and `check_key` takes `&K`, so an owned key is needed.
    let owned_key: String = key.to_owned();
    match limiter.check_key(&owned_key) {
        Ok(_) => Ok(()),
        Err(_) => Err(HttpResponse::TooManyRequests()
            .insert_header(("Retry-After", "1"))
            .json(json!({
                "error": "Rate limit exceeded",
                "message": "Too many requests; try again later."
            }))),
    }
}
/// Applies the inbound limiter, if one is configured, to the client making `req`.
///
/// The client is identified by `rate_limit_client_key(req, trust_x_forwarded_for)`. Returns
/// `Ok(())` when no limiter is configured or the request is within quota. Otherwise returns
/// the 429 response built by `check_keyed`.
pub fn check_inbound_optional(
    limiter: &Option<Arc<KeyedStringRateLimiter>>,
    trust_x_forwarded_for: bool,
    req: &HttpRequest,
) -> Result<(), HttpResponse> {
    match limiter {
        // Rate limiting disabled: every request passes.
        None => Ok(()),
        Some(active) => check_keyed(active, &rate_limit_client_key(req, trust_x_forwarded_for)),
    }
}
/// Returns whether an outbound Supabase call may proceed right now.
///
/// All outbound calls share the single key `"supabase_outbound"`, so they draw from one common
/// quota. With no limiter configured, this always returns `true`.
pub fn check_outbound_supabase_optional(limiter: &Option<Arc<KeyedStringRateLimiter>>) -> bool {
    match limiter.as_deref() {
        Some(shared) => shared
            .check_key(&String::from("supabase_outbound"))
            .is_ok(),
        None => true,
    }
}