uvb-storage-memory 0.2.1

In-memory storage backend for UVB testing and development
Documentation
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

use uvb_storage_api::{
    RateLimitConfig, RateLimitError, RateLimitResult, RateLimitScope, RateLimitStore,
};

/// Per-scope state for the fixed-window counter kept by [`InMemoryRateLimitStore`].
#[derive(Clone, Debug)]
struct RateLimitEntry {
    /// Attempts recorded in the current window.
    count: u32,
    /// Unix timestamp (seconds) at which the current window began.
    window_start: i64,
    /// Window length in seconds from the config most recently evaluated against
    /// this entry (`0` if the entry was only ever touched by `apply_penalty`).
    /// The cleanup sweep uses it so long windows are not evicted mid-window.
    window_secs: i64,
    /// Unix timestamp (seconds) until which the scope is penalized, if any.
    penalty_expires_at: Option<i64>,
}

impl RateLimitEntry {
    /// Creates an entry with no attempts whose window starts at `now`.
    fn new(now: i64, window_secs: i64) -> Self {
        Self {
            count: 0,
            window_start: now,
            window_secs,
            penalty_expires_at: None,
        }
    }

    /// Returns the penalty expiry if a penalty is still in force at `now`.
    fn active_penalty(&self, now: i64) -> Option<i64> {
        self.penalty_expires_at.filter(|&expires| expires > now)
    }

    /// Whether a window of `window_secs` that began at `window_start` has ended by `now`.
    fn window_expired(&self, now: i64, window_secs: i64) -> bool {
        now.saturating_sub(self.window_start) >= window_secs
    }

    /// Timestamp at which the current window (of length `window_secs`) ends.
    fn window_end(&self, window_secs: i64) -> i64 {
        self.window_start.saturating_add(window_secs)
    }
}

/// Minimum time, in seconds after `window_start`, for which the cleanup sweep
/// retains an entry whose own window is shorter than this.
const MIN_RETENTION_SECS: i64 = 3600;

/// Probability that a `check_and_increment` call also sweeps stale entries.
const CLEANUP_PROBABILITY: f32 = 0.1;

/// Converts a configured number of seconds to `i64`, saturating at `i64::MAX`
/// instead of wrapping negative the way an `as` cast of a very large `u64`
/// would (a negative window would make every window look expired).
fn secs_to_i64<T>(secs: T) -> i64
where
    T: std::convert::TryInto<i64>,
{
    <T as std::convert::TryInto<i64>>::try_into(secs).unwrap_or(i64::MAX)
}

/// In-memory rate limit store for testing
///
/// Counts attempts per [`RateLimitScope`] in fixed windows kept in a
/// process-local map behind a `tokio` [`RwLock`]; nothing is persisted or
/// shared between processes.
///
/// Not recommended for production use across multiple instances.
/// Use Redis for distributed rate limiting.
pub struct InMemoryRateLimitStore {
    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
}

impl InMemoryRateLimitStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            entries: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Map key for a scope: its `to_string()` rendering.
    fn make_key(&self, scope: &RateLimitScope) -> String {
        scope.to_string()
    }

    /// Drops entries that have neither an active penalty nor a live window.
    ///
    /// An entry is kept until `max(window_secs, MIN_RETENTION_SECS)` seconds
    /// after its window started. Using the entry's own window (rather than a
    /// fixed hour) matters: evicting a long-window entry mid-window would reset
    /// its count and silently re-open the limit.
    fn purge_stale(entries: &mut HashMap<String, RateLimitEntry>, now: i64) {
        entries.retain(|_, entry| {
            if entry.active_penalty(now).is_some() {
                return true;
            }
            let retention = entry.window_secs.max(MIN_RETENTION_SECS);
            now.saturating_sub(entry.window_start) < retention
        });
    }
}

