use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::security::audit::{AuditActor, AuditEvent, AuditEventType, AuditLogger, AuditOutcome};
/// Tuning parameters for [`BruteForceProtection`].
///
/// Every `*_secs` field is a duration in whole seconds.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct BruteForceConfig {
    /// Failure count (within the current window) from which every recorded
    /// failure raises the IP's delay level by one.
    pub attempts_before_delay: u32,
    /// Delay, in seconds, applied at delay level 1.
    pub initial_delay_secs: u32,
    /// Factor by which the delay grows for each additional delay level.
    pub delay_multiplier: f32,
    /// Upper bound, in seconds, on any computed delay.
    pub max_delay_secs: u32,
    /// Failure count (within the current window) at which the IP is blocked.
    /// CAPTCHA challenges are only issued below this count.
    pub attempts_before_block: u32,
    /// Block length, in seconds, for excessive failures or detected key-prefix
    /// enumeration; threat-intel matches are blocked for 24 times this value.
    pub block_duration_secs: u32,
    /// Length, in seconds, of the window in which failures are counted; the
    /// counter restarts when a failure arrives after the window has elapsed.
    pub failure_window_secs: u32,
    /// Failure count from which requests must solve a CAPTCHA (while still
    /// below `attempts_before_block`).
    pub captcha_threshold: u32,
    /// Inactivity period, in seconds, after which an IP's accumulated
    /// penalties should be forgotten.
    pub reset_after_secs: u32,
}
impl Default for BruteForceConfig {
    /// Defaults: escalating delays from the 3rd failure (1s doubling, capped at
    /// 30s), a CAPTCHA from the 5th, and a one-hour block from the 10th failure
    /// within a five-minute window.
    fn default() -> Self {
        let five_minutes = 5 * 60;
        let one_hour = 60 * 60;
        Self {
            // Escalation thresholds.
            attempts_before_delay: 3,
            captcha_threshold: 5,
            attempts_before_block: 10,
            // Progressive delay curve.
            initial_delay_secs: 1,
            delay_multiplier: 2.0,
            max_delay_secs: 30,
            // Time windows.
            failure_window_secs: five_minutes,
            block_duration_secs: one_hour,
            reset_after_secs: one_hour,
        }
    }
}
/// Mutable per-IP bookkeeping kept by [`BruteForceProtection`].
#[derive(Debug, Clone)]
struct IpState {
    /// Failures counted in the current failure window.
    failed_attempts: u32,
    /// When the current failure window started.
    window_start: Instant,
    /// When the most recent failure was recorded (or the state was created).
    last_attempt: Instant,
    /// Progressive-delay level; 0 means no delay.
    delay_level: u32,
    /// Deadline of the active block, if one has been set.
    blocked_until: Option<Instant>,
    /// Always initialised empty by the code in this module.
    suspicious_patterns: Vec<SuspiciousPattern>,
}
/// Kinds of suspicious behaviour that can be attached to an IP.
///
/// NOTE(review): no variant is constructed in this module; the per-variant
/// descriptions below are read from the names and should be confirmed.
#[derive(Debug, Clone)]
pub enum SuspiciousPattern {
    /// Presumably repeated attempts against one key prefix, with a hit count.
    PrefixEnumeration { prefix: String, count: u32 },
    /// Presumably attempts that follow a sequential pattern.
    SequentialAttempts,
    /// Presumably a named, known attack signature.
    AttackSignature(String),
    /// Presumably a mismatch between the expected and observed location.
    GeoAnomaly { expected: String, actual: String },
}
/// Verdict returned by [`BruteForceProtection::check_request`].
#[derive(Debug, Clone)]
pub enum BruteForceAction {
    /// Let the request through untouched.
    Allow,
    /// Hold the request for `seconds` before processing it.
    Delay { seconds: u32 },
    /// Require the client to solve a CAPTCHA identified by `challenge_id`
    /// (a freshly generated UUID v4 string).
    RequireCaptcha { challenge_id: String },
    /// Reject the request; the block lasts roughly `duration_secs` more seconds.
    Block {
        reason: BlockReason,
        duration_secs: u32,
    },
}
/// Why an IP was blocked.
///
/// `Copy`, `PartialEq`, `Eq` and `Hash` are derived so callers can compare
/// reasons and use them as map keys. The type is fieldless, so every derive
/// is trivially correct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlockReason {
    /// Too many failed attempts within the failure window. Also reported for
    /// any block that is already active when a request is checked.
    ExcessiveFailures,
    /// Key-prefix enumeration was detected.
    AttackPatternDetected,
    /// The configured threat-intelligence provider flagged the IP.
    ThreatIntelMatch,
    /// Location anomaly; not produced by the detection logic in this module.
    GeoAnomaly,
}
/// Attempt counters for a single API-key prefix.
#[derive(Debug, Clone)]
struct PrefixState {
    /// Attempt count per client IP for this prefix.
    ips_attempting: HashMap<IpAddr, u32>,
    /// Timestamp stored alongside the counters (refreshed when they are updated).
    last_attempt: Instant,
}
/// Source of threat-intelligence verdicts for client IPs.
///
/// [`BruteForceProtection::check_request`] blocks any IP for which
/// `check_ip` returns `Some`, regardless of the returned details.
#[async_trait::async_trait]
pub trait ThreatIntelProvider: Send + Sync {
    /// Looks up `ip`; `None` means nothing is known against it.
    async fn check_ip(&self, ip: IpAddr) -> Option<ThreatInfo>;
}
/// Details a [`ThreatIntelProvider`] reports about a flagged IP.
#[derive(Debug, Clone)]
pub struct ThreatInfo {
    /// Provider-defined threat category.
    pub category: String,
    /// Provider-defined risk score (scale not fixed here — confirm with provider).
    pub risk_score: u8,
    /// When the provider last observed the IP, in UTC.
    pub last_seen: chrono::DateTime<chrono::Utc>,
}
/// Sink for operational alerts raised by the brute-force detector.
#[async_trait::async_trait]
pub trait AlertService: Send + Sync {
    /// Delivers `alert`; an `Err` carries a human-readable failure message.
    /// Callers in this module treat delivery as best-effort and ignore errors.
    async fn send_alert(&self, alert: Alert) -> Result<(), String>;
}
/// A single alert handed to an [`AlertService`].
#[derive(Debug, Clone)]
pub struct Alert {
    /// How urgent the alert is.
    pub severity: AlertSeverity,
    /// Short one-line summary.
    pub title: String,
    /// Human-readable description.
    pub message: String,
    /// Structured context attached to the alert.
    pub metadata: serde_json::Value,
}
/// Urgency of an [`Alert`], ordered from least to most severe.
///
/// `PartialOrd`/`Ord` follow declaration order, so
/// `Low < Medium < High < Critical` and callers can filter with
/// `severity >= AlertSeverity::High`. `PartialEq`/`Eq`/`Hash` allow comparing
/// severities and using them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AlertSeverity {
    Low,
    Medium,
    High,
    Critical,
}
/// Tracks failed attempts per client IP and per key prefix, and decides
/// whether a request is allowed, delayed, challenged with a CAPTCHA or blocked.
pub struct BruteForceProtection {
    /// Thresholds and durations driving every decision.
    config: BruteForceConfig,
    /// Per-IP failure counters, delay levels and block deadlines.
    ip_state: Arc<RwLock<HashMap<IpAddr, IpState>>>,
    /// Per-key-prefix attempt counters used for enumeration detection.
    prefix_state: Arc<RwLock<HashMap<String, PrefixState>>>,
    /// Optional threat-intelligence lookup consulted on every check.
    threat_intel: Option<Arc<dyn ThreatIntelProvider>>,
    /// Destination for security audit events.
    audit: Arc<AuditLogger>,
    /// Optional sink for operational alerts.
    alerter: Option<Arc<dyn AlertService>>,
}
impl BruteForceProtection {
pub fn new(config: BruteForceConfig, audit: Arc<AuditLogger>) -> Self {
Self {
config,
ip_state: Arc::new(RwLock::new(HashMap::new())),
prefix_state: Arc::new(RwLock::new(HashMap::new())),
threat_intel: None,
audit,
alerter: None,
}
}
pub fn with_threat_intel(mut self, provider: Arc<dyn ThreatIntelProvider>) -> Self {
self.threat_intel = Some(provider);
self
}
pub fn with_alerter(mut self, alerter: Arc<dyn AlertService>) -> Self {
self.alerter = Some(alerter);
self
}
pub async fn check_request(&self, ip: IpAddr, key_prefix: Option<&str>) -> BruteForceAction {
if let Some(action) = self.check_ip_block(ip).await {
return action;
}
if let Some(ref threat_intel) = self.threat_intel {
if let Some(threat) = threat_intel.check_ip(ip).await {
self.block_ip(ip, BlockReason::ThreatIntelMatch, self.config.block_duration_secs * 24)
.await;
return BruteForceAction::Block {
duration_secs: self.config.block_duration_secs * 24,
reason: BlockReason::ThreatIntelMatch,
};
}
}
if let Some(prefix) = key_prefix {
if self.detect_enumeration(ip, prefix).await {
self.block_ip(ip, BlockReason::AttackPatternDetected, self.config.block_duration_secs)
.await;
return BruteForceAction::Block {
duration_secs: self.config.block_duration_secs,
reason: BlockReason::AttackPatternDetected,
};
}
}
self.apply_progressive_penalty(ip).await
}
pub async fn record_failure(&self, ip: IpAddr, key_prefix: Option<&str>, error_type: &str) {
let now = Instant::now();
let mut state = self.ip_state.write().await;
let ip_state = state.entry(ip).or_insert_with(|| IpState {
failed_attempts: 0,
window_start: now,
last_attempt: now,
delay_level: 0,
blocked_until: None,
suspicious_patterns: Vec::new(),
});
if now.duration_since(ip_state.window_start)
> Duration::from_secs(self.config.failure_window_secs as u64)
{
ip_state.failed_attempts = 0;
ip_state.window_start = now;
}
ip_state.failed_attempts += 1;
ip_state.last_attempt = now;
if ip_state.failed_attempts >= self.config.attempts_before_delay {
ip_state.delay_level = ip_state.delay_level.saturating_add(1);
}
let failed_attempts = ip_state.failed_attempts;
let should_block = failed_attempts >= self.config.attempts_before_block;
if should_block {
ip_state.blocked_until =
Some(now + Duration::from_secs(self.config.block_duration_secs as u64));
drop(state);
let _ = self
.audit
.log(
AuditEvent::new(AuditEventType::BruteForceDetected, AuditActor::Anonymous)
.with_details(serde_json::json!({
"ip": ip.to_string(),
"failed_attempts": failed_attempts,
"block_duration_secs": self.config.block_duration_secs,
"key_prefix": key_prefix,
"error_type": error_type,
}))
.with_ip(ip)
.with_outcome(AuditOutcome::Failure {
error_code: "BRUTE_FORCE".to_string(),
error_message: format!("IP blocked after {} failures", failed_attempts),
}),
)
.await;
if let Some(ref alerter) = self.alerter {
let _ = alerter
.send_alert(Alert {
severity: AlertSeverity::High,
title: "Brute Force Attack Detected".to_string(),
message: format!(
"IP {} blocked after {} failed attempts",
ip, failed_attempts
),
metadata: serde_json::json!({
"ip": ip.to_string(),
"attempts": failed_attempts,
}),
})
.await;
}
} else {
drop(state);
}
if let Some(prefix) = key_prefix {
self.record_prefix_attempt(ip, prefix).await;
}
}
pub async fn record_success(&self, ip: IpAddr) {
let mut state = self.ip_state.write().await;
if let Some(ip_state) = state.get_mut(&ip) {
ip_state.delay_level = ip_state.delay_level.saturating_sub(2);
ip_state.failed_attempts = ip_state.failed_attempts.saturating_sub(1);
}
}
async fn check_ip_block(&self, ip: IpAddr) -> Option<BruteForceAction> {
let state = self.ip_state.read().await;
if let Some(ip_state) = state.get(&ip) {
if let Some(blocked_until) = ip_state.blocked_until {
if Instant::now() < blocked_until {
let remaining = blocked_until - Instant::now();
return Some(BruteForceAction::Block {
duration_secs: remaining.as_secs() as u32,
reason: BlockReason::ExcessiveFailures,
});
}
}
}
None
}
async fn apply_progressive_penalty(&self, ip: IpAddr) -> BruteForceAction {
let state = self.ip_state.read().await;
if let Some(ip_state) = state.get(&ip) {
if ip_state.failed_attempts >= self.config.captcha_threshold
&& ip_state.failed_attempts < self.config.attempts_before_block
{
return BruteForceAction::RequireCaptcha {
challenge_id: Uuid::new_v4().to_string(),
};
}
if ip_state.delay_level > 0 {
let delay = self.calculate_delay(ip_state.delay_level);
return BruteForceAction::Delay { seconds: delay };
}
}
BruteForceAction::Allow
}
fn calculate_delay(&self, level: u32) -> u32 {
let delay = self.config.initial_delay_secs as f32
* self.config.delay_multiplier.powi(level as i32 - 1);
(delay as u32).min(self.config.max_delay_secs)
}
async fn detect_enumeration(&self, ip: IpAddr, prefix: &str) -> bool {
let mut state = self.prefix_state.write().await;
let prefix_state = state.entry(prefix.to_string()).or_insert_with(|| PrefixState {
ips_attempting: HashMap::new(),
last_attempt: Instant::now(),
});
let count = prefix_state.ips_attempting.entry(ip).or_insert(0);
*count += 1;
prefix_state.last_attempt = Instant::now();
let unique_ips = prefix_state.ips_attempting.len();
let this_ip_attempts = *count;
unique_ips >= 5 || this_ip_attempts >= 10
}
async fn record_prefix_attempt(&self, ip: IpAddr, prefix: &str) {
let mut state = self.prefix_state.write().await;
let prefix_state = state.entry(prefix.to_string()).or_insert_with(|| PrefixState {
ips_attempting: HashMap::new(),
last_attempt: Instant::now(),
});
*prefix_state.ips_attempting.entry(ip).or_insert(0) += 1;
}
async fn block_ip(&self, ip: IpAddr, reason: BlockReason, duration_secs: u32) {
let mut state = self.ip_state.write().await;
let ip_state = state.entry(ip).or_insert_with(|| IpState {
failed_attempts: 0,
window_start: Instant::now(),
last_attempt: Instant::now(),
delay_level: 0,
blocked_until: None,
suspicious_patterns: Vec::new(),
});
ip_state.blocked_until =
Some(Instant::now() + Duration::from_secs(duration_secs as u64));
drop(state);
let _ = self
.audit
.log(
AuditEvent::new(AuditEventType::IpBlocked, AuditActor::System)
.with_details(serde_json::json!({
"ip": ip.to_string(),
"reason": format!("{:?}", reason),
"duration_secs": duration_secs,
}))
.with_ip(ip)
.with_outcome(AuditOutcome::Success),
)
.await;
}
pub async fn unblock_ip(&self, ip: IpAddr) {
let mut state = self.ip_state.write().await;
if let Some(ip_state) = state.get_mut(&ip) {
ip_state.blocked_until = None;
ip_state.failed_attempts = 0;
ip_state.delay_level = 0;
}
}
pub async fn get_ip_state(&self, ip: IpAddr) -> Option<IpStateInfo> {
let state = self.ip_state.read().await;
state.get(&ip).map(|s| IpStateInfo {
failed_attempts: s.failed_attempts,
delay_level: s.delay_level,
is_blocked: s.blocked_until.map(|b| Instant::now() < b).unwrap_or(false),
block_remaining_secs: s.blocked_until.and_then(|b| {
if Instant::now() < b {
Some((b - Instant::now()).as_secs() as u32)
} else {
None
}
}),
})
}
}
/// Read-only snapshot of an IP's brute-force state, returned by
/// `BruteForceProtection::get_ip_state`.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly, e.g.
/// in tests or change detection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IpStateInfo {
    /// Failures counted in the current failure window.
    pub failed_attempts: u32,
    /// Current progressive-delay level (0 = no delay).
    pub delay_level: u32,
    /// Whether a block is active at the time of the snapshot.
    pub is_blocked: bool,
    /// Seconds left on the active block, if any.
    pub block_remaining_secs: Option<u32>,
}
/// Threat-intel provider that never reports a threat; useful as an explicit
/// "disabled" provider.
///
/// Derives the common traits so it can be debug-printed, defaulted and copied.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpThreatIntel;

