use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uvb_storage_api::{
RateLimitConfig, RateLimitError, RateLimitResult, RateLimitScope, RateLimitStore,
};
/// Probability that a single `check_and_increment` call also sweeps stale
/// entries. A sweep takes the write lock and scans every key, so it is
/// amortized over calls instead of running each time.
const CLEANUP_PROBABILITY: f32 = 0.1;

/// Minimum age (seconds since `window_start`) before an entry without an
/// active penalty may be evicted. Entries whose recorded window is longer
/// than this are kept for their full window instead.
const MIN_RETENTION_SECS: i64 = 3600;

/// Per-key counter state tracked by [`InMemoryRateLimitStore`].
#[derive(Clone, Debug)]
struct RateLimitEntry {
    /// Attempts recorded in the current window; saturates at `u32::MAX`.
    count: u32,
    /// Unix timestamp (seconds) at which the current window started.
    window_start: i64,
    /// Window length (seconds) seen by the most recent `check_and_increment`
    /// for this key, or 0 if the entry was only created by `apply_penalty`.
    /// Read by `cleanup_expired` so long windows are not evicted early.
    window_secs: i64,
    /// Unix timestamp (seconds) until which the key is penalized, if any.
    penalty_expires_at: Option<i64>,
}

impl RateLimitEntry {
    /// An empty, unpenalized entry whose window begins at `now`.
    fn new(now: i64, window_secs: i64) -> Self {
        Self {
            count: 0,
            window_start: now,
            window_secs,
            penalty_expires_at: None,
        }
    }

    /// Whether a penalty is set and still in the future relative to `now`.
    fn is_penalized_at(&self, now: i64) -> bool {
        matches!(self.penalty_expires_at, Some(expires) if expires > now)
    }
}

/// A [`RateLimitStore`] that keeps every counter in a process-local `HashMap`
/// guarded by a `tokio::sync::RwLock`.
///
/// State lives only in this process's memory: it is not shared between
/// store instances or processes and does not survive a restart.
pub struct InMemoryRateLimitStore {
    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
}

impl InMemoryRateLimitStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            entries: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Map key for `scope`: its `to_string()` rendering.
    // NOTE(review): two scopes that render to the same string would share one
    // counter — confirm `RateLimitScope`'s `Display` output is unambiguous.
    fn make_key(&self, scope: &RateLimitScope) -> String {
        scope.to_string()
    }

    /// Removes entries whose penalty (if any) has lapsed and which are older
    /// than the larger of their recorded window and [`MIN_RETENTION_SECS`].
    ///
    /// Using the recorded window keeps limits with windows longer than an hour
    /// from being silently reset by the periodic sweep.
    async fn cleanup_expired(&self) {
        let mut entries = self.entries.write().await;
        let now = chrono::Utc::now().timestamp();
        entries.retain(|_, entry| {
            let retention = entry.window_secs.max(MIN_RETENTION_SECS);
            entry.is_penalized_at(now) || now - entry.window_start < retention
        });
    }
}

