aetheris_server/auth/
rate_limit.rs1use dashmap::DashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::time::Instant;
5
6use tonic::Status;
7use tracing::{info, warn};
8
/// Dimension a rate-limit budget is tracked against.
///
/// Combined with the identity string to form the limiter's map key, so the same
/// text used as an email and as an IP is counted under two separate budgets.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum RateLimitType {
    /// Budget keyed by an email identity (e.g. `"test@example.com"`).
    Email,
    /// Budget keyed by an IP identity (e.g. `"127.0.0.1"`).
    Ip,
}
14
/// Counter state for one `(RateLimitType, identity)` key within its current window.
#[derive(Debug)]
struct RateLimitEntry {
    /// Number of requests accepted so far in the current window.
    count: u32,
    /// Point in time at which the current window ends and the counter may start over.
    reset_at: Instant,
}
20
21#[derive(Clone, Default)]
26pub struct InMemoryRateLimiter {
27 state: Arc<DashMap<(RateLimitType, String), RateLimitEntry>>,
29}
30
31impl InMemoryRateLimiter {
32 #[must_use]
33 pub fn new() -> Self {
34 Self {
35 state: Arc::new(DashMap::new()),
36 }
37 }
38
39 #[allow(clippy::duration_suboptimal_units)]
43 pub fn check_limit(&self, limit_type: RateLimitType, identity: &str) -> Result<(), Status> {
44 let key = (limit_type, identity.to_string());
45 let now = Instant::now();
46
47 let mut entry = self
49 .state
50 .entry(key.clone())
51 .or_insert_with(|| RateLimitEntry {
52 count: 0,
53 reset_at: now + Duration::from_secs(3600),
54 });
55
56 if now > entry.reset_at {
58 entry.count = 0;
59 entry.reset_at = now + Duration::from_secs(3600);
60 }
61
62 let limit = match limit_type {
64 RateLimitType::Email => 5,
65 RateLimitType::Ip => 30,
66 };
67
68 if entry.count >= limit {
69 warn!(
70 type = ?limit_type,
71 identity = %identity,
72 count = entry.count,
73 "Rate limit exceeded"
74 );
75 return Err(Status::resource_exhausted(format!(
76 "Rate limit exceeded for {limit_type:?}: {identity}. Try again later."
77 )));
78 }
79
80 entry.count += 1;
82 info!(
83 type = ?limit_type,
84 identity = %identity,
85 count = entry.count,
86 "Rate limit check passed"
87 );
88
89 Ok(())
90 }
91
92 pub fn cleanup(&self) {
94 let now = Instant::now();
95 self.state.retain(|_, entry| entry.reset_at > now);
96 }
97}
98
#[cfg(test)]
// Windows are spelled in seconds here to mirror the limiter's own 3600-second window.
#[allow(clippy::duration_suboptimal_units)]
mod tests {
    use super::*;

    #[test]
    fn test_email_rate_limit() {
        let limiter = InMemoryRateLimiter::new();
        let email = "test@example.com";

        for _ in 0..5 {
            assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
        }

        let result = limiter.check_limit(RateLimitType::Email, email);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::ResourceExhausted);
    }

    #[test]
    fn test_ip_rate_limit() {
        let limiter = InMemoryRateLimiter::new();
        let ip = "127.0.0.1";

        for _ in 0..30 {
            assert!(limiter.check_limit(RateLimitType::Ip, ip).is_ok());
        }

        let result = limiter.check_limit(RateLimitType::Ip, ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::ResourceExhausted);
    }

    #[tokio::test]
    async fn test_rate_limit_reset() {
        // Paused time lets `advance` move `tokio::time::Instant::now()` deterministically.
        tokio::time::pause();
        let limiter = InMemoryRateLimiter::new();
        let email = "reset@example.com";

        for _ in 0..5 {
            let _ = limiter.check_limit(RateLimitType::Email, email);
        }
        assert!(limiter.check_limit(RateLimitType::Email, email).is_err());

        tokio::time::advance(Duration::from_secs(3601)).await;

        assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
    }

    /// Exactly `limit` of many simultaneous callers for the same key must succeed.
    ///
    /// Uses real OS threads: `check_limit` is synchronous, so spawning it as tokio
    /// tasks on the default current-thread test runtime would run the calls one
    /// after another and never actually race on the map entry.
    #[test]
    fn test_rate_limit_concurrency() {
        let limiter = InMemoryRateLimiter::new();
        let limiter = &limiter;
        let ip = "192.168.1.1";

        let success_count = std::thread::scope(|s| {
            // Collect first so every thread is started before any is joined.
            let handles: Vec<_> = (0..100)
                .map(|_| s.spawn(move || limiter.check_limit(RateLimitType::Ip, ip)))
                .collect();
            handles
                .into_iter()
                .map(|h| h.join().expect("rate-limit worker thread panicked"))
                .filter(Result::is_ok)
                .count()
        });

        assert_eq!(success_count, 30);
    }

    #[test]
    fn test_rate_limit_cleanup() {
        let limiter = InMemoryRateLimiter::new();
        let now = Instant::now();
        let email_stale = "stale@example.com";
        let email_fresh = "fresh@example.com";

        // `reset_at == now` already counts as expired for `cleanup`, which keeps only
        // entries with `reset_at` strictly after its own (later or equal) `now`.
        // Subtracting an hour via `checked_sub(..).unwrap()` could panic on platforms
        // whose monotonic clock started less than an hour ago.
        limiter.state.insert(
            (RateLimitType::Email, email_stale.to_string()),
            RateLimitEntry {
                count: 5,
                reset_at: now,
            },
        );

        limiter.state.insert(
            (RateLimitType::Email, email_fresh.to_string()),
            RateLimitEntry {
                count: 1,
                reset_at: now + Duration::from_secs(3600),
            },
        );

        assert_eq!(limiter.state.len(), 2);

        limiter.cleanup();

        assert_eq!(limiter.state.len(), 1);
        assert!(
            limiter
                .state
                .contains_key(&(RateLimitType::Email, email_fresh.to_string()))
        );
    }
}