// reasonkit-web 0.1.7
//
// High-performance MCP server for browser automation, web capture, and content
// extraction. Rust-powered CDP client for AI agents. (Crate metadata — kept as
// a plain comment so it does not break the `//!` module docs below.)
//! Rate limiting middleware for API keys in the ReasonKit portal.
//!
//! This module provides a distributed rate limiting system using a sliding window
//! algorithm with Redis backend (falling back to in-memory when Redis is unavailable).
//!
//! # Features
//!
//! - Sliding window rate limiting per API key
//! - Distributed rate limiting via Redis
//! - Automatic fallback to in-memory cache
//! - Standard HTTP rate limit headers
//! - Configurable limits per API key
//!
//! # Example
//!
//! ```rust,no_run
//! use axum::Router;
//! use reasonkit_web::portal::rate_limiting::{require_rate_limit, RateLimitConfig, RateLimiter};
//!
//! #[tokio::main]
//! async fn main() {
//!     let config = RateLimitConfig {
//!         default_limit: 60,
//!         window_size_secs: 60,
//!         redis_url: Some("redis://localhost:6379".to_string()),
//!     };
//!
//!     let rate_limiter = RateLimiter::new(config).await;
//!
//!     let app = Router::new()
//!         .layer(axum::middleware::from_fn_with_state(
//!             rate_limiter.clone(),
//!             require_rate_limit
//!         ));
//! }
//! ```

