athena_rs 3.22.1

Hyper-performant polyglot database driver
Documentation
//! Config-driven rate limits: inbound per client (peer IP, or the first X-Forwarded-For entry
//! when trusted) plus an optional process-wide bucket for outbound Supabase HTTP.

use actix_web::{HttpRequest, HttpResponse};
use governor::clock::DefaultClock;
use governor::middleware::NoOpMiddleware;
use governor::state::keyed::DefaultKeyedStateStore;
use governor::{Quota, RateLimiter};
use serde_json::json;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;

/// Keyed rate limiter that keeps independent rate-limit state per `String` key.
///
/// In this module the keys are client addresses from [`rate_limit_client_key`] (inbound
/// traffic) or the fixed key `"supabase_outbound"` (the process-wide outbound bucket).
/// Instances are created by [`build_keyed_limiter`]. The limiter uses governor's default keyed
/// state store and clock, with no middleware.
pub type KeyedStringRateLimiter =
    RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock, NoOpMiddleware>;

/// Creates a shareable keyed limiter allowing `per_second` requests per second per key,
/// with `burst` as the quota's burst size.
///
/// A value of `0` for either argument is raised to `1`, so a zero in configuration can
/// never produce an invalid (zero) quota.
pub fn build_keyed_limiter(per_second: u32, burst: u32) -> Arc<KeyedStringRateLimiter> {
    // `.max(1)` guarantees a non-zero value, so neither `expect` can fire.
    let quota = Quota::per_second(NonZeroU32::new(per_second.max(1)).expect("per_second"))
        .allow_burst(NonZeroU32::new(burst.max(1)).expect("burst"));
    let limiter: KeyedStringRateLimiter = RateLimiter::keyed(quota);
    Arc::new(limiter)
}

/// Returns the client key for the request.
pub fn rate_limit_client_key(req: &HttpRequest, trust_x_forwarded_for: bool) -> String {
    if trust_x_forwarded_for {
        if let Some(raw) = req
            .headers()
            .get("x-forwarded-for")
            .and_then(|v| v.to_str().ok())
        {
            let first: &str = raw.split(',').next().map(str::trim).unwrap_or("");
            if !first.is_empty() && looks_like_ip_or_host(first) {
                return first.to_string();
            }
        }
    }
    req.connection_info()
        .peer_addr()
        .map(|addr| addr.to_string())
        .unwrap_or_else(|| "unknown".to_string())
}

/// Returns true if `s` looks like an IP address, a socket address, or a host name.
///
/// Accepted forms:
/// - an IP address, or a bracketed IPv6 address such as `[2001:db8::1]`;
/// - a socket address (`ip:port` or `[v6]:port`);
/// - a host name, checked loosely against RFC 1123.
///
/// A host name must be at most 253 bytes, optionally ending in a single `.`. It consists of
/// non-empty dot-separated labels of at most 63 bytes. Labels may contain only ASCII letters,
/// digits, `-` and `_` (underscores are technically invalid but appear in practice).
///
/// Anything else (punctuation, whitespace, control characters, oversized values) is rejected,
/// so arbitrary header garbage cannot become a rate-limit key.
fn looks_like_ip_or_host(s: &str) -> bool {
    let unbracketed = s
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(s);
    if unbracketed.parse::<IpAddr>().is_ok() || s.parse::<SocketAddr>().is_ok() {
        return true;
    }
    // A fully-qualified name may carry one trailing root dot.
    let host = s.strip_suffix('.').unwrap_or(s);
    !host.is_empty()
        && host.len() <= 253
        && host.split('.').all(|label| {
            !label.is_empty()
                && label.len() <= 63
                && label
                    .bytes()
                    .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
        })
}

/// Consumes one request from `key`'s bucket in `limiter`.
///
/// Returns `Ok(())` when the request is admitted. When the key is over quota, it returns
/// `Err` carrying a ready-to-send [`HttpResponse`] with status 429, a `Retry-After: 1`
/// header, and a JSON error body.
pub fn check_keyed(limiter: &KeyedStringRateLimiter, key: &str) -> Result<(), HttpResponse> {
    // The limiter is keyed by `String`, so the lookup key has to be owned.
    let owned_key: String = key.to_owned();
    limiter.check_key(&owned_key).map(|_| ()).map_err(|_| {
        let body = json!({
            "error": "Rate limit exceeded",
            "message": "Too many requests; try again later."
        });
        HttpResponse::TooManyRequests()
            .insert_header(("Retry-After", "1"))
            .json(body)
    })
}

/// Applies the inbound rate limit to `req` if a limiter is configured.
///
/// With `limiter == None` every request is admitted. Otherwise the request's bucket key
/// comes from [`rate_limit_client_key`], which honours `trust_x_forwarded_for`, and is
/// checked with [`check_keyed`]. On rejection the returned `Err` holds the 429 response.
pub fn check_inbound_optional(
    limiter: &Option<Arc<KeyedStringRateLimiter>>,
    trust_x_forwarded_for: bool,
    req: &HttpRequest,
) -> Result<(), HttpResponse> {
    match limiter {
        Some(shared) => check_keyed(shared, &rate_limit_client_key(req, trust_x_forwarded_for)),
        None => Ok(()),
    }
}

/// Reports whether an outbound Supabase HTTP call may proceed right now.
///
/// All calls share the single key `"supabase_outbound"`, so the limiter acts as one
/// process-wide bucket. Returns `true` when no limiter is configured. Each call that
/// returns `true` with a limiter present has consumed capacity from that bucket.
pub fn check_outbound_supabase_optional(limiter: &Option<Arc<KeyedStringRateLimiter>>) -> bool {
    match limiter {
        None => true,
        Some(shared) => {
            let bucket_key = String::from("supabase_outbound");
            shared.check_key(&bucket_key).is_ok()
        }
    }
}