#[async_trait::async_trait]
impl ThreatIntelProvider for NoOpThreatIntel {
    /// Always returns `None`.
    async fn check_ip(&self, _ip: IpAddr) -> Option<ThreatInfo> {
        None
    }
}
/// Alert sink that silently accepts and discards every alert.
///
/// Derives the common traits so it can be debug-printed, defaulted and copied.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpAlertService;

#[async_trait::async_trait]
impl AlertService for NoOpAlertService {
    /// Drops `_alert` and reports success.
    async fn send_alert(&self, _alert: Alert) -> Result<(), String> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::audit::{AuditConfig, InMemoryAuditStore};

    /// Builds a detector with `config`, backed by a fresh in-memory audit store.
    fn protection_with(config: BruteForceConfig) -> BruteForceProtection {
        let store = Arc::new(InMemoryAuditStore::new());
        let logger = Arc::new(AuditLogger::new(store, AuditConfig::default()));
        BruteForceProtection::new(config, logger)
    }

    /// Detector using the default configuration.
    fn create_test_protection() -> BruteForceProtection {
        protection_with(BruteForceConfig::default())
    }

    /// The loopback client address shared by every test.
    fn client() -> IpAddr {
        IpAddr::from([127, 0, 0, 1])
    }

    #[tokio::test]
    async fn test_allows_initial_request() {
        let verdict = create_test_protection().check_request(client(), None).await;
        assert!(matches!(verdict, BruteForceAction::Allow));
    }

