Skip to main content

pylon_auth/
rate_limit.rs

//! In-process fixed-window rate limiter for auth endpoints.
//!
//! Sized for the auth surface specifically: small key space (per-IP
//! and per-account, not per-(IP, route)), short windows, fixed
//! defaults that match Better-Auth's posture. Apps that need
//! cluster-wide rate limits across multiple replicas should put a
//! reverse proxy in front (Cloudflare / Caddy / nginx limit_req).
//!
//! Two scopes:
//!   - **per-IP**: blanket cap on auth attempts from a single client.
//!     Stops trivial credential-stuffing from one box.
//!   - **per-account**: caps attempts against a single
//!     email/user_id/phone — slower than per-IP but harder to bypass
//!     (an attacker who rotates IPs still hits the per-account cap).
//!
//! Limits are tuned to be invisible to humans (a couple of retries per
//! minute stays under every per-IP cap) but make brute force impractical.
18
19use std::collections::HashMap;
20use std::sync::Mutex;
21
/// Auth endpoint families with distinct rate limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthBucket {
    /// `/api/auth/password/login`, `/api/auth/totp/verify` — credential
    /// guesses. Strictest cap.
    Login,
    /// `/api/auth/password/register`, `/api/auth/magic-link/send`,
    /// `/api/auth/magic/send`, `/api/auth/password/reset/request`,
    /// `/api/auth/phone/send-code` — sends an email/SMS or creates a
    /// user. Caps email-bombing + signup spam.
    Send,
    /// `/api/auth/passkey/login/finish`, `/api/auth/siwe/verify` —
    /// public verify endpoints with cryptographic gates. Caps the
    /// signature-fuzzing class.
    Verify,
}

