// authx_core/brute_force.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, Mutex},
4    time::{Duration, Instant},
5};
6
7/// In-memory brute-force / account lockout tracker.
8///
9/// Tracks consecutive failed sign-in attempts per key (typically an email
10/// address). After `max_failures` failed attempts within `window`, the
11/// account is considered locked until the window elapses.
12///
13/// Successful sign-ins reset the failure counter for that key.
14///
15/// The tracker is `Clone + Send + Sync` — share a single instance across the
16/// application via `Arc` or embed it directly in service state.
17#[derive(Clone)]
18pub struct LoginAttemptTracker {
19    inner: Arc<Mutex<HashMap<String, FailureRecord>>>,
20    cfg: LockoutConfig,
21}
22
/// Lockout policy applied by [`LoginAttemptTracker`] to every key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LockoutConfig {
    /// Number of failures within one window at which a key becomes locked.
    pub max_failures: u32,
    /// Length of the failure window.
    ///
    /// The window opens at a key's first failure and is not restarted by later
    /// failures. A lock therefore lasts until this window elapses, which is
    /// `window` minus the time it took to reach `max_failures`, not
    /// necessarily the full `window`.
    pub window: Duration,
}

impl LockoutConfig {
    /// Builds a policy that locks a key after `max_failures` failures within
    /// `window`.
    #[must_use]
    pub const fn new(max_failures: u32, window: Duration) -> Self {
        Self {
            max_failures,
            window,
        }
    }
}
39
/// Per-key bookkeeping: when the current failure window opened and how many
/// failures have been counted in it.
struct FailureRecord {
    /// Instant at which the current failure window began.
    window_start: Instant,
    /// Failures recorded since `window_start`.
    count: u32,
}
44
45impl LoginAttemptTracker {
46    pub fn new(cfg: LockoutConfig) -> Self {
47        Self {
48            inner: Arc::new(Mutex::new(HashMap::new())),
49            cfg,
50        }
51    }
52
53    /// Returns `true` if the key is currently locked out.
54    pub fn is_locked(&self, key: &str) -> bool {
55        let now = Instant::now();
56        let map = match self.inner.lock() {
57            Ok(g) => g,
58            Err(e) => {
59                tracing::error!("lockout tracker mutex poisoned — recovering");
60                e.into_inner()
61            }
62        };
63        match map.get(key) {
64            None => false,
65            Some(rec) => {
66                if now.duration_since(rec.window_start) >= self.cfg.window {
67                    false // window expired
68                } else {
69                    rec.count >= self.cfg.max_failures
70                }
71            }
72        }
73    }
74
75    /// Record a failed attempt. Call this when credentials are wrong.
76    pub fn record_failure(&self, key: &str) {
77        let now = Instant::now();
78        let mut map = match self.inner.lock() {
79            Ok(g) => g,
80            Err(e) => {
81                tracing::error!("lockout tracker mutex poisoned — recovering");
82                e.into_inner()
83            }
84        };
85        let rec = map.entry(key.to_owned()).or_insert(FailureRecord {
86            count: 0,
87            window_start: now,
88        });
89
90        if now.duration_since(rec.window_start) >= self.cfg.window {
91            // New window — reset count.
92            rec.window_start = now;
93            rec.count = 1;
94        } else {
95            rec.count += 1;
96        }
97
98        tracing::warn!(
99            key = key,
100            failures = rec.count,
101            "failed login attempt recorded"
102        );
103    }
104
105    /// Reset the failure counter on successful sign-in.
106    pub fn record_success(&self, key: &str) {
107        let mut map = match self.inner.lock() {
108            Ok(g) => g,
109            Err(e) => {
110                tracing::error!("lockout tracker mutex poisoned — recovering");
111                e.into_inner()
112            }
113        };
114        map.remove(key);
115        tracing::debug!(key = key, "login success — failure counter cleared");
116    }
117}
118
119// ── Per-key request rate limiter ──────────────────────────────────────────────
120
121/// Sliding-window rate limiter keyed by an arbitrary string (e.g. email or IP).
122///
123/// Returns `true` from [`Self::check_and_record`] when the request is allowed,
124/// `false` when the caller has exceeded `max_requests` within `window`.
125#[derive(Clone)]
126pub struct KeyedRateLimiter {
127    inner: Arc<Mutex<HashMap<String, RateRecord>>>,
128    max_requests: u32,
129    window: Duration,
130}
131
/// Per-key bookkeeping: when the current rate window opened and how many
/// requests have been allowed in it.
struct RateRecord {
    /// Instant at which the current window began.
    window_start: Instant,
    /// Requests allowed since `window_start`.
    count: u32,
}
136
137impl KeyedRateLimiter {
138    pub fn new(max_requests: u32, window: Duration) -> Self {
139        Self {
140            inner: Arc::new(Mutex::new(HashMap::new())),
141            max_requests,
142            window,
143        }
144    }
145
146    /// Returns `true` if the request should be allowed, `false` if rate-limited.
147    pub fn check_and_record(&self, key: &str) -> bool {
148        let now = Instant::now();
149        let mut map = match self.inner.lock() {
150            Ok(g) => g,
151            Err(e) => {
152                tracing::error!("rate-limiter mutex poisoned — recovering");
153                e.into_inner()
154            }
155        };
156        let rec = map.entry(key.to_owned()).or_insert(RateRecord {
157            count: 0,
158            window_start: now,
159        });
160
161        if now.duration_since(rec.window_start) >= self.window {
162            rec.window_start = now;
163            rec.count = 1;
164            return true;
165        }
166
167        if rec.count >= self.max_requests {
168            tracing::warn!(key = key, count = rec.count, "rate limit exceeded");
169            return false;
170        }
171
172        rec.count += 1;
173        true
174    }
175}