use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
use thiserror::Error;
use uvb_core::TenantId;
#[derive(Debug, Error)]
pub enum RateLimitError {
#[error("storage error: {0}")]
Storage(String),
#[error("serialization error: {0}")]
Serialization(String),
}
/// Identifies what a rate limit is counted against.
///
/// The string form of each variant is produced by this type's `Display` impl.
/// NOTE(review): variant and field order are significant for the derived
/// `Hash` and for order-sensitive serde formats — do not reorder.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RateLimitScope {
    /// A single user inside a tenant.
    Subject {
        user_id: String,
        tenant_id: TenantId,
    },
    /// A client address, stored in its textual form (e.g. `203.0.113.1`).
    IpAddress {
        ip: String,
    },
    /// A user inside a tenant, narrowed to one factor (e.g. `totp`).
    FactorAttempt {
        user_id: String,
        tenant_id: TenantId,
        factor_id: String,
    },
    /// A route, identified by its path and method (presumably HTTP — confirm).
    Endpoint {
        path: String,
        method: String,
    },
    /// A caller-supplied key for anything the other variants do not cover.
    Custom {
        key: String,
    },
}
/// Renders the scope as a colon-delimited key such as `subject:<tenant>:<user>`.
///
/// In variants with several components, every component except the last is
/// backslash-escapes `\` → `\\` and `:` → `\:`, so the boundary between
/// components is always unambiguous. Without escaping,
/// `FactorAttempt { user_id: "a:b", factor_id: "c", .. }` and
/// `FactorAttempt { user_id: "a", factor_id: "b:c", .. }` would both render as
/// `factor:<tenant>:a:b:c`. The last component needs no escaping because it
/// simply runs to the end of the string.
///
/// Components that contain neither `:` nor `\` render verbatim, e.g.
/// `subject:tenant_a:user_1`. Single-component variants (`ip`, `custom`) are
/// never escaped, so IPv6 addresses keep their natural form.
impl fmt::Display for RateLimitScope {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RateLimitScope::Subject { user_id, tenant_id } => {
                write!(f, "subject:{}:{}", Escaped(&tenant_id.to_string()), user_id)
            }
            RateLimitScope::IpAddress { ip } => write!(f, "ip:{}", ip),
            RateLimitScope::FactorAttempt {
                user_id,
                tenant_id,
                factor_id,
            } => write!(
                f,
                "factor:{}:{}:{}",
                Escaped(&tenant_id.to_string()),
                Escaped(user_id.as_str()),
                factor_id
            ),
            RateLimitScope::Endpoint { path, method } => {
                write!(f, "endpoint:{}:{}", Escaped(method.as_str()), path)
            }
            RateLimitScope::Custom { key } => write!(f, "custom:{}", key),
        }
    }
}

/// `Display` adapter that backslash-escapes `\` and `:` in one key component.
struct Escaped<'a>(&'a str);

impl fmt::Display for Escaped<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut rest = self.0;
        while let Some(idx) = rest.find(|c: char| c == ':' || c == '\\') {
            // Both special characters are single-byte ASCII, so `idx..=idx`
            // and `idx + 1` always fall on char boundaries.
            f.write_str(&rest[..idx])?;
            f.write_str("\\")?;
            f.write_str(&rest[idx..=idx])?;
            rest = &rest[idx + 1..];
        }
        f.write_str(rest)
    }
}
/// Limit settings applied to a [`RateLimitScope`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Attempt limit for one window.
    pub max_attempts: u32,
    /// Length of the counting window, in seconds.
    pub window_secs: u64,
    /// Penalty length in seconds; `None` (what [`RateLimitConfig::new`] sets)
    /// means no penalty is configured.
    pub penalty_secs: Option<u64>,
}

impl RateLimitConfig {
    /// Creates a config with `max_attempts` per `window_secs`-second window
    /// and no penalty.
    #[must_use]
    pub fn new(max_attempts: u32, window_secs: u64) -> Self {
        Self {
            max_attempts,
            window_secs,
            penalty_secs: None,
        }
    }

