gatekpr-rate-limiter 0.2.3

Reusable rate limiting with multiple backend support
Documentation
//! Axum middleware integration for rate limiting
//!
//! Provides Tower layer and service for integrating rate limiting
//! with Axum applications.
//!
//! # Example
//!
//! ```rust,ignore
//! use gatekpr_rate_limiter::axum_layer::RateLimitLayer;
//! use axum::{Router, routing::get};
//!
//! async fn handler() -> &'static str {
//!     "Hello, World!"
//! }
//!
//! let app = Router::new()
//!     .route("/", get(handler))
//!     .layer(RateLimitLayer::new());
//! ```

use crate::config::RateLimitConfig;
use crate::state::{RateLimitResult, RateLimitStore};
use axum::{
    body::Body,
    http::{Request, Response, StatusCode},
};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower::{Layer, Service};

/// Tower [`Layer`] that adds per-key rate limiting to an Axum application.
///
/// Each service produced by this layer receives a clone of the layer's
/// [`RateLimitStore`] and checks every request against it.
///
/// Request headers consulted:
/// - `X-Tenant-ID`: preferred rate limit key
/// - `X-Forwarded-For`: fallback key when no tenant ID is present (first IP only)
/// - `X-Tenant-Plan`: selects the limit configuration (free, pro, enterprise);
///   `free` is used when the header is absent or unreadable
///
/// These headers are read verbatim from the incoming request, so they should be
/// populated by trusted upstream middleware or proxies rather than by clients.
#[derive(Clone)]
pub struct RateLimitLayer {
    store: RateLimitStore,
}

impl RateLimitLayer {
    /// Build a layer backed by a newly created [`RateLimitStore`].
    pub fn new() -> Self {
        Self::with_store(RateLimitStore::new())
    }

    /// Build a layer around an existing store (for example one created with
    /// `RateLimitStore::with_capacity`).
    pub fn with_store(store: RateLimitStore) -> Self {
        Self { store }
    }

    /// Evict expired entries from the underlying store.
    ///
    /// Nothing in this layer calls this automatically; schedule it
    /// periodically (e.g. from a background task) to keep memory bounded.
    pub fn cleanup_expired(&self) {
        let store = &self.store;
        store.cleanup_expired();
    }

    /// Borrow the underlying store for direct inspection or manual checks.
    pub fn store(&self) -> &RateLimitStore {
        &self.store
    }
}

impl Default for RateLimitLayer {
    fn default() -> Self {
        Self::new()
    }
}

impl<S> Layer<S> for RateLimitLayer {
    type Service = RateLimitService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        RateLimitService {
            inner,
            store: self.store.clone(),
        }
    }
}

/// Tower [`Service`] produced by [`RateLimitLayer`].
///
/// Checks each request against `store` before forwarding it to `inner`.
/// Requests over the limit are answered with `429 Too Many Requests` and
/// never reach the inner service.
#[derive(Clone)]
pub struct RateLimitService<S> {
    inner: S,
    store: RateLimitStore,
}

impl<S> RateLimitService<S> {
    /// Create a new rate limit service
    pub fn new(inner: S, store: RateLimitStore) -> Self {
        Self { inner, store }
    }
}

impl<S> Service<Request<Body>> for RateLimitService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the wrapped service; the rate limiter itself
    /// adds no backpressure here.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Check the request against its plan's limits.
    ///
    /// Over-limit requests are answered immediately with `429 Too Many
    /// Requests`, a `Retry-After` header, an `X-RateLimit-Limit-Type` header
    /// and a JSON error body; the inner service is not called. Allowed
    /// requests are forwarded, and the remaining per-minute / per-hour quota
    /// is added to the inner service's response headers.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // Extract tenant key from headers
        let tenant_key = extract_rate_limit_key(&req);

        // Plan tier (typically set by auth middleware). A missing header, or one
        // that is not visible ASCII, falls back to the "free" plan.
        let plan = req
            .headers()
            .get("X-Tenant-Plan")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("free");

        let config = RateLimitConfig::for_plan(plan);

        // Check rate limit
        let result = self.store.check(&tenant_key, &config);

        match result {
            RateLimitResult::Exceeded {
                retry_after,
                limit_type,
            } => {
                tracing::warn!(
                    tenant = %tenant_key,
                    limit_type = %limit_type,
                    retry_after = %retry_after,
                    "Rate limit exceeded"
                );

                // `self.inner` is not touched on this path, so the readiness
                // obtained in `poll_ready` is kept for the next call.
                Box::pin(async move {
                    let response = Response::builder()
                        .status(StatusCode::TOO_MANY_REQUESTS)
                        .header("Retry-After", retry_after.to_string())
                        .header("X-RateLimit-Limit-Type", limit_type)
                        .header("Content-Type", "application/json")
                        .body(Body::from(format!(
                            r#"{{"error":"Rate limit exceeded","error_code":"RATE_LIMIT_EXCEEDED","retry_after":{},"limit_type":"{}"}}"#,
                            retry_after, limit_type
                        )))
                        .expect("Failed to build rate limit response");
                    Ok(response)
                })
            }
            RateLimitResult::Allowed {
                remaining_minute,
                remaining_hour,
            } => {
                // Tower only guarantees that the instance `poll_ready` was driven
                // on is ready. Move that instance into the future and leave a
                // fresh clone behind for the next `poll_ready`/`call` cycle;
                // calling a fresh clone instead could dispatch the request to a
                // service that never reported readiness.
                let fresh = self.inner.clone();
                let mut inner = std::mem::replace(&mut self.inner, fresh);
                Box::pin(async move {
                    let mut response = inner.call(req).await?;

                    // Add rate limit headers to response
                    let headers = response.headers_mut();
                    if let Ok(v) = remaining_minute.to_string().parse() {
                        headers.insert("X-RateLimit-Remaining-Minute", v);
                    }
                    if let Ok(v) = remaining_hour.to_string().parse() {
                        headers.insert("X-RateLimit-Remaining-Hour", v);
                    }

                    Ok(response)
                })
            }
        }
    }
}