impl AuthBucket {
    /// Returns `(per_ip_limit_per_min, per_account_limit_per_hour)`.
    ///
    /// Takes `self` by value: the enum is `Copy`, so a reference buys
    /// nothing.
    fn caps(self) -> (u32, u32) {
        match self {
            // 5 logins/min/IP, 30/hr/account — Better-Auth-equivalent.
            Self::Login => (5, 30),
            // 3 sends/min/IP, 10/hr/email — protects SMS/email spend.
            Self::Send => (3, 10),
            // 30/min/IP, 100/hr/account — generous because legitimate
            // flows can retry.
            Self::Verify => (30, 100),
        }
    }
}
52
/// Outcome of a rate-limit check.
///
/// Marked `#[must_use]`: a decision that is dropped without being
/// inspected means the endpoint is not actually rate limited.
#[must_use = "a rate-limit decision that is never inspected enforces nothing"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// The request is within every cap that was checked.
    Allow,
    /// Caller exceeded the cap. `retry_after_secs` is a hint for the
    /// 429 `Retry-After` header.
    Deny { retry_after_secs: u64 },
}
60
/// Fixed-window counter — for each `(bucket, key)` we track the
/// epoch-second window start + count. When the window rolls over,
/// the count resets. Cheap O(1) per check.
///
/// (This is a fixed window, not a token bucket: nothing refills
/// gradually, the whole allowance comes back at once at rollover.)
#[derive(Debug, Clone, Copy)]
struct Counter {
    /// Epoch second at which the current window opened.
    window_start: u64,
    /// Requests admitted since `window_start`.
    count: u32,
}
69
70pub struct AuthRateLimiter {
71    per_ip: Mutex<HashMap<(AuthBucket, String), Counter>>,
72    per_account: Mutex<HashMap<(AuthBucket, String), Counter>>,
73}
74
75impl Default for AuthRateLimiter {
76    fn default() -> Self {
77        Self {
78            per_ip: Mutex::new(HashMap::new()),
79            per_account: Mutex::new(HashMap::new()),
80        }
81    }
82}
83
84impl AuthRateLimiter {
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Process-wide singleton. Auth endpoints use this so per-IP
90    /// counters survive across requests without plumbing a store
91    /// through every call site.
92    pub fn shared() -> &'static AuthRateLimiter {
93        static CELL: std::sync::OnceLock<AuthRateLimiter> = std::sync::OnceLock::new();
94        CELL.get_or_init(AuthRateLimiter::default)
95    }
96
97    /// Check + bump. `account_key` is the email/user_id/phone — pass
98    /// `None` for endpoints with no pre-auth account binding (e.g.
99    /// passkey/login/begin).
100    pub fn check(
101        &self,
102        bucket: AuthBucket,
103        ip: &str,
104        account_key: Option<&str>,
105    ) -> RateLimitDecision {
106        let (ip_cap, acct_cap) = bucket.caps();
107        let now = now_secs();
108        // 1-minute window for IP, 1-hour window for account.
109        if let Some(retry) = bump(&self.per_ip, (bucket, ip.to_string()), 60, ip_cap, now) {
110            return RateLimitDecision::Deny {
111                retry_after_secs: retry,
112            };
113        }
114        if let Some(key) = account_key {
115            if let Some(retry) = bump(
116                &self.per_account,
117                (bucket, key.to_ascii_lowercase()),
118                3600,
119                acct_cap,
120                now,
121            ) {
122                return RateLimitDecision::Deny {
123                    retry_after_secs: retry,
124                };
125            }
126        }
127        RateLimitDecision::Allow
128    }
129}
130
131fn bump(
132    map: &Mutex<HashMap<(AuthBucket, String), Counter>>,
133    key: (AuthBucket, String),
134    window_secs: u64,
135    cap: u32,
136    now: u64,
137) -> Option<u64> {
138    let mut g = map.lock().unwrap();
139    let entry = g.entry(key).or_insert(Counter {
140        window_start: now,
141        count: 0,
142    });
143    if now >= entry.window_start + window_secs {
144        entry.window_start = now;
145        entry.count = 0;
146    }
147    if entry.count >= cap {
148        return Some(entry.window_start + window_secs - now);
149    }
150    entry.count += 1;
151    None
152}
153
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
/// This reads `SystemTime`, so the value follows clock adjustments and
/// is not guaranteed to be monotonic.
fn now_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// A distinct client address per index, so that no single IP gets
    /// near its own per-minute cap while an account counter is drained.
    fn rotating_ip(i: u32) -> String {
        format!("10.0.0.{i}")
    }

    #[test]
    fn allows_within_cap() {
        let rl = AuthRateLimiter::new();
        for _ in 0..5 {
            assert_eq!(
                rl.check(AuthBucket::Login, "1.2.3.4", Some("a@b.com")),
                RateLimitDecision::Allow
            );
        }
    }

    #[test]
    fn denies_after_per_ip_cap() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        let (ip_cap, _) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(rl.check(bucket, "1.2.3.4", None), RateLimitDecision::Allow);
        }
        match rl.check(bucket, "1.2.3.4", None) {
            // A refusal always has time left in the one-minute window.
            RateLimitDecision::Deny { retry_after_secs } => {
                assert!((1..=60).contains(&retry_after_secs));
            }
            RateLimitDecision::Allow => panic!("expected Deny"),
        }
    }

    #[test]
    fn per_account_cap_independent_of_ip() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        // Rotate IPs to exhaust per-account before per-IP.
        for i in 0..acct_cap {
            assert_eq!(
                rl.check(bucket, &rotating_ip(i), Some("victim@x.com")),
                RateLimitDecision::Allow
            );
        }
        let result = rl.check(bucket, "10.0.0.99", Some("victim@x.com"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn account_key_lowercased() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        // Rotate IPs so we exhaust the per-account counter before
        // any single IP hits its own per-minute cap.
        for i in 0..acct_cap {
            assert_eq!(
                rl.check(bucket, &rotating_ip(i), Some("a@b.com")),
                RateLimitDecision::Allow
            );
        }
        // Capitalized variant of the same email must hit the same
        // (now-exhausted) per-account bucket from a fresh IP.
        let result = rl.check(bucket, "172.16.0.1", Some("A@B.COM"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    #[test]
    fn buckets_are_counted_separately() {
        let rl = AuthRateLimiter::new();
        let (login_cap, _) = AuthBucket::Login.caps();
        for _ in 0..login_cap {
            assert_eq!(
                rl.check(AuthBucket::Login, "1.2.3.4", None),
                RateLimitDecision::Allow
            );
        }
        assert!(matches!(
            rl.check(AuthBucket::Login, "1.2.3.4", None),
            RateLimitDecision::Deny { .. }
        ));
        // Exhausting Login for this IP must not spill into Send.
        assert_eq!(
            rl.check(AuthBucket::Send, "1.2.3.4", None),
            RateLimitDecision::Allow
        );
    }
}