mockforge-registry-server 0.3.127

Plugin registry server for MockForge
Documentation
//! Redis connection and utilities for rate limiting and caching

use anyhow::Result;
use chrono::Datelike;
use std::sync::Arc;

// Import redis crate types
use redis::{aio::ConnectionManager, AsyncCommands, Client};

/// Redis connection wrapper
#[derive(Clone)]
pub struct RedisPool {
    manager: Arc<ConnectionManager>,
}

impl RedisPool {
    /// Create a new Redis connection pool
    pub async fn connect(redis_url: &str) -> Result<Self> {
        let client = Client::open(redis_url)?;
        let manager = ConnectionManager::new(client).await?;

        Ok(Self {
            manager: Arc::new(manager),
        })
    }

    /// Get a connection for async operations
    /// Note: ConnectionManager is already cloneable, so we can use it directly
    pub fn get_connection(&self) -> Arc<ConnectionManager> {
        self.manager.clone()
    }

    /// Increment a counter with expiration
    /// Returns the new count after increment
    pub async fn increment_with_expiry(&self, key: &str, expiry_seconds: u64) -> Result<i64> {
        // ConnectionManager is already async-safe, we can use it directly
        let mut conn = (*self.manager).clone();

        // Use Redis pipeline for atomic increment + expiry
        let count: i64 = conn.incr(key, 1).await?;

        // Set expiry on first increment (count == 1)
        if count == 1 {
            conn.expire::<_, ()>(key, expiry_seconds as i64).await?;
        }

        Ok(count)
    }

    /// Get a counter value
    pub async fn get_counter(&self, key: &str) -> Result<i64> {
        let mut conn = (*self.manager).clone();
        let count: i64 = conn.get(key).await.unwrap_or(0);
        Ok(count)
    }

    /// Set a key with expiration
    pub async fn set_with_expiry(&self, key: &str, value: &str, expiry_seconds: u64) -> Result<()> {
        let mut conn = (*self.manager).clone();
        conn.set_ex::<_, _, ()>(key, value, expiry_seconds).await?;
        Ok(())
    }

    /// Get a key value
    pub async fn get(&self, key: &str) -> Result<Option<String>> {
        let mut conn = (*self.manager).clone();
        let value: Option<String> = conn.get(key).await?;
        Ok(value)
    }

    /// Delete a key
    pub async fn delete(&self, key: &str) -> Result<()> {
        let mut conn = (*self.manager).clone();
        conn.del::<_, ()>(key).await?;
        Ok(())
    }

    /// Scan for keys matching a glob pattern using Redis SCAN
    pub async fn scan_keys(&self, pattern: &str) -> Result<Vec<String>> {
        let mut conn = (*self.manager).clone();
        let mut cursor: u64 = 0;
        let mut keys = Vec::new();

        loop {
            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(pattern)
                .arg("COUNT")
                .arg(100)
                .query_async(&mut conn)
                .await?;

            keys.extend(batch);
            cursor = next_cursor;

            if cursor == 0 {
                break;
            }
        }

        Ok(keys)
    }

    /// Health check - verify Redis connectivity
    pub async fn ping(&self) -> Result<()> {
        let mut conn = (*self.manager).clone();
        // Use the AsyncCommands trait method directly
        let _: String = conn.get("__ping_test__").await.unwrap_or_else(|_| "PONG".to_string());
        Ok(())
    }
}

/// Generate Redis key for org usage counter
pub fn org_usage_key(org_id: &uuid::Uuid, period: &str) -> String {
    format!("usage:{}:{}", org_id, period)
}

/// Generate Redis key for org usage counter by type
pub fn org_usage_key_by_type(org_id: &uuid::Uuid, period: &str, usage_type: &str) -> String {
    format!("usage:{}:{}:{}", org_id, period, usage_type)
}

/// Generate Redis key for org rate limit
pub fn org_rate_limit_key(org_id: &uuid::Uuid) -> String {
    format!("ratelimit:{}", org_id)
}

/// Get current month period string (YYYY-MM)
pub fn current_month_period() -> String {
    let now = chrono::Utc::now();
    format!("{}-{:02}", now.year(), now.month())
}

/// Generate Redis key for 2FA setup secret
pub fn two_factor_setup_key(user_id: &uuid::Uuid) -> String {
    format!("2fa_setup:{}", user_id)
}

/// Generate Redis key for 2FA backup codes (stored during setup)
pub fn two_factor_backup_codes_key(user_id: &uuid::Uuid) -> String {
    format!("2fa_backup_codes:{}", user_id)
}

