use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// Tuning knobs for [`BruteForceProtector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BruteForceConfig {
    /// Number of failures inside `attempt_window_secs` that triggers a lockout.
    pub max_attempts: u32,
    /// Base lockout length, in seconds.
    pub lockout_duration_secs: u64,
    /// Sliding window, in seconds, over which failed attempts are counted.
    pub attempt_window_secs: u64,
    /// When `true`, every repeated lockout of the same key doubles the previous length.
    pub progressive_lockout: bool,
}

impl Default for BruteForceConfig {
    /// Five failures per five-minute window, a fifteen-minute base lockout,
    /// and progressive doubling enabled.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        Self {
            progressive_lockout: true,
            attempt_window_secs: 5 * MINUTE,
            lockout_duration_secs: 15 * MINUTE,
            max_attempts: 5,
        }
    }
}
/// Rules enforced by [`PasswordStrengthValidator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordStrengthConfig {
    /// Minimum password length.
    pub min_length: usize,
    /// Require at least one uppercase letter.
    pub require_uppercase: bool,
    /// Require at least one lowercase letter.
    pub require_lowercase: bool,
    /// Require at least one ASCII digit.
    pub require_digit: bool,
    /// Require at least one non-alphanumeric character.
    pub require_special: bool,
    /// Minimum value of the validator's entropy estimate, in bits.
    pub min_entropy_bits: f64,
    /// Words that, if found anywhere in the password (case-insensitively), reject it.
    pub banned_passwords: Vec<String>,
}

impl Default for PasswordStrengthConfig {
    /// Eight characters minimum, every character class required, a 3-bit entropy
    /// floor and a short list of very common passwords.
    fn default() -> Self {
        let banned_passwords = ["password", "123456", "qwerty", "admin", "letmein"]
            .iter()
            .map(|word| word.to_string())
            .collect();
        Self {
            min_length: 8,
            require_uppercase: true,
            require_lowercase: true,
            require_digit: true,
            require_special: true,
            min_entropy_bits: 3.0,
            banned_passwords,
        }
    }
}
/// Per-IP sliding-window request limit enforced by [`RateLimiter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Requests accepted from one IP within the window; further requests are rejected.
    pub max_requests: u32,
    /// Length of the sliding window, in seconds.
    pub window_secs: u64,
}

impl Default for RateLimitConfig {
    /// Ten requests per sixty-second window.
    fn default() -> Self {
        let window_secs = 60;
        let max_requests = 10;
        Self {
            max_requests,
            window_secs,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
pub brute_force: BruteForceConfig,
pub password_strength: PasswordStrengthConfig,
pub rate_limit: RateLimitConfig,
pub enable_audit_log: bool,
}
impl Default for SecurityConfig {
fn default() -> Self {
Self {
brute_force: BruteForceConfig::default(),
password_strength: PasswordStrengthConfig::default(),
rate_limit: RateLimitConfig::default(),
enable_audit_log: true,
}
}
}
/// Kind of security-relevant event recorded in the audit log.
///
/// Fieldless, so it is `Copy`; `PartialEq`/`Eq`/`Hash` let callers compare or
/// group events (e.g. filter audit entries by event kind) without `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AuthEvent {
    /// A successful authentication was recorded.
    Success,
    /// A failed authentication was recorded.
    Failure,
    /// An attempt was rejected because the account or the IP address is locked.
    AccountLocked,
    /// An attempt was rejected by the per-IP rate limiter.
    RateLimitExceeded,
    /// A password change was recorded.
    PasswordChanged,
    /// A user creation was recorded.
    UserCreated,
    /// A user deletion was recorded.
    UserDeleted,
}

/// One audit-log record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuditLogEntry {
    /// Seconds since the Unix epoch at which the entry was recorded.
    pub timestamp: u64,
    /// What happened.
    pub event: AuthEvent,
    /// The username the event concerns.
    pub username: String,
    /// Client address, when the caller supplied one.
    pub ip_address: Option<IpAddr>,
    /// Optional free-form context.
    pub details: Option<String>,
}
/// Bounded, in-memory audit trail.
///
/// Holds at most `max_entries` records; once full, each new record evicts the
/// oldest one. A `VecDeque` makes that eviction O(1) instead of re-copying the
/// whole buffer on every insert.
pub struct AuditLogger {
    entries: Arc<RwLock<std::collections::VecDeque<AuditLogEntry>>>,
    max_entries: usize,
}