    #[tokio::test]
    async fn test_delays_after_threshold() {
        let protection = protection_with(BruteForceConfig {
            attempts_before_delay: 2,
            ..BruteForceConfig::default()
        });
        for _ in 0..3 {
            protection.record_failure(client(), None, "invalid_key").await;
        }
        let verdict = protection.check_request(client(), None).await;
        assert!(matches!(verdict, BruteForceAction::Delay { .. }));
    }

    #[tokio::test]
    async fn test_blocks_after_threshold() {
        let protection = protection_with(BruteForceConfig {
            attempts_before_block: 3,
            ..BruteForceConfig::default()
        });
        for _ in 0..4 {
            protection.record_failure(client(), None, "invalid_key").await;
        }
        let verdict = protection.check_request(client(), None).await;
        assert!(matches!(verdict, BruteForceAction::Block { .. }));
    }

    #[tokio::test]
    async fn test_unblock_ip() {
        let protection = create_test_protection();
        let ip = client();
        protection
            .block_ip(ip, BlockReason::ExcessiveFailures, 3600)
            .await;
        let before = protection.check_request(ip, None).await;
        assert!(matches!(before, BruteForceAction::Block { .. }));
        protection.unblock_ip(ip).await;
        let after = protection.check_request(ip, None).await;
        assert!(matches!(after, BruteForceAction::Allow));
    }

    #[test]
    fn test_calculate_delay() {
        let protection = protection_with(BruteForceConfig {
            initial_delay_secs: 1,
            delay_multiplier: 2.0,
            max_delay_secs: 30,
            ..Default::default()
        });
        // (delay level, expected seconds): doubling from 1s, capped at 30s.
        let cases = [(1, 1), (2, 2), (3, 4), (4, 8), (10, 30)];
        for (level, expected) in cases {
            assert_eq!(protection.calculate_delay(level), expected, "level {}", level);
        }
    }
}