gatekpr-rate-limiter 0.2.3

Reusable rate limiting with multiple backend support
Documentation
//! Rate limit state tracking
//!
//! Provides per-key rate limit tracking with fixed-window counters (one per-minute
//! and one per-hour window per key, each reset lazily once it has elapsed).

use crate::config::RateLimitConfig;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

/// Outcome of checking a single request against a key's rate limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request fits within both windows and has been counted.
    Allowed {
        /// Requests still available in the current minute window
        remaining_minute: u32,
        /// Requests still available in the current hour window
        remaining_hour: u32,
    },
    /// The request was rejected because one of the windows is exhausted.
    Exceeded {
        /// Number of seconds the caller should wait before retrying
        retry_after: u64,
        /// The exhausted window: either `"minute"` or `"hour"`
        limit_type: &'static str,
    },
}

impl RateLimitResult {
    /// Returns `true` when the request was admitted.
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allowed { .. } => true,
            Self::Exceeded { .. } => false,
        }
    }

    /// Returns `true` when the request was rejected.
    pub fn is_exceeded(&self) -> bool {
        match self {
            Self::Allowed { .. } => false,
            Self::Exceeded { .. } => true,
        }
    }

    /// Seconds to wait before retrying, or `None` if the request was admitted.
    pub fn retry_after(&self) -> Option<u64> {
        if let Self::Exceeded { retry_after, .. } = self {
            Some(*retry_after)
        } else {
            None
        }
    }
}

/// Fixed-window rate limit state for a single key.
///
/// Two independent counters are kept: one for a 60-second window and one for a
/// 3600-second window. Both windows start at construction. After that, a window
/// restarts on the first check made after its previous end has passed.
#[derive(Debug)]
pub struct KeyRateLimit {
    /// Requests admitted in the current minute window.
    minute_count: u32,
    /// Requests admitted in the current hour window.
    hour_count: u32,
    /// Instant at which the current minute window ends.
    minute_reset: Instant,
    /// Instant at which the current hour window ends.
    hour_reset: Instant,
}

impl KeyRateLimit {
    /// Length of the short window.
    const MINUTE: Duration = Duration::from_secs(60);
    /// Length of the long window.
    const HOUR: Duration = Duration::from_secs(3600);

    /// Create a fresh state with zeroed counters and both windows starting now.
    pub fn new() -> Self {
        let now = Instant::now();
        Self {
            minute_count: 0,
            hour_count: 0,
            minute_reset: now + Self::MINUTE,
            hour_reset: now + Self::HOUR,
        }
    }

    /// Check the rate limit and increment counters if allowed.
    ///
    /// Any window whose end has passed is reset first. A rejected request is
    /// not counted against either window.
    ///
    /// # Returns
    /// * `Allowed` with the budget left in each window after counting this request.
    /// * `Exceeded` with the exhausted window and the whole number of seconds
    ///   (rounded up, at least 1) until that window resets. If both windows are
    ///   exhausted, the later reset is reported, because the request cannot
    ///   succeed before then.
    pub fn check_and_increment(&mut self, config: &RateLimitConfig) -> RateLimitResult {
        let now = Instant::now();

        // Reset minute counter if window has passed
        if now >= self.minute_reset {
            self.minute_count = 0;
            self.minute_reset = now + Self::MINUTE;
        }

        // Reset hour counter if window has passed
        if now >= self.hour_reset {
            self.hour_count = 0;
            self.hour_reset = now + Self::HOUR;
        }

        let minute_exceeded = self.minute_count >= config.requests_per_minute;
        let hour_exceeded = self.hour_count >= config.requests_per_hour;

        if minute_exceeded || hour_exceeded {
            // The binding limit is whichever exhausted window resets last.
            let hour_binds =
                hour_exceeded && (!minute_exceeded || self.hour_reset >= self.minute_reset);
            let (deadline, limit_type) = if hour_binds {
                (self.hour_reset, "hour")
            } else {
                (self.minute_reset, "minute")
            };
            return RateLimitResult::Exceeded {
                retry_after: Self::secs_until(deadline, now),
                limit_type,
            };
        }

        self.minute_count += 1;
        self.hour_count += 1;

        // Cannot underflow: both counts were strictly below their limits before
        // being incremented.
        RateLimitResult::Allowed {
            remaining_minute: config.requests_per_minute - self.minute_count,
            remaining_hour: config.requests_per_hour - self.hour_count,
        }
    }