impl AuditLogger {
    /// Creates an empty logger retaining at most `max_entries` records
    /// (`0` means nothing is retained).
    pub fn new(max_entries: usize) -> Self {
        Self {
            entries: Arc::new(RwLock::new(std::collections::VecDeque::new())),
            max_entries,
        }
    }

    /// Appends a record stamped with the current Unix time (seconds), evicting
    /// the oldest records if the capacity is exceeded.
    pub async fn log(
        &self,
        event: AuthEvent,
        username: String,
        ip_address: Option<IpAddr>,
        details: Option<String>,
    ) {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let entry = AuditLogEntry {
            timestamp,
            event,
            username,
            ip_address,
            details,
        };
        let mut entries = self.entries.write().await;
        entries.push_back(entry);
        while entries.len() > self.max_entries {
            entries.pop_front();
        }
    }

    /// Returns up to `count` of the newest records, oldest first.
    pub async fn get_recent(&self, count: usize) -> Vec<AuditLogEntry> {
        let entries = self.entries.read().await;
        let start = entries.len().saturating_sub(count);
        entries.iter().skip(start).cloned().collect()
    }

    /// Returns every retained record whose `username` equals `username`, oldest first.
    pub async fn get_for_user(&self, username: &str) -> Vec<AuditLogEntry> {
        let entries = self.entries.read().await;
        entries
            .iter()
            .filter(|entry| entry.username == username)
            .cloned()
            .collect()
    }
}
/// One failed authentication attempt.
#[derive(Clone, Debug)]
struct FailedAttempt {
    /// Unix time, in seconds, at which the failure was recorded.
    timestamp: u64,
}

/// An active or expired lockout for a single username or IP address.
#[derive(Clone, Debug)]
struct LockoutInfo {
    /// Unix time, in seconds, at which the lockout began.
    locked_at: u64,
    /// Length of this lockout, in seconds.
    duration_secs: u64,
    /// How many lockouts this key has accumulated, counting this one.
    lockout_count: u32,
}
/// Tracks failed authentication attempts per username and per IP address and
/// applies a temporary lockout once `max_attempts` failures occur within
/// `attempt_window_secs`.
///
/// Lockout records are not removed when they expire (only `unlock_*` removes
/// them), so with `progressive_lockout` the next lockout of the same key is
/// longer than the previous one.
///
/// Lock order: a `locked_*` map may be held while taking the matching
/// `*_attempts` map, never the reverse.
pub struct BruteForceProtector {
    config: BruteForceConfig,
    user_attempts: Arc<RwLock<HashMap<String, Vec<FailedAttempt>>>>,
    ip_attempts: Arc<RwLock<HashMap<IpAddr, Vec<FailedAttempt>>>>,
    locked_accounts: Arc<RwLock<HashMap<String, LockoutInfo>>>,
    locked_ips: Arc<RwLock<HashMap<IpAddr, LockoutInfo>>>,
}