impl Default for InMemoryRateLimitStore {
    /// Equivalent to [`InMemoryRateLimitStore::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl RateLimitStore for InMemoryRateLimitStore {
    /// Records one attempt for `scope` and reports whether it is allowed.
    ///
    /// An active penalty denies the attempt without counting it. Otherwise the
    /// window is restarted if it has elapsed and the counter is incremented.
    /// Once the count exceeds `config.max_attempts`, the attempt is denied and,
    /// if `config.penalty_secs` is set, a penalty starts from now.
    async fn check_and_increment(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError> {
        if rand::random::<f32>() < CLEANUP_PROBABILITY {
            self.cleanup_expired().await;
        }
        let key = self.make_key(scope);
        let now = chrono::Utc::now().timestamp();
        let window_secs = config.window_secs as i64;
        let mut entries = self.entries.write().await;
        let entry = entries
            .entry(key)
            .or_insert_with(|| RateLimitEntry::new(now, window_secs));
        // Remember the latest window so cleanup never evicts a still-open window.
        entry.window_secs = window_secs;

        if let Some(penalty_expires) = entry.penalty_expires_at {
            if penalty_expires > now {
                return Ok(RateLimitResult::denied(
                    entry.count,
                    config.max_attempts,
                    entry.window_start + window_secs,
                )
                .with_penalty(penalty_expires));
            }
            // Penalty has lapsed: clear it and evaluate the attempt normally.
            entry.penalty_expires_at = None;
        }

        if now - entry.window_start >= window_secs {
            entry.count = 0;
            entry.window_start = now;
        }
        // Saturate: a plain `+= 1` panics in debug builds and wraps to 0 in
        // release builds, which would silently re-allow a flooding client.
        entry.count = entry.count.saturating_add(1);
        let current_count = entry.count;
        let reset_at = entry.window_start + window_secs;

        if current_count > config.max_attempts {
            if let Some(penalty_secs) = config.penalty_secs {
                entry.penalty_expires_at = Some(now + penalty_secs as i64);
            }
            Ok(RateLimitResult::denied(
                current_count,
                config.max_attempts,
                reset_at,
            ))
        } else {
            Ok(RateLimitResult::allowed(
                current_count,
                config.max_attempts,
                reset_at,
            ))
        }
    }

    /// Reports whether the *next* attempt for `scope` would be allowed,
    /// without recording anything.
    async fn check(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError> {
        let key = self.make_key(scope);
        let now = chrono::Utc::now().timestamp();
        let window_secs = config.window_secs as i64;
        let entries = self.entries.read().await;
        let entry = match entries.get(&key) {
            Some(entry) => entry,
            None => {
                return Ok(RateLimitResult::allowed(
                    0,
                    config.max_attempts,
                    now + window_secs,
                ));
            }
        };

        if let Some(penalty_expires) = entry.penalty_expires_at {
            if penalty_expires > now {
                return Ok(RateLimitResult::denied(
                    entry.count,
                    config.max_attempts,
                    entry.window_start + window_secs,
                )
                .with_penalty(penalty_expires));
            }
        }

        // An elapsed window counts as empty and, like a missing entry, would
        // restart now, so report `now + window` rather than a reset time that
        // already lies in the past.
        let (current_count, reset_at) = if now - entry.window_start >= window_secs {
            (0, now + window_secs)
        } else {
            (entry.count, entry.window_start + window_secs)
        };

        // `>=` because this predicts the next attempt: incrementing from
        // `max_attempts` would exceed the limit in `check_and_increment`.
        if current_count >= config.max_attempts {
            Ok(RateLimitResult::denied(
                current_count,
                config.max_attempts,
                reset_at,
            ))
        } else {
            Ok(RateLimitResult::allowed(
                current_count,
                config.max_attempts,
                reset_at,
            ))
        }
    }

    /// Forgets all state (count, window and penalty) for `scope`.
    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError> {
        let key = self.make_key(scope);
        let mut entries = self.entries.write().await;
        entries.remove(&key);
        Ok(())
    }

    /// Penalizes `scope` for `penalty_secs` seconds from now, replacing any
    /// existing penalty. The count and window of an existing entry are left
    /// unchanged; a missing entry is created empty.
    async fn apply_penalty(
        &self,
        scope: &RateLimitScope,
        penalty_secs: u64,
    ) -> Result<(), RateLimitError> {
        let key = self.make_key(scope);
        let now = chrono::Utc::now().timestamp();
        let mut entries = self.entries.write().await;
        // The window length is unknown here. A value of 0 makes cleanup fall back to
        // MIN_RETENTION_SECS until a later check_and_increment records the real one.
        let entry = entries
            .entry(key)
            .or_insert_with(|| RateLimitEntry::new(now, 0));
        entry.penalty_expires_at = Some(now + penalty_secs as i64);
        Ok(())
    }

    /// Whether `scope` currently has an unexpired penalty.
    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError> {
        let key = self.make_key(scope);
        let now = chrono::Utc::now().timestamp();
        let entries = self.entries.read().await;
        Ok(matches!(entries.get(&key), Some(entry) if entry.is_penalized_at(now)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use uvb_core::TenantId;

    /// Exactly `max_attempts` attempts are allowed; the next is denied and still counted.
    #[tokio::test]
    async fn test_rate_limit_basic() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        let config = RateLimitConfig::new(3, 60);
        for i in 1..=3 {
            let result = store.check_and_increment(&scope, &config).await.unwrap();
            assert!(result.allowed, "attempt {} should be allowed", i);
            assert_eq!(result.current_attempts, i);
        }
        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert_eq!(result.current_attempts, 4);
    }

    /// Once the window elapses, counting restarts at 1 and attempts are allowed again.
    #[tokio::test]
    async fn test_rate_limit_window_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        let config = RateLimitConfig::new(2, 1);
        let result1 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result1.allowed);
        let result2 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result2.allowed);
        let result3 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result3.allowed);
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        let result4 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result4.allowed);
        assert_eq!(result4.current_attempts, 1);
    }

    /// An explicit penalty denies both `check` and `check_and_increment`.
    #[tokio::test]
    async fn test_penalty() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };
        store.apply_penalty(&scope, 5).await.unwrap();
        assert!(store.is_penalized(&scope).await.unwrap());
        let config = RateLimitConfig::new(3, 60);
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(result.penalty_expires_at.is_some());
        // A penalized attempt is rejected without being counted.
        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(result.penalty_expires_at.is_some());
        assert_eq!(result.current_attempts, 0);
    }

    /// `reset` clears a limited key back to zero attempts.
    #[tokio::test]
    async fn test_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        let config = RateLimitConfig::new(3, 60);
        for _ in 0..4 {
            // Unwrap so a store error fails the test instead of being ignored.
            store.check_and_increment(&scope, &config).await.unwrap();
        }
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        store.reset(&scope).await.unwrap();
        let result = store.check(&scope, &config).await.unwrap();
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 0);
    }
}