use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::error::{Error, Result};
/// Settings for a rate limiter: the request quota, the window it applies to,
/// which window algorithm to use, and optional escalation to a temporary ban.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests allowed per window.
    pub max_requests: u32,
    /// Length of the counting window.
    pub window: Duration,
    /// `true` selects the sliding-window store, `false` the fixed-window store.
    pub sliding_window: bool,
    /// How long a key is banned once `ban_threshold` is reached; `None`
    /// disables automatic bans.
    pub ban_duration: Option<Duration>,
    /// Number of over-limit attempts that triggers an automatic ban.
    pub ban_threshold: u32,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second sliding window, with no automatic ban.
    fn default() -> Self {
        Self {
            sliding_window: true,
            window: Duration::from_secs(60),
            max_requests: 100,
            ban_threshold: 3,
            ban_duration: None,
        }
    }
}

impl RateLimitConfig {
    /// Same as [`RateLimitConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of requests allowed per window.
    pub fn with_max_requests(self, max: u32) -> Self {
        Self {
            max_requests: max,
            ..self
        }
    }

    /// Sets the window length.
    pub fn with_window(self, window: Duration) -> Self {
        Self { window, ..self }
    }

    /// Chooses between the sliding-window (`true`) and fixed-window (`false`) store.
    pub fn with_sliding_window(self, enabled: bool) -> Self {
        Self {
            sliding_window: enabled,
            ..self
        }
    }

    /// Enables automatic bans lasting `duration`.
    pub fn with_ban_duration(self, duration: Duration) -> Self {
        Self {
            ban_duration: Some(duration),
            ..self
        }
    }

    /// Sets how many over-limit attempts trigger an automatic ban.
    pub fn with_ban_threshold(self, threshold: u32) -> Self {
        Self {
            ban_threshold: threshold,
            ..self
        }
    }

    /// Login preset: 5 attempts per minute. The third over-limit attempt bans
    /// the key for 15 minutes.
    pub fn for_login() -> Self {
        let minute = Duration::from_secs(60);
        Self {
            max_requests: 5,
            window: minute,
            sliding_window: true,
            ban_duration: Some(Duration::from_secs(15 * 60)),
            ban_threshold: 3,
        }
    }

    /// API preset: 60 requests per minute.
    ///
    /// `ban_threshold` is 5, but with no `ban_duration` the key is never
    /// banned automatically.
    pub fn for_api() -> Self {
        let minute = Duration::from_secs(60);
        Self {
            max_requests: 60,
            window: minute,
            sliding_window: true,
            ban_duration: None,
            ban_threshold: 5,
        }
    }

    /// Password-reset preset: 3 requests per hour. The second over-limit
    /// attempt bans the key for an hour.
    pub fn for_password_reset() -> Self {
        let hour = Duration::from_secs(60 * 60);
        Self {
            max_requests: 3,
            window: hour,
            sliding_window: true,
            ban_duration: Some(hour),
            ban_threshold: 2,
        }
    }
}
/// Snapshot of a key's rate-limit state, as returned to callers.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Requests still permitted (0 when limited or banned).
    pub remaining: u32,
    /// Maximum requests allowed per window.
    pub limit: u32,
    /// Time until capacity frees up; equals the ban time left when banned.
    pub reset_after: Duration,
    /// Whether the key is currently banned.
    pub is_banned: bool,
    /// Time left on the ban, if banned.
    pub ban_remaining: Option<Duration>,
}

impl RateLimitInfo {
    /// State for a key that has used up its quota but is not banned.
    pub fn limited(limit: u32, reset_after: Duration) -> Self {
        Self {
            remaining: 0,
            limit,
            reset_after,
            is_banned: false,
            ban_remaining: None,
        }
    }

    /// State for a key that still has `remaining` requests available.
    pub fn allowed(remaining: u32, limit: u32, reset_after: Duration) -> Self {
        Self {
            remaining,
            ..Self::limited(limit, reset_after)
        }
    }