/// Extract the rate limit key from a request.
///
/// Priority:
/// 1. `X-Tenant-ID` header, when present, readable and not blank
/// 2. First IP from `X-Forwarded-For` header (must parse as an IP address)
/// 3. `"anonymous"` as fallback
///
/// # Security Note
/// `X-Tenant-ID` is used verbatim, so it must be set by trusted upstream
/// middleware and stripped from client requests; otherwise a client can pick
/// its own bucket. IP addresses are validated for format only. Configure your
/// reverse proxy to set `X-Forwarded-For` correctly and strip client-provided
/// values.
fn extract_rate_limit_key(req: &Request<Body>) -> String {
    let headers = req.headers();

    // A blank tenant ID carries no identity: treat it like a missing header so
    // the request falls back to IP-based keying instead of every such client
    // sharing a single "" bucket.
    let tenant = headers
        .get("X-Tenant-ID")
        .and_then(|v| v.to_str().ok())
        .filter(|s| !s.trim().is_empty());
    if let Some(tenant) = tenant {
        return tenant.to_string();
    }

    headers
        .get("X-Forwarded-For")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.split(',').next())
        .map(str::trim)
        .and_then(validate_ip_key)
        .unwrap_or_else(|| "anonymous".to_string())
}

/// Parse `ip` as an IPv4/IPv6 address and return its canonical text form.
///
/// Returns `None` when `ip` is not a bare IP address (e.g. empty, a hostname,
/// or an address with a port), letting the caller fall back to `"anonymous"`.
fn validate_ip_key(ip: &str) -> Option<String> {
    let parsed: std::net::IpAddr = ip.parse().ok()?;
    Some(parsed.to_string())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request carrying the given `(name, value)` headers.
    fn request_with(headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder
            .body(Body::empty())
            .expect("static test request is valid")
    }

    #[test]
    fn test_layer_creation() {
        let layer = RateLimitLayer::new();
        assert!(layer.store().is_empty());
    }

    #[test]
    fn test_layer_with_store() {
        let store = RateLimitStore::with_capacity(500);
        let layer = RateLimitLayer::with_store(store);
        assert!(layer.store().is_empty());
    }

    #[test]
    fn test_key_prefers_tenant_id() {
        let req = request_with(&[("X-Tenant-ID", "acme"), ("X-Forwarded-For", "10.0.0.1")]);
        assert_eq!(extract_rate_limit_key(&req), "acme");
    }

    #[test]
    fn test_key_uses_first_forwarded_ip() {
        let req = request_with(&[("X-Forwarded-For", "10.0.0.1, 192.168.1.1")]);
        assert_eq!(extract_rate_limit_key(&req), "10.0.0.1");
    }

    #[test]
    fn test_key_normalizes_ipv6() {
        let req = request_with(&[("X-Forwarded-For", "0:0:0:0:0:0:0:1")]);
        assert_eq!(extract_rate_limit_key(&req), "::1");
    }

    #[test]
    fn test_key_invalid_ip_is_anonymous() {
        let req = request_with(&[("X-Forwarded-For", "not-an-ip")]);
        assert_eq!(extract_rate_limit_key(&req), "anonymous");
    }

    #[test]
    fn test_key_without_headers_is_anonymous() {
        let req = request_with(&[]);
        assert_eq!(extract_rate_limit_key(&req), "anonymous");
    }
}