1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3#![deny(rustdoc::broken_intra_doc_links)]
4
5use std::net::IpAddr;
6use std::time::{Duration, Instant};
7
8use dashmap::DashMap;
9
/// Tunable thresholds for [`AuthGuard`].
///
/// Failures are tracked in two independent scopes:
///
/// * **account** — keyed by the normalized client address *and* the username;
/// * **ip** — keyed by the normalized client address alone (any username).
///
/// Each scope has its own failure threshold, sliding window and base lockout.
/// Every further lockout of the same key is scaled by `backoff_multiplier`,
/// and no single lockout exceeds `max_lockout_secs`.
#[derive(Debug, Clone, PartialEq)]
pub struct AuthGuardConfig {
    /// Number of failures from one (address, username) pair, counted within
    /// `account_window_secs`, that triggers an account-scope lockout.
    pub max_failures_account: u32,
    /// Length, in seconds, of the sliding window used to count account-scope
    /// failures.
    pub account_window_secs: u64,
    /// Duration, in seconds, of the first account-scope lockout.
    pub base_lockout_secs: u64,
    /// Number of failures from one address (for any username), counted within
    /// `ip_window_secs`, that triggers an ip-scope lockout.
    pub max_failures_ip: u32,
    /// Length, in seconds, of the sliding window used to count ip-scope
    /// failures.
    pub ip_window_secs: u64,
    /// Duration, in seconds, of the first ip-scope lockout.
    pub ip_base_lockout_secs: u64,
    /// Factor applied to the lockout duration for each consecutive lockout of
    /// the same key (e.g. `2.0` doubles it every time).
    pub backoff_multiplier: f64,
    /// Upper bound, in seconds, on any single lockout, in either scope.
    pub max_lockout_secs: u64,
}

