construct/gateway/
auth_rate_limit.rs1use parking_lot::Mutex;
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
/// How many attempts a key may record inside one [`WINDOW_SECS`] window
/// before `AuthRateLimiter::check_rate_limit` locks it out.
pub const MAX_ATTEMPTS: u32 = 10;
/// Length, in seconds, of the sliding window over which attempts are counted.
pub const WINDOW_SECS: u64 = 60;
/// How long, in seconds, a lockout lasts once triggered (five minutes).
pub const LOCKOUT_SECS: u64 = 5 * 60;
/// Minimum spacing, in seconds, between sweeps of expired state (five minutes).
const SWEEP_INTERVAL_SECS: u64 = 5 * 60;
19
/// Returned by `AuthRateLimiter::check_rate_limit` while a key is locked out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitError {
    /// Seconds the caller should wait before trying again.
    pub retry_after_secs: u64,
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "too many authentication attempts; retry after {} seconds",
            self.retry_after_secs
        )
    }
}

// Lets the error be boxed as `dyn Error` or propagated with `?`.
impl std::error::Error for RateLimitError {}
26
27#[derive(Debug)]
29pub struct AuthRateLimiter {
30 inner: Mutex<Inner>,
31}
32
/// Mutable limiter state, reached only through the mutex held by
/// `AuthRateLimiter`.
#[derive(Debug)]
struct Inner {
    /// Timestamps of recorded attempts, grouped by client key.
    attempts: HashMap<String, Vec<Instant>>,
    /// The instant at which each currently locked-out key was locked.
    lockouts: HashMap<String, Instant>,
    /// When expired entries were last swept out of both maps.
    last_sweep: Instant,
}
41
42impl AuthRateLimiter {
43 pub fn new() -> Self {
44 Self {
45 inner: Mutex::new(Inner {
46 attempts: HashMap::new(),
47 lockouts: HashMap::new(),
48 last_sweep: Instant::now(),
49 }),
50 }
51 }
52
53 pub fn check_rate_limit(
64 &self,
65 key: &str,
66 peer_is_loopback: bool,
67 ) -> Result<(), RateLimitError> {
68 if peer_is_loopback {
69 return Ok(());
70 }
71
72 let now = Instant::now();
73 let mut inner = self.inner.lock();
74 Self::maybe_sweep(&mut inner, now);
75
76 if let Some(&locked_at) = inner.lockouts.get(key) {
78 let elapsed = now.duration_since(locked_at).as_secs();
79 if elapsed < LOCKOUT_SECS {
80 return Err(RateLimitError {
81 retry_after_secs: LOCKOUT_SECS - elapsed,
82 });
83 }
84 inner.lockouts.remove(key);
86 inner.attempts.remove(key);
87 }
88
89 let window = Duration::from_secs(WINDOW_SECS);
91 if let Some(timestamps) = inner.attempts.get_mut(key) {
92 timestamps.retain(|t| now.duration_since(*t) < window);
93 if timestamps.len() >= MAX_ATTEMPTS as usize {
94 inner.lockouts.insert(key.to_owned(), now);
96 return Err(RateLimitError {
97 retry_after_secs: LOCKOUT_SECS,
98 });
99 }
100 }
101
102 Ok(())
103 }
104
105 pub fn record_attempt(&self, key: &str, peer_is_loopback: bool) {
110 if peer_is_loopback {
111 return;
112 }
113
114 let now = Instant::now();
115 let mut inner = self.inner.lock();
116 inner.attempts.entry(key.to_owned()).or_default().push(now);
117 }
118
119 pub fn is_locked_out(&self, key: &str, peer_is_loopback: bool) -> bool {
124 if peer_is_loopback {
125 return false;
126 }
127
128 let now = Instant::now();
129 let inner = self.inner.lock();
130 if let Some(&locked_at) = inner.lockouts.get(key) {
131 return now.duration_since(locked_at).as_secs() < LOCKOUT_SECS;
132 }
133 false
134 }
135
136 fn maybe_sweep(inner: &mut Inner, now: Instant) {
138 if inner.last_sweep.elapsed() < Duration::from_secs(SWEEP_INTERVAL_SECS) {
139 return;
140 }
141 inner.last_sweep = now;
142
143 let lockout_dur = Duration::from_secs(LOCKOUT_SECS);
144 let window_dur = Duration::from_secs(WINDOW_SECS);
145
146 inner
147 .lockouts
148 .retain(|_, locked_at| now.duration_since(*locked_at) < lockout_dur);
149
150 inner.attempts.retain(|_, timestamps| {
151 timestamps.retain(|t| now.duration_since(*t) < window_dur);
152 !timestamps.is_empty()
153 });
154 }
155}
156
157impl Default for AuthRateLimiter {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Performs `MAX_ATTEMPTS` check-then-record cycles for a non-loopback
    /// `key`, asserting that every check is allowed.
    fn exhaust(limiter: &AuthRateLimiter, key: &str) {
        for _ in 0..MAX_ATTEMPTS {
            assert!(limiter.check_rate_limit(key, false).is_ok());
            limiter.record_attempt(key, false);
        }
    }

    #[test]
    fn loopback_peer_is_exempt() {
        let limiter = AuthRateLimiter::new();
        for _ in 0..20 {
            assert!(limiter.check_rate_limit("whatever", true).is_ok());
            limiter.record_attempt("whatever", true);
        }
        assert!(!limiter.is_locked_out("whatever", true));
    }

    #[test]
    fn spoofed_loopback_key_is_not_exempt() {
        // Only the peer flag grants the exemption; a key that merely looks
        // like a loopback address is limited like any other.
        let limiter = AuthRateLimiter::new();
        let key = "127.0.0.1";
        exhaust(&limiter, key);
        assert!(limiter.check_rate_limit(key, false).is_err());
        assert!(limiter.is_locked_out(key, false));
    }

    #[test]
    fn lockout_after_max_attempts() {
        let limiter = AuthRateLimiter::new();
        let key = "192.168.1.100";
        exhaust(&limiter, key);

        // The check that starts the lockout reports the full lockout length.
        let err = limiter.check_rate_limit(key, false).unwrap_err();
        assert_eq!(err.retry_after_secs, LOCKOUT_SECS);
        assert!(limiter.is_locked_out(key, false));
    }

    #[test]
    fn repeated_checks_during_lockout_keep_failing() {
        let limiter = AuthRateLimiter::new();
        let key = "203.0.113.7";
        exhaust(&limiter, key);
        assert!(limiter.check_rate_limit(key, false).is_err());

        let again = limiter.check_rate_limit(key, false).unwrap_err();
        assert!((1..=LOCKOUT_SECS).contains(&again.retry_after_secs));
        assert!(limiter.is_locked_out(key, false));
    }

    #[test]
    fn keys_are_tracked_independently() {
        let limiter = AuthRateLimiter::new();
        exhaust(&limiter, "198.51.100.1");
        assert!(limiter.check_rate_limit("198.51.100.1", false).is_err());

        assert!(limiter.check_rate_limit("198.51.100.2", false).is_ok());
        assert!(!limiter.is_locked_out("198.51.100.2", false));
    }

    #[test]
    fn under_limit_is_ok() {
        let limiter = AuthRateLimiter::new();
        let key = "10.0.0.1";

        for _ in 0..(MAX_ATTEMPTS - 1) {
            assert!(limiter.check_rate_limit(key, false).is_ok());
            limiter.record_attempt(key, false);
        }
        assert!(limiter.check_rate_limit(key, false).is_ok());
    }
}