/// TTL for 2FA setup secrets (5 minutes)
pub const TWO_FACTOR_SETUP_TTL_SECONDS: u64 = 300;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_org_usage_key() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key = org_usage_key(&org_id, period);

        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert_eq!(key, format!("usage:{}:{}", org_id, period));
    }

    #[test]
    fn test_org_usage_key_by_type() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let usage_type = "api_calls";

        let key = org_usage_key_by_type(&org_id, period, usage_type);

        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert!(key.contains(usage_type));
        assert_eq!(key, format!("usage:{}:{}:{}", org_id, period, usage_type));
    }

    #[test]
    fn test_org_rate_limit_key() {
        let org_id = uuid::Uuid::new_v4();

        let key = org_rate_limit_key(&org_id);

        assert!(key.starts_with("ratelimit:"));
        assert!(key.contains(&org_id.to_string()));
        assert_eq!(key, format!("ratelimit:{}", org_id));
    }

    #[test]
    fn test_current_month_period_format() {
        let period = current_month_period();

        // Should be in YYYY-MM format
        assert_eq!(period.len(), 7); // "YYYY-MM" is 7 characters
        assert!(period.contains('-'));

        // Parse and validate format
        let parts: Vec<&str> = period.split('-').collect();
        assert_eq!(parts.len(), 2);

        // Year should be 4 digits
        assert_eq!(parts[0].len(), 4);
        let year: i32 = parts[0].parse().expect("Year should be numeric");
        assert!(year >= 2025); // Should be current year or later

        // Month should be 2 digits
        assert_eq!(parts[1].len(), 2);
        let month: u32 = parts[1].parse().expect("Month should be numeric");
        assert!((1..=12).contains(&month));
    }

    #[test]
    fn test_current_month_period_consistency() {
        // Call multiple times in quick succession, should return same value
        let period1 = current_month_period();
        let period2 = current_month_period();

        assert_eq!(period1, period2);
    }

    #[test]
    fn test_org_usage_key_different_periods() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-02");

        assert_ne!(key1, key2);
        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-02"));
    }

    #[test]
    fn test_org_usage_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key(&org_id1, period);
        let key2 = org_usage_key(&org_id2, period);

        assert_ne!(key1, key2);
        assert!(key1.contains(&org_id1.to_string()));
        assert!(key2.contains(&org_id2.to_string()));
    }

    #[test]
    fn test_org_usage_key_by_type_different_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth");

        assert_ne!(key1, key2);
        assert_ne!(key2, key3);
        assert!(key1.contains("api_calls"));
        assert!(key2.contains("storage"));
        assert!(key3.contains("bandwidth"));
    }

    #[test]
    fn test_org_rate_limit_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();

        let key1 = org_rate_limit_key(&org_id1);
        let key2 = org_rate_limit_key(&org_id2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_key_format_no_spaces() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);

        assert!(!key1.contains(' '));
        assert!(!key2.contains(' '));
        assert!(!key3.contains(' '));
    }

    #[test]
    fn test_key_format_no_special_chars() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);

        // Keys should only contain alphanumeric, hyphens, underscores, and colons
        let valid_chars =
            |s: &str| s.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == ':');

        assert!(valid_chars(&key1));
        assert!(valid_chars(&key2));
        assert!(valid_chars(&key3));
    }

    #[test]
    fn test_usage_key_with_special_period_formats() {
        let org_id = uuid::Uuid::new_v4();

        // Test various period formats
        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-12");
        let key3 = org_usage_key(&org_id, "2024-06");

        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-12"));
        assert!(key3.contains("2024-06"));
    }

    #[test]
    fn test_usage_key_by_type_with_special_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        // Test various usage types
        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage_gb");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth_mb");

        assert!(key1.ends_with("api_calls"));
        assert!(key2.ends_with("storage_gb"));
        assert!(key3.ends_with("bandwidth_mb"));
    }

    #[test]
    fn test_redis_pool_clone() {
        // This tests the Clone trait on RedisPool
        // We can't actually create a RedisPool without a Redis server,
        // but we can verify the trait is implemented via compilation
        fn requires_clone<T: Clone>() {}
        requires_clone::<RedisPool>();
    }

    #[test]
    fn test_current_month_period_matches_chrono() {
        let period = current_month_period();
        let now = chrono::Utc::now();
        let expected = format!("{}-{:02}", now.year(), now.month());

        assert_eq!(period, expected);
    }

    // Mock-based tests would require a Redis server, so we focus on
    // testing the key generation functions which don't require external services
}