    /// State for a banned key; `reset_after` mirrors `ban_remaining`.
    pub fn banned(limit: u32, ban_remaining: Duration) -> Self {
        Self {
            is_banned: true,
            ban_remaining: Some(ban_remaining),
            ..Self::limited(limit, ban_remaining)
        }
    }
}
/// Per-key state for the sliding-window store: the instants of accepted
/// requests plus ban bookkeeping.
#[derive(Debug, Clone)]
struct RequestRecord {
    /// Instants of accepted requests, in the order they were recorded.
    timestamps: Vec<Instant>,
    /// Over-limit attempts counted toward an automatic ban.
    violation_count: u32,
    /// When set, the instant at which the current ban ends.
    banned_until: Option<Instant>,
}

impl RequestRecord {
    /// Creates an empty record: no requests, no violations, not banned.
    fn new() -> Self {
        Self {
            timestamps: Vec::new(),
            violation_count: 0,
            banned_until: None,
        }
    }

    /// Drops timestamps that are `window` or more in the past.
    ///
    /// This compares each timestamp's age against `window` rather than
    /// computing `Instant::now() - window`. That subtraction panics when
    /// `window` reaches back past the clock's representable origin: a very long
    /// window, or on some platforms any window longer than the system uptime.
    /// Callers run this while holding a write lock, so a panic would poison it.
    fn cleanup(&mut self, window: Duration) {
        let now = Instant::now();
        self.timestamps
            .retain(|&ts| now.saturating_duration_since(ts) < window);
    }

    /// Whether a ban is set and has not yet expired.
    fn is_banned(&self) -> bool {
        self.banned_until
            .is_some_and(|until| Instant::now() < until)
    }

    /// Time left on an active ban; `None` when no ban is set or it has expired.
    fn ban_remaining(&self) -> Option<Duration> {
        self.banned_until
            .and_then(|until| until.checked_duration_since(Instant::now()))
            // An expiry of exactly "now" counts as no longer banned.
            .filter(|left| !left.is_zero())
    }
}
/// Per-key state for the fixed-window store: a request counter for the
/// current window plus ban bookkeeping.
#[derive(Debug, Clone)]
struct FixedWindowRecord {
    /// Requests accepted in the current window.
    count: u32,
    /// When the current window began.
    window_start: Instant,
    /// Over-limit attempts counted toward an automatic ban.
    violation_count: u32,
    /// When set, the instant at which the current ban ends.
    banned_until: Option<Instant>,
}

impl FixedWindowRecord {
    /// Creates a record whose first window starts now.
    fn new() -> Self {
        let window_start = Instant::now();
        Self {
            window_start,
            count: 0,
            banned_until: None,
            violation_count: 0,
        }
    }

    /// Starts a fresh window (zeroing `count`) once `window` has elapsed since
    /// `window_start`. Violations and bans are left untouched.
    fn check_reset(&mut self, window: Duration) {
        let still_open = self.window_start.elapsed() < window;
        if still_open {
            return;
        }
        self.count = 0;
        self.window_start = Instant::now();
    }

    /// Whether a ban is set and has not yet expired.
    fn is_banned(&self) -> bool {
        self.banned_until
            .is_some_and(|until| Instant::now() < until)
    }

    /// Time left on an active ban; `None` when no ban is set or it has expired.
    fn ban_remaining(&self) -> Option<Duration> {
        self.banned_until
            .and_then(|until| until.checked_duration_since(Instant::now()))
            .filter(|left| !left.is_zero())
    }
}
/// Storage backend that tracks per-key request history and bans for a
/// rate limiter.
///
/// The in-memory implementations in this module return an error from
/// `check_and_record` when a key is rate limited or banned.
#[async_trait]
pub trait RateLimitStore: Send + Sync {
    /// Checks `key` against `config` and records the request if it is allowed.
    async fn check_and_record(&self, key: &str, config: &RateLimitConfig) -> Result<RateLimitInfo>;

    /// Reports the current state of `key` under `config` without recording a request.
    async fn get_status(&self, key: &str, config: &RateLimitConfig) -> RateLimitInfo;

