rok-rate-limit 0.3.0

Rate limiting Tower middleware and programmatic Limiter API for the rok ecosystem
Documentation
use std::{
    future::Future,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use axum::{
    http::{header, HeaderValue, Request, Response, StatusCode},
    response::IntoResponse,
};
use tower::{Layer, Service};

use crate::limiter::{LimitResult, Limiter};

// ── ThrottleRule ──────────────────────────────────────────────────────────────

/// A single rate-limit rule applied by [`ThrottleLayer`].
///
/// A rule pairs a key-derivation function with a budget of `requests` per
/// `window_secs`. For every incoming request the middleware calls `key_fn`
/// on the request head and checks the resulting key against the limiter.
pub struct ThrottleRule {
    /// Derives the per-request bucket key from request headers/extensions.
    ///
    /// Requests that map to the same key share one budget.
    pub key_fn: Arc<dyn Fn(&axum::http::request::Parts) -> String + Send + Sync>,
    /// Max requests per window.
    pub requests: u64,
    /// Window in seconds.
    pub window_secs: u64,
}

impl ThrottleRule {
    /// Build a rule that limits by a static key (global handler budget).
    ///
    /// Every request maps to the same bucket `key`, so the budget of
    /// `requests` per `window_secs` seconds is shared by all callers.
    pub fn global(key: impl Into<String>, requests: u64, window_secs: u64) -> Self {
        // `Arc<str>` keeps the captured key cheap to share with the closure.
        let key: Arc<str> = key.into().into();
        Self {
            key_fn: Arc::new(move |_| String::from(&*key)),
            requests,
            window_secs,
        }
    }

    /// Build a rule that limits per client IP as reported by `X-Forwarded-For`.
    ///
    /// The bucket key is `"{prefix}:ip:{ip}"`, where `ip` is the first
    /// comma-separated entry of the `X-Forwarded-For` header, trimmed. When
    /// the header is absent, is not valid visible ASCII, or its first entry is
    /// empty, `ip` is the literal `"unknown"`, so all such requests share one
    /// bucket. The socket peer address is not consulted.
    ///
    /// NOTE: the header value is taken from the request as-is. Unless a
    /// trusted reverse proxy overwrites `X-Forwarded-For`, clients can pick
    /// their own bucket by sending arbitrary values.
    pub fn by_ip(prefix: impl Into<String>, requests: u64, window_secs: u64) -> Self {
        let prefix: Arc<str> = prefix.into().into();
        Self {
            key_fn: Arc::new(move |parts| {
                let ip = parts
                    .headers
                    .get("x-forwarded-for")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.split(',').next())
                    .map(str::trim)
                    // An empty first entry (e.g. `X-Forwarded-For: ` or a
                    // leading comma) carries no address; treat it as missing
                    // instead of producing the key "<prefix>:ip:".
                    .filter(|s| !s.is_empty())
                    .unwrap_or("unknown");
                format!("{}:ip:{}", prefix, ip)
            }),
            requests,
            window_secs,
        }
    }
}

// ── ThrottleLayer ─────────────────────────────────────────────────────────────

/// Tower [`Layer`] that applies one or more rate-limit rules to every request.
///
/// Rules are checked in the order they were supplied. On the first exceeded
/// rule the middleware immediately returns `429 Too Many Requests` with
/// `Retry-After`, `X-RateLimit-Limit`, and `X-RateLimit-Remaining` headers,
/// and the wrapped service is not called.  Allowed requests pass through with
/// `X-RateLimit-Limit`, `X-RateLimit-Remaining`, and `X-RateLimit-Reset`
/// set on the *response*; when several rules are configured these values
/// come from the last rule in the list.
///
/// # Example
///
/// ```rust,ignore
/// use axum::{routing::get, Router};
/// use rok_rate_limit::ThrottleLayer;
///
/// // 100 req/min globally for this router
/// let layer = ThrottleLayer::global("api", 100, 60);
///
/// let app = Router::new()
///     .route("/", get(handler))
///     .layer(layer);
/// ```
#[derive(Clone)]
pub struct ThrottleLayer {
    // Limiter handle shared with every service produced by this layer.
    limiter: Arc<Limiter>,
    // Rules evaluated, in order, for each request.
    rules: Arc<Vec<ThrottleRule>>,
}

impl ThrottleLayer {
    /// Single shared-key rule backed by a clone of the crate's global limiter.
    ///
    /// All requests passing through the layer draw from the one bucket named
    /// `key`: at most `requests` per `window_secs` seconds.
    pub fn global(key: impl Into<String>, requests: u64, window_secs: u64) -> Self {
        let limiter = crate::global_limiter().clone();
        let rules = vec![ThrottleRule::global(key, requests, window_secs)];
        Self::new(limiter, rules)
    }

    /// Single per-IP rule backed by a clone of the crate's global limiter.
    ///
    /// Bucket keys are built by [`ThrottleRule::by_ip`] using `prefix`.
    pub fn by_ip(prefix: impl Into<String>, requests: u64, window_secs: u64) -> Self {
        let limiter = crate::global_limiter().clone();
        let rules = vec![ThrottleRule::by_ip(prefix, requests, window_secs)];
        Self::new(limiter, rules)
    }

