//! uvb-storage-api 0.2.1
//!
//! Storage backend trait abstractions for UVB data persistence.
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
use thiserror::Error;
use uvb_core::TenantId;

#[derive(Debug, Error)]
pub enum RateLimitError {
    #[error("storage error: {0}")]
    Storage(String),

    #[error("serialization error: {0}")]
    Serialization(String),
}

/// Identifies what a rate limit is counted against.
///
/// Variant and field order are significant for serialized forms and hashing,
/// so new variants/fields should be appended rather than inserted.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum RateLimitScope {
    /// Limit keyed on a single subject (user) inside a tenant.
    Subject {
        user_id: String,
        tenant_id: TenantId,
    },

    /// Limit keyed on a client IP address, held in its string form.
    IpAddress { ip: String },

    /// Limit keyed on one factor (e.g. `"totp"`) of one subject inside a tenant.
    FactorAttempt {
        user_id: String,
        tenant_id: TenantId,
        factor_id: String,
    },

    /// Limit keyed on an API endpoint, identified by its method and path.
    Endpoint { path: String, method: String },

    /// Free-form caller-supplied key for limits not covered by the other variants.
    Custom { key: String },
}

impl fmt::Display for RateLimitScope {
    /// Renders the scope as a colon-separated key, e.g. `subject:<tenant>:<user>`.
    ///
    /// Every field except the last one of a variant is escaped (`\` becomes `\\`,
    /// `:` becomes `\:`) so that a `:` inside a field can never be mistaken for a
    /// separator. Without escaping, tenant `"a"` + user `"b:c"` and tenant `"a:b"`
    /// + user `"c"` would both render as `subject:a:b:c` and share one counter.
    ///
    /// The final field is written verbatim: with a fixed number of fields it is
    /// already unambiguous, and leaving it raw keeps keys such as IPv6 addresses,
    /// request paths and custom keys byte-for-byte unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes `part`, prefixing every `\` and `:` with a backslash.
        fn write_escaped(f: &mut fmt::Formatter<'_>, part: &str) -> fmt::Result {
            let mut rest = part;
            while let Some(idx) = rest.find(|c: char| c == '\\' || c == ':') {
                // Both special characters are single-byte ASCII, so `idx + 1`
                // is always a char boundary.
                f.write_str(&rest[..idx])?;
                f.write_str("\\")?;
                f.write_str(&rest[idx..=idx])?;
                rest = &rest[idx + 1..];
            }
            f.write_str(rest)
        }

        match self {
            RateLimitScope::Subject { user_id, tenant_id } => {
                f.write_str("subject:")?;
                write_escaped(f, &tenant_id.to_string())?;
                write!(f, ":{}", user_id)
            }
            RateLimitScope::IpAddress { ip } => write!(f, "ip:{}", ip),
            RateLimitScope::FactorAttempt {
                user_id,
                tenant_id,
                factor_id,
            } => {
                f.write_str("factor:")?;
                write_escaped(f, &tenant_id.to_string())?;
                f.write_str(":")?;
                write_escaped(f, user_id)?;
                write!(f, ":{}", factor_id)
            }
            RateLimitScope::Endpoint { path, method } => {
                f.write_str("endpoint:")?;
                write_escaped(f, method)?;
                write!(f, ":{}", path)
            }
            RateLimitScope::Custom { key } => write!(f, "custom:{}", key),
        }
    }
}

/// Parameters of a rate limit: an attempt budget per time window, plus an
/// optional penalty once the budget is exceeded.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Upper bound on attempts permitted within one window.
    pub max_attempts: u32,

    /// Length of the counting window, in seconds.
    pub window_secs: u64,

    /// Penalty length in seconds applied after the limit is exceeded;
    /// `None` means no penalty is configured.
    pub penalty_secs: Option<u64>,
}

impl RateLimitConfig {
    /// Creates a config allowing `max_attempts` attempts per window of
    /// `window_secs` seconds, with no penalty configured.
    #[must_use]
    pub fn new(max_attempts: u32, window_secs: u64) -> Self {
        Self {
            max_attempts,
            window_secs,
            penalty_secs: None,
        }
    }

    /// Returns this config with a penalty of `penalty_secs` seconds.
    ///
    /// Consumes `self`; the returned value must be used, otherwise the
    /// penalty is silently lost (hence `#[must_use]`).
    #[must_use]
    pub fn with_penalty(mut self, penalty_secs: u64) -> Self {
        self.penalty_secs = Some(penalty_secs);
        self
    }
}

/// Outcome of a rate-limit check, as returned by [`RateLimitStore`] methods.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct RateLimitResult {
    /// `true` when the operation may proceed.
    pub allowed: bool,

    /// Attempts counted so far in the current window.
    pub current_attempts: u32,

    /// Attempt ceiling for the window.
    pub max_attempts: u32,

    /// Attempts still available before the ceiling is reached.
    pub remaining_attempts: u32,

    /// Unix timestamp at which the current window resets.
    pub reset_at: i64,

    /// Unix timestamp at which an active penalty ends; `None` when no penalty
    /// is in effect.
    pub penalty_expires_at: Option<i64>,
}