    /// Bans `key` for `duration`.
    async fn ban(&self, key: &str, duration: Duration);

    /// Lifts any ban on `key`.
    async fn unban(&self, key: &str);

    /// Forgets all state held for `key`.
    async fn reset(&self, key: &str);

    /// Drops per-key state that is no longer needed.
    async fn cleanup(&self);
}
/// In-memory [`RateLimitStore`] that records an instant per accepted request
/// and counts the ones inside a sliding window.
#[derive(Debug, Default)]
pub struct InMemorySlidingWindowStore {
    /// Per-key request history, guarded by a single lock.
    records: RwLock<HashMap<String, RequestRecord>>,
}

impl InMemorySlidingWindowStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Time until the oldest in-window request stops counting against the quota.
///
/// `oldest` is the earliest timestamp still inside `window`. When there is
/// none, the full `window` is reported.
fn sliding_reset_after(oldest: Option<&Instant>, window: Duration) -> Duration {
    oldest.map_or(window, |ts| window.saturating_sub(ts.elapsed()))
}

#[async_trait]
impl RateLimitStore for InMemorySlidingWindowStore {
    /// Checks `key` against `config` and, if allowed, records the request.
    ///
    /// # Errors
    ///
    /// * `Error::internal` if the records lock is poisoned.
    /// * `Error::rate_limited` with the ban time left, while the key is banned.
    /// * `Error::rate_limited(ban_duration)` when this over-limit attempt
    ///   reaches `config.ban_threshold` and `config.ban_duration` is set. The
    ///   key is banned for that long from now.
    /// * `Error::rate_limited` with the time until the oldest request leaves
    ///   the window, for any other over-limit attempt.
    async fn check_and_record(&self, key: &str, config: &RateLimitConfig) -> Result<RateLimitInfo> {
        let mut records = self
            .records
            .write()
            .map_err(|_| Error::internal("Failed to acquire lock"))?;
        let record = records
            .entry(key.to_string())
            .or_insert_with(RequestRecord::new);
        if record.is_banned() {
            let ban_remaining = record.ban_remaining().unwrap_or(Duration::ZERO);
            return Err(Error::rate_limited(ban_remaining));
        }
        // Past this point any ban has expired. Clear it together with the
        // violations that led to it.
        if record.banned_until.is_some() {
            record.banned_until = None;
            record.violation_count = 0;
        }
        record.cleanup(config.window);
        let current_count = record.timestamps.len() as u32;
        if current_count >= config.max_requests {
            // Saturate: violations are only cleared by an expired ban or
            // `unban`. With `ban_duration: None` they grow without bound, and
            // overflowing here would panic (in debug builds) while the write
            // lock is held, which poisons the store for every key.
            record.violation_count = record.violation_count.saturating_add(1);
            if let Some(ban_duration) = config.ban_duration
                && record.violation_count >= config.ban_threshold
            {
                record.banned_until = Some(Instant::now() + ban_duration);
                return Err(Error::rate_limited(ban_duration));
            }
            let reset_after = sliding_reset_after(record.timestamps.first(), config.window);
            return Err(Error::rate_limited(reset_after));
        }
        record.timestamps.push(Instant::now());
        let remaining = config.max_requests - current_count - 1;
        let reset_after = sliding_reset_after(record.timestamps.first(), config.window);
        Ok(RateLimitInfo::allowed(
            remaining,
            config.max_requests,
            reset_after,
        ))
    }

    /// Forgets everything about `key`. Best-effort: a poisoned lock is ignored.
    async fn reset(&self, key: &str) {
        if let Ok(mut records) = self.records.write() {
            records.remove(key);
        }
    }