impl Default for InMemoryRateLimitStore {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl RateLimitStore for InMemoryRateLimitStore {
    /// Records one attempt for `scope` and reports whether it is within the limit.
    ///
    /// While a penalty is active the attempt is denied without being counted.
    /// Otherwise the window is restarted if it has elapsed, the counter is
    /// incremented, and the attempt is denied once the count exceeds
    /// `config.max_attempts`. If `config.penalty_secs` is set, that denial also
    /// starts a penalty, which is reported on the returned result.
    async fn check_and_increment(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError> {
        let key = self.make_key(scope);
        let window_secs = secs_to_i64(config.window_secs);
        let now = chrono::Utc::now().timestamp();
        let mut entries = self.entries.write().await;

        // Opportunistic sweep, done under the lock already held rather than
        // acquiring the write lock twice per call.
        if rand::random::<f32>() < CLEANUP_PROBABILITY {
            Self::purge_stale(&mut entries, now);
        }

        let entry = entries
            .entry(key)
            .or_insert_with(|| RateLimitEntry::new(now, window_secs));
        // Record the window in force so the sweep keeps this entry alive for it.
        entry.window_secs = window_secs;

        if let Some(penalty_expires) = entry.active_penalty(now) {
            return Ok(RateLimitResult::denied(
                entry.count,
                config.max_attempts,
                entry.window_end(window_secs),
            )
            .with_penalty(penalty_expires));
        }
        // Any penalty still recorded has lapsed.
        entry.penalty_expires_at = None;

        if entry.window_expired(now, window_secs) {
            entry.count = 0;
            entry.window_start = now;
        }

        // Saturate instead of `+= 1`: that panics on overflow in debug builds
        // and wraps to 0 (re-opening the limit) in release builds.
        entry.count = entry.count.saturating_add(1);
        let current_count = entry.count;
        let reset_at = entry.window_end(window_secs);

        if current_count <= config.max_attempts {
            return Ok(RateLimitResult::allowed(
                current_count,
                config.max_attempts,
                reset_at,
            ));
        }

        let denied = RateLimitResult::denied(current_count, config.max_attempts, reset_at);
        match config.penalty_secs {
            Some(penalty_secs) => {
                let penalty_expires = now.saturating_add(secs_to_i64(penalty_secs));
                entry.penalty_expires_at = Some(penalty_expires);
                // Report the penalty on the attempt that triggered it, as the
                // penalized path above does for later attempts.
                Ok(denied.with_penalty(penalty_expires))
            }
            None => Ok(denied),
        }
    }

    /// Reports the state of `scope` without recording an attempt.
    ///
    /// Denies while a penalty is active, or when the current window already
    /// holds at least `config.max_attempts` attempts (so the next increment
    /// would be denied). A missing entry or an elapsed window reports zero
    /// attempts with a fresh window ending `window_secs` from now.
    async fn check(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError> {
        let key = self.make_key(scope);
        let window_secs = secs_to_i64(config.window_secs);
        let now = chrono::Utc::now().timestamp();
        let entries = self.entries.read().await;

        let entry = match entries.get(&key) {
            Some(entry) => entry,
            None => {
                // No entry means no attempts yet.
                return Ok(RateLimitResult::allowed(
                    0,
                    config.max_attempts,
                    now.saturating_add(window_secs),
                ));
            }
        };

        if let Some(penalty_expires) = entry.active_penalty(now) {
            return Ok(RateLimitResult::denied(
                entry.count,
                config.max_attempts,
                entry.window_end(window_secs),
            )
            .with_penalty(penalty_expires));
        }

        // An elapsed window is equivalent to having no entry: zero attempts and
        // a fresh window, rather than a `reset_at` that is already in the past.
        let (current_count, reset_at) = if entry.window_expired(now, window_secs) {
            (0, now.saturating_add(window_secs))
        } else {
            (entry.count, entry.window_end(window_secs))
        };

        if current_count >= config.max_attempts {
            Ok(RateLimitResult::denied(current_count, config.max_attempts, reset_at))
        } else {
            Ok(RateLimitResult::allowed(current_count, config.max_attempts, reset_at))
        }
    }

    /// Forgets all state (attempts and penalty) for `scope`.
    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError> {
        let key = self.make_key(scope);
        let mut entries = self.entries.write().await;
        entries.remove(&key);
        Ok(())
    }

    /// Starts, or replaces, a penalty on `scope` lasting `penalty_secs` from now,
    /// creating an entry with no attempts if the scope has none yet.
    async fn apply_penalty(
        &self,
        scope: &RateLimitScope,
        penalty_secs: u64,
    ) -> Result<(), RateLimitError> {
        let key = self.make_key(scope);
        let now = chrono::Utc::now().timestamp();
        let mut entries = self.entries.write().await;

        let entry = entries
            .entry(key)
            .or_insert_with(|| RateLimitEntry::new(now, 0));
        entry.penalty_expires_at = Some(now.saturating_add(secs_to_i64(penalty_secs)));
        Ok(())
    }

    /// Whether `scope` currently has a penalty that has not yet expired.
    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError> {
        let key = self.make_key(scope);
        let now = chrono::Utc::now().timestamp();
        let entries = self.entries.read().await;

        Ok(entries
            .get(&key)
            .and_then(|entry| entry.active_penalty(now))
            .is_some())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use uvb_core::TenantId;

    // A fresh scope admits exactly `max_attempts` increments; the next one is
    // denied and still reports the incremented attempt count.
    #[tokio::test]
    async fn test_rate_limit_basic() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        let config = RateLimitConfig::new(3, 60);

        // First 3 attempts should succeed
        for i in 1..=3 {
            let result = store.check_and_increment(&scope, &config).await.unwrap();
            assert!(result.allowed, "attempt {} should be allowed", i);
            assert_eq!(result.current_attempts, i);
        }

        // 4th attempt should be denied
        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert_eq!(result.current_attempts, 4);
    }

    // Once the window elapses, the counter restarts at 1. The store reads the
    // wall clock (`chrono::Utc::now`), so this test really sleeps ~2 seconds.
    #[tokio::test]
    async fn test_rate_limit_window_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        let config = RateLimitConfig::new(2, 1); // 2 attempts per second

        // Use up the limit
        let result1 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result1.allowed);

        let result2 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result2.allowed);

        let result3 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result3.allowed);

        // Wait for window to expire
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        // Should be allowed again
        let result4 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result4.allowed);
        assert_eq!(result4.current_attempts, 1);
    }

    // A penalty applied to a scope with no prior attempts is visible through
    // both `is_penalized` and `check` (denied, with the penalty expiry set).
    #[tokio::test]
    async fn test_penalty() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };

        // Apply penalty
        store.apply_penalty(&scope, 5).await.unwrap();

        // Should be penalized
        assert!(store.is_penalized(&scope).await.unwrap());

        // Check should return denied with penalty
        let config = RateLimitConfig::new(3, 60);
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(result.penalty_expires_at.is_some());
    }

    // `reset` discards a scope's state: after exceeding the limit, a reset
    // scope reports zero attempts and is allowed again.
    #[tokio::test]
    async fn test_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        let config = RateLimitConfig::new(3, 60);

        // Use up the limit
        // (4 attempts against a limit of 3; the results are intentionally ignored)
        for _ in 0..4 {
            let _ = store.check_and_increment(&scope, &config).await;
        }

        // Should be denied
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);

        // Reset
        store.reset(&scope).await.unwrap();

        // Should be allowed again
        let result = store.check(&scope, &config).await.unwrap();
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 0);
    }
}