    /// Full control: bring your own `Limiter` and an ordered list of rules.
    ///
    /// Both are placed behind `Arc`s so that cloning the layer, and the
    /// services it creates, only bumps reference counts.
    pub fn new(limiter: Limiter, rules: Vec<ThrottleRule>) -> Self {
        let limiter = Arc::new(limiter);
        let rules = Arc::new(rules);
        Self { limiter, rules }
    }
}

impl<S> Layer<S> for ThrottleLayer {
    type Service = ThrottleService<S>;

    /// Wraps `inner` in a [`ThrottleService`] that shares this layer's
    /// limiter and rule list.
    fn layer(&self, inner: S) -> Self::Service {
        let limiter = Arc::clone(&self.limiter);
        let rules = Arc::clone(&self.rules);
        ThrottleService {
            inner,
            limiter,
            rules,
        }
    }
}

// ── ThrottleService ───────────────────────────────────────────────────────────

/// Tower [`Service`] produced by [`ThrottleLayer`].
///
/// Holds the wrapped service together with shared handles to the limiter and
/// the rule list.
#[derive(Clone)]
pub struct ThrottleService<S> {
    // The wrapped downstream service.
    inner: S,
    // Limiter consulted for every rule check.
    limiter: Arc<Limiter>,
    // Rules evaluated, in order, for each request.
    rules: Arc<Vec<ThrottleRule>>,
}

impl<S, B> Service<Request<B>> for ThrottleService<S>
where
    S: Service<Request<B>, Response = Response<axum::body::Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    B: Send + 'static,
{
    type Response = Response<axum::body::Body>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, S::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Checks every rule in order, then forwards the request.
    ///
    /// * If a rule reports `Exceeded`, a `429` response is returned right
    ///   away and the wrapped service is not called.
    /// * Otherwise the request is forwarded and the response gets
    ///   `x-ratelimit-limit`, `x-ratelimit-remaining` and `x-ratelimit-reset`
    ///   headers taken from the last rule that was checked. With no rules
    ///   configured the response is returned unchanged.
    ///
    /// Errors from the wrapped service are propagated as-is.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let limiter = Arc::clone(&self.limiter);
        let rules = Arc::clone(&self.rules);

        // `poll_ready` was driven on `self.inner`, so that exact instance is
        // the one that is ready. A fresh clone carries no such guarantee, so
        // move the ready instance into the future and leave the clone behind
        // for the next `poll_ready`/`call` cycle.
        let not_ready = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, not_ready);

        Box::pin(async move {
            let (parts, body) = req.into_parts();

            // Check all rules; fail on the first exceeded one.
            let mut last_allowed: Option<(u64, u64, u64)> = None; // (remaining, reset, limit)
            for rule in rules.iter() {
                let key = (rule.key_fn)(&parts);
                match limiter
                    .for_key(key)
                    .requests(rule.requests)
                    .per(std::time::Duration::from_secs(rule.window_secs))
                    .check()
                {
                    LimitResult::Exceeded { retry_after_secs } => {
                        return Ok(
                            rate_limit_exceeded(retry_after_secs, rule.requests).into_response()
                        );
                    }
                    LimitResult::Allowed {
                        remaining,
                        reset_epoch,
                    } => {
                        last_allowed = Some((remaining, reset_epoch, rule.requests));
                    }
                }
            }

            let req = Request::from_parts(parts, body);
            let mut resp = inner.call(req).await?;

            // Inject rate-limit headers from the last checked rule.
            if let Some((remaining, reset_epoch, limit)) = last_allowed {
                let headers = resp.headers_mut();
                set_header(headers, "x-ratelimit-limit", &limit.to_string());
                set_header(headers, "x-ratelimit-remaining", &remaining.to_string());
                set_header(headers, "x-ratelimit-reset", &reset_epoch.to_string());
            }

            Ok(resp)
        })
    }
}

// ── helpers ───────────────────────────────────────────────────────────────────

/// Builds the `429 Too Many Requests` response sent when a rule is exceeded.
///
/// The body is a JSON object carrying an error code, a message and the
/// retry delay. The response also gets `Retry-After` (seconds),
/// `x-ratelimit-limit` (the rule's budget) and `x-ratelimit-remaining: 0`.
fn rate_limit_exceeded(retry_after_secs: u64, limit: u64) -> impl IntoResponse {
    let payload = serde_json::json!({
        "error": "too_many_requests",
        "message": "rate limit exceeded",
        "retry_after": retry_after_secs,
    });
    let mut response = (StatusCode::TOO_MANY_REQUESTS, axum::Json(payload)).into_response();

    let retry_after = retry_after_secs.to_string();
    let limit = limit.to_string();
    let header_pairs = [
        (header::RETRY_AFTER.as_str(), retry_after.as_str()),
        ("x-ratelimit-limit", limit.as_str()),
        ("x-ratelimit-remaining", "0"),
    ];
    for (name, value) in header_pairs {
        set_header(response.headers_mut(), name, value);
    }

    response
}

/// Best-effort header insert.
///
/// Parses `name` and `value` into their typed header forms and inserts the
/// pair, replacing any existing value for that name. If either one fails to
/// parse, the map is left untouched and no error is reported.
fn set_header(headers: &mut axum::http::HeaderMap, name: &str, value: &str) {
    let parsed_name = name.parse::<header::HeaderName>().ok();
    let parsed_value = HeaderValue::from_str(value).ok();

    if let Some((header_name, header_value)) = parsed_name.zip(parsed_value) {
        headers.insert(header_name, header_value);
    }
}