    /// Reports the state of `key` without recording a request or pruning history.
    ///
    /// Unknown keys and a poisoned lock are both reported as a full quota.
    async fn get_status(&self, key: &str, config: &RateLimitConfig) -> RateLimitInfo {
        let full_quota =
            || RateLimitInfo::allowed(config.max_requests, config.max_requests, config.window);
        let Ok(records) = self.records.read() else {
            return full_quota();
        };
        let Some(record) = records.get(key) else {
            return full_quota();
        };
        if record.is_banned() {
            let ban_remaining = record.ban_remaining().unwrap_or(Duration::ZERO);
            return RateLimitInfo::banned(config.max_requests, ban_remaining);
        }
        // Compare each timestamp's age with the window instead of computing
        // `Instant::now() - config.window`. That subtraction panics when the
        // window reaches back past the clock's representable origin.
        let now = Instant::now();
        let in_window = |ts: &Instant| now.saturating_duration_since(*ts) < config.window;
        let count = record.timestamps.iter().filter(|ts| in_window(ts)).count() as u32;
        let oldest = record.timestamps.iter().find(|ts| in_window(ts));
        RateLimitInfo::allowed(
            config.max_requests.saturating_sub(count),
            config.max_requests,
            sliding_reset_after(oldest, config.window),
        )
    }

    /// Bans `key` for `duration` from now, creating a record if needed.
    /// Best-effort: a poisoned lock is ignored.
    async fn ban(&self, key: &str, duration: Duration) {
        if let Ok(mut records) = self.records.write() {
            let record = records
                .entry(key.to_string())
                .or_insert_with(RequestRecord::new);
            record.banned_until = Some(Instant::now() + duration);
        }
    }

    /// Lifts any ban on `key` and clears its violation count.
    /// Best-effort: unknown keys and a poisoned lock are ignored.
    async fn unban(&self, key: &str) {
        if let Ok(mut records) = self.records.write()
            && let Some(record) = records.get_mut(key)
        {
            record.banned_until = None;
            record.violation_count = 0;
        }
    }

    /// Drops records that are not under an active ban and whose latest
    /// accepted request is an hour or more old (or that have none).
    async fn cleanup(&self) {
        if let Ok(mut records) = self.records.write() {
            let now = Instant::now();
            records.retain(|_, record| {
                if let Some(until) = record.banned_until
                    && now < until
                {
                    return true;
                }
                if let Some(last) = record.timestamps.last() {
                    last.elapsed() < Duration::from_secs(3600)
                } else {
                    false
                }
            });
        }
    }
}
/// In-memory [`RateLimitStore`] that counts requests per fixed window.
#[derive(Debug, Default)]
pub struct InMemoryFixedWindowStore {
    /// Per-key window counters, guarded by a single lock.
    records: RwLock<HashMap<String, FixedWindowRecord>>,
}

impl InMemoryFixedWindowStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }
}
#[async_trait]
impl RateLimitStore for InMemoryFixedWindowStore {
    /// Checks `key` against `config` and, if allowed, counts the request in the
    /// current fixed window. A new window starts once the old one has lapsed.
    ///
    /// # Errors
    ///
    /// * `Error::internal` if the records lock is poisoned.
    /// * `Error::rate_limited` with the ban time left, while the key is banned.
    /// * `Error::rate_limited(ban_duration)` when this over-limit attempt
    ///   reaches `config.ban_threshold` and `config.ban_duration` is set. The
    ///   key is banned for that long from now.
    /// * `Error::rate_limited` with the time left in the current window, for
    ///   any other over-limit attempt.
    async fn check_and_record(&self, key: &str, config: &RateLimitConfig) -> Result<RateLimitInfo> {
        let mut records = self
            .records
            .write()
            .map_err(|_| Error::internal("Failed to acquire lock"))?;
        let record = records
            .entry(key.to_string())
            .or_insert_with(FixedWindowRecord::new);
        if record.is_banned() {
            let ban_remaining = record.ban_remaining().unwrap_or(Duration::ZERO);
            return Err(Error::rate_limited(ban_remaining));
        }
        // Past this point any ban has expired. Clear it together with the
        // violations that led to it.
        if record.banned_until.is_some() {
            record.banned_until = None;
            record.violation_count = 0;
        }
        record.check_reset(config.window);
        if record.count >= config.max_requests {
            // Saturate: violations are only cleared by an expired ban or
            // `unban`. With `ban_duration: None` they grow without bound, and
            // overflowing here would panic (in debug builds) while the write
            // lock is held, which poisons the store for every key.
            record.violation_count = record.violation_count.saturating_add(1);
            if let Some(ban_duration) = config.ban_duration
                && record.violation_count >= config.ban_threshold
            {
                record.banned_until = Some(Instant::now() + ban_duration);
                return Err(Error::rate_limited(ban_duration));
            }
            let reset_after = config.window.saturating_sub(record.window_start.elapsed());
            return Err(Error::rate_limited(reset_after));
        }
        record.count += 1;
        let remaining = config.max_requests - record.count;
        let reset_after = config.window.saturating_sub(record.window_start.elapsed());
        Ok(RateLimitInfo::allowed(
            remaining,
            config.max_requests,
            reset_after,
        ))
    }

