use dashmap::DashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use tonic::Status;
use tracing::{info, warn};
/// Category of identity that a rate limit is tracked against.
///
/// The variant is part of the limiter's map key, so the same identity string
/// is counted separately per variant. The quota applied to each variant is
/// chosen in `InMemoryRateLimiter::check_limit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitType {
/// Limit keyed by an email-address identity.
Email,
/// Limit keyed by an IP-address identity.
Ip,
}
/// Per-key counter state for one rate-limit window.
#[derive(Debug)]
struct RateLimitEntry {
/// Number of checks accepted so far in the current window.
count: u32,
/// Instant at which the current window ends; once passed, the counter is
/// reset on the next check and the entry becomes eligible for cleanup.
reset_at: Instant,
}
/// In-memory rate limiter backed by a concurrent map.
///
/// Cloning is cheap: the map lives behind an `Arc`, so every clone shares the
/// same counters.
#[derive(Clone, Default)]
pub struct InMemoryRateLimiter {
/// Counter state keyed by `(limit type, identity string)`.
state: Arc<DashMap<(RateLimitType, String), RateLimitEntry>>,
}
impl InMemoryRateLimiter {
    /// Length of one rate-limit window.
    // `duration_suboptimal_units` would suggest a coarser-unit constructor here;
    // presumably the window is intentionally spelled in seconds.
    #[allow(clippy::duration_suboptimal_units)]
    const WINDOW: Duration = Duration::from_secs(3600);

    /// Maximum accepted checks per window for [`RateLimitType::Email`].
    const EMAIL_LIMIT: u32 = 5;

    /// Maximum accepted checks per window for [`RateLimitType::Ip`].
    const IP_LIMIT: u32 = 30;

    /// Creates a limiter with no recorded activity.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::new(DashMap::new()),
        }
    }

    /// Records one attempt for `identity` under `limit_type` and reports whether
    /// it is within the quota for the current window.
    ///
    /// A window starts at the first attempt for a key and lasts
    /// [`Self::WINDOW`]. Once the window's end instant is reached, the next
    /// attempt starts a fresh window with a zeroed counter. Rejected attempts do
    /// not increase the counter.
    ///
    /// # Errors
    ///
    /// Returns a `RESOURCE_EXHAUSTED` [`Status`] when the quota for the current
    /// window has already been used up.
    pub fn check_limit(&self, limit_type: RateLimitType, identity: &str) -> Result<(), Status> {
        let limit = match limit_type {
            RateLimitType::Email => Self::EMAIL_LIMIT,
            RateLimitType::Ip => Self::IP_LIMIT,
        };
        let now = Instant::now();

        // Keep the map guard in its own scope so the shard lock is released
        // before any logging or string formatting happens.
        let (allowed, count) = {
            let mut entry = self
                .state
                .entry((limit_type, identity.to_owned()))
                .or_insert_with(|| RateLimitEntry {
                    count: 0,
                    reset_at: now + Self::WINDOW,
                });

            // `>=` so that the exact end instant counts as expired, matching
            // `cleanup`, which drops entries unless `reset_at > now`.
            if now >= entry.reset_at {
                entry.count = 0;
                entry.reset_at = now + Self::WINDOW;
            }

            let allowed = entry.count < limit;
            if allowed {
                entry.count += 1;
            }
            (allowed, entry.count)
        };

        if !allowed {
            warn!(
                type = ?limit_type,
                identity = %identity,
                count,
                "Rate limit exceeded"
            );
            return Err(Status::resource_exhausted(format!(
                "Rate limit exceeded for {limit_type:?}: {identity}. Try again later."
            )));
        }

        info!(
            type = ?limit_type,
            identity = %identity,
            count,
            "Rate limit check passed"
        );
        Ok(())
    }

    /// Removes every entry whose window has ended.
    ///
    /// Entries are only dropped here; nothing in this type calls it
    /// automatically, so callers are expected to invoke it themselves.
    pub fn cleanup(&self) {
        let now = Instant::now();
        self.state.retain(|_, entry| entry.reset_at > now);
    }
}
#[cfg(test)]
// Durations in the tests are written in seconds to mirror the limiter's window.
#[allow(clippy::duration_suboptimal_units)]
mod tests {
    use super::*;

    /// Five checks for one email pass; the sixth is rejected with
    /// `ResourceExhausted`.
    #[test]
    fn test_email_rate_limit() {
        let limiter = InMemoryRateLimiter::new();
        let email = "test@example.com";
        for _ in 0..5 {
            assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
        }
        let result = limiter.check_limit(RateLimitType::Email, email);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::ResourceExhausted);
    }

    /// Thirty checks for one IP pass; the thirty-first is rejected with
    /// `ResourceExhausted`.
    #[test]
    fn test_ip_rate_limit() {
        let limiter = InMemoryRateLimiter::new();
        let ip = "127.0.0.1";
        for _ in 0..30 {
            assert!(limiter.check_limit(RateLimitType::Ip, ip).is_ok());
        }
        let result = limiter.check_limit(RateLimitType::Ip, ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::ResourceExhausted);
    }

    /// After the window has elapsed, an exhausted key is accepted again.
    #[tokio::test]
    async fn test_rate_limit_reset() {
        tokio::time::pause();
        let limiter = InMemoryRateLimiter::new();
        let email = "reset@example.com";
        for _ in 0..5 {
            assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
        }
        assert!(limiter.check_limit(RateLimitType::Email, email).is_err());
        tokio::time::advance(Duration::from_secs(3601)).await;
        assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
    }

    /// With 100 concurrent checks against one IP, exactly the quota succeeds.
    #[tokio::test]
    async fn test_rate_limit_concurrency() {
        let limiter = Arc::new(InMemoryRateLimiter::new());
        let ip = "192.168.1.1";
        let mut handles = vec![];
        for _ in 0..100 {
            let l = Arc::clone(&limiter);
            let target = ip.to_string();
            handles.push(tokio::spawn(async move {
                l.check_limit(RateLimitType::Ip, &target)
            }));
        }
        let results = futures::future::join_all(handles).await;
        let success_count = results
            .into_iter()
            .filter(|r| r.as_ref().unwrap().is_ok())
            .count();
        assert_eq!(success_count, 30);
    }

    /// `cleanup` drops entries whose window has ended and keeps the rest.
    ///
    /// Uses the paused tokio clock instead of subtracting from `Instant::now()`:
    /// an instant one hour in the past is not representable on every platform
    /// (e.g. shortly after boot), which made the old `checked_sub(..).unwrap()`
    /// a potential panic.
    #[tokio::test]
    async fn test_rate_limit_cleanup() {
        tokio::time::pause();
        let limiter = InMemoryRateLimiter::new();
        let now = Instant::now();
        let email_stale = "stale@example.com";
        let email_fresh = "fresh@example.com";
        limiter.state.insert(
            (RateLimitType::Email, email_stale.to_string()),
            RateLimitEntry {
                count: 5,
                reset_at: now + Duration::from_secs(1),
            },
        );
        limiter.state.insert(
            (RateLimitType::Email, email_fresh.to_string()),
            RateLimitEntry {
                count: 1,
                reset_at: now + Duration::from_secs(7200),
            },
        );
        assert_eq!(limiter.state.len(), 2);

        // Move the clock past the stale entry's window but not the fresh one's.
        tokio::time::advance(Duration::from_secs(3600)).await;
        limiter.cleanup();

        assert_eq!(limiter.state.len(), 1);
        assert!(
            limiter
                .state
                .contains_key(&(RateLimitType::Email, email_fresh.to_string()))
        );
    }
}