1use std::collections::HashMap;
20use std::sync::Mutex;
21
/// Kind of authentication action being throttled.
///
/// Every bucket is limited independently: each has its own per-IP cap
/// (applied per 60-second window) and per-account cap (applied per
/// 3600-second window), so exhausting one bucket never affects another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthBucket {
    /// Login attempts — 5 per IP, 30 per account.
    Login,
    /// "Send" actions (presumably sending a code or link — confirm with
    /// callers) — 3 per IP, 10 per account.
    Send,
    /// Verification attempts — 30 per IP, 100 per account.
    Verify,
}
38
39impl AuthBucket {
40 fn caps(&self) -> (u32, u32) {
42 match self {
43 Self::Login => (5, 30),
45 Self::Send => (3, 10),
47 Self::Verify => (30, 100),
49 }
50 }
51}
52
/// Outcome of `AuthRateLimiter::check`.
///
/// Marked `#[must_use]`: a caller that drops the decision (e.g. a bare
/// `limiter.check(..);`) has silently disabled throttling, so the compiler
/// should flag it.
#[must_use = "a rate-limit decision must be acted on; ignoring it disables throttling"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// The attempt is within every applicable cap and has been counted.
    Allow,
    /// A cap is already exhausted for the current window.
    Deny {
        /// Seconds until the window of the cap that denied the attempt ends.
        retry_after_secs: u64,
    },
}
62
/// Fixed-window attempt counter for a single `(bucket, key)` pair.
#[derive(Debug, Clone, Copy)]
struct Counter {
    /// Unix timestamp (whole seconds) at which the current window opened.
    window_start: u64,
    /// Attempts admitted so far in the current window.
    count: u32,
}
71
72pub struct AuthRateLimiter {
73 per_ip: Mutex<HashMap<(AuthBucket, String), Counter>>,
74 per_account: Mutex<HashMap<(AuthBucket, String), Counter>>,
75}
76
77impl Default for AuthRateLimiter {
78 fn default() -> Self {
79 Self {
80 per_ip: Mutex::new(HashMap::new()),
81 per_account: Mutex::new(HashMap::new()),
82 }
83 }
84}
85
86impl AuthRateLimiter {
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn shared() -> &'static AuthRateLimiter {
95 static CELL: std::sync::OnceLock<AuthRateLimiter> = std::sync::OnceLock::new();
96 CELL.get_or_init(AuthRateLimiter::default)
97 }
98
99 pub fn check(
103 &self,
104 bucket: AuthBucket,
105 ip: &str,
106 account_key: Option<&str>,
107 ) -> RateLimitDecision {
108 let (ip_cap, acct_cap) = bucket.caps();
109 let now = now_secs();
110 if let Some(retry) = bump(&self.per_ip, (bucket, ip.to_string()), 60, ip_cap, now) {
112 return RateLimitDecision::Deny {
113 retry_after_secs: retry,
114 };
115 }
116 if let Some(key) = account_key {
117 if let Some(retry) = bump(
118 &self.per_account,
119 (bucket, key.to_ascii_lowercase()),
120 3600,
121 acct_cap,
122 now,
123 ) {
124 return RateLimitDecision::Deny {
125 retry_after_secs: retry,
126 };
127 }
128 }
129 RateLimitDecision::Allow
130 }
131}
132
133fn bump(
134 map: &Mutex<HashMap<(AuthBucket, String), Counter>>,
135 key: (AuthBucket, String),
136 window_secs: u64,
137 cap: u32,
138 now: u64,
139) -> Option<u64> {
140 let mut g = map.lock().unwrap();
141 let entry = g.entry(key).or_insert(Counter {
142 window_start: now,
143 count: 0,
144 });
145 if now >= entry.window_start + window_secs {
146 entry.window_start = now;
147 entry.count = 0;
148 }
149 if entry.count >= cap {
150 return Some(entry.window_start + window_secs - now);
151 }
152 entry.count += 1;
153 None
154}
155
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch. This is
/// `SystemTime`, which is not monotonic and may move backwards.
fn now_secs() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed();
    since_epoch.map(|d| d.as_secs()).unwrap_or(0)
}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Builds a counter-map key in the Login bucket.
    fn key(s: &str) -> (AuthBucket, String) {
        (AuthBucket::Login, s.to_string())
    }

    #[test]
    fn allows_within_cap() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        let (ip_cap, _) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(
                rl.check(bucket, "1.2.3.4", Some("a@b.com")),
                RateLimitDecision::Allow
            );
        }
    }

    #[test]
    fn denies_after_per_ip_cap() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        let (ip_cap, _) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(rl.check(bucket, "1.2.3.4", None), RateLimitDecision::Allow);
        }
        match rl.check(bucket, "1.2.3.4", None) {
            // The window is still open, so at least one second remains.
            RateLimitDecision::Deny { retry_after_secs } => {
                assert!((1..=60).contains(&retry_after_secs))
            }
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    #[test]
    fn per_account_cap_independent_of_ip() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        for i in 0..acct_cap {
            let ip = format!("10.0.0.{i}");
            assert_eq!(
                rl.check(bucket, &ip, Some("victim@x.com")),
                RateLimitDecision::Allow
            );
        }
        let result = rl.check(bucket, "10.0.0.99", Some("victim@x.com"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn account_key_lowercased() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        for i in 0..acct_cap {
            let ip = format!("10.0.0.{i}");
            let _ = rl.check(bucket, &ip, Some("a@b.com"));
        }
        let result = rl.check(bucket, "172.16.0.1", Some("A@B.COM"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn buckets_do_not_share_budget() {
        let rl = AuthRateLimiter::new();
        let (login_cap, _) = AuthBucket::Login.caps();
        for _ in 0..login_cap {
            let _ = rl.check(AuthBucket::Login, "1.2.3.4", None);
        }
        assert!(matches!(
            rl.check(AuthBucket::Login, "1.2.3.4", None),
            RateLimitDecision::Deny { .. }
        ));
        assert_eq!(
            rl.check(AuthBucket::Verify, "1.2.3.4", None),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn ip_denial_does_not_charge_account() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (ip_cap, acct_cap) = bucket.caps();
        for _ in 0..ip_cap {
            let _ = rl.check(bucket, "9.9.9.9", None);
        }
        // Rejected at the IP layer, so they must not consume the account budget.
        for _ in 0..acct_cap {
            assert!(matches!(
                rl.check(bucket, "9.9.9.9", Some("v@x.com")),
                RateLimitDecision::Deny { .. }
            ));
        }
        assert_eq!(
            rl.check(bucket, "8.8.8.8", Some("v@x.com")),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn bump_counts_down_and_resets_window() {
        let map = Mutex::new(HashMap::new());
        assert_eq!(bump(&map, key("k"), 60, 2, 1_000), None);
        assert_eq!(bump(&map, key("k"), 60, 2, 1_010), None);
        assert_eq!(bump(&map, key("k"), 60, 2, 1_020), Some(40));
        assert_eq!(bump(&map, key("k"), 60, 2, 1_059), Some(1));
        assert_eq!(bump(&map, key("k"), 60, 2, 1_060), None);
        assert_eq!(bump(&map, key("other"), 60, 2, 1_060), None);
    }

    #[test]
    fn bump_with_zero_cap_always_denies() {
        let map = Mutex::new(HashMap::new());
        assert_eq!(bump(&map, key("k"), 60, 0, 500), Some(60));
    }
}