    /// Forgets everything about `key`. Best-effort: a poisoned lock is ignored.
    async fn reset(&self, key: &str) {
        if let Ok(mut records) = self.records.write() {
            records.remove(key);
        }
    }

    /// Reports the state of `key` without recording a request.
    ///
    /// Unknown keys, a poisoned lock, and a lapsed window are all reported as
    /// a full quota.
    async fn get_status(&self, key: &str, config: &RateLimitConfig) -> RateLimitInfo {
        let full_quota =
            || RateLimitInfo::allowed(config.max_requests, config.max_requests, config.window);
        let Ok(records) = self.records.read() else {
            return full_quota();
        };
        let Some(record) = records.get(key) else {
            return full_quota();
        };
        if record.is_banned() {
            let ban_remaining = record.ban_remaining().unwrap_or(Duration::ZERO);
            return RateLimitInfo::banned(config.max_requests, ban_remaining);
        }
        let elapsed = record.window_start.elapsed();
        // The stored window has lapsed; the next check starts a fresh one.
        if elapsed >= config.window {
            return full_quota();
        }
        RateLimitInfo::allowed(
            config.max_requests.saturating_sub(record.count),
            config.max_requests,
            config.window.saturating_sub(elapsed),
        )
    }

    /// Bans `key` for `duration` from now, creating a record if needed.
    /// Best-effort: a poisoned lock is ignored.
    async fn ban(&self, key: &str, duration: Duration) {
        if let Ok(mut records) = self.records.write() {
            let record = records
                .entry(key.to_string())
                .or_insert_with(FixedWindowRecord::new);
            record.banned_until = Some(Instant::now() + duration);
        }
    }

    /// Lifts any ban on `key` and clears its violation count.
    /// Best-effort: unknown keys and a poisoned lock are ignored.
    async fn unban(&self, key: &str) {
        if let Ok(mut records) = self.records.write()
            && let Some(record) = records.get_mut(key)
        {
            record.banned_until = None;
            record.violation_count = 0;
        }
    }

    /// Drops records that are not under an active ban and whose current
    /// window started an hour or more ago.
    async fn cleanup(&self) {
        if let Ok(mut records) = self.records.write() {
            let now = Instant::now();
            records.retain(|_, record| {
                if let Some(until) = record.banned_until
                    && now < until
                {
                    return true;
                }
                record.window_start.elapsed() < Duration::from_secs(3600)
            });
        }
    }
}
/// Applies one [`RateLimitConfig`] to arbitrary keys through a shared
/// [`RateLimitStore`].
pub struct RateLimiter {
    config: RateLimitConfig,
    store: Arc<dyn RateLimitStore>,
}

impl RateLimiter {
    /// Creates a limiter backed by a fresh in-memory store. The store is
    /// sliding-window when `config.sliding_window` is set, fixed-window otherwise.
    pub fn new(config: RateLimitConfig) -> Self {
        let store = Self::in_memory_store(config.sliding_window);
        Self { config, store }
    }

    /// Alias for [`RateLimiter::new`].
    pub fn with_config(config: RateLimitConfig) -> Self {
        Self::new(config)
    }

