1use std::collections::HashMap;
20use std::sync::Mutex;
21
/// Category of authentication request that is rate limited.
///
/// Counters are keyed by bucket, so requests charged to one bucket never
/// consume another bucket's allowance. Each bucket carries its own pair of
/// caps, see [`AuthBucket::caps`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthBucket {
    /// Caps: 5 per IP, 30 per account.
    Login,
    /// Caps: 3 per IP, 10 per account.
    Send,
    /// Caps: 30 per IP, 100 per account.
    Verify,
}

impl AuthBucket {
    /// Returns the `(per_ip, per_account)` request caps for this bucket.
    fn caps(&self) -> (u32, u32) {
        let per_ip = match *self {
            AuthBucket::Login => 5,
            AuthBucket::Send => 3,
            AuthBucket::Verify => 30,
        };
        let per_account = match *self {
            AuthBucket::Login => 30,
            AuthBucket::Send => 10,
            AuthBucket::Verify => 100,
        };
        (per_ip, per_account)
    }
}
52
/// Outcome of a rate-limit check.
///
/// The type is `#[must_use]`: a caller that drops the decision on the floor
/// effectively lets the request through unthrottled, so the compiler warns
/// when the value is ignored.
#[must_use = "ignoring the decision lets the request bypass rate limiting"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// The request fits within every cap that was checked.
    Allow,
    /// A cap was exceeded.
    ///
    /// `retry_after_secs` is the number of seconds remaining in the window
    /// of the counter that rejected the request.
    Deny { retry_after_secs: u64 },
}
60
/// State of one fixed rate-limit window for a single key.
#[derive(Debug, Clone, Copy)]
struct Counter {
    /// Timestamp, in seconds, at which the current window began. It is set
    /// from the `now` value handed to `bump`.
    window_start: u64,
    /// Number of requests admitted so far in the current window; denied
    /// requests do not increment it.
    count: u32,
}
69
70pub struct AuthRateLimiter {
71 per_ip: Mutex<HashMap<(AuthBucket, String), Counter>>,
72 per_account: Mutex<HashMap<(AuthBucket, String), Counter>>,
73}
74
75impl Default for AuthRateLimiter {
76 fn default() -> Self {
77 Self {
78 per_ip: Mutex::new(HashMap::new()),
79 per_account: Mutex::new(HashMap::new()),
80 }
81 }
82}
83
84impl AuthRateLimiter {
85 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn shared() -> &'static AuthRateLimiter {
93 static CELL: std::sync::OnceLock<AuthRateLimiter> = std::sync::OnceLock::new();
94 CELL.get_or_init(AuthRateLimiter::default)
95 }
96
97 pub fn check(
101 &self,
102 bucket: AuthBucket,
103 ip: &str,
104 account_key: Option<&str>,
105 ) -> RateLimitDecision {
106 let (ip_cap, acct_cap) = bucket.caps();
107 let now = now_secs();
108 if let Some(retry) = bump(&self.per_ip, (bucket, ip.to_string()), 60, ip_cap, now) {
110 return RateLimitDecision::Deny {
111 retry_after_secs: retry,
112 };
113 }
114 if let Some(key) = account_key {
115 if let Some(retry) = bump(
116 &self.per_account,
117 (bucket, key.to_ascii_lowercase()),
118 3600,
119 acct_cap,
120 now,
121 ) {
122 return RateLimitDecision::Deny {
123 retry_after_secs: retry,
124 };
125 }
126 }
127 RateLimitDecision::Allow
128 }
129}
130
131fn bump(
132 map: &Mutex<HashMap<(AuthBucket, String), Counter>>,
133 key: (AuthBucket, String),
134 window_secs: u64,
135 cap: u32,
136 now: u64,
137) -> Option<u64> {
138 let mut g = map.lock().unwrap();
139 let entry = g.entry(key).or_insert(Counter {
140 window_start: now,
141 count: 0,
142 });
143 if now >= entry.window_start + window_secs {
144 entry.window_start = now;
145 entry.count = 0;
146 }
147 if entry.count >= cap {
148 return Some(entry.window_start + window_secs - now);
149 }
150 entry.count += 1;
151 None
152}
153
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of an error.
fn now_secs() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is set before 1970: fall back to the zero duration.
        Err(_) => 0,
    }
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Five login attempts from one IP and account stay within both caps.
    #[test]
    fn allows_within_cap() {
        let limiter = AuthRateLimiter::new();
        (0..5).for_each(|_| {
            let decision = limiter.check(AuthBucket::Login, "1.2.3.4", Some("a@b.com"));
            assert_eq!(decision, RateLimitDecision::Allow);
        });
    }

    /// The request right after the per-IP cap is denied, with a retry delay
    /// no longer than the 60-second IP window.
    #[test]
    fn denies_after_per_ip_cap() {
        let limiter = AuthRateLimiter::new();
        let (per_ip_cap, _) = AuthBucket::Login.caps();
        let attempt = || limiter.check(AuthBucket::Login, "1.2.3.4", None);

        (0..per_ip_cap).for_each(|_| assert_eq!(attempt(), RateLimitDecision::Allow));

        let RateLimitDecision::Deny { retry_after_secs } = attempt() else {
            panic!("expected Deny")
        };
        assert!(retry_after_secs <= 60);
    }

    /// Spreading requests over many IPs does not evade the per-account cap.
    #[test]
    fn per_account_cap_independent_of_ip() {
        let limiter = AuthRateLimiter::new();
        let (_, per_account_cap) = AuthBucket::Send.caps();
        let target = Some("victim@x.com");

        for octet in 0..per_account_cap {
            let source = format!("10.0.0.{octet}");
            let decision = limiter.check(AuthBucket::Send, &source, target);
            assert_eq!(decision, RateLimitDecision::Allow);
        }

        let overflow = limiter.check(AuthBucket::Send, "10.0.0.99", target);
        assert!(matches!(overflow, RateLimitDecision::Deny { .. }));
    }

    /// Account keys that differ only in ASCII case share one counter.
    #[test]
    fn account_key_lowercased() {
        let limiter = AuthRateLimiter::new();
        let (_, per_account_cap) = AuthBucket::Send.caps();

        // Use up the account allowance with the lower-case spelling.
        for octet in 0..per_account_cap {
            let source = format!("10.0.0.{octet}");
            let _ = limiter.check(AuthBucket::Send, &source, Some("a@b.com"));
        }

        let upper_case = limiter.check(AuthBucket::Send, "172.16.0.1", Some("A@B.COM"));
        assert!(matches!(upper_case, RateLimitDecision::Deny { .. }));
    }
}