// aetheris_server/auth/rate_limit.rs

1use dashmap::DashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::time::Instant;
5
6use tonic::Status;
7use tracing::{info, warn};
8
/// The kind of identity an authentication attempt is counted against.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RateLimitType {
    /// Attempts counted per email address.
    Email,
    /// Attempts counted per IP address string.
    Ip,
}
14
/// Bookkeeping for one identity's current fixed window.
#[derive(Debug)]
struct RateLimitEntry {
    /// Attempts accepted so far in the current window.
    count: u32,
    /// Instant at which the current window ends.
    reset_at: Instant,
}
20
21/// In-memory rate limiter for authentication attempts.
22///
23/// M10146 — Implements per-email (5/h) and per-IP (30/h) limits
24/// to prevent OTP brute-force and resource exhaustion.
25#[derive(Clone, Default)]
26pub struct InMemoryRateLimiter {
27    /// Maps (Type, Identity) -> Entry
28    state: Arc<DashMap<(RateLimitType, String), RateLimitEntry>>,
29}
30
31impl InMemoryRateLimiter {
32    #[must_use]
33    pub fn new() -> Self {
34        Self {
35            state: Arc::new(DashMap::new()),
36        }
37    }
38
39    /// Checks if a request should be rate-limited.
40    ///
41    /// returns `Ok(())` if allowed, or `Err(Status)` if limited.
42    #[allow(clippy::duration_suboptimal_units)]
43    pub fn check_limit(&self, limit_type: RateLimitType, identity: &str) -> Result<(), Status> {
44        let key = (limit_type, identity.to_string());
45        let now = Instant::now();
46
47        // 1. Get or create entry
48        let mut entry = self
49            .state
50            .entry(key.clone())
51            .or_insert_with(|| RateLimitEntry {
52                count: 0,
53                reset_at: now + Duration::from_secs(3600),
54            });
55
56        // 2. Check for reset
57        if now > entry.reset_at {
58            entry.count = 0;
59            entry.reset_at = now + Duration::from_secs(3600);
60        }
61
62        // 3. Enforce limit
63        let limit = match limit_type {
64            RateLimitType::Email => 5,
65            RateLimitType::Ip => 30,
66        };
67
68        if entry.count >= limit {
69            warn!(
70                type = ?limit_type,
71                identity = %identity,
72                count = entry.count,
73                "Rate limit exceeded"
74            );
75            return Err(Status::resource_exhausted(format!(
76                "Rate limit exceeded for {limit_type:?}: {identity}. Try again later."
77            )));
78        }
79
80        // 4. Increment
81        entry.count += 1;
82        info!(
83            type = ?limit_type,
84            identity = %identity,
85            count = entry.count,
86            "Rate limit check passed"
87        );
88
89        Ok(())
90    }
91
92    /// Periodic cleanup of expired entries (optional for MVP, but good for hygiene).
93    pub fn cleanup(&self) {
94        let now = Instant::now();
95        self.state.retain(|_, entry| entry.reset_at > now);
96    }
97}
98
#[cfg(test)]
// Test durations are spelled in seconds to mirror the production window.
#[allow(clippy::duration_suboptimal_units)]
mod tests {
    use super::*;

    #[test]
    fn test_email_rate_limit() {
        let limiter = InMemoryRateLimiter::new();
        let email = "test@example.com";

        // First 5 should succeed
        for _ in 0..5 {
            assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
        }

        // 6th should fail
        let result = limiter.check_limit(RateLimitType::Email, email);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::ResourceExhausted);
    }

    #[test]
    fn test_ip_rate_limit() {
        let limiter = InMemoryRateLimiter::new();
        let ip = "127.0.0.1";

        // First 30 should succeed
        for _ in 0..30 {
            assert!(limiter.check_limit(RateLimitType::Ip, ip).is_ok());
        }

        // 31st should fail
        let result = limiter.check_limit(RateLimitType::Ip, ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::ResourceExhausted);
    }

    #[tokio::test]
    async fn test_rate_limit_reset() {
        tokio::time::pause();
        let limiter = InMemoryRateLimiter::new();
        let email = "reset@example.com";

        // Exhaust the limit
        for _ in 0..5 {
            let _ = limiter.check_limit(RateLimitType::Email, email);
        }
        assert!(limiter.check_limit(RateLimitType::Email, email).is_err());

        // Advance time by 1 hour + 1 second
        tokio::time::advance(Duration::from_secs(3601)).await;

        // Should succeed now
        assert!(limiter.check_limit(RateLimitType::Email, email).is_ok());
    }

    /// Hits one IP from 100 OS threads at once; exactly the IP limit may pass.
    #[test]
    fn test_rate_limit_concurrency() {
        let limiter = InMemoryRateLimiter::new();
        let ip = "192.168.1.1";

        // `check_limit` is synchronous, so real threads give genuine parallelism.
        // A default `#[tokio::test]` runtime is single-threaded, and without an
        // `.await` inside `check_limit` its tasks would simply run one by one.
        let success_count = std::thread::scope(|scope| {
            // Spawn every thread before joining any, so they actually overlap.
            let handles: Vec<_> = (0..100)
                .map(|_| scope.spawn(|| limiter.check_limit(RateLimitType::Ip, ip)))
                .collect();
            handles
                .into_iter()
                .map(|handle| handle.join().expect("rate-limit worker thread panicked"))
                .filter(Result::is_ok)
                .count()
        });

        // Exactly 30 should have succeeded
        assert_eq!(success_count, 30);
    }

    #[test]
    fn test_rate_limit_cleanup() {
        let limiter = InMemoryRateLimiter::new();
        let now = Instant::now();
        let email_stale = "stale@example.com";
        let email_fresh = "fresh@example.com";

        // Create a stale entry. Its window ends at `now`, and `cleanup` samples
        // the clock again afterwards, so it is already expired. This avoids
        // `now - 1h`, which `Instant::checked_sub` may be unable to represent
        // on some platforms shortly after host boot.
        limiter.state.insert(
            (RateLimitType::Email, email_stale.to_string()),
            RateLimitEntry {
                count: 5,
                reset_at: now,
            },
        );

        // Create a fresh entry
        limiter.state.insert(
            (RateLimitType::Email, email_fresh.to_string()),
            RateLimitEntry {
                count: 1,
                reset_at: now + Duration::from_secs(3600),
            },
        );

        assert_eq!(limiter.state.len(), 2);

        // Cleanup
        limiter.cleanup();

        // Only fresh should remain
        assert_eq!(limiter.state.len(), 1);
        assert!(
            limiter
                .state
                .contains_key(&(RateLimitType::Email, email_fresh.to_string()))
        );
    }
}