    /// Returns this config with `penalty_secs` set to `Some(penalty_secs)`.
    ///
    /// Takes `self` by value builder-style. It is `#[must_use]` because calling it
    /// without using the return value silently leaves the original config
    /// penalty-free.
    #[must_use]
    pub fn with_penalty(mut self, penalty_secs: u64) -> Self {
        self.penalty_secs = Some(penalty_secs);
        self
    }
}
/// Outcome of a rate-limit check.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RateLimitResult {
    /// Whether the attempt is permitted.
    pub allowed: bool,
    /// Attempt count reported for the scope.
    pub current_attempts: u32,
    /// The limit the count was compared against.
    pub max_attempts: u32,
    /// Attempts left. Set by [`RateLimitResult::allowed`] to
    /// `max_attempts - current_attempts` (saturating at 0); always 0 from
    /// [`RateLimitResult::denied`].
    pub remaining_attempts: u32,
    /// When the window resets.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm with the
    /// store implementations.
    pub reset_at: i64,
    /// When an applied penalty ends, if any; same time base as `reset_at`.
    pub penalty_expires_at: Option<i64>,
}

impl RateLimitResult {
    /// Builds a permitted result; `remaining_attempts` is the saturating
    /// difference `max_attempts - current_attempts`.
    #[must_use]
    pub fn allowed(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
        Self {
            allowed: true,
            current_attempts,
            max_attempts,
            remaining_attempts: max_attempts.saturating_sub(current_attempts),
            reset_at,
            penalty_expires_at: None,
        }
    }

    /// Builds a rejected result with `remaining_attempts` fixed at 0.
    #[must_use]
    pub fn denied(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
        Self {
            allowed: false,
            current_attempts,
            max_attempts,
            remaining_attempts: 0,
            reset_at,
            penalty_expires_at: None,
        }
    }

    /// Returns this result with `penalty_expires_at` set.
    ///
    /// Takes `self` by value builder-style. It is `#[must_use]` because discarding
    /// the return value silently drops the penalty.
    #[must_use]
    pub fn with_penalty(mut self, penalty_expires_at: i64) -> Self {
        self.penalty_expires_at = Some(penalty_expires_at);
        self
    }
}
/// Backend that tracks attempt counts and penalties per [`RateLimitScope`].
///
/// The `Send + Sync` bound lets one store be shared across threads.
/// NOTE(review): the per-method descriptions below are inferred from method
/// names and signatures; confirm them against the concrete implementations.
#[async_trait]
pub trait RateLimitStore: Send + Sync {
    /// Evaluates `scope` against `config` and records one more attempt.
    async fn check_and_increment(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError>;

    /// Evaluates `scope` against `config`; presumably does not record an attempt.
    async fn check(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError>;

    /// Clears the tracked state for `scope`.
    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError>;

    /// Places `scope` under a penalty lasting `penalty_secs` seconds.
    async fn apply_penalty(
        &self,
        scope: &RateLimitScope,
        penalty_secs: u64,
    ) -> Result<(), RateLimitError>;

    /// Reports whether `scope` is currently under a penalty.
    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each scope variant renders to its expected colon-delimited key.
    #[test]
    fn test_rate_limit_scope_display() {
        let cases = [
            (
                RateLimitScope::Subject {
                    user_id: "user_1".to_string(),
                    tenant_id: TenantId::new("tenant_a"),
                },
                "subject:tenant_a:user_1",
            ),
            (
                RateLimitScope::IpAddress {
                    ip: "203.0.113.1".to_string(),
                },
                "ip:203.0.113.1",
            ),
            (
                RateLimitScope::FactorAttempt {
                    user_id: "user_1".to_string(),
                    tenant_id: TenantId::new("tenant_a"),
                    factor_id: "totp".to_string(),
                },
                "factor:tenant_a:user_1:totp",
            ),
        ];
        for (scope, expected) in cases {
            assert_eq!(scope.to_string(), expected);
        }
    }

    /// The `allowed` and `denied` constructors fill in the derived fields.
    #[test]
    fn test_rate_limit_result() {
        let permitted = RateLimitResult::allowed(3, 10, 1234567890);
        assert!(permitted.allowed);
        assert_eq!(permitted.current_attempts, 3);
        assert_eq!(permitted.remaining_attempts, 7);

        let rejected = RateLimitResult::denied(10, 10, 1234567890);
        assert!(!rejected.allowed);
        assert_eq!(rejected.remaining_attempts, 0);
    }
}