use crate::cache::Cache;
use crate::http::{HttpResponse, Request, Response};
use crate::middleware::{Middleware, Next};
use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Type-erased limiter callback kept in the global registry: given a request, it
/// yields every [`Limit`] that applies to it.
type LimiterFn = Arc<dyn Fn(&Request) -> Vec<Limit> + Send + Sync>;

/// Process-wide map from limiter name to its callback, created lazily on first access.
fn limiter_registry() -> &'static DashMap<String, LimiterFn> {
    static LIMITERS: OnceLock<DashMap<String, LimiterFn>> = OnceLock::new();
    LIMITERS.get_or_init(DashMap::default)
}
/// A single fixed-window rate limit: at most `max_requests` hits per
/// `window_seconds`, counted per identifier.
///
/// Build one with [`Limit::per_second`], [`Limit::per_minute`], [`Limit::per_hour`]
/// or [`Limit::per_day`], then refine it with [`Limit::by`] and [`Limit::response`].
///
/// `Clone` is cheap: the optional response factory is shared through an `Arc`.
#[derive(Clone)]
pub struct Limit {
    /// Maximum number of requests allowed within one window.
    pub max_requests: u32,
    /// Window length in seconds; expected to be non-zero.
    pub window_seconds: u64,
    /// Custom counter identifier. When `None`, the middleware keys the counter by
    /// the client IP instead.
    key: Option<String>,
    /// Optional factory for the response returned when this limit is exceeded.
    response_fn: Option<Arc<dyn Fn() -> HttpResponse + Send + Sync>>,
}
impl Limit {
    /// Common constructor behind the fixed-window presets: no custom key and no
    /// custom rejection response.
    fn with_window(max_requests: u32, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            key: None,
            response_fn: None,
        }
    }

    /// Allows `max` requests per 1-second window.
    pub fn per_second(max: u32) -> Self {
        Self::with_window(max, 1)
    }

    /// Allows `max` requests per 60-second window.
    pub fn per_minute(max: u32) -> Self {
        Self::with_window(max, 60)
    }

    /// Allows `max` requests per 3600-second window.
    pub fn per_hour(max: u32) -> Self {
        Self::with_window(max, 60 * 60)
    }

    /// Allows `max` requests per 86400-second window.
    pub fn per_day(max: u32) -> Self {
        Self::with_window(max, 24 * 60 * 60)
    }

    /// Counts requests against `key` instead of the client IP.
    pub fn by(self, key: impl Into<String>) -> Self {
        Self {
            key: Some(key.into()),
            ..self
        }
    }

    /// Replaces the default JSON 429 body with the response produced by `f`
    /// whenever this limit is exceeded.
    pub fn response<F>(self, f: F) -> Self
    where
        F: Fn() -> HttpResponse + Send + Sync + 'static,
    {
        Self {
            response_fn: Some(Arc::new(f)),
            ..self
        }
    }
}
/// What a limiter definition may return: one limit or several.
///
/// The `From` conversions below let [`RateLimiter::define`] closures return either a
/// bare [`Limit`] or a `Vec<Limit>`.
pub enum LimiterResponse {
    /// Exactly one limit applies.
    Single(Limit),
    /// Several limits apply; every one of them is checked.
    Multiple(Vec<Limit>),
}

impl From<Limit> for LimiterResponse {
    fn from(limit: Limit) -> Self {
        Self::Single(limit)
    }
}

impl From<Vec<Limit>> for LimiterResponse {
    fn from(limits: Vec<Limit>) -> Self {
        Self::Multiple(limits)
    }
}

impl LimiterResponse {
    /// Flattens the response into a list; a single limit becomes a one-element list.
    fn into_vec(self) -> Vec<Limit> {
        match self {
            Self::Multiple(limits) => limits,
            Self::Single(limit) => Vec::from([limit]),
        }
    }
}
/// Facade for defining and resolving named rate limiters.
///
/// Named limiters are closures registered with [`RateLimiter::define`]. The
/// [`Throttle`] middleware (built via [`Throttle::named`]) resolves them per request.
pub struct RateLimiter;

impl RateLimiter {
    /// Registers the limiter `name`, replacing any previous definition with that name.
    ///
    /// `f` is evaluated for every request and may return a single [`Limit`] or a
    /// `Vec<Limit>`.
    pub fn define<F, T>(name: &str, f: F)
    where
        F: Fn(&Request) -> T + Send + Sync + 'static,
        T: Into<LimiterResponse>,
    {
        let wrapped: LimiterFn = Arc::new(move |req| {
            let response: LimiterResponse = f(req).into();
            response.into_vec()
        });
        limiter_registry().insert(name.to_string(), wrapped);
    }