#[cfg(feature = "portal")]
use axum::{
    extract::{Request, State},
    http::{HeaderMap, HeaderValue, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

#[cfg(feature = "portal")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "portal")]
use std::{
    collections::HashMap,
    sync::Arc,
    time::{SystemTime, UNIX_EPOCH},
};

#[cfg(feature = "portal")]
use tokio::sync::RwLock;

#[cfg(feature = "portal")]
use tracing::{debug, error, warn};

/// Configuration for the rate limiter.
///
/// Construct directly or via [`Default`] (60 requests per 60-second window,
/// no Redis).
#[cfg(feature = "portal")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Number of requests allowed per window when the API key record does
    /// not supply its own limit. (With the default 60-second window this is
    /// effectively "per minute".)
    pub default_limit: u32,

    /// Size of the sliding window in seconds.
    pub window_size_secs: u64,

    /// Optional Redis connection URL for distributed rate limiting.
    /// If None, uses in-memory cache only.
    /// NOTE(review): the Redis backend is not wired up in this module —
    /// `RateLimiter::new` logs a warning and falls back to the in-memory
    /// cache even when this is set.
    pub redis_url: Option<String>,
}

#[cfg(feature = "portal")]
impl Default for RateLimitConfig {
    /// 60 requests per 60-second window, in-memory backend only.
    fn default() -> Self {
        RateLimitConfig {
            redis_url: None,
            window_size_secs: 60,
            default_limit: 60,
        }
    }
}

/// Tracks request timestamps for a single API key.
///
/// Timestamps are appended in arrival order, so the vector stays sorted
/// ascending (the clock values are monotonically non-decreasing per key
/// under the write lock held by the caller).
#[cfg(feature = "portal")]
#[derive(Debug, Clone, Default)]
struct RateLimitEntry {
    /// Unix timestamps (seconds) of requests within the current window.
    timestamps: Vec<u64>,
}

#[cfg(feature = "portal")]
impl RateLimitEntry {
    /// Discard every timestamp older than `window_start`, keeping only
    /// requests that fall inside the current sliding window.
    fn prune(&mut self, window_start: u64) {
        self.timestamps.retain(|&stamp| stamp >= window_start);
    }

    /// Number of requests recorded in the current window.
    fn count(&self) -> usize {
        self.timestamps.len()
    }

    /// Record a request that arrived at `timestamp` (Unix seconds).
    fn add(&mut self, timestamp: u64) {
        self.timestamps.push(timestamp);
    }
}

/// Backend storage for rate limit data.
///
/// Only an in-memory variant exists today; a Redis variant would slot in
/// here to support multi-instance (distributed) deployments.
#[cfg(feature = "portal")]
#[derive(Clone)]
enum RateLimitBackend {
    /// In-memory backend for single-instance deployments.
    /// Maps API key -> timestamps of requests in the current window.
    Memory(Arc<RwLock<HashMap<String, RateLimitEntry>>>),
}

/// Rate limiter that enforces per-API-key request limits.
///
/// Cheap to clone: the backend is shared via `Arc`, so all clones observe
/// and update the same counters.
#[cfg(feature = "portal")]
#[derive(Clone)]
pub struct RateLimiter {
    /// Limits and window size; immutable after construction.
    config: RateLimitConfig,
    /// Shared storage for per-key request timestamps.
    backend: RateLimitBackend,
}

#[cfg(feature = "portal")]
impl RateLimiter {
    /// Create a new rate limiter with the given configuration.
    ///
    /// Uses in-memory storage. If `config.redis_url` is set, a warning is
    /// logged because Redis support is not wired up; the limiter still
    /// works, but limits are enforced per-process only.
    pub async fn new(config: RateLimitConfig) -> Self {
        if config.redis_url.is_some() {
            warn!("Redis URL configured but redis feature not enabled. Using in-memory cache.");
        } else {
            debug!("Using in-memory rate limiting.");
        }

        // Both configurations currently share the same backend; the branch
        // above only differs in how loudly it reports that choice.
        let backend = RateLimitBackend::Memory(Arc::new(RwLock::new(HashMap::new())));

        Self { config, backend }
    }

    /// Check if a request is allowed for the given API key.
    ///
    /// Returns `Ok(remaining)` — requests left in the window after this one —
    /// if allowed, or `Err(retry_after_secs)` if the rate limit is exceeded.
    pub async fn check_rate_limit(&self, api_key: &str, limit: u32) -> Result<u32, u64> {
        let now = Self::current_timestamp();
        let window_start = now.saturating_sub(self.config.window_size_secs);

        match &self.backend {
            RateLimitBackend::Memory(cache) => {
                self.check_memory(cache.clone(), api_key, limit, now, window_start)
                    .await
            }
        }
    }

    /// Check rate limit using the in-memory backend.
    ///
    /// Holds the write lock for the whole prune/count/insert sequence so the
    /// check-and-record step is atomic per call.
    async fn check_memory(
        &self,
        cache: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
        api_key: &str,
        limit: u32,
        now: u64,
        window_start: u64,
    ) -> Result<u32, u64> {
        let mut cache = cache.write().await;
        let entry = cache.entry(api_key.to_string()).or_default();

        // Drop requests that have aged out of the sliding window.
        entry.prune(window_start);

        let count = entry.count();

        if count >= limit as usize {
            // Report when the oldest in-window request actually expires,
            // rather than always claiming the full window size. Timestamps
            // are appended in arrival order, so `first()` is the oldest.
            // Clamp to at least 1 so a rejection never carries
            // `Retry-After: 0`.
            let retry_after = entry
                .timestamps
                .first()
                .map(|&oldest| (oldest + self.config.window_size_secs).saturating_sub(now))
                .unwrap_or(self.config.window_size_secs)
                .max(1);
            debug!(
                "Rate limit exceeded for API key {}. Retry after {} seconds.",
                api_key, retry_after
            );
            return Err(retry_after);
        }

        // Record the current request.
        entry.add(now);

        Ok(limit.saturating_sub(count as u32 + 1))
    }

    /// Get the current Unix timestamp in seconds.
    ///
    /// Falls back to 0 if the system clock reports a time before the epoch.
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Get the configured default limit.
    pub fn default_limit(&self) -> u32 {
        self.config.default_limit
    }

    /// Get the window size in seconds.
    pub fn window_size_secs(&self) -> u64 {
        self.config.window_size_secs
    }
}

/// Rate limit error response produced by [`require_rate_limit`].
///
/// Converted into an HTTP response with a JSON body `{"error": message}`
/// plus the supplied headers.
#[cfg(feature = "portal")]
#[derive(Debug)]
pub struct RateLimitError {
    // HTTP status to respond with (401 or 429 in this module).
    pub status: StatusCode,
    // Extra headers (e.g. `Retry-After`, `X-RateLimit-*`) merged into the response.
    pub headers: HeaderMap,
    // Human-readable message placed in the JSON body's "error" field.
    pub message: String,
}

#[cfg(feature = "portal")]
impl IntoResponse for RateLimitError {
    /// Render the error as `(status, {"error": message})`, then merge in
    /// the carried rate-limit headers.
    fn into_response(self) -> Response {
        let RateLimitError {
            status,
            headers,
            message,
        } = self;
        let payload = axum::Json(serde_json::json!({ "error": message }));
        let mut out = (status, payload).into_response();
        out.headers_mut().extend(headers);
        out
    }
}

/// Middleware function that enforces rate limiting for API keys.
///
/// # Headers
///
/// - **Request**: Requires `X-API-Key` header
/// - **Response**: Includes `X-RateLimit-Limit`, `X-RateLimit-Remaining`, `X-RateLimit-Reset`
/// - **Response (429)**: Includes `Retry-After` header
///
/// # Errors
///
/// Returns `401 Unauthorized` if `X-API-Key` header is missing.
/// Returns `429 Too Many Requests` if rate limit is exceeded.
#[cfg(feature = "portal")]
pub async fn require_rate_limit(
    State(rate_limiter): State<RateLimiter>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, RateLimitError> {
    // Extract the API key; reject outright when the header is absent or its
    // bytes are not visible ASCII (`to_str` fails on such values).
    let api_key = headers
        .get("X-API-Key")
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| RateLimitError {
            status: StatusCode::UNAUTHORIZED,
            headers: HeaderMap::new(),
            message: "Missing or invalid X-API-Key header".to_string(),
        })?;

    // TODO: Fetch the actual rate limit from the database for this API key
    // For now, use the default limit from config
    // In production, this should query the api_keys table to get rate_limit_per_minute
    let limit = rate_limiter.default_limit();

    // Check rate limit
    match rate_limiter.check_rate_limit(api_key, limit).await {
        Ok(remaining) => {
            // Allowed: run the inner service, then attach the standard
            // rate-limit headers to its response.
            let mut response = next.run(request).await;

            // `HeaderValue` implements `From<u32>` / `From<u64>`, which is
            // infallible — no `from_str(&x.to_string()).unwrap()` round-trip.
            let headers = response.headers_mut();
            headers.insert("X-RateLimit-Limit", HeaderValue::from(limit));
            headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));

            let reset_time = RateLimiter::current_timestamp() + rate_limiter.window_size_secs();
            headers.insert("X-RateLimit-Reset", HeaderValue::from(reset_time));

            Ok(response)
        }
        Err(retry_after) => {
            // Rejected: respond 429 with Retry-After and exhausted-quota headers.
            let mut headers = HeaderMap::new();
            headers.insert("Retry-After", HeaderValue::from(retry_after));
            headers.insert("X-RateLimit-Limit", HeaderValue::from(limit));
            headers.insert("X-RateLimit-Remaining", HeaderValue::from_static("0"));

            let reset_time = RateLimiter::current_timestamp() + retry_after;
            headers.insert("X-RateLimit-Reset", HeaderValue::from(reset_time));

            Err(RateLimitError {
                status: StatusCode::TOO_MANY_REQUESTS,
                headers,
                message: "Rate limit exceeded. Please try again later.".to_string(),
            })
        }
    }
}

