rok-rate-limit 0.3.0

Rate limiting Tower middleware and programmatic Limiter API for the rok ecosystem
Documentation
#![cfg(feature = "axum")]

use axum::{
    body::Body,
    http::{Request, StatusCode},
    routing::get,
    Router,
};
use rok_rate_limit::{Limiter, ThrottleLayer, ThrottleRule};
use tower::ServiceExt;

// ── helpers ───────────────────────────────────────────────────────────────────

/// Builds a minimal router with a single `GET /` route whose handler returns `"ok"`.
///
/// Every test wraps this router in the layer under test.
fn ok_handler() -> Router {
    let root = get(|| async { "ok" });
    Router::new().route("/", root)
}

// ── ThrottleLayer + ThrottleRule::global ─────────────────────────────────────

/// A single request against a global rule created with `(10, 60)` is let through
/// with `200 OK`.
#[tokio::test]
async fn global_layer_allows_within_limit() {
    // Each test builds its own in-memory limiter so state never leaks between tests.
    let rules = vec![ThrottleRule::global("test:allow", 10, 60)];
    let throttle = ThrottleLayer::new(Limiter::memory(), rules);
    let service = ok_handler().layer(throttle);

    let request = Request::builder().uri("/").body(Body::empty()).unwrap();
    let response = service.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
}

/// With a global rule limited to one request, the first request succeeds and the
/// second one is rejected with `429 Too Many Requests`.
#[tokio::test]
async fn global_layer_returns_429_when_exceeded() {
    // Limit of 1, then we make 2 requests. The limiter is moved straight into the
    // layer: nothing else in this test touches it, so no clone is needed.
    let layer = ThrottleLayer::new(
        Limiter::memory(),
        vec![ThrottleRule::global("test:exceed-gl", 1, 60)],
    );

    // First request — allowed (consumes the 1 slot).
    // `oneshot` consumes the router, so each request gets a freshly built one; the
    // test relies on the cloned layer sharing its counters with the original.
    let first = ok_handler()
        .layer(layer.clone())
        .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(first.status(), StatusCode::OK);

    // Second request — exceeded
    let second = ok_handler()
        .layer(layer)
        .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
}

/// An allowed response carries the `x-ratelimit-limit`, `x-ratelimit-remaining`
/// and `x-ratelimit-reset` headers.
#[tokio::test]
async fn response_contains_ratelimit_headers_when_allowed() {
    let rules = vec![ThrottleRule::global("test:headers", 5, 60)];
    let service = ok_handler().layer(ThrottleLayer::new(Limiter::memory(), rules));

    let request = Request::builder().uri("/").body(Body::empty()).unwrap();
    let response = service.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    // Header name paired with the message reported when it is missing.
    let expected = [
        ("x-ratelimit-limit", "should have x-ratelimit-limit"),
        ("x-ratelimit-remaining", "should have x-ratelimit-remaining"),
        ("x-ratelimit-reset", "should have x-ratelimit-reset"),
    ];
    let headers = response.headers();
    for (name, message) in expected {
        assert!(headers.contains_key(name), "{}", message);
    }
}

/// A rejected response (`429`) carries a `retry-after` header.
#[tokio::test]
async fn response_contains_retry_after_when_exceeded() {
    let key = "test:retry-hdr";
    let window = std::time::Duration::from_secs(60);
    let limiter = Limiter::memory();

    // Use up the key's budget through the programmatic API (two checks against a
    // limit of 1) before the limiter is handed to the layer, so the very first
    // HTTP request is already over the limit.
    (0..2).for_each(|_| {
        limiter.for_key(key).requests(1).per(window).check();
    });

    let throttle = ThrottleLayer::new(limiter, vec![ThrottleRule::global(key, 1, 60)]);
    let service = ok_handler().layer(throttle);

    let request = Request::builder().uri("/").body(Body::empty()).unwrap();
    let response = service.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    assert!(
        response.headers().contains_key("retry-after"),
        "should have Retry-After header"
    );
}

// ── ThrottleLayer + ThrottleRule::by_ip ─────────────────────────────────────

#[tokio::test]
async fn by_ip_layer_uses_forwarded_for_header() {
    let limiter = Limiter::memory();
    let layer = ThrottleLayer::new(limiter.clone(), vec![ThrottleRule::by_ip("test:ip", 1, 60)]);

    // First request from 1.1.1.1 — allowed
    let app = ok_handler().layer(layer.clone());
    let resp = app
        .oneshot(
            Request::builder()
                .uri("/")
                .header("x-forwarded-for", "1.1.1.1")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(resp.status(), StatusCode::OK);

    // First request from 2.2.2.2 — also allowed (different IP bucket)
    let app = ok_handler().layer(layer.clone());
    let resp = app
        .oneshot(
            Request::builder()
                .uri("/")
                .header("x-forwarded-for", "2.2.2.2")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(resp.status(), StatusCode::OK);

    // Second request from 1.1.1.1 — exceeded
    let app = ok_handler().layer(layer);
    let resp = app
        .oneshot(
            Request::builder()
                .uri("/")
                .header("x-forwarded-for", "1.1.1.1")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
}