impl BruteForceProtector {
    /// Creates a protector with no recorded attempts or lockouts.
    pub fn new(config: BruteForceConfig) -> Self {
        Self {
            config,
            user_attempts: Arc::new(RwLock::new(HashMap::new())),
            ip_attempts: Arc::new(RwLock::new(HashMap::new())),
            locked_accounts: Arc::new(RwLock::new(HashMap::new())),
            locked_ips: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Current Unix time in seconds (0 if the clock is before the epoch).
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Seconds left on `lockout` at time `now`, or `None` once it has expired.
    ///
    /// The end time saturates so a very long lockout cannot overflow `u64`.
    fn remaining_lockout(lockout: &LockoutInfo, now: u64) -> Option<u64> {
        lockout
            .locked_at
            .saturating_add(lockout.duration_secs)
            .checked_sub(now)
            .filter(|&remaining| remaining > 0)
    }

    /// Length of the `lockout_count`-th lockout of a key.
    ///
    /// Progressive mode uses `lockout_duration_secs * 2^(lockout_count - 1)`,
    /// saturating at `u64::MAX`. An unchecked power/multiply would panic in
    /// debug builds and wrap in release builds (possibly to a 0-second lockout).
    fn lockout_duration(&self, lockout_count: u32) -> u64 {
        if !self.config.progressive_lockout {
            return self.config.lockout_duration_secs;
        }
        let multiplier = 2_u64
            .checked_pow(lockout_count.saturating_sub(1))
            .unwrap_or(u64::MAX);
        self.config.lockout_duration_secs.saturating_mul(multiplier)
    }

    /// Drops attempts at or before `cutoff`, records one at `now`, and reports
    /// whether the remaining history reaches `max_attempts`.
    fn register_failure(
        history: &mut Vec<FailedAttempt>,
        now: u64,
        cutoff: u64,
        max_attempts: u32,
    ) -> bool {
        history.retain(|attempt| attempt.timestamp > cutoff);
        history.push(FailedAttempt { timestamp: now });
        history.len() >= max_attempts as usize
    }

    /// Whether `username` is currently under an unexpired lockout.
    pub async fn is_user_locked(&self, username: &str) -> bool {
        self.get_unlock_time(username).await.is_some()
    }

    /// Whether `ip` is currently under an unexpired lockout.
    pub async fn is_ip_locked(&self, ip: &IpAddr) -> bool {
        let locked = self.locked_ips.read().await;
        let now = Self::current_timestamp();
        locked
            .get(ip)
            .and_then(|lockout| Self::remaining_lockout(lockout, now))
            .is_some()
    }

    /// Records a failure for `username` (and for `ip`, if given) and locks
    /// whichever key reached `max_attempts` within the attempt window.
    pub async fn record_failed_attempt(&self, username: &str, ip: Option<IpAddr>) {
        let now = Self::current_timestamp();
        // Saturate so a huge window (or a clock at the epoch) cannot underflow.
        let cutoff = now.saturating_sub(self.config.attempt_window_secs);
        let max_attempts = self.config.max_attempts;

        // Each attempts guard is released at the end of its block, before
        // `lock_*` takes the `locked_*` map (see the lock-order note above).
        let user_exceeded = {
            let mut attempts = self.user_attempts.write().await;
            let history = attempts.entry(username.to_string()).or_default();
            Self::register_failure(history, now, cutoff, max_attempts)
        };
        if user_exceeded {
            self.lock_user(username).await;
        }

        if let Some(ip_addr) = ip {
            let ip_exceeded = {
                let mut attempts = self.ip_attempts.write().await;
                let history = attempts.entry(ip_addr).or_default();
                Self::register_failure(history, now, cutoff, max_attempts)
            };
            if ip_exceeded {
                self.lock_ip(&ip_addr).await;
            }
        }
    }

    /// Starts a new lockout for `username` and clears its attempt history.
    async fn lock_user(&self, username: &str) {
        let now = Self::current_timestamp();
        let mut locked = self.locked_accounts.write().await;
        let lockout_count = locked
            .get(username)
            .map_or(1, |previous| previous.lockout_count.saturating_add(1));
        let duration_secs = self.lockout_duration(lockout_count);
        locked.insert(
            username.to_string(),
            LockoutInfo {
                locked_at: now,
                duration_secs,
                lockout_count,
            },
        );
        self.user_attempts.write().await.remove(username);
    }

    /// Starts a new lockout for `ip` and clears its attempt history.
    async fn lock_ip(&self, ip: &IpAddr) {
        let now = Self::current_timestamp();
        let mut locked = self.locked_ips.write().await;
        let lockout_count = locked
            .get(ip)
            .map_or(1, |previous| previous.lockout_count.saturating_add(1));
        let duration_secs = self.lockout_duration(lockout_count);
        locked.insert(
            *ip,
            LockoutInfo {
                locked_at: now,
                duration_secs,
                lockout_count,
            },
        );
        self.ip_attempts.write().await.remove(ip);
    }

    /// Forgets recorded failures for `username` (does not lift a lockout).
    pub async fn clear_user_attempts(&self, username: &str) {
        self.user_attempts.write().await.remove(username);
    }

    /// Forgets recorded failures for `ip` (does not lift a lockout).
    pub async fn clear_ip_attempts(&self, ip: &IpAddr) {
        self.ip_attempts.write().await.remove(ip);
    }

    /// Lifts any lockout on `username`, resets its lockout count, and clears
    /// its failure history.
    pub async fn unlock_user(&self, username: &str) {
        let mut locked = self.locked_accounts.write().await;
        locked.remove(username);
        self.user_attempts.write().await.remove(username);
    }

    /// Lifts any lockout on `ip`, resets its lockout count, and clears its
    /// failure history.
    pub async fn unlock_ip(&self, ip: &IpAddr) {
        let mut locked = self.locked_ips.write().await;
        locked.remove(ip);
        self.ip_attempts.write().await.remove(ip);
    }

    /// Seconds remaining until `username` is unlocked, or `None` if it is not
    /// currently locked. (Despite the name, this is a duration, not a timestamp.)
    pub async fn get_unlock_time(&self, username: &str) -> Option<u64> {
        let locked = self.locked_accounts.read().await;
        let now = Self::current_timestamp();
        locked
            .get(username)
            .and_then(|lockout| Self::remaining_lockout(lockout, now))
    }
}
/// Outcome of validating one password with [`PasswordStrengthValidator`].
#[derive(Clone, Debug)]
pub struct PasswordStrengthResult {
    /// `true` exactly when `errors` is empty.
    pub valid: bool,
    /// Human-readable description of every rule the password broke.
    pub errors: Vec<String>,
    /// The validator's entropy estimate for the password, in bits.
    pub entropy_bits: f64,
}
/// Checks candidate passwords against a [`PasswordStrengthConfig`].
pub struct PasswordStrengthValidator {
    config: PasswordStrengthConfig,
}

impl PasswordStrengthValidator {
    /// Creates a validator enforcing `config`.
    pub fn new(config: PasswordStrengthConfig) -> Self {
        Self { config }
    }

    /// Validates `password` against every configured rule.
    ///
    /// All rules are evaluated, so the result lists every failure at once.
    /// Length and entropy are measured in `char`s (Unicode scalar values), not
    /// UTF-8 bytes, so multi-byte characters are not over-counted.
    pub fn validate(&self, password: &str) -> PasswordStrengthResult {
        let cfg = &self.config;
        let mut errors = Vec::new();

        if password.chars().count() < cfg.min_length {
            errors.push(format!(
                "Password must be at least {} characters long",
                cfg.min_length
            ));
        }
        if cfg.require_uppercase && !password.chars().any(|c| c.is_uppercase()) {
            errors.push("Password must contain at least one uppercase letter".to_string());
        }
        if cfg.require_lowercase && !password.chars().any(|c| c.is_lowercase()) {
            errors.push("Password must contain at least one lowercase letter".to_string());
        }
        if cfg.require_digit && !password.chars().any(|c| c.is_ascii_digit()) {
            errors.push("Password must contain at least one digit".to_string());
        }
        // Any non-alphanumeric character (including whitespace) counts as special.
        if cfg.require_special && !password.chars().any(|c| !c.is_alphanumeric()) {
            errors.push("Password must contain at least one special character".to_string());
        }

        let password_lower = password.to_lowercase();
        // Empty entries are skipped: `contains("")` is always true and would
        // otherwise reject every password.
        let has_banned_word = cfg
            .banned_passwords
            .iter()
            .filter(|banned| !banned.is_empty())
            .any(|banned| password_lower.contains(&banned.to_lowercase()));
        if has_banned_word {
            errors.push("Password contains a commonly used word or pattern".to_string());
        }

        let entropy = self.calculate_entropy(password);
        if entropy < cfg.min_entropy_bits {
            errors.push(format!(
                "Password entropy too low ({:.2} bits, minimum {:.2} bits required)",
                entropy, cfg.min_entropy_bits
            ));
        }

        PasswordStrengthResult {
            valid: errors.is_empty(),
            errors,
            entropy_bits: entropy,
        }
    }

    /// Character-frequency entropy estimate: Shannon entropy per character
    /// (from the password's own character distribution) times its length.
    ///
    /// This only reflects character variety, not guessability.
    fn calculate_entropy(&self, password: &str) -> f64 {
        let mut char_counts: HashMap<char, usize> = HashMap::new();
        for c in password.chars() {
            *char_counts.entry(c).or_insert(0) += 1;
        }
        // Divide by the number of chars (not bytes) so probabilities sum to 1.
        let total: usize = char_counts.values().sum();
        if total == 0 {
            return 0.0;
        }
        let len = total as f64;
        let bits_per_char: f64 = char_counts
            .values()
            .map(|&count| {
                let probability = count as f64 / len;
                -probability * probability.log2()
            })
            .sum();
        bits_per_char * len
    }
}
/// Per-IP sliding-window rate limiter.
///
/// For each IP it keeps the Unix timestamps (seconds) of requests accepted
/// within the last `window_secs`.
pub struct RateLimiter {
    config: RateLimitConfig,
    request_counts: Arc<RwLock<HashMap<IpAddr, Vec<u64>>>>,
}

impl RateLimiter {
    /// Creates a limiter with no recorded requests.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            request_counts: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns `true` and records the request if `ip` has made fewer than
    /// `max_requests` accepted requests within the window; otherwise returns
    /// `false` without recording anything.
    pub async fn check_rate_limit(&self, ip: &IpAddr) -> bool {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Saturate so a huge window (or a clock at the epoch) cannot underflow.
        let cutoff = now.saturating_sub(self.config.window_secs);
        let mut counts = self.request_counts.write().await;
        let ip_counts = counts.entry(*ip).or_default();
        ip_counts.retain(|&timestamp| timestamp > cutoff);
        if ip_counts.len() >= self.config.max_requests as usize {
            return false;
        }
        ip_counts.push(now);
        true
    }

    /// Forgets every recorded request for `ip`.
    pub async fn reset_limit(&self, ip: &IpAddr) {
        self.request_counts.write().await.remove(ip);
    }
}
pub struct AuthSecurity {
config: SecurityConfig,
brute_force: BruteForceProtector,
password_validator: PasswordStrengthValidator,
rate_limiter: RateLimiter,
audit_logger: Option<AuditLogger>,
}
impl AuthSecurity {
pub fn new(config: SecurityConfig) -> Self {
let audit_logger = if config.enable_audit_log {
Some(AuditLogger::new(10000)) } else {
None
};
Self {
brute_force: BruteForceProtector::new(config.brute_force.clone()),
password_validator: PasswordStrengthValidator::new(config.password_strength.clone()),
rate_limiter: RateLimiter::new(config.rate_limit.clone()),
audit_logger,
config,
}
}
pub fn brute_force(&self) -> &BruteForceProtector {
&self.brute_force
}
pub async fn check_auth_attempt(&self, username: &str, ip: Option<IpAddr>) -> Result<()> {
if let Some(ip_addr) = ip {
if !self.rate_limiter.check_rate_limit(&ip_addr).await {
if let Some(logger) = &self.audit_logger {
logger
.log(
AuthEvent::RateLimitExceeded,
username.to_string(),
Some(ip_addr),
None,
)
.await;
}
return Err(anyhow!("Rate limit exceeded for IP address"));
}
if self.brute_force.is_ip_locked(&ip_addr).await {
if let Some(logger) = &self.audit_logger {
logger
.log(
AuthEvent::AccountLocked,
username.to_string(),
Some(ip_addr),
Some("IP address locked".to_string()),
)
.await;
}
return Err(anyhow!("IP address is temporarily locked"));
}
}
if self.brute_force.is_user_locked(username).await {
if let Some(remaining) = self.brute_force.get_unlock_time(username).await {
if let Some(logger) = &self.audit_logger {
logger
.log(
AuthEvent::AccountLocked,
username.to_string(),
ip,
Some(format!("Account locked for {} seconds", remaining)),
)
.await;
}
return Err(anyhow!(
"Account is temporarily locked. Try again in {} seconds",
remaining
));
}
}
Ok(())
}
pub async fn record_auth_success(&self, username: &str, ip: Option<IpAddr>) {
self.brute_force.clear_user_attempts(username).await;
if let Some(ip_addr) = ip {
self.brute_force.clear_ip_attempts(&ip_addr).await;
}
if let Some(logger) = &self.audit_logger {
logger
.log(AuthEvent::Success, username.to_string(), ip, None)
.await;
}
}
pub async fn record_auth_failure(&self, username: &str, ip: Option<IpAddr>) {
self.brute_force.record_failed_attempt(username, ip).await;
if let Some(logger) = &self.audit_logger {
logger
.log(AuthEvent::Failure, username.to_string(), ip, None)
.await;
}
}
pub fn validate_password(&self, password: &str) -> PasswordStrengthResult {
self.password_validator.validate(password)
}
pub fn check_password_strength(&self, password: &str) -> Result<()> {
let result = self.password_validator.validate(password);
if result.valid {
Ok(())
} else {
Err(anyhow!(
"Password strength validation failed: {}",
result.errors.join(", ")
))
}
}
pub async fn log_password_change(&self, username: &str, ip: Option<IpAddr>) {
if let Some(logger) = &self.audit_logger {
logger
.log(AuthEvent::PasswordChanged, username.to_string(), ip, None)
.await;
}
}
pub async fn log_user_created(&self, username: &str, ip: Option<IpAddr>) {
if let Some(logger) = &self.audit_logger {
logger
.log(AuthEvent::UserCreated, username.to_string(), ip, None)
.await;
}
}
pub async fn log_user_deleted(&self, username: &str, ip: Option<IpAddr>) {
if let Some(logger) = &self.audit_logger {
logger
.log(AuthEvent::UserDeleted, username.to_string(), ip, None)
.await;
}
}
pub async fn get_audit_log(&self, count: usize) -> Option<Vec<AuditLogEntry>> {
if let Some(logger) = &self.audit_logger {
Some(logger.get_recent(count).await)
} else {
None
}
}
pub async fn get_user_audit_log(&self, username: &str) -> Option<Vec<AuditLogEntry>> {
if let Some(logger) = &self.audit_logger {
Some(logger.get_for_user(username).await)
} else {
None
}
}
pub async fn unlock_user(&self, username: &str) {
self.brute_force.unlock_user(username).await;
}
pub async fn unlock_ip(&self, ip: &IpAddr) {
self.brute_force.unlock_ip(ip).await;
}
pub async fn reset_rate_limit(&self, ip: &IpAddr) {
self.rate_limiter.reset_limit(ip).await;
}
pub fn config(&self) -> &SecurityConfig {
&self.config
}
}
/// Fluent builder for [`AuthSecurity`], starting from `SecurityConfig::default()`.
///
/// Every setter consumes the builder and returns it, so results must be chained
/// or rebound. `#[must_use]` turns a silently discarded builder (and thus
/// discarded configuration) into a compiler warning.
#[must_use = "a builder does nothing until `build()` is called"]
pub struct AuthSecurityBuilder {
    config: SecurityConfig,
}

impl AuthSecurityBuilder {
    /// Starts from the default security configuration.
    pub fn new() -> Self {
        Self {
            config: SecurityConfig::default(),
        }
    }

    /// Replaces the brute-force protection settings.
    pub fn brute_force_config(mut self, config: BruteForceConfig) -> Self {
        self.config.brute_force = config;
        self
    }

    /// Replaces the password-strength rules.
    pub fn password_strength_config(mut self, config: PasswordStrengthConfig) -> Self {
        self.config.password_strength = config;
        self
    }

    /// Replaces the rate-limit settings.
    pub fn rate_limit_config(mut self, config: RateLimitConfig) -> Self {
        self.config.rate_limit = config;
        self
    }

    /// Enables or disables the in-memory audit log.
    pub fn enable_audit_log(mut self, enable: bool) -> Self {
        self.config.enable_audit_log = enable;
        self
    }

    /// Constructs the [`AuthSecurity`] instance from the accumulated configuration.
    #[must_use]
    pub fn build(self) -> AuthSecurity {
        AuthSecurity::new(self.config)
    }
}

impl Default for AuthSecurityBuilder {
    /// Same as [`AuthSecurityBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}