impl Default for AuthGuardConfig {
    /// Conservative defaults:
    ///
    /// * 5 failures within 15 minutes lock the account for 30 minutes;
    /// * 20 failures within 1 hour lock the address for 1 hour;
    /// * each repeated lockout doubles, up to 24 hours.
    fn default() -> Self {
        Self {
            max_failures_account: 5,
            account_window_secs: 900,
            base_lockout_secs: 1800,
            max_failures_ip: 20,
            ip_window_secs: 3600,
            ip_base_lockout_secs: 3600,
            backoff_multiplier: 2.0,
            max_lockout_secs: 86400,
        }
    }
}
53
/// Mutable bookkeeping kept for one tracked key (an address, or an
/// address + username pair).
struct FailureRecord {
    /// How many lockouts this key has already been given; passed to
    /// `lockout_duration` so the next lockout can be scaled.
    consecutive_lockouts: u32,
    /// End of the most recently imposed lockout, if any.
    lockout_until: Option<Instant>,
    /// Failure timestamps recorded since the last lockout, pruned to the
    /// sliding window whenever a new failure is recorded.
    failures: Vec<Instant>,
}
59
/// Outcome of [`AuthGuard::check`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthCheck {
    /// No active lockout applies; the authentication attempt may proceed.
    Allowed,
    /// An active lockout (ip scope or account scope) applies; the attempt
    /// should be rejected without verifying credentials.
    LockedOut {
        /// Seconds remaining until the lockout reported by `check` ends.
        remaining_secs: u64,
    },
}
73
74pub struct AuthGuard {
82 config: AuthGuardConfig,
83 account_failures: DashMap<(IpAddr, String), FailureRecord>,
84 ip_failures: DashMap<IpAddr, FailureRecord>,
85}
86
87pub fn lockout_duration(
96 base_secs: u64,
97 consecutive_lockouts: u32,
98 multiplier: f64,
99 max_secs: u64,
100) -> u64 {
101 let backoff = mailrs_backoff::Backoff {
102 initial: Duration::from_secs(base_secs),
103 multiplier,
104 max: Duration::from_secs(max_secs),
105 jitter: mailrs_backoff::Jitter::None,
106 };
107 backoff.base_delay(consecutive_lockouts).as_secs()
108}
109
/// Canonicalises a client address into the key used for failure tracking.
///
/// * IPv4 addresses are returned unchanged.
/// * IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`, which dual-stack sockets
///   report for IPv4 clients) are converted back to plain IPv4. Without this,
///   masking them to /64 would collapse *every* IPv4 client onto the single
///   key `::`. One attacker could then lock out all IPv4 users, and account
///   counters would be shared between unrelated clients.
/// * Every other IPv6 address is truncated to its /64 prefix, because a
///   single client typically controls a whole /64 and could otherwise rotate
///   addresses to evade the counters.
fn normalize_ip(ip: IpAddr) -> IpAddr {
    match ip {
        IpAddr::V4(_) => ip,
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return IpAddr::V4(v4);
            }
            // Keep the first four 16-bit groups (the /64 network prefix).
            let [a, b, c, d, ..] = v6.segments();
            IpAddr::V6(std::net::Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0))
        }
    }
}
130
131impl AuthGuard {
132 pub fn new(config: AuthGuardConfig) -> Self {
135 Self {
136 config,
137 account_failures: DashMap::new(),
138 ip_failures: DashMap::new(),
139 }
140 }
141
142 pub fn check(&self, ip: IpAddr, username: &str) -> AuthCheck {
152 let ip = normalize_ip(ip);
153 let now = Instant::now();
154
155 if let Some(rec) = self.ip_failures.get(&ip)
156 && let Some(until) = rec.lockout_until
157 && now < until {
158 let remaining = until.duration_since(now).as_secs();
159 return AuthCheck::LockedOut {
160 remaining_secs: remaining,
161 };
162 }
163
164 let key = (ip, username.to_string());
165 if let Some(rec) = self.account_failures.get(&key)
166 && let Some(until) = rec.lockout_until
167 && now < until {
168 let remaining = until.duration_since(now).as_secs();
169 return AuthCheck::LockedOut {
170 remaining_secs: remaining,
171 };
172 }
173
174 AuthCheck::Allowed
175 }
176
177 pub fn record_failure(&self, ip: IpAddr, username: &str) {
184 let ip = normalize_ip(ip);
185 let now = Instant::now();
186
187 tracing::warn!(
188 event = "auth_failure",
189 ip = %ip,
190 username = username,
191 );
192
193 let key = (ip, username.to_string());
195 let mut entry = self
196 .account_failures
197 .entry(key)
198 .or_insert_with(|| FailureRecord {
199 failures: Vec::new(),
200 lockout_until: None,
201 consecutive_lockouts: 0,
202 });
203
204 let window_start = now - Duration::from_secs(self.config.account_window_secs);
205 entry.failures.retain(|t| *t > window_start);
206 entry.failures.push(now);
207
208 if entry.failures.len() as u32 >= self.config.max_failures_account {
209 let duration = lockout_duration(
210 self.config.base_lockout_secs,
211 entry.consecutive_lockouts,
212 self.config.backoff_multiplier,
213 self.config.max_lockout_secs,
214 );
215 entry.lockout_until = Some(now + Duration::from_secs(duration));
216 entry.consecutive_lockouts += 1;
217 entry.failures.clear();
218
219 tracing::warn!(
220 event = "auth_lockout",
221 ip = %ip,
222 username = username,
223 scope = "account",
224 duration_secs = duration,
225 );
226 }
227
228 let mut entry = self.ip_failures.entry(ip).or_insert_with(|| FailureRecord {
230 failures: Vec::new(),
231 lockout_until: None,
232 consecutive_lockouts: 0,
233 });
234
235 let window_start = now - Duration::from_secs(self.config.ip_window_secs);
236 entry.failures.retain(|t| *t > window_start);
237 entry.failures.push(now);
238
239 if entry.failures.len() as u32 >= self.config.max_failures_ip {
240 let duration = lockout_duration(
241 self.config.ip_base_lockout_secs,
242 entry.consecutive_lockouts,
243 self.config.backoff_multiplier,
244 self.config.max_lockout_secs,
245 );
246 entry.lockout_until = Some(now + Duration::from_secs(duration));
247 entry.consecutive_lockouts += 1;
248 entry.failures.clear();
249
250 tracing::warn!(
251 event = "auth_lockout",
252 ip = %ip,
253 scope = "ip",
254 duration_secs = duration,
255 );
256 }
257 }
258
259 pub fn record_success(&self, ip: IpAddr, username: &str) {
267 let ip = normalize_ip(ip);
268 let key = (ip, username.to_string());
269 self.account_failures.remove(&key);
270 }
271
272 pub fn cleanup_stale(&self, before: Instant) {
280 self.account_failures.retain(|_, rec| {
281 if let Some(until) = rec.lockout_until
282 && until < before {
283 return false;
284 }
285 !rec.failures.is_empty() || rec.lockout_until.is_some()
286 });
287 self.ip_failures.retain(|_, rec| {
288 if let Some(until) = rec.lockout_until
289 && until < before {
290 return false;
291 }
292 !rec.failures.is_empty() || rec.lockout_until.is_some()
293 });
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    // ----- lockout_duration: pure backoff arithmetic -----

    // With no prior lockouts the base duration is returned unchanged.
    #[test]
    fn lockout_duration_base() {
        assert_eq!(lockout_duration(1800, 0, 2.0, 86400), 1800);
    }

    // Two prior lockouts with multiplier 2.0: 1800 * 2 * 2 = 7200.
    #[test]
    fn lockout_duration_exponential() {
        assert_eq!(lockout_duration(1800, 2, 2.0, 86400), 7200);
    }

    // 1800 * 2^10 would exceed the 86400-second cap, so the cap is returned.
    #[test]
    fn lockout_duration_capped() {
        assert_eq!(lockout_duration(1800, 10, 2.0, 86400), 86400);
    }

    // ----- account-scope thresholds -----

    // Four failures stay one short of the 5-failure account threshold.
    #[test]
    fn allowed_below_threshold() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 5,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        for _ in 0..4 {
            guard.record_failure(ip, "alice");
        }
        assert!(matches!(guard.check(ip, "alice"), AuthCheck::Allowed));
    }

    // The fifth failure reaches the threshold and triggers an account lockout.
    #[test]
    fn locked_at_threshold() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 5,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        for _ in 0..5 {
            guard.record_failure(ip, "alice");
        }
        assert!(matches!(
            guard.check(ip, "alice"),
            AuthCheck::LockedOut { .. }
        ));
    }

    // record_success drops the account record, so counting restarts from
    // zero: 4 failures, a success, then 1 failure stays below the threshold.
    #[test]
    fn success_resets_account() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 5,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        for _ in 0..4 {
            guard.record_failure(ip, "alice");
        }
        guard.record_success(ip, "alice");
        guard.record_failure(ip, "alice");
        assert!(matches!(guard.check(ip, "alice"), AuthCheck::Allowed));
    }

    // ----- address normalization -----

    // Two addresses in the same /64 normalize to the same key.
    #[test]
    fn ipv6_normalized_to_64() {
        let ip1: IpAddr = "2001:db8::1".parse().unwrap();
        let ip2: IpAddr = "2001:db8::ffff".parse().unwrap();
        assert_eq!(normalize_ip(ip1), normalize_ip(ip2));
    }

    // Plain IPv4 addresses are used as-is.
    #[test]
    fn ipv4_unchanged() {
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        assert_eq!(normalize_ip(ip), ip);
    }

    // Addresses differing within the first 64 bits keep distinct keys.
    #[test]
    fn ipv6_different_subnets_not_merged() {
        let ip1: IpAddr = "2001:db8:aaaa:bbbb::1".parse().unwrap();
        let ip2: IpAddr = "2001:db8:cccc:dddd::1".parse().unwrap();
        assert_ne!(normalize_ip(ip1), normalize_ip(ip2));
    }

    // ----- ip scope, expiry and cleanup -----

    // Three failures against one username lock the whole address for every
    // username. The account threshold is raised so only the ip scope can fire.
    #[test]
    fn ip_lockout_at_threshold() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_ip: 3,
            max_failures_account: 100,
            ..Default::default()
        });
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        for _ in 0..3 {
            guard.record_failure(ip, "user1");
        }
        assert!(matches!(
            guard.check(ip, "any_user"),
            AuthCheck::LockedOut { .. }
        ));
    }

    // Timing-sensitive: a 1-second account lockout (the 1-second cap and
    // multiplier 1.0 keep it at 1 s) must be gone after sleeping 1.1 s.
    #[test]
    fn lockout_expires_after_duration() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            base_lockout_secs: 1,
            max_lockout_secs: 1,
            backoff_multiplier: 1.0,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();

        guard.record_failure(ip, "bob");
        guard.record_failure(ip, "bob");
        assert!(matches!(
            guard.check(ip, "bob"),
            AuthCheck::LockedOut { remaining_secs }
                if remaining_secs <= 1
        ));

        // Slightly longer than the lockout, to stay clear of the boundary.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        assert!(matches!(guard.check(ip, "bob"), AuthCheck::Allowed));
    }

    // Both scopes lock after two failures with 1-second lockouts. A cleanup
    // whose cutoff is an hour in the future must evict both records.
    #[test]
    fn cleanup_stale_removes_expired_lockouts() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            base_lockout_secs: 1,
            max_lockout_secs: 1,
            backoff_multiplier: 1.0,
            max_failures_ip: 2,
            ip_base_lockout_secs: 1,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();

        guard.record_failure(ip, "carol");
        guard.record_failure(ip, "carol");
        assert!(!guard.account_failures.is_empty());
        assert!(!guard.ip_failures.is_empty());

        // A future cutoff stands in for "long after the lockouts ended"
        // without having to sleep.
        let future = Instant::now() + std::time::Duration::from_secs(3600);
        guard.cleanup_stale(future);
        assert!(guard.account_failures.is_empty());
        assert!(guard.ip_failures.is_empty());
    }

    // A single recent failure (no lockout yet) must survive a cleanup at "now"
    // in both scopes.
    #[test]
    fn cleanup_stale_preserves_active_records() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 10,
            max_failures_ip: 10,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();

        guard.record_failure(ip, "dave");
        assert_eq!(guard.account_failures.len(), 1);
        assert_eq!(guard.ip_failures.len(), 1);

        guard.cleanup_stale(Instant::now());
        assert_eq!(guard.account_failures.len(), 1);
        assert_eq!(guard.ip_failures.len(), 1);
    }

    // A fresh guard blocks nobody.
    #[test]
    fn normal_login_not_blocked() {
        let guard = AuthGuard::new(AuthGuardConfig::default());
        let ip: IpAddr = "192.168.1.100".parse().unwrap();
        assert!(matches!(guard.check(ip, "admin"), AuthCheck::Allowed));
    }

    // ----- backoff across consecutive lockouts -----

    // With a threshold of 1, every failure locks the account. The first
    // lockout is lifted by hand, by reaching into the private map, without
    // resetting consecutive_lockouts. The second failure must then produce a
    // lockout longer than the 10-second base.
    #[test]
    fn exponential_backoff_increases_lockout() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 1,
            base_lockout_secs: 10,
            backoff_multiplier: 2.0,
            max_lockout_secs: 86400,
            account_window_secs: 1,
            ..Default::default()
        });
        let ip: IpAddr = "127.0.0.1".parse().unwrap();

        guard.record_failure(ip, "eve");
        if let AuthCheck::LockedOut { remaining_secs } = guard.check(ip, "eve") {
            assert!(remaining_secs <= 10);
        } else {
            panic!("expected lockout after first failure");
        }

        // NOTE(review): the assertions below do not appear to depend on this
        // sleep, because the lockout is cleared manually; confirm its intent.
        std::thread::sleep(std::time::Duration::from_millis(100));
        if let Some(mut rec) = guard.account_failures.get_mut(&(ip, "eve".to_string())) {
            // Lift the lockout but keep the lockout history.
            rec.lockout_until = None;
        }

        guard.record_failure(ip, "eve");
        if let AuthCheck::LockedOut { remaining_secs } = guard.check(ip, "eve") {
            assert!(
                remaining_secs > 10,
                "second lockout should be longer than first, got {remaining_secs}"
            );
        } else {
            panic!("expected lockout after second round of failures");
        }
    }

    // ----- IPv6 grouping end to end -----

    // Failures from one address in a /64 lock out another address in the
    // same /64 for the same username.
    #[test]
    fn ipv6_lockout_applies_to_same_subnet() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            ..Default::default()
        });
        let ip1: IpAddr = "2001:db8:1:2::aaaa".parse().unwrap();
        let ip2: IpAddr = "2001:db8:1:2::bbbb".parse().unwrap();

        guard.record_failure(ip1, "frank");
        guard.record_failure(ip1, "frank");

        assert!(matches!(
            guard.check(ip2, "frank"),
            AuthCheck::LockedOut { .. }
        ));
    }

    // A lockout earned in one /64 does not affect a different /64.
    #[test]
    fn ipv6_different_subnets_not_blocked_together() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            ..Default::default()
        });
        let ip1: IpAddr = "2001:db8:1:2::aaaa".parse().unwrap();
        let ip2: IpAddr = "2001:db8:3:4::bbbb".parse().unwrap();
        guard.record_failure(ip1, "alice");
        guard.record_failure(ip1, "alice");
        assert!(matches!(guard.check(ip2, "alice"), AuthCheck::Allowed));
    }

    // ----- edge cases -----

    // Account lockouts are per username. The ip threshold is raised so the
    // address itself stays unlocked.
    #[test]
    fn different_usernames_track_independently() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            max_failures_ip: 100,
            ..Default::default()
        });
        let ip: IpAddr = "192.0.2.1".parse().unwrap();
        guard.record_failure(ip, "alice");
        guard.record_failure(ip, "alice");
        assert!(matches!(guard.check(ip, "alice"), AuthCheck::LockedOut { .. }));
        assert!(matches!(guard.check(ip, "bob"), AuthCheck::Allowed));
    }

    // Failures keep being recorded while already locked out; this re-triggers
    // (and escalates) the lockout, and must not panic.
    #[test]
    fn record_failure_during_lockout_does_not_panic() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            ..Default::default()
        });
        let ip: IpAddr = "192.0.2.10".parse().unwrap();
        guard.record_failure(ip, "alice");
        guard.record_failure(ip, "alice");
        for _ in 0..10 {
            guard.record_failure(ip, "alice");
        }
        assert!(matches!(guard.check(ip, "alice"), AuthCheck::LockedOut { .. }));
    }

    // Three failures across different usernames trip the ip threshold. A
    // later success for one of them clears only that account record, so the
    // address stays locked for everyone.
    #[test]
    fn record_success_does_not_clear_ip_counter() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 100,
            max_failures_ip: 3,
            ..Default::default()
        });
        let ip: IpAddr = "192.0.2.20".parse().unwrap();
        guard.record_failure(ip, "user1");
        guard.record_failure(ip, "user2");
        guard.record_failure(ip, "user3");
        guard.record_success(ip, "user1");
        assert!(matches!(guard.check(ip, "anyone"), AuthCheck::LockedOut { .. }));
    }

    // Cleanup on empty maps, with a present and a future cutoff, is a no-op.
    #[test]
    fn cleanup_stale_handles_empty_maps() {
        let guard = AuthGuard::new(AuthGuardConfig::default());
        guard.cleanup_stale(Instant::now());
        guard.cleanup_stale(Instant::now() + Duration::from_secs(3600));
    }

    // NOTE(review): despite its name, this test uses a threshold of 1, not 0.
    // It checks that a single failure locks the account.
    #[test]
    fn zero_max_failures_locks_immediately() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 1,
            ..Default::default()
        });
        let ip: IpAddr = "192.0.2.30".parse().unwrap();
        guard.record_failure(ip, "alice");
        assert!(matches!(guard.check(ip, "alice"), AuthCheck::LockedOut { .. }));
    }

    // A very large lockout count (100) must still come back as the cap.
    #[test]
    fn high_max_lockout_secs_caps_at_max() {
        let d = lockout_duration(1800, 100, 2.0, 86400);
        assert_eq!(d, 86400);
    }

    // A multiplier below 1 must not make successive lockouts grow. The base
    // is still returned for the first lockout.
    #[test]
    fn backoff_multiplier_below_one_does_not_explode() {
        let d0 = lockout_duration(1800, 0, 0.5, 86400);
        let d1 = lockout_duration(1800, 1, 0.5, 86400);
        let d2 = lockout_duration(1800, 2, 0.5, 86400);
        assert_eq!(d0, 1800);
        assert!(d1 <= d0);
        assert!(d2 <= d1);
    }

    // 8 threads x 50 failures on the same key (default threshold 5) must
    // neither panic nor deadlock, and must end in a lockout.
    #[test]
    fn concurrent_record_failures_dont_panic() {
        use std::sync::Arc;
        use std::thread;

        let guard = Arc::new(AuthGuard::new(AuthGuardConfig::default()));
        let ip: IpAddr = "192.0.2.40".parse().unwrap();
        let mut handles = vec![];
        for _ in 0..8 {
            let g = guard.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..50 {
                    g.record_failure(ip, "alice");
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert!(matches!(guard.check(ip, "alice"), AuthCheck::LockedOut { .. }));
    }

    // `::1` normalizes to its /64 prefix `::`, which is a different key from
    // `127.0.0.1`. Locking the IPv4 loopback therefore leaves IPv6 loopback
    // allowed.
    #[test]
    fn ipv4_loopback_treated_separately_from_ipv6_loopback() {
        let guard = AuthGuard::new(AuthGuardConfig {
            max_failures_account: 2,
            ..Default::default()
        });
        let v4: IpAddr = "127.0.0.1".parse().unwrap();
        let v6: IpAddr = "::1".parse().unwrap();
        guard.record_failure(v4, "alice");
        guard.record_failure(v4, "alice");
        assert!(matches!(guard.check(v4, "alice"), AuthCheck::LockedOut { .. }));
        assert!(matches!(guard.check(v6, "alice"), AuthCheck::Allowed));
    }
}