touch_ratelimit 0.1.0

A composable, extensible rate limiting crate for Rust
Documentation
//! Axum integration for `touch_ratelimit`.
//!
//! This module provides an Axum-compatible rate limiting layer that can be
//! attached to an `axum::Router` using `.layer(...)`.
//!
//! ## How it works
//!
//! - Each incoming request is identified by a key extracted from the request
//!   (currently always the `x-forwarded-for` header). Requests for which no
//!   key can be extracted are forwarded without being rate limited.
//! - The extracted key is checked against a [`RateLimitStore`].
//! - If the request exceeds the rate limit, the request is rejected with
//!   **HTTP 429 (Too Many Requests)**.
//!
//! ## Storage behavior
//!
//! This adapter does **not** store any data itself. All rate-limiting state
//! lives inside the provided [`RateLimitStore`] (e.g. `InMemoryStore`).
//!
//! ## Example
//!
//! ```rust,ignore
//! use axum::{routing::get, Router};
//! use std::sync::Arc;
//! use touch_ratelimit::{
//!     adapters::axum::axum_rate_limit_layer,
//!     storage::in_memory::InMemoryStore,
//!     bucket::token_bucket::TokenBucket,
//! };
//!
//! let store = Arc::new(InMemoryStore::new(|| TokenBucket::new(10.0, 1.0)));
//!
//! let _ = Router::new()
//!     .route("/", get(|| async { "hello" }))
//!     .layer(axum_rate_limit_layer(store));
//! ```

use crate::{middleware::KeyExtractor, storage::RateLimitStore};
use axum::{
    http::{Request, StatusCode},
    response::{IntoResponse, Response},
};
use std::{
    convert::Infallible,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower::Service;
/// Extracts a rate-limiting key from an Axum request.
///
/// The key is the first (left-most) entry of the `x-forwarded-for` HTTP
/// header, with surrounding whitespace removed. When a request has passed
/// through several proxies the header holds a comma-separated list such as
/// `client, proxy1, proxy2`. Keying on the first entry keeps the key the
/// same for a given client regardless of which proxy chain the request took.
///
/// No key is produced, and rate limiting is skipped for that request, in
/// three cases: the header is missing, it is not valid visible ASCII, or its
/// first entry is empty.
///
/// This is useful when running behind a reverse proxy or load balancer
/// that sets the `x-forwarded-for` header. The header can be set by the
/// client unless the proxy overwrites it, so only rely on it behind
/// infrastructure that sanitises it.
#[derive(Clone, Copy, Debug)]
pub struct AxumIpExtractor;

impl<B> KeyExtractor<Request<B>> for AxumIpExtractor {
    fn extract(&self, req: &Request<B>) -> Option<String> {
        // `to_str` fails for values that are not visible ASCII; treat those
        // as "no key" rather than guessing at an encoding.
        let header = req.headers().get("x-forwarded-for")?.to_str().ok()?;

        // `split` always yields at least one item, so this `?` never fires;
        // it only unwraps the `Option` without a panic path.
        let client = header.split(',').next()?.trim();

        // An empty entry would otherwise put every such request into a single
        // shared bucket keyed by "".
        if client.is_empty() {
            None
        } else {
            Some(client.to_owned())
        }
    }
}
/// Tower [`Service`] produced by [`AxumRateLimitLayer`].
///
/// It checks each incoming request against a rate-limit store and only hands
/// the request to the wrapped service when the store allows it.
#[derive(Clone)]
pub struct AxumRateLimitService<S, Store> {
    /// Derives the rate-limiting key for each request.
    extractor: AxumIpExtractor,
    /// Rate-limiting state consulted for every request that yields a key.
    store: Store,
    /// Wrapped service that handles the requests which are let through.
    inner: S,
}

impl<S, Store, B> Service<Request<B>> for AxumRateLimitService<S, Store>
where
    S: Service<Request<B>, Response = Response, Error = Infallible> + Send,
    S::Future: Send + 'static,
    Store: RateLimitStore,
    B: Send + 'static,
{
    type Response = Response;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness entirely to the inner service.
    ///
    /// The rate-limit check in `call` is synchronous, so it never needs to
    /// wait here.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Checks the request against the store, then either rejects it or
    /// forwards it.
    ///
    /// A rejected request gets **HTTP 429 (Too Many Requests)** with the body
    /// `"rate limit exceeded"`. Allowed requests go to the inner service.
    /// Requests for which no key can be extracted are always forwarded.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        if let Some(key) = self.extractor.extract(&req) {
            if !self.store.allow(&key) {
                let response =
                    (StatusCode::TOO_MANY_REQUESTS, "rate limit exceeded").into_response();
                // The response is already known; no async work is needed.
                return Box::pin(std::future::ready(Ok(response)));
            }
        }

        // The inner future already has the right output type and is
        // `Send + 'static`, so it can be boxed directly without wrapping it
        // in an extra `async` block.
        Box::pin(self.inner.call(req))
    }
}
/// Tower layer that wraps Axum services with rate limiting.
///
/// Attach it to an `axum::Router` with `.layer(...)`. Every service this layer
/// wraps receives its own clone of the store. Requests that exceed the
/// configured limit are answered with **HTTP 429 (Too Many Requests)**.
#[derive(Clone)]
pub struct AxumRateLimitLayer<Store> {
    /// Store that is cloned into every service produced by this layer.
    store: Store,
}

impl<S, Store: Clone> tower::Layer<S> for AxumRateLimitLayer<Store> {
    type Service = AxumRateLimitService<S, Store>;

    /// Wraps `inner` in an [`AxumRateLimitService`].
    ///
    /// The service gets a clone of this layer's store and uses the
    /// `x-forwarded-for`-based [`AxumIpExtractor`] to key requests.
    fn layer(&self, inner: S) -> Self::Service {
        let store = Clone::clone(&self.store);
        let extractor = AxumIpExtractor;
        AxumRateLimitService {
            extractor,
            store,
            inner,
        }
    }
}

/// Create an Axum rate-limiting layer using the provided store.
///
/// # Arguments
///
/// - `store`: A [`RateLimitStore`] implementation that holds
///   rate-limiting state (e.g. `InMemoryStore`). The layer clones it into
///   every service it wraps. If limits should be shared across everything
///   the layer is applied to, pass a handle whose clones share state (such
///   as an `Arc`).
///
/// # Behavior
///
/// - Requests are identified using the `x-forwarded-for` header.
/// - Requests for which no key can be extracted (e.g. the header is missing)
///   are forwarded without being rate limited.
/// - Requests that exceed the rate limit receive
///   **HTTP 429 (Too Many Requests)**.
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
/// use touch_ratelimit::{
///     adapters::axum::axum_rate_limit_layer,
///     storage::in_memory::InMemoryStore,
///     bucket::token_bucket::TokenBucket,
/// };
///
/// let store = Arc::new(InMemoryStore::new(|| TokenBucket::new(5.0, 1.0)));
/// let layer = axum_rate_limit_layer(store);
/// ```
#[must_use = "the layer has no effect until it is attached, e.g. with `Router::layer`"]
pub fn axum_rate_limit_layer<S>(store: S) -> AxumRateLimitLayer<S>
where
    S: RateLimitStore + Clone,
{
    AxumRateLimitLayer { store }
}