    /// Creates a limiter backed by a caller-supplied store.
    pub fn with_store<S: RateLimitStore + 'static>(config: RateLimitConfig, store: S) -> Self {
        let store: Arc<dyn RateLimitStore> = Arc::new(store);
        Self { config, store }
    }

    /// Builds the in-memory store matching the requested window mode.
    fn in_memory_store(sliding: bool) -> Arc<dyn RateLimitStore> {
        if sliding {
            Arc::new(InMemorySlidingWindowStore::new())
        } else {
            Arc::new(InMemoryFixedWindowStore::new())
        }
    }

    /// Checks `key` and records the request if it is allowed.
    ///
    /// # Errors
    ///
    /// Returns the store's error when the key is limited or banned, or when
    /// the store itself fails.
    pub async fn check(&self, key: &str) -> Result<RateLimitInfo> {
        let config = &self.config;
        self.store.check_and_record(key, config).await
    }

    /// Reports the state of `key` without recording a request.
    pub async fn status(&self, key: &str) -> RateLimitInfo {
        let config = &self.config;
        self.store.get_status(key, config).await
    }

    /// Forgets all state for `key`.
    pub async fn reset(&self, key: &str) {
        self.store.reset(key).await
    }

    /// Bans `key` for `duration`.
    pub async fn ban(&self, key: &str, duration: Duration) {
        self.store.ban(key, duration).await
    }

    /// Lifts any ban on `key`.
    pub async fn unban(&self, key: &str) {
        self.store.unban(key).await
    }

    /// Asks the store to drop stale per-key state.
    pub async fn cleanup(&self) {
        self.store.cleanup().await
    }

    /// The configuration this limiter applies.
    pub fn config(&self) -> &RateLimitConfig {
        &self.config
    }
}
/// Parameters for a token-bucket limiter.
#[derive(Debug, Clone)]
pub struct TokenBucketConfig {
    /// Maximum number of tokens a bucket holds; new buckets start full.
    pub capacity: u32,
    /// Tokens added back per elapsed second.
    pub refill_rate: f64,
    /// Tokens taken by a single `check`.
    pub tokens_per_request: u32,
}

impl Default for TokenBucketConfig {
    /// 100 tokens, refilled at 10 per second, 1 token per request.
    fn default() -> Self {
        Self::new(100, 10.0)
    }
}

impl TokenBucketConfig {
    /// Creates a config that charges one token per request.
    pub fn new(capacity: u32, refill_rate: f64) -> Self {
        Self {
            tokens_per_request: 1,
            refill_rate,
            capacity,
        }
    }

    /// Sets how many tokens a single `check` consumes.
    pub fn with_tokens_per_request(self, tokens: u32) -> Self {
        Self {
            tokens_per_request: tokens,
            ..self
        }
    }
}
/// Token balance for one key, refilled lazily from elapsed time.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Current balance; fractional because refills are proportional to time.
    tokens: f64,
    /// When the balance was last refilled.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens.
    fn new(capacity: u32) -> Self {
        Self {
            tokens: f64::from(capacity),
            last_update: Instant::now(),
        }
    }

    /// Credits `refill_rate` tokens per second elapsed since the last refill,
    /// capped at `capacity`.
    fn refill(&mut self, config: &TokenBucketConfig) {
        let now = Instant::now();
        let elapsed_secs = now.duration_since(self.last_update).as_secs_f64();
        let topped_up = self.tokens + elapsed_secs * config.refill_rate;
        self.tokens = topped_up.min(f64::from(config.capacity));
        self.last_update = now;
    }

    /// Refills, then withdraws `tokens` if the balance covers them.
    /// Returns whether the withdrawal happened.
    fn try_consume(&mut self, tokens: u32, config: &TokenBucketConfig) -> bool {
        self.refill(config);
        let cost = f64::from(tokens);
        let affordable = self.tokens >= cost;
        if affordable {
            self.tokens -= cost;
        }
        affordable
    }

    /// Refills, then returns the current balance.
    fn current_tokens(&mut self, config: &TokenBucketConfig) -> f64 {
        self.refill(config);
        self.tokens
    }
}
/// Token-bucket rate limiter with one bucket per string key.
///
/// Each bucket is created full on the first consuming call for its key.
pub struct TokenBucketLimiter {
    config: TokenBucketConfig,
    buckets: RwLock<HashMap<String, TokenBucket>>,
}