    /// Evaluates the limiter registered under `name` for `req`.
    ///
    /// Returns `None` when no limiter with that name has been defined.
    pub fn resolve(name: &str, req: &Request) -> Option<Vec<Limit>> {
        // Clone the `Arc` out so the map guard (and the shard lock it holds) is
        // released before user code runs. Invoking the closure while holding the
        // guard would block concurrent `define` calls on that shard for the whole
        // call, and could deadlock if the closure itself called `define`.
        let limiter: LimiterFn = limiter_registry()
            .get(name)
            .map(|entry| Arc::clone(entry.value()))?;
        Some(limiter(req))
    }

    /// Shorthand for [`Limit::per_second`].
    pub fn per_second(max: u32) -> Limit {
        Limit::per_second(max)
    }

    /// Shorthand for [`Limit::per_minute`].
    pub fn per_minute(max: u32) -> Limit {
        Limit::per_minute(max)
    }

    /// Shorthand for [`Limit::per_hour`].
    pub fn per_hour(max: u32) -> Limit {
        Limit::per_hour(max)
    }

    /// Shorthand for [`Limit::per_day`].
    pub fn per_day(max: u32) -> Limit {
        Limit::per_day(max)
    }
}
/// Outcome of a single fixed-window check for one limit and one identifier.
///
/// All fields are plain integers/flags, so the type is `Copy`.
#[derive(Debug, Clone, Copy)]
struct RateLimitResult {
    /// Whether this hit falls within the limit.
    allowed: bool,
    /// The configured maximum for the window.
    limit: u32,
    /// Requests still available in the current window (never below 0).
    remaining: u32,
    /// Seconds the client should wait before the window resets.
    retry_after: u64,
}
/// Best-effort client identifier used as the default rate-limit key.
///
/// Sources are tried in this order:
/// 1. the first entry of `X-Forwarded-For`;
/// 2. `X-Real-IP`;
/// 3. the literal `"unknown"`.
///
/// Values are trimmed. Blank values are skipped, so an empty header falls through
/// to the next source instead of producing an empty (shared) counter key.
///
/// NOTE(security): both headers are client-supplied unless a trusted reverse proxy
/// overwrites them. When the app is reachable directly, a client can rotate these
/// values to get fresh counters and bypass the limit.
fn get_client_ip(request: &Request) -> String {
    let forwarded = request
        .header("X-Forwarded-For")
        .and_then(|s| s.split(',').next())
        .map(|s| s.trim())
        .filter(|s| !s.is_empty());

    forwarded
        .or_else(|| {
            request
                .header("X-Real-IP")
                .map(|s| s.trim())
                .filter(|s| !s.is_empty())
        })
        .map(|s| s.to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
/// Records one hit for `identifier` against `limit` and reports whether it is allowed.
///
/// Uses a fixed window:
/// - The current window number is `now / window_seconds`.
/// - The counter is stored in the cache under
///   `rate_limit:{name}:{identifier}:{window_number}`.
/// - The first hit in a window sets a TTL one second longer than the window, so
///   stale counters expire.
///
/// Fails open: if the cache increment errors, the hit is allowed and the error is
/// logged to stderr. A failure to set the TTL is logged but does not affect the result.
async fn check_rate_limit(limit: &Limit, name: &str, identifier: &str) -> RateLimitResult {
    // `window_seconds` is a public field, so a caller can set it to 0. Clamp it to
    // avoid panicking with a division by zero below.
    let window = limit.window_seconds.max(1);
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let window_number = now_secs / window;
    let key = format!("rate_limit:{name}:{identifier}:{window_number}");

    let count = match Cache::increment(&key, 1).await {
        // Saturate rather than truncate: an `as` cast would wrap a huge (or
        // negative) counter into a small value and wrongly allow the request.
        Ok(c) => u32::try_from(c).unwrap_or(u32::MAX),
        Err(e) => {
            eprintln!("[ferro] Rate limiter cache error (fail-open): {e}");
            return RateLimitResult {
                allowed: true,
                limit: limit.max_requests,
                remaining: limit.max_requests,
                retry_after: window,
            };
        }
    };

    // Only the hit that created the counter sets its expiry.
    if count == 1 {
        let ttl = Duration::from_secs(window.saturating_add(1));
        if let Err(e) = Cache::expire(&key, ttl).await {
            eprintln!("[ferro] Rate limiter expire error: {e}");
        }
    }

    let remaining = limit.max_requests.saturating_sub(count);
    // Seconds left until the current window rolls over.
    let retry_after = window - (now_secs % window);
    RateLimitResult {
        allowed: count <= limit.max_requests,
        limit: limit.max_requests,
        remaining,
        retry_after,
    }
}
/// Attaches the `X-RateLimit-Limit`, `X-RateLimit-Remaining` and
/// `X-RateLimit-Reset` headers to `response`.
///
/// `X-RateLimit-Reset` carries `retry_after` as a relative number of seconds, not
/// an absolute timestamp.
fn add_rate_limit_headers(
    response: HttpResponse,
    limit: u32,
    remaining: u32,
    retry_after: u64,
) -> HttpResponse {
    let with_limit = response.header("X-RateLimit-Limit", limit.to_string());
    let with_remaining = with_limit.header("X-RateLimit-Remaining", remaining.to_string());
    with_remaining.header("X-RateLimit-Reset", retry_after.to_string())
}
/// Middleware that enforces rate limits from one of two sources:
/// - a named limiter registered with [`RateLimiter::define`];
/// - a single inline [`Limit`].
pub struct Throttle {
    /// Registered limiter to resolve per request; `None` for inline throttles.
    name: Option<String>,
    /// Limits applied when `name` is `None`.
    inline_limits: Vec<Limit>,
}

impl Throttle {
    /// Throttles using the limiter registered under `name`.
    pub fn named(name: &str) -> Self {
        Self {
            name: Some(name.to_owned()),
            inline_limits: Vec::new(),
        }
    }

    /// Builds an inline throttle around exactly one limit.
    fn single(limit: Limit) -> Self {
        Self {
            name: None,
            inline_limits: vec![limit],
        }
    }

    /// Inline throttle allowing `max` requests per second.
    pub fn per_second(max: u32) -> Self {
        Self::single(Limit::per_second(max))
    }

    /// Inline throttle allowing `max` requests per minute.
    pub fn per_minute(max: u32) -> Self {
        Self::single(Limit::per_minute(max))
    }

    /// Inline throttle allowing `max` requests per hour.
    pub fn per_hour(max: u32) -> Self {
        Self::single(Limit::per_hour(max))
    }

    /// Inline throttle allowing `max` requests per day.
    pub fn per_day(max: u32) -> Self {
        Self::single(Limit::per_day(max))
    }
}
#[async_trait]
impl Middleware for Throttle {
async fn handle(&self, request: Request, next: Next) -> Response {
let (limiter_name, limits) = if let Some(ref name) = self.name {
match RateLimiter::resolve(name, &request) {
Some(limits) => (name.clone(), limits),
None => {
eprintln!(
"[ferro] Rate limiter '{name}' not registered (fail-open, allowing request)"
);
return next(request).await;
}
}
} else {
let limits: Vec<Limit> = self
.inline_limits
.iter()
.map(|l| Limit {
max_requests: l.max_requests,
window_seconds: l.window_seconds,
key: l.key.clone(),
response_fn: l.response_fn.clone(),
})
.collect();
("inline".to_string(), limits)
};
let client_ip = get_client_ip(&request);
let mut most_restrictive: Option<(
RateLimitResult,
Option<Arc<dyn Fn() -> HttpResponse + Send + Sync>>,
)> = None;
for limit in &limits {
let identifier = limit.key.as_deref().unwrap_or(&client_ip);
let result = check_rate_limit(limit, &limiter_name, identifier).await;
if !result.allowed {
let error_response = if let Some(ref response_fn) = limit.response_fn {
response_fn()
} else {
HttpResponse::json(serde_json::json!({
"error": "Too Many Requests",
"message": "Rate limit exceeded. Please try again later.",
"retry_after": result.retry_after
}))
.status(429)
};
let error_response =
add_rate_limit_headers(error_response, result.limit, 0, result.retry_after)
.header("Retry-After", result.retry_after.to_string());
return Err(error_response);
}
let is_more_restrictive = most_restrictive
.as_ref()
.map(|(prev, _)| result.remaining < prev.remaining)
.unwrap_or(true);
if is_more_restrictive {
most_restrictive = Some((result, limit.response_fn.clone()));
}
}
let response = next(request).await;
if let Some((result, _)) = most_restrictive {
match response {
Ok(http_response) => Ok(add_rate_limit_headers(
http_response,
result.limit,
result.remaining,
result.retry_after,
)),
Err(http_response) => Err(add_rate_limit_headers(
http_response,
result.limit,
result.remaining,
result.retry_after,
)),
}
} else {
response
}
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the rate limiter. Tests marked `#[serial]` touch process-wide
    // state (the global limiter registry and the cache bound in the `App`
    // container), so they are serialized rather than run in parallel.
    use super::*;
    use crate::cache::{CacheStore, InMemoryCache};
    use crate::container::App;
    use serial_test::serial;
    use std::sync::Arc;
    // Binds a fresh in-memory `CacheStore` in the `App` container.
    // NOTE(review): presumably this replaces any previous binding and is what
    // `Cache::increment` reads from, giving each test empty counters — confirm.
    fn setup_test_cache() {
        App::bind::<dyn CacheStore>(Arc::new(InMemoryCache::new()));
    }
    // Produces a real `Request` by serving one HTTP/1 connection on an ephemeral
    // loopback port, sending it a single `/test` request, and handing the captured
    // request back through a oneshot channel wrapped in `Request::new`.
    // NOTE(review): presumably needed because `Request` can only be built from a
    // `hyper::Request<Incoming>` — confirm.
    async fn test_request() -> Request {
        use hyper_util::rt::TokioIo;
        use std::sync::Mutex;
        use tokio::sync::oneshot;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // The service closure is `Fn` and may be cloned/called repeatedly, but a
        // oneshot sender can be used once: keep it in a `Mutex<Option<_>>` and
        // `take()` it on the first request only.
        let tx_holder = Arc::new(Mutex::new(Some(tx)));
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(stream);
            let tx_holder = tx_holder.clone();
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let tx_holder = tx_holder.clone();
                    async move {
                        if let Some(tx) = tx_holder.lock().unwrap().take() {
                            let _ = tx.send(Request::new(req));
                        }
                        // Reply with an empty body so the client side completes.
                        Ok::<_, hyper::Error>(hyper::Response::new(http_body_util::Empty::<
                            bytes::Bytes,
                        >::new(
                        )))
                    }
                });
            hyper::server::conn::http1::Builder::new()
                .serve_connection(io, service)
                .await
                .ok();
        });
        // Client side: connect, drive the connection in the background, and send
        // one request so the server task captures it.
        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move {
            conn.await.ok();
        });
        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }
    #[test]
    fn test_limit_per_minute() {
        let limit = Limit::per_minute(60);
        assert_eq!(limit.max_requests, 60);
        assert_eq!(limit.window_seconds, 60);
        assert!(limit.key.is_none());
        assert!(limit.response_fn.is_none());
    }
    #[test]
    fn test_limit_per_hour() {
        let limit = Limit::per_hour(1000);
        assert_eq!(limit.max_requests, 1000);
        assert_eq!(limit.window_seconds, 3600);
    }
    #[test]
    fn test_limit_per_second() {
        let limit = Limit::per_second(10);
        assert_eq!(limit.max_requests, 10);
        assert_eq!(limit.window_seconds, 1);
    }
    #[test]
    fn test_limit_per_day() {
        let limit = Limit::per_day(10000);
        assert_eq!(limit.max_requests, 10000);
        assert_eq!(limit.window_seconds, 86400);
    }
    #[test]
    fn test_limit_by_key() {
        let limit = Limit::per_minute(60).by("user:1");
        assert_eq!(limit.key, Some("user:1".to_string()));
    }
    #[test]
    fn test_limit_response_factory() {
        let limit = Limit::per_minute(60)
            .response(|| HttpResponse::json(serde_json::json!({"error": "custom"})).status(429));
        assert!(limit.response_fn.is_some());
    }
    // Registry tests clear the global registry first so earlier definitions
    // cannot leak into the assertions.
    #[tokio::test]
    #[serial]
    async fn test_define_and_resolve() {
        limiter_registry().clear();
        RateLimiter::define("test", |_req| Limit::per_minute(100));
        let req = test_request().await;
        let limits = RateLimiter::resolve("test", &req);
        assert!(limits.is_some(), "defined limiter should resolve");
        let limits = limits.unwrap();
        assert_eq!(limits.len(), 1);
        assert_eq!(limits[0].max_requests, 100);
        assert_eq!(limits[0].window_seconds, 60);
    }
    #[tokio::test]
    #[serial]
    async fn test_resolve_undefined() {
        limiter_registry().clear();
        let req = test_request().await;
        let result = RateLimiter::resolve("nonexistent", &req);
        assert!(result.is_none(), "undefined limiter should resolve to None");
    }
    #[tokio::test]
    #[serial]
    async fn test_define_multiple_limits() {
        limiter_registry().clear();
        RateLimiter::define("login", |_req| {
            vec![Limit::per_minute(500), Limit::per_minute(5).by("email")]
        });
        let req = test_request().await;
        let limits = RateLimiter::resolve("login", &req).unwrap();
        assert_eq!(limits.len(), 2);
        assert_eq!(limits[0].max_requests, 500);
        assert!(limits[0].key.is_none());
        assert_eq!(limits[1].max_requests, 5);
        assert_eq!(limits[1].key, Some("email".to_string()));
    }
    // NOTE(review): counters are keyed by the current fixed window, so a test that
    // straddles a minute boundary sees its counter reset mid-loop and could
    // assert the wrong `remaining`. Rare, but a potential source of flakiness.
    #[tokio::test]
    #[serial]
    async fn test_allows_within_limit() {
        setup_test_cache();
        let limit = Limit::per_minute(10);
        for i in 1..=5 {
            let result = check_rate_limit(&limit, "test_allow", "ip:127.0.0.1").await;
            assert!(result.allowed, "request {i} should be allowed");
            assert_eq!(result.remaining, 10 - i);
            assert_eq!(result.limit, 10);
        }
    }
    #[tokio::test]
    #[serial]
    async fn test_exceeds_limit() {
        setup_test_cache();
        let limit = Limit::per_minute(3);
        for i in 1..=3 {
            let result = check_rate_limit(&limit, "test_exceed", "ip:10.0.0.1").await;
            assert!(result.allowed, "request {i} should be allowed");
        }
        let result = check_rate_limit(&limit, "test_exceed", "ip:10.0.0.1").await;
        assert!(!result.allowed, "request 4 should be rate limited");
        assert_eq!(result.remaining, 0);
    }
    #[tokio::test]
    #[serial]
    async fn test_separate_keys_independent() {
        setup_test_cache();
        let limit = Limit::per_minute(2);
        for _ in 0..2 {
            check_rate_limit(&limit, "test_sep", "key_a").await;
        }
        let result_a = check_rate_limit(&limit, "test_sep", "key_a").await;
        assert!(!result_a.allowed, "key_a should be exhausted");
        let result_b = check_rate_limit(&limit, "test_sep", "key_b").await;
        assert!(result_b.allowed, "key_b should still be allowed");
        assert_eq!(result_b.remaining, 1);
    }
    // NOTE(review): this test never binds a failing cache. If an earlier serial test
    // left an `InMemoryCache` bound, the increment succeeds (count 1 <= 5) and the
    // assertion passes without exercising the fail-open branch at all.
    // TODO: bind a `CacheStore` whose increment returns an error.
    #[tokio::test]
    #[serial]
    async fn test_cache_failure_allows_request() {
        let limit = Limit::per_minute(5);
        let result = check_rate_limit(&limit, "failopen", "test").await;
        assert!(result.allowed);
    }
    #[test]
    fn test_throttle_per_minute() {
        let throttle = Throttle::per_minute(60);
        assert!(throttle.name.is_none());
        assert_eq!(throttle.inline_limits.len(), 1);
        assert_eq!(throttle.inline_limits[0].max_requests, 60);
        assert_eq!(throttle.inline_limits[0].window_seconds, 60);
    }
    #[test]
    fn test_throttle_per_second() {
        let throttle = Throttle::per_second(10);
        assert_eq!(throttle.inline_limits[0].max_requests, 10);
        assert_eq!(throttle.inline_limits[0].window_seconds, 1);
    }
    #[test]
    fn test_throttle_per_hour() {
        let throttle = Throttle::per_hour(1000);
        assert_eq!(throttle.inline_limits[0].max_requests, 1000);
        assert_eq!(throttle.inline_limits[0].window_seconds, 3600);
    }
    #[test]
    fn test_throttle_per_day() {
        let throttle = Throttle::per_day(5000);
        assert_eq!(throttle.inline_limits[0].max_requests, 5000);
        assert_eq!(throttle.inline_limits[0].window_seconds, 86400);
    }
    #[test]
    fn test_throttle_named() {
        let throttle = Throttle::named("api");
        assert_eq!(throttle.name, Some("api".to_string()));
        assert!(throttle.inline_limits.is_empty());
    }
    #[test]
    fn test_limiter_response_single() {
        let response: LimiterResponse = Limit::per_minute(60).into();
        let limits = response.into_vec();
        assert_eq!(limits.len(), 1);
        assert_eq!(limits[0].max_requests, 60);
    }
    #[test]
    fn test_limiter_response_multiple() {
        let response: LimiterResponse = vec![Limit::per_minute(60), Limit::per_hour(1000)].into();
        let limits = response.into_vec();
        assert_eq!(limits.len(), 2);
        assert_eq!(limits[0].max_requests, 60);
        assert_eq!(limits[1].max_requests, 1000);
    }
}