    /// Whether this state can be discarded without affecting future decisions.
    ///
    /// This is true once both windows have elapsed. At that point the next
    /// check would reset both counters anyway, so the entry behaves exactly like
    /// a fresh one. Checking only the hour window is not enough: a live minute
    /// count could be dropped, letting the key burst past its per-minute limit.
    pub fn is_expired(&self) -> bool {
        let now = Instant::now();
        now >= self.minute_reset && now >= self.hour_reset
    }

    /// Get the current minute count
    pub fn minute_count(&self) -> u32 {
        self.minute_count
    }

    /// Get the current hour count
    pub fn hour_count(&self) -> u32 {
        self.hour_count
    }

    /// Whole seconds from `now` until `deadline`, rounded up and never less than 1.
    ///
    /// Rounding up keeps clients from retrying a fraction of a second before
    /// the window actually resets. Truncation would report 59 for 59.9 s.
    fn secs_until(deadline: Instant, now: Instant) -> u64 {
        let remaining = deadline.saturating_duration_since(now);
        let secs = remaining.as_secs() + u64::from(remaining.subsec_nanos() > 0);
        secs.max(1)
    }
}

impl Default for KeyRateLimit {
    fn default() -> Self {
        Self::new()
    }
}

/// Thread-safe rate limit store backed by a [`DashMap`].
///
/// DashMap spreads keys across internally sharded locks, so checks for
/// different keys rarely contend. Access is not lock-free, though. Cloning the
/// store is cheap and yields a handle to the same shared state.
#[derive(Clone)]
pub struct RateLimitStore {
    state: Arc<DashMap<String, KeyRateLimit>>,
}

impl RateLimitStore {
    /// Number of keys [`RateLimitStore::new`] pre-allocates room for.
    const DEFAULT_CAPACITY: usize = 1000;

    /// Create a new rate limit store
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Create a store with pre-allocated capacity
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            state: Arc::new(DashMap::with_capacity(capacity)),
        }
    }

    /// Check and increment the rate limit for a key.
    ///
    /// A key not seen before starts from fresh [`KeyRateLimit`] state. The owned
    /// `String` key is only allocated the first time a key is inserted.
    pub fn check(&self, key: &str, config: &RateLimitConfig) -> RateLimitResult {
        // Fast path for known keys: lookup by `&str`, no allocation. The shard
        // guard is released at the end of this `if let`, before `entry` below
        // locks the shard again, so the two calls cannot deadlock.
        if let Some(mut limit) = self.state.get_mut(key) {
            return limit.check_and_increment(config);
        }

        // Slow path: `entry` is atomic per shard, so concurrent first requests
        // for the same key still end up sharing a single state.
        let mut entry = self.state.entry(key.to_owned()).or_default();
        entry.check_and_increment(config)
    }

    /// Clean up expired entries
    ///
    /// Should be called periodically to prevent memory growth.
    pub fn cleanup_expired(&self) {
        self.remove_expired();
    }

    /// Drop expired entries and return how many were removed.
    ///
    /// Removals are counted during the sweep. Diffing `len()` before and after
    /// would be wrong (and could underflow) if other threads insert keys
    /// concurrently.
    fn remove_expired(&self) -> usize {
        let mut removed = 0;
        self.state.retain(|_, limit| {
            let expired = limit.is_expired();
            removed += usize::from(expired);
            !expired
        });
        removed
    }

    /// Get the number of tracked keys
    pub fn len(&self) -> usize {
        self.state.len()
    }

    /// Check if the store is empty
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }

    /// Remove a specific key from tracking
    pub fn remove(&self, key: &str) {
        self.state.remove(key);
    }

    /// Clear all tracked keys
    pub fn clear(&self) {
        self.state.clear();
    }

    /// Spawn a background task to periodically clean up expired entries
    ///
    /// Returns a join handle that can be used to abort the task if needed.
    /// When a sweep removes at least one entry, the task logs it at debug level.
    ///
    /// # Arguments
    /// * `interval` - How often to run cleanup (recommended: 1 hour)
    ///
    /// # Panics
    /// `tokio::time::interval` panics on a zero period. That panic happens
    /// inside the spawned task and surfaces through the returned handle.
    ///
    /// # Example
    /// ```ignore
    /// let store = Arc::new(RateLimitStore::new());
    /// let handle = store.clone().spawn_cleanup_task(Duration::from_secs(3600));
    /// // Later: handle.abort() to stop cleanup
    /// ```
    #[cfg(feature = "cleanup-task")]
    pub fn spawn_cleanup_task(
        self: Arc<Self>,
        interval: std::time::Duration,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // The first tick completes immediately; skip it so the first sweep
            // runs one `interval` after spawning.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                let before = self.len();
                let removed = self.remove_expired();
                if removed > 0 {
                    tracing::debug!(
                        before = before,
                        after = self.len(),
                        removed = removed,
                        "Rate limiter cleanup completed"
                    );
                }
            }
        })
    }
}

