sockudo 2.8.0

A simple, fast, and secure WebSocket server for real-time applications.
Documentation
#![allow(unused_assignments)]
#![allow(unused_variables)]
#![allow(dead_code)]

// src/rate_limiter/redis_limiter.rs
use super::{RateLimitConfig, RateLimitResult, RateLimiter};
use crate::error::{Error, Result};
use async_trait::async_trait;
use redis::{AsyncCommands, Client};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Rate limiter that keeps its per-key request log in Redis.
pub struct RedisRateLimiter {
    /// Limits (request count and window length) this limiter enforces
    config: RateLimitConfig,
    /// Namespace placed in front of every Redis key this limiter touches
    prefix: String,
    /// Redis ConnectionManager; each operation clones this handle before issuing commands
    connection: redis::aio::ConnectionManager,
    /// Client the connection manager was created from (not read by the methods in this file)
    client: Client,
}

impl RedisRateLimiter {
    /// Create a new Redis-based rate limiter
    pub async fn new(
        client: Client,
        prefix: String,
        max_requests: u32,
        window_secs: u64,
    ) -> Result<Self> {
        // Create ConnectionManager with same config as RedisAdapter for consistency
        let connection_manager_config = redis::aio::ConnectionManagerConfig::new()
            .set_number_of_retries(5)
            .set_exponent_base(2)
            .set_factor(500)
            .set_max_delay(5000);

        let connection = client
            .get_connection_manager_with_config(connection_manager_config)
            .await
            .map_err(|e| Error::Redis(format!("Failed to connect to Redis: {e}")))?;

        let config = RateLimitConfig {
            max_requests,
            window_secs,
            identifier: Some("redis".to_string()),
        };

        Ok(Self {
            client,
            connection,
            prefix,
            config,
        })
    }

    /// Create a new Redis-based rate limiter with a specific configuration
    pub async fn with_config(
        client: Client,
        prefix: String,
        config: RateLimitConfig,
    ) -> Result<Self> {
        // Create ConnectionManager with same config as RedisAdapter for consistency
        let connection_manager_config = redis::aio::ConnectionManagerConfig::new()
            .set_number_of_retries(5)
            .set_exponent_base(2)
            .set_factor(500)
            .set_max_delay(5000);

        let connection = client
            .get_connection_manager_with_config(connection_manager_config)
            .await
            .map_err(|e| Error::Redis(format!("Failed to connect to Redis: {e}")))?;

        Ok(Self {
            client,
            connection,
            prefix,
            config,
        })
    }

    /// Get a key formatted with the prefix
    fn get_key(&self, key: &str) -> String {
        format!("{}:rl:{}", self.prefix, key)
    }

    /// Get the Unix timestamp for the current time
    fn get_current_time() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_else(|_| Duration::from_secs(0))
            .as_secs()
    }

    /// Run sliding window rate limiting using Redis
    /// This uses a sorted set with scores as timestamps
    async fn run_sliding_window_check(
        &self,
        key: &str,
        increment: bool,
    ) -> Result<RateLimitResult> {
        let redis_key = self.get_key(key);
        let now = Self::get_current_time();
        let window_start = now - self.config.window_secs;

        // Get a cloned connection
        let mut conn = self.connection.clone();

        // Remove all elements older than our window
        let _: () = conn
            .zrevrangebyscore(&redis_key, 0, window_start as i64)
            .await
            .map_err(|e| Error::Redis(format!("Failed to clean up Redis sorted set: {e}")))?;

        // Count current elements in the window
        let count: u32 = conn
            .zcard(&redis_key)
            .await
            .map_err(|e| Error::Redis(format!("Failed to count Redis sorted set: {e}")))?;

        // Set expiry on the key for automatic cleanup
        let _: () = conn
            .expire(&redis_key, self.config.window_secs as usize as i64)
            .await
            .map_err(|e| Error::Redis(format!("Failed to set expiry on Redis key: {e}")))?;

        let remaining = self.config.max_requests.saturating_sub(count);
        let allowed = remaining > 0;

        // If we should increment and we're allowed, add the current timestamp
        if increment && allowed {
            let _: () = conn
                .zadd(&redis_key, now, now)
                .await
                .map_err(|e| Error::Redis(format!("Failed to increment Redis counter: {e}")))?;

            // Recalculate remaining after increment
            let new_remaining = remaining.saturating_sub(1);

            return Ok(RateLimitResult {
                allowed,
                remaining: new_remaining,
                reset_after: self.config.window_secs,
                limit: self.config.max_requests,
            });
        }

        Ok(RateLimitResult {
            allowed,
            remaining,
            reset_after: self.config.window_secs,
            limit: self.config.max_requests,
        })
    }
}

#[async_trait]
impl RateLimiter for RedisRateLimiter {
    /// Report the current window state for `key` without recording a request.
    async fn check(&self, key: &str) -> Result<RateLimitResult> {
        let record_request = false;
        self.run_sliding_window_check(key, record_request).await
    }

    /// Record a request for `key` (only if it is allowed) and report the resulting state.
    async fn increment(&self, key: &str) -> Result<RateLimitResult> {
        let record_request = true;
        self.run_sliding_window_check(key, record_request).await
    }

    /// Delete the Redis key backing `key`, clearing its request history.
    async fn reset(&self, key: &str) -> Result<()> {
        let mut conn = self.connection.clone();
        let () = conn
            .del(self.get_key(key))
            .await
            .map_err(|e| Error::Redis(format!("Failed to delete Redis key: {e}")))?;
        Ok(())
    }

    /// Number of requests `key` may still make, as reported by a non-recording check.
    async fn get_remaining(&self, key: &str) -> Result<u32> {
        self.check(key).await.map(|status| status.remaining)
    }
}