impl TokenBucketLimiter {
    /// Creates a limiter with no buckets yet.
    pub fn new(config: TokenBucketConfig) -> Self {
        Self {
            config,
            buckets: RwLock::new(HashMap::new()),
        }
    }

    /// Consumes `config.tokens_per_request` tokens for `key`.
    /// Returns `true` if the request is allowed.
    pub fn check(&self, key: &str) -> bool {
        self.try_consume(key, self.config.tokens_per_request)
    }

    /// Attempts to take `tokens` from `key`'s bucket, creating a full bucket on
    /// first use.
    ///
    /// Returns `false` when too few tokens are available. It also returns
    /// `false` when the lock is poisoned, so the limiter fails closed.
    pub fn try_consume(&self, key: &str, tokens: u32) -> bool {
        let mut buckets = match self.buckets.write() {
            Ok(b) => b,
            Err(_) => return false,
        };
        let bucket = buckets
            .entry(key.to_string())
            .or_insert_with(|| TokenBucket::new(self.config.capacity));
        bucket.try_consume(tokens, &self.config)
    }

    /// Returns the tokens currently available to `key`, after refilling.
    ///
    /// An unknown key is reported as a full bucket without inserting one, so
    /// read-only queries for arbitrary keys do not grow the map. A later
    /// consuming call still starts that key with a full bucket. Returns `0.0`
    /// when the lock is poisoned.
    pub fn available_tokens(&self, key: &str) -> f64 {
        let mut buckets = match self.buckets.write() {
            Ok(b) => b,
            Err(_) => return 0.0,
        };
        match buckets.get_mut(key) {
            Some(bucket) => bucket.current_tokens(&self.config),
            None => f64::from(self.config.capacity),
        }
    }

    /// Forgets `key`'s bucket. Best-effort: a poisoned lock is ignored.
    pub fn reset(&self, key: &str) {
        if let Ok(mut buckets) = self.buckets.write() {
            buckets.remove(key);
        }
    }

    /// Drops buckets whose last refill was an hour or more ago.
    /// Best-effort: a poisoned lock is ignored.
    pub fn cleanup(&self) {
        if let Ok(mut buckets) = self.buckets.write() {
            buckets.retain(|_, bucket| {
                bucket.last_update.elapsed() < Duration::from_secs(3600)
            });
        }
    }
}
/// Applies several [`RateLimiter`]s to the same key, in list order.
pub struct CompositeRateLimiter {
    limiters: Vec<Arc<RateLimiter>>,
}

impl CompositeRateLimiter {
    /// Creates a composite over `limiters`, which are consulted in order.
    pub fn new(limiters: Vec<Arc<RateLimiter>>) -> Self {
        Self { limiters }
    }

    /// Checks `key` against each limiter in order and returns one info per
    /// limiter.
    ///
    /// # Errors
    ///
    /// Returns the first limiter's error. By then, the limiters earlier in the
    /// list have already allowed and recorded this request. Later limiters are
    /// not consulted.
    pub async fn check(&self, key: &str) -> Result<Vec<RateLimitInfo>> {
        let mut accepted = Vec::with_capacity(self.limiters.len());
        for limiter in self.limiters.iter() {
            accepted.push(limiter.check(key).await?);
        }
        Ok(accepted)
    }

    /// Reports each limiter's state for `key`, in order, without recording a request.
    pub async fn status(&self, key: &str) -> Vec<RateLimitInfo> {
        let mut snapshots = Vec::with_capacity(self.limiters.len());
        for limiter in self.limiters.iter() {
            let snapshot = limiter.status(key).await;
            snapshots.push(snapshot);
        }
        snapshots
    }