impl Default for RateLimitStore {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_rate_limit_new() {
        let limit = KeyRateLimit::new();
        assert_eq!(limit.minute_count, 0);
        assert_eq!(limit.hour_count, 0);
    }

    #[test]
    fn test_check_and_increment_allowed() {
        let mut limit = KeyRateLimit::new();
        let config = RateLimitConfig::for_plan("free");

        let result = limit.check_and_increment(&config);
        assert_eq!(
            result,
            RateLimitResult::Allowed {
                remaining_minute: config.requests_per_minute - 1,
                remaining_hour: config.requests_per_hour - 1,
            }
        );
        assert_eq!(limit.minute_count, 1);
        assert_eq!(limit.hour_count, 1);
    }

    #[test]
    fn test_check_and_increment_exceeded_minute() {
        let mut limit = KeyRateLimit::new();
        let config = RateLimitConfig::for_plan("free");
        // Read the limit from the plan instead of hard-coding it, so the test
        // keeps working if the free plan's quota changes.
        let per_minute = config.requests_per_minute;

        // Use up all minute requests
        for _ in 0..per_minute {
            let result = limit.check_and_increment(&config);
            assert!(result.is_allowed());
        }

        // Next request should be rate limited
        let result = limit.check_and_increment(&config);
        assert!(result.is_exceeded());
        assert!(result.retry_after().unwrap() > 0);
        // A rejected request must not be counted.
        assert_eq!(limit.minute_count, per_minute);
    }

    #[test]
    fn test_rate_limit_result_methods() {
        let allowed = RateLimitResult::Allowed {
            remaining_minute: 10,
            remaining_hour: 100,
        };
        assert!(allowed.is_allowed());
        assert!(!allowed.is_exceeded());
        assert!(allowed.retry_after().is_none());

        let exceeded = RateLimitResult::Exceeded {
            retry_after: 30,
            limit_type: "minute",
        };
        assert!(!exceeded.is_allowed());
        assert!(exceeded.is_exceeded());
        assert_eq!(exceeded.retry_after(), Some(30));
    }

    #[test]
    fn test_store_basic() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        let result = store.check("user1", &config);
        assert!(result.is_allowed());
        assert_eq!(store.len(), 1);

        let result = store.check("user2", &config);
        assert!(result.is_allowed());
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_store_reuses_state_for_same_key() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        store.check("user1", &config);
        let result = store.check("user1", &config);
        assert_eq!(
            result,
            RateLimitResult::Allowed {
                remaining_minute: config.requests_per_minute - 2,
                remaining_hour: config.requests_per_hour - 2,
            }
        );
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_store_remove() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        store.check("user1", &config);
        assert_eq!(store.len(), 1);

        store.remove("user1");
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_store_clear() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        store.check("user1", &config);
        store.check("user2", &config);
        assert_eq!(store.len(), 2);

        store.clear();
        assert!(store.is_empty());
    }
}