#[cfg(all(test, feature = "portal"))]
mod tests {
    use super::*;

    /// Build an in-memory limiter with the given per-window limit and window size.
    async fn limiter(limit: u32, window_secs: u64) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            default_limit: limit,
            window_size_secs: window_secs,
            redis_url: None,
        })
        .await
    }

    #[tokio::test]
    async fn test_memory_backend_basic() {
        let limiter = limiter(5, 60).await;

        // The first five requests fit inside the window.
        for attempt in 1..=5 {
            assert!(
                limiter.check_rate_limit("test_key", 5).await.is_ok(),
                "Request {} should be allowed",
                attempt
            );
        }

        // The sixth request exceeds the limit.
        assert!(
            limiter.check_rate_limit("test_key", 5).await.is_err(),
            "6th request should be rate limited"
        );
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let limiter = limiter(2, 60).await;

        // Exhaust key1's quota.
        assert!(limiter.check_rate_limit("key1", 2).await.is_ok());
        assert!(limiter.check_rate_limit("key1", 2).await.is_ok());
        assert!(limiter.check_rate_limit("key1", 2).await.is_err());

        // key2 tracks its own independent window.
        assert!(limiter.check_rate_limit("key2", 2).await.is_ok());
        assert!(limiter.check_rate_limit("key2", 2).await.is_ok());
    }

    #[tokio::test]
    async fn test_window_pruning() {
        // 1-second window so expiry is quick to observe.
        let limiter = limiter(2, 1).await;

        // Fill the window, then get rejected.
        assert!(limiter.check_rate_limit("test_key", 2).await.is_ok());
        assert!(limiter.check_rate_limit("test_key", 2).await.is_ok());
        assert!(limiter.check_rate_limit("test_key", 2).await.is_err());

        // After the window elapses, stale timestamps are pruned and requests pass.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(limiter.check_rate_limit("test_key", 2).await.is_ok());
    }
}