impl RateLimitResult {
    /// Builds a result for a permitted operation.
    ///
    /// `remaining_attempts` is `max_attempts - current_attempts`, clamped at
    /// zero so a count above the maximum cannot underflow. No penalty is set.
    #[must_use]
    pub fn allowed(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
        Self {
            allowed: true,
            current_attempts,
            max_attempts,
            remaining_attempts: max_attempts.saturating_sub(current_attempts),
            reset_at,
            penalty_expires_at: None,
        }
    }

    /// Builds a result for a rejected operation.
    ///
    /// `remaining_attempts` is always zero. No penalty is set; chain
    /// [`with_penalty`](Self::with_penalty) to record one.
    #[must_use]
    pub fn denied(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
        Self {
            allowed: false,
            current_attempts,
            max_attempts,
            remaining_attempts: 0,
            reset_at,
            penalty_expires_at: None,
        }
    }

    /// Returns this result with `penalty_expires_at` set to the given Unix
    /// timestamp.
    ///
    /// Consumes `self`; dropping the return value discards the penalty.
    #[must_use]
    pub fn with_penalty(mut self, penalty_expires_at: i64) -> Self {
        self.penalty_expires_at = Some(penalty_expires_at);
        self
    }
}

/// Storage backend for rate-limit counters and penalties.
///
/// Counters that carry their own expiry (for example Redis keys with a TTL)
/// are the most natural fit, but in-memory and SQL-backed implementations are
/// also possible. The `Send + Sync` bound lets one store be shared across
/// threads.
///
/// Every method reports backend failures through [`RateLimitError`].
#[async_trait]
pub trait RateLimitStore: Send + Sync {
    /// Checks the limit for `scope` and records one attempt when there is room.
    ///
    /// Implementations must perform the check and the increment atomically:
    /// - below the limit: increment, then return a result with `allowed == true`;
    /// - at or above the limit: return a result with `allowed == false`;
    /// - window expiry is handled by the implementation, not the caller.
    async fn check_and_increment(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError>;

    /// Reports the current status for `scope` without recording an attempt.
    async fn check(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError>;

    /// Resets the rate limit for `scope`, e.g. as an administrative override.
    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError>;

    /// Places `scope` under a penalty (temporary ban) lasting `penalty_secs`
    /// seconds.
    async fn apply_penalty(
        &self,
        scope: &RateLimitScope,
        penalty_secs: u64,
    ) -> Result<(), RateLimitError>;

    /// Returns `true` while `scope` is inside a penalty period.
    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_scope_display() {
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        assert_eq!(scope.to_string(), "subject:tenant_a:user_1");

        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        assert_eq!(scope.to_string(), "ip:203.0.113.1");

        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };
        assert_eq!(scope.to_string(), "factor:tenant_a:user_1:totp");
    }

    #[test]
    fn test_rate_limit_scope_display_remaining_variants() {
        // IPv6 addresses contain ':' and must be rendered verbatim.
        let scope = RateLimitScope::IpAddress {
            ip: "2001:db8::1".to_string(),
        };
        assert_eq!(scope.to_string(), "ip:2001:db8::1");

        // Method comes before path in the rendered key.
        let scope = RateLimitScope::Endpoint {
            path: "/v1/login".to_string(),
            method: "POST".to_string(),
        };
        assert_eq!(scope.to_string(), "endpoint:POST:/v1/login");

        let scope = RateLimitScope::Custom {
            key: "otp-resend".to_string(),
        };
        assert_eq!(scope.to_string(), "custom:otp-resend");
    }

    #[test]
    fn test_rate_limit_result() {
        let result = RateLimitResult::allowed(3, 10, 1234567890);
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 3);
        assert_eq!(result.remaining_attempts, 7);

        let result = RateLimitResult::denied(10, 10, 1234567890);
        assert!(!result.allowed);
        assert_eq!(result.remaining_attempts, 0);
    }

    #[test]
    fn test_rate_limit_result_saturation_and_penalty() {
        // A count above the maximum must clamp to zero rather than underflow.
        let result = RateLimitResult::allowed(12, 10, 1);
        assert_eq!(result.remaining_attempts, 0);
        assert_eq!(result.penalty_expires_at, None);

        let result = RateLimitResult::denied(10, 10, 1).with_penalty(99);
        assert!(!result.allowed);
        assert_eq!(result.reset_at, 1);
        assert_eq!(result.penalty_expires_at, Some(99));
    }

    #[test]
    fn test_rate_limit_config_builder() {
        let config = RateLimitConfig::new(5, 60);
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.window_secs, 60);
        assert_eq!(config.penalty_secs, None);

        let config = config.with_penalty(300);
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.window_secs, 60);
        assert_eq!(config.penalty_secs, Some(300));
    }
}