Skip to main content

pylon_auth/
rate_limit.rs

//! In-process fixed-window rate limiter for auth endpoints.
//!
//! Sized for the auth surface specifically: small key space (per-IP
//! and per-account, not per-(IP, route)), short windows, fixed
//! defaults that match Better-Auth's posture. Apps that need
//! cluster-wide rate limits across multiple replicas should put a
//! reverse proxy in front (Cloudflare / Caddy / nginx `limit_req`).
//!
//! Two scopes:
//!   - **per-IP**: blanket cap on auth attempts from a single client
//!     over a 1-minute window. Stops trivial credential-stuffing from
//!     one box.
//!   - **per-account**: caps attempts against a single
//!     email/user_id/phone over a 1-hour window — a slower budget than
//!     per-IP, but harder to bypass (an attacker who rotates IPs still
//!     hits the per-account cap).
//!
//! Counters use fixed windows rather than a token bucket, so a client
//! can fit up to twice a cap into a short burst that straddles a
//! window boundary.
//!
//! Limits are tuned to be invisible to humans (an occasional retry
//! stays well under every cap) but make brute force impractical.
18
19use std::collections::HashMap;
20use std::sync::Mutex;
21
/// Families of auth endpoints. Each family is rate-limited with its own
/// pair of caps (per-IP and per-account).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AuthBucket {
    /// Credential guesses: `/api/auth/password/login`,
    /// `/api/auth/totp/verify`.
    Login,
    /// Endpoints that send an email/SMS or create a user:
    /// `/api/auth/password/register`, `/api/auth/magic-link/send`,
    /// `/api/auth/magic/send`, `/api/auth/password/reset/request`,
    /// `/api/auth/phone/send-code`. Limits email-bombing and signup spam.
    Send,
    /// Public verify endpoints guarded by a cryptographic check:
    /// `/api/auth/passkey/login/finish`, `/api/auth/siwe/verify`.
    /// Limits signature fuzzing.
    Verify,
}
38
39impl AuthBucket {
40    /// `(per_ip_limit_per_min, per_account_limit_per_hour)`.
41    fn caps(&self) -> (u32, u32) {
42        match self {
43            // 5 logins/min/IP, 30/hr/account — Better-Auth-equivalent.
44            Self::Login => (5, 30),
45            // 3 sends/min/IP, 10/hr/email — protects SMS/email spend.
46            Self::Send => (3, 10),
47            // 30/min/IP — generous because legitimate flows can retry.
48            Self::Verify => (30, 100),
49        }
50    }
51}
52
/// Outcome of a rate-limit check.
///
/// Marked `#[must_use]`: a call site that drops the decision on the floor
/// silently disables rate limiting for that endpoint.
#[must_use = "a rate-limit decision that is not acted on enforces nothing"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// The attempt is within every applicable cap and has been counted.
    Allow,
    /// Caller exceeded the cap. `retry_after_secs` is a hint for the
    /// 429 `Retry-After` header.
    Deny {
        /// Seconds until the exhausted window resets.
        retry_after_secs: u64,
    },
}
62
/// Fixed-window counter for one `(bucket, key)` pair.
///
/// `window_start` is the Unix-epoch second at which the current window
/// opened and `count` is how many attempts have been admitted in it. Once
/// the window length has elapsed, the window restarts at the current time
/// and `count` drops back to zero. This is a fixed window, not a token
/// bucket: a client can fit up to twice the cap into a short burst that
/// straddles a window boundary.
#[derive(Clone, Copy, Debug)]
struct Counter {
    /// Epoch second when the current window opened.
    window_start: u64,
    /// Attempts admitted so far in the current window.
    count: u32,
}
71
72pub struct AuthRateLimiter {
73    per_ip: Mutex<HashMap<(AuthBucket, String), Counter>>,
74    per_account: Mutex<HashMap<(AuthBucket, String), Counter>>,
75}
76
77impl Default for AuthRateLimiter {
78    fn default() -> Self {
79        Self {
80            per_ip: Mutex::new(HashMap::new()),
81            per_account: Mutex::new(HashMap::new()),
82        }
83    }
84}
85
86impl AuthRateLimiter {
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Process-wide singleton. Auth endpoints use this so per-IP
92    /// counters survive across requests without plumbing a store
93    /// through every call site.
94    pub fn shared() -> &'static AuthRateLimiter {
95        static CELL: std::sync::OnceLock<AuthRateLimiter> = std::sync::OnceLock::new();
96        CELL.get_or_init(AuthRateLimiter::default)
97    }
98
99    /// Check + bump. `account_key` is the email/user_id/phone — pass
100    /// `None` for endpoints with no pre-auth account binding (e.g.
101    /// passkey/login/begin).
102    pub fn check(
103        &self,
104        bucket: AuthBucket,
105        ip: &str,
106        account_key: Option<&str>,
107    ) -> RateLimitDecision {
108        let (ip_cap, acct_cap) = bucket.caps();
109        let now = now_secs();
110        // 1-minute window for IP, 1-hour window for account.
111        if let Some(retry) = bump(&self.per_ip, (bucket, ip.to_string()), 60, ip_cap, now) {
112            return RateLimitDecision::Deny {
113                retry_after_secs: retry,
114            };
115        }
116        if let Some(key) = account_key {
117            if let Some(retry) = bump(
118                &self.per_account,
119                (bucket, key.to_ascii_lowercase()),
120                3600,
121                acct_cap,
122                now,
123            ) {
124                return RateLimitDecision::Deny {
125                    retry_after_secs: retry,
126                };
127            }
128        }
129        RateLimitDecision::Allow
130    }
131}
132
133fn bump(
134    map: &Mutex<HashMap<(AuthBucket, String), Counter>>,
135    key: (AuthBucket, String),
136    window_secs: u64,
137    cap: u32,
138    now: u64,
139) -> Option<u64> {
140    let mut g = map.lock().unwrap();
141    let entry = g.entry(key).or_insert(Counter {
142        window_start: now,
143        count: 0,
144    });
145    if now >= entry.window_start + window_secs {
146        entry.window_start = now;
147        entry.count = 0;
148    }
149    if entry.count >= cap {
150        return Some(entry.window_start + window_secs - now);
151    }
152    entry.count += 1;
153    None
154}
155
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allows_within_cap() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        // Derive the bound from the caps table so the test tracks it.
        let (ip_cap, _) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(
                rl.check(bucket, "1.2.3.4", Some("a@b.com")),
                RateLimitDecision::Allow
            );
        }
    }

    #[test]
    fn denies_after_per_ip_cap() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        let (ip_cap, _) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(rl.check(bucket, "1.2.3.4", None), RateLimitDecision::Allow);
        }
        match rl.check(bucket, "1.2.3.4", None) {
            // A usable Retry-After: at least one second, at most the
            // 1-minute per-IP window.
            RateLimitDecision::Deny { retry_after_secs } => {
                assert!((1..=60).contains(&retry_after_secs), "got {retry_after_secs}")
            }
            RateLimitDecision::Allow => panic!("expected Deny"),
        }
    }

    #[test]
    fn per_account_cap_independent_of_ip() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        // Rotate IPs to exhaust per-account before per-IP.
        for i in 0..acct_cap {
            let ip = format!("10.0.0.{i}");
            assert_eq!(
                rl.check(bucket, &ip, Some("victim@x.com")),
                RateLimitDecision::Allow
            );
        }
        let result = rl.check(bucket, "10.0.0.99", Some("victim@x.com"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn account_key_lowercased() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        // Rotate IPs so we exhaust the per-account counter before
        // any single IP hits its own per-minute cap.
        for i in 0..acct_cap {
            let ip = format!("10.0.0.{i}");
            assert_eq!(
                rl.check(bucket, &ip, Some("a@b.com")),
                RateLimitDecision::Allow
            );
        }
        // Capitalized variant of the same email must hit the same
        // (now-exhausted) per-account bucket from a fresh IP.
        let result = rl.check(bucket, "172.16.0.1", Some("A@B.COM"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn ip_denial_does_not_charge_account() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        let (ip_cap, acct_cap) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(rl.check(bucket, "1.1.1.1", None), RateLimitDecision::Allow);
        }
        // Every attempt from the throttled IP is rejected by the per-IP
        // cap before the account counter is consulted.
        for _ in 0..acct_cap {
            let result = rl.check(bucket, "1.1.1.1", Some("victim@x.com"));
            assert!(matches!(result, RateLimitDecision::Deny { .. }));
        }
        // The account's own budget is untouched, so a fresh IP gets in.
        assert_eq!(
            rl.check(bucket, "2.2.2.2", Some("victim@x.com")),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn buckets_are_independent() {
        let rl = AuthRateLimiter::new();
        let (login_cap, _) = AuthBucket::Login.caps();
        for _ in 0..login_cap {
            assert_eq!(
                rl.check(AuthBucket::Login, "3.3.3.3", None),
                RateLimitDecision::Allow
            );
        }
        assert!(matches!(
            rl.check(AuthBucket::Login, "3.3.3.3", None),
            RateLimitDecision::Deny { .. }
        ));
        // Same IP, different endpoint family: separate counter.
        assert_eq!(
            rl.check(AuthBucket::Verify, "3.3.3.3", None),
            RateLimitDecision::Allow
        );
    }
}