    /// Resets `key` in every limiter.
    pub async fn reset(&self, key: &str) {
        for limiter in self.limiters.iter() {
            limiter.reset(key).await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Sliding-window config allowing `max_requests` per 60-second window.
    fn per_minute(max_requests: u32) -> RateLimitConfig {
        RateLimitConfig::new()
            .with_max_requests(max_requests)
            .with_window(Duration::from_secs(60))
    }

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(per_minute(3));
        let key = "test:user";
        for _ in 0..3 {
            assert!(limiter.check(key).await.is_ok());
        }
        assert!(limiter.check(key).await.is_err());
    }

    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = RateLimiter::new(per_minute(2));
        let key = "test:reset";
        for _ in 0..2 {
            assert!(limiter.check(key).await.is_ok());
        }
        assert!(limiter.check(key).await.is_err());
        limiter.reset(key).await;
        assert!(limiter.check(key).await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limiter_status() {
        let limiter = RateLimiter::new(per_minute(5));
        let key = "test:status";
        let before = limiter.status(key).await;
        assert_eq!(before.remaining, 5);
        assert_eq!(before.limit, 5);
        for _ in 0..2 {
            limiter.check(key).await.unwrap();
        }
        let after = limiter.status(key).await;
        assert_eq!(after.remaining, 3);
    }

    #[tokio::test]
    async fn test_rate_limiter_ban() {
        let limiter = RateLimiter::new(per_minute(5));
        let key = "test:ban";
        limiter.ban(key, Duration::from_secs(60)).await;
        assert!(limiter.check(key).await.is_err());
        limiter.unban(key).await;
        assert!(limiter.check(key).await.is_ok());
    }

    #[tokio::test]
    async fn test_fixed_window_limiter() {
        let limiter = RateLimiter::new(per_minute(3).with_sliding_window(false));
        let key = "test:fixed";
        for _ in 0..3 {
            assert!(limiter.check(key).await.is_ok());
        }
        assert!(limiter.check(key).await.is_err());
    }

    #[test]
    fn test_token_bucket_basic() {
        let limiter = TokenBucketLimiter::new(TokenBucketConfig::new(5, 1.0));
        let key = "test:bucket";
        for _ in 0..5 {
            assert!(limiter.check(key));
        }
        assert!(!limiter.check(key));
    }

    #[test]
    fn test_token_bucket_refill() {
        // At 100 tokens per second, roughly 5 tokens return during the 50 ms sleep.
        let limiter = TokenBucketLimiter::new(TokenBucketConfig::new(10, 100.0));
        let key = "test:refill";
        (0..10).for_each(|_| {
            limiter.check(key);
        });
        thread::sleep(Duration::from_millis(50));
        assert!(limiter.available_tokens(key) > 0.0);
    }

    #[tokio::test]
    async fn test_composite_limiter() {
        let composite = CompositeRateLimiter::new(vec![
            Arc::new(RateLimiter::new(per_minute(2))),
            Arc::new(RateLimiter::new(per_minute(5))),
        ]);
        let key = "test:composite";
        for _ in 0..2 {
            assert!(composite.check(key).await.is_ok());
        }
        // The stricter limiter (2 per minute) rejects the third request.
        assert!(composite.check(key).await.is_err());
    }

    #[test]
    fn test_login_config() {
        let config = RateLimitConfig::for_login();
        assert_eq!(config.max_requests, 5);
        assert_eq!(config.window, Duration::from_secs(60));
        assert!(config.ban_duration.is_some());
    }

    #[test]
    fn test_api_config() {
        let config = RateLimitConfig::for_api();
        assert_eq!(config.max_requests, 60);
        assert_eq!(config.window, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_remaining_count() {
        let limiter = RateLimiter::new(per_minute(5));
        let key = "test:remaining";
        for expected_remaining in [4, 3, 2] {
            let info = limiter.check(key).await.unwrap();
            assert_eq!(info.remaining, expected_remaining);
            assert_eq!(info.limit, 5);
        }
    }
}