use crate::errors::{AuthError, Result};
use rand::Rng;
pub use rate_limit::RateLimiter;
use ring::digest;
use std::time::{SystemTime, UNIX_EPOCH};
pub mod password {
use super::*;
/// Hashes `password` with bcrypt at `bcrypt::DEFAULT_COST` and returns the hash string.
///
/// # Errors
///
/// A bcrypt failure is reported through `AuthError::crypto` with the message
/// `Password hashing failed: <cause>`.
pub fn hash_password(password: &str) -> Result<String> {
    match bcrypt::hash(password, bcrypt::DEFAULT_COST) {
        Ok(hashed) => Ok(hashed),
        Err(e) => Err(AuthError::crypto(format!("Password hashing failed: {e}"))),
    }
}
/// Checks `password` against a previously produced bcrypt `hash`.
///
/// Returns `Ok(true)` on a match and `Ok(false)` on a mismatch.
///
/// # Errors
///
/// A bcrypt failure is reported through `AuthError::crypto` with the message
/// `Password verification failed: <cause>`.
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
    match bcrypt::verify(password, hash) {
        Ok(matched) => Ok(matched),
        Err(e) => Err(AuthError::crypto(format!(
            "Password verification failed: {e}"
        ))),
    }
}
/// Builds a random password of exactly `length` characters.
///
/// Every character is drawn independently from ASCII letters, digits and the
/// symbols `!@#$%^&*`, so a given result is not guaranteed to contain all
/// character classes. A `length` of zero yields an empty string.
pub fn generate_password(length: usize) -> String {
    const CHARSET: &[u8] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*";
    let mut rng = rand::rng();
    let mut generated = String::with_capacity(length);
    for _ in 0..length {
        let pick = rng.random_range(0..CHARSET.len());
        generated.push(char::from(CHARSET[pick]));
    }
    generated
}
/// Scores `password` against a fixed rule set and reports every rule it misses.
///
/// One point is awarded for each of the following, for a maximum of 7:
/// at least 8 characters, at least 12 characters, at least 16 characters,
/// a lowercase letter, an uppercase letter, an ASCII digit, and one of the
/// characters in `SPECIAL_CHARS`.
///
/// A password that equals an entry of `COMMON_PASSWORDS` (compared after
/// lowercasing) is forced to a score of 0 regardless of the other rules.
///
/// The score maps to a level as 0-2 `Weak`, 3-4 `Medium`, 5-6 `Strong`,
/// 7 `VeryStrong`. `feedback` lists one message per unmet rule, in rule order.
///
/// Length is measured in `char`s rather than bytes, so multi-byte text such as
/// Cyrillic or CJK does not earn length points it has not earned.
pub fn check_password_strength(password: &str) -> PasswordStrength {
    const SPECIAL_CHARS: &str = "!@#$%^&*()_+-=[]{}|;:,.<>?";
    const COMMON_PASSWORDS: [&str; 5] =
        ["password", "123456", "password123", "admin", "letmein"];

    let mut score: u8 = 0;
    let mut feedback = Vec::new();

    // `str::len` counts bytes; the rules (and the feedback text) talk about characters.
    let char_count = password.chars().count();
    if char_count >= 8 {
        score += 1;
    } else {
        feedback.push("Password should be at least 8 characters long".to_string());
    }
    // The longer tiers only add points; missing them produces no feedback.
    score += u8::from(char_count >= 12);
    score += u8::from(char_count >= 16);

    // (rule satisfied?, message to report when it is not), in reporting order.
    let class_rules = [
        (
            password.chars().any(char::is_lowercase),
            "Password should contain lowercase letters",
        ),
        (
            password.chars().any(char::is_uppercase),
            "Password should contain uppercase letters",
        ),
        (
            password.chars().any(|c| c.is_ascii_digit()),
            "Password should contain numbers",
        ),
        (
            password.chars().any(|c| SPECIAL_CHARS.contains(c)),
            "Password should contain special characters",
        ),
    ];
    for (satisfied, hint) in class_rules {
        if satisfied {
            score += 1;
        } else {
            feedback.push(hint.to_string());
        }
    }

    // A well-known password is weak no matter how many rules it happens to satisfy.
    if COMMON_PASSWORDS.contains(&password.to_lowercase().as_str()) {
        score = 0;
        feedback.push("Password is too common".to_string());
    }

    let level = match score {
        0..=2 => PasswordStrengthLevel::Weak,
        3..=4 => PasswordStrengthLevel::Medium,
        5..=6 => PasswordStrengthLevel::Strong,
        _ => PasswordStrengthLevel::VeryStrong,
    };
    PasswordStrength {
        level,
        score,
        feedback,
    }
}
/// Result of `check_password_strength`: the numeric score, the level derived
/// from it, and the messages for the rules the password did not satisfy.
#[derive(Debug, Clone)]
pub struct PasswordStrength {
/// Strength bucket chosen from `score`.
pub level: PasswordStrengthLevel,
/// Number of satisfied rules (0 to 7); forced to 0 for a known common password.
pub score: u8,
/// One human-readable message per unmet rule; empty when nothing was flagged.
pub feedback: Vec<String>,
}
/// Coarse strength bucket assigned by `check_password_strength` from its score.
#[derive(Debug, Clone, PartialEq)]
pub enum PasswordStrengthLevel {
/// Score 0 to 2; `validate_password` rejects this level when complexity is required.
Weak,
/// Score 3 or 4.
Medium,
/// Score 5 or 6.
Strong,
/// Any higher score (7, the maximum the scoring rules can produce).
VeryStrong,
}
}
pub mod crypto {
use super::*;
/// Returns a random string of exactly `length` ASCII letters and digits.
///
/// Each character is drawn independently from `A-Z`, `a-z` and `0-9`; a
/// `length` of zero yields an empty string.
pub fn generate_random_string(length: usize) -> String {
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    let mut rng = rand::rng();
    std::iter::repeat_with(|| {
        let pick = rng.random_range(0..CHARSET.len());
        char::from(CHARSET[pick])
    })
    .take(length)
    .collect()
}
/// Returns a vector of `length` bytes filled by `rand::rng()`.
pub fn generate_random_bytes(length: usize) -> Vec<u8> {
    use rand::RngCore;
    let mut rng = rand::rng();
    let mut buffer = vec![0u8; length];
    rng.fill_bytes(buffer.as_mut_slice());
    buffer
}
/// Computes the SHA-256 digest of `data` and returns it as an owned byte vector.
pub fn sha256(data: &[u8]) -> Vec<u8> {
    digest::digest(&digest::SHA256, data).as_ref().to_vec()
}
/// Computes the SHA-256 digest of `data` and returns it hex-encoded.
pub fn sha256_hex(data: &[u8]) -> String {
    let raw_digest = sha256(data);
    hex::encode(raw_digest)
}
/// Generates a random token encoded as unpadded URL-safe base64.
///
/// `length` is the number of random bytes fed to the encoder, not the length
/// of the returned string.
pub fn generate_token(length: usize) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;
    let entropy = generate_random_bytes(length);
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Compares two strings byte-wise without stopping at the first differing byte.
///
/// A length mismatch returns `false` immediately, so the lengths themselves
/// are not hidden. For equal lengths every byte pair is XOR-ed and the results
/// are OR-ed together; the strings are equal when the accumulator stays zero.
pub fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    let difference = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    difference == 0
}
}
pub mod time {
use super::*;
use std::time::Duration;
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub fn current_timestamp() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // A pre-1970 clock is a host misconfiguration; fail loudly with a reason.
        .expect("system clock is set before the Unix epoch")
        .as_secs()
}
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// The value saturates at `u64::MAX` instead of wrapping if the millisecond
/// count ever exceeds 64 bits.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub fn current_timestamp_millis() -> u64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // A pre-1970 clock is a host misconfiguration; fail loudly with a reason.
        .expect("system clock is set before the Unix epoch")
        .as_millis();
    // Clamp before narrowing so the `as` cast can never silently wrap.
    millis.min(u128::from(u64::MAX)) as u64
}
/// Returns the whole seconds contained in `duration`; any sub-second part is dropped.
pub fn duration_to_seconds(duration: Duration) -> u64 {
    Duration::as_secs(&duration)
}
/// Builds a `Duration` of exactly `seconds` whole seconds.
pub fn seconds_to_duration(seconds: u64) -> Duration {
    Duration::new(seconds, 0)
}
/// Reports whether `expires_at` (seconds since the Unix epoch) lies strictly in the past.
///
/// A deadline equal to the current second is still considered valid.
pub fn is_expired(expires_at: u64) -> bool {
    let now = current_timestamp();
    expires_at < now
}
/// Returns how long remains until `expires_at` (seconds since the Unix epoch).
///
/// Yields `None` when the deadline is the current second or already past,
/// otherwise the remaining whole seconds as a `Duration`.
pub fn time_until_expiry(expires_at: u64) -> Option<Duration> {
    // `checked_sub` is `None` once the deadline is behind the clock.
    let remaining = expires_at.checked_sub(current_timestamp())?;
    // A deadline equal to "now" has nothing left and also reports `None`.
    (remaining > 0).then(|| Duration::from_secs(remaining))
}
}
pub mod string {
/// Hides the middle of `input`, keeping `visible_chars` characters at each end.
///
/// Behaviour by case, with lengths counted in `char`s:
/// - empty input returns an empty string;
/// - `visible_chars` at or above the input length returns the input unmasked;
/// - an input no longer than `2 * visible_chars` is replaced entirely by
///   asterisks, at most 8 of them;
/// - otherwise the first and last `visible_chars` characters are kept and
///   every character in between becomes `*`.
///
/// Working in characters rather than bytes keeps multi-byte input from being
/// sliced in the middle of a code point, which would panic.
pub fn mask_string(input: &str, visible_chars: usize) -> String {
    let total = input.chars().count();
    if total == 0 {
        return String::new();
    }
    if visible_chars >= total {
        return input.to_string();
    }
    // Too short to show both ends without overlap: hide everything.
    if total <= visible_chars.saturating_mul(2) {
        return "*".repeat(total.min(8));
    }

    // Byte offset where the `char_idx`-th character starts (end of string if past the last one).
    let byte_offset = |char_idx: usize| {
        input
            .char_indices()
            .nth(char_idx)
            .map_or(input.len(), |(offset, _)| offset)
    };
    let head = &input[..byte_offset(visible_chars)];
    let tail = &input[byte_offset(total - visible_chars)..];
    let hidden = total - visible_chars * 2;
    format!("{}{}{}", head, "*".repeat(hidden), tail)
}
/// Shortens `input` to fit `max_length` bytes, marking the cut with `...`.
///
/// Input that already fits is returned unchanged. Otherwise the result is the
/// longest prefix of at most `max_length - 3` bytes that ends on a character
/// boundary, followed by `...`. Backing off to a boundary keeps multi-byte
/// text from being sliced mid-character, which would panic.
///
/// When `max_length` is below 3 the result is just `...`, which is longer
/// than `max_length`.
pub fn truncate(input: &str, max_length: usize) -> String {
    if input.len() <= max_length {
        return input.to_string();
    }
    let mut cut = max_length.saturating_sub(3);
    // Offset 0 is always a boundary, so this loop terminates.
    while !input.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &input[..cut])
}
/// Performs a lightweight structural check of an email address.
///
/// The address is accepted only when all of the following hold:
/// - it is longer than 5 bytes and contains no space;
/// - it contains exactly one `@`, with a non-empty part on each side;
/// - the domain (the part after `@`) contains a `.`, does not start or end
///   with `.`, and has no `..` sequence.
pub fn is_valid_email(email: &str) -> bool {
    if email.len() <= 5 || email.contains(' ') {
        return false;
    }
    let (local, domain) = match email.split_once('@') {
        Some(halves) => halves,
        None => return false,
    };
    // A second `@` can only be in the remainder after the first split.
    let single_at = !domain.contains('@');
    let domain_well_formed = domain.contains('.')
        && !domain.starts_with('.')
        && !domain.ends_with('.')
        && !domain.contains("..");
    single_at && !local.is_empty() && !domain.is_empty() && domain_well_formed
}
/// Returns `email` with surrounding whitespace removed and all letters lowercased.
pub fn normalize_email(email: &str) -> String {
    let trimmed = email.trim();
    trimmed.to_lowercase()
}
/// Generates a fresh version-4 UUID string, optionally prefixed.
///
/// With `Some(prefix)` the result is `<prefix>_<uuid>`; with `None` it is the
/// bare UUID.
pub fn generate_id(prefix: Option<&str>) -> String {
    let id = uuid::Uuid::new_v4().to_string();
    if let Some(prefix) = prefix {
        format!("{prefix}_{id}")
    } else {
        id
    }
}
}
pub mod validation {
use super::*;
/// Validates a username and reports the first rule it breaks.
///
/// Rules, checked in this order:
/// 1. not empty;
/// 2. at least 3 and at most 50 bytes long (`str::len` is used, so
///    multi-byte characters count more than once);
/// 3. only alphanumeric characters (per `char::is_alphanumeric`), `_` or `-`;
/// 4. does not start or end with `_` or `-`.
///
/// # Errors
///
/// Returns `AuthError::validation` carrying the message of the first failing rule.
pub fn validate_username(username: &str) -> Result<()> {
    let is_separator = |c: char| matches!(c, '_' | '-');
    let is_permitted = |c: char| c.is_alphanumeric() || matches!(c, '_' | '-');

    let problem = if username.is_empty() {
        Some("Username cannot be empty")
    } else if username.len() < 3 {
        Some("Username must be at least 3 characters long")
    } else if username.len() > 50 {
        Some("Username cannot be longer than 50 characters")
    } else if !username.chars().all(is_permitted) {
        Some("Username can only contain letters, numbers, underscores, and hyphens")
    } else if username.starts_with(is_separator) || username.ends_with(is_separator) {
        Some("Username cannot start or end with underscore or hyphen")
    } else {
        None
    };

    match problem {
        Some(message) => Err(AuthError::validation(message)),
        None => Ok(()),
    }
}
/// Validates `email` by delegating to `SecureValidation::validate_email`.
///
/// The value returned by the delegate on success is discarded; only the
/// pass/fail outcome is reported.
///
/// # Errors
///
/// Propagates the delegate's error unchanged.
pub fn validate_email(email: &str) -> Result<()> {
    use crate::security::secure_utils::SecureValidation;
    match SecureValidation::validate_email(email) {
        Ok(_) => Ok(()),
        Err(rejection) => Err(rejection),
    }
}
/// Validates a password against a minimum length and, optionally, a strength check.
///
/// `min_length` is compared against the number of `char`s in `password`, not
/// its byte length, so multi-byte text cannot satisfy the minimum with fewer
/// characters than the error message promises.
///
/// When `require_complexity` is set, the password is additionally scored with
/// `password::check_password_strength` and rejected if it rates `Weak`.
///
/// # Errors
///
/// Returns `AuthError::validation` when the password is empty, shorter than
/// `min_length` characters, or (with `require_complexity`) rated `Weak`; the
/// last case lists the strength feedback in the message.
pub fn validate_password(
    password: &str,
    min_length: usize,
    require_complexity: bool,
) -> Result<()> {
    if password.is_empty() {
        return Err(AuthError::validation("Password cannot be empty"));
    }
    // `str::len` counts bytes; the policy and its message are about characters.
    if password.chars().count() < min_length {
        return Err(AuthError::validation(format!(
            "Password must be at least {min_length} characters long"
        )));
    }
    if require_complexity {
        let strength = password::check_password_strength(password);
        if matches!(strength.level, password::PasswordStrengthLevel::Weak) {
            return Err(AuthError::validation(format!(
                "Password is too weak: {}",
                strength.feedback.join(", ")
            )));
        }
    }
    Ok(())
}
/// Validates the shape of an API key.
///
/// Checks, in order: the key is not empty; it starts with `expected_prefix`
/// when one is given; its byte length is between 16 and 128 inclusive.
///
/// # Errors
///
/// Returns `AuthError::validation` describing the first failing check.
pub fn validate_api_key(api_key: &str, expected_prefix: Option<&str>) -> Result<()> {
    if api_key.is_empty() {
        return Err(AuthError::validation("API key cannot be empty"));
    }
    match expected_prefix {
        Some(prefix) if !api_key.starts_with(prefix) => {
            return Err(AuthError::validation(format!(
                "API key must start with '{prefix}'"
            )));
        }
        _ => {}
    }
    match api_key.len() {
        0..=15 => Err(AuthError::validation("API key is too short")),
        16..=128 => Ok(()),
        _ => Err(AuthError::validation("API key is too long")),
    }
}
}
pub mod rate_limit {
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Per-key request limiter.
///
/// Each key owns a bucket holding a request count and the instant its current
/// window began. A bucket admits up to `max_requests` calls and is reset once
/// `window` has elapsed since that instant. Buckets are only removed by
/// `cleanup`.
#[derive(Debug)]
pub struct RateLimiter {
// Bucket per key, created on the first `is_allowed` call for that key.
buckets: Arc<DashMap<String, Bucket>>,
// Maximum number of admitted requests per key within one window.
max_requests: u32,
// Length of a window, measured from the bucket's `window_start`.
window: Duration,
}
/// Counter state tracked for a single rate-limit key.
#[derive(Debug)]
struct Bucket {
// Requests admitted since `window_start`.
count: u32,
// Instant the current window began; reset when the window has elapsed.
window_start: Instant,
}
impl RateLimiter {
    /// Creates a limiter with no buckets that admits at most `max_requests`
    /// calls per key within each `window`.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        let buckets = Arc::new(DashMap::new());
        Self {
            buckets,
            max_requests,
            window,
        }
    }

    /// Records a request for `key` and reports whether it is admitted.
    ///
    /// A bucket is created on the first call for a key. If the bucket's window
    /// has elapsed, its count and start instant are reset first. The request
    /// is counted only when it is admitted; rejected calls leave the count
    /// unchanged.
    pub fn is_allowed(&self, key: &str) -> bool {
        let now = Instant::now();
        let fresh = Bucket {
            count: 0,
            window_start: now,
        };
        let mut slot = self.buckets.entry(key.to_string()).or_insert(fresh);

        let window_over = now.duration_since(slot.window_start) >= self.window;
        if window_over {
            slot.count = 0;
            slot.window_start = now;
        }

        let admitted = slot.count < self.max_requests;
        if admitted {
            slot.count += 1;
        }
        admitted
    }

    /// Returns how many more requests `key` may make in its current window.
    ///
    /// An unknown key, or one whose window has already elapsed, reports the
    /// full `max_requests`. This method never modifies the bucket.
    pub fn remaining_requests(&self, key: &str) -> u32 {
        match self.buckets.get(key) {
            Some(slot) if Instant::now().duration_since(slot.window_start) < self.window => {
                self.max_requests.saturating_sub(slot.count)
            }
            _ => self.max_requests,
        }
    }

    /// Drops every bucket whose window has already elapsed.
    pub fn cleanup(&self) {
        let (now, window) = (Instant::now(), self.window);
        self.buckets
            .retain(|_, slot| now.duration_since(slot.window_start) < window);
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A hash verifies against its own password and rejects a different one.
#[test]
fn test_password_hashing() {
    let plain = "test_password_123";
    let hashed = password::hash_password(plain).unwrap();

    let accepts_original = password::verify_password(plain, &hashed).unwrap();
    assert!(accepts_original);

    let accepts_other = password::verify_password("wrong_password", &hashed).unwrap();
    assert!(!accepts_other);
}
/// A three-digit password rates Weak; a long mixed-class one rates Strong or better.
#[test]
fn test_password_strength() {
    use super::password::PasswordStrengthLevel as Level;

    let weak_level = password::check_password_strength("123").level;
    assert!(matches!(weak_level, Level::Weak));

    let strong_level = password::check_password_strength("MySecureP@ssw0rd!").level;
    assert!(matches!(strong_level, Level::Strong | Level::VeryStrong));
}
/// Random strings have the requested length; a SHA-256 hex digest is 64 characters.
#[test]
fn test_crypto_utils() {
    let generated = crypto::generate_random_string(16);
    assert_eq!(generated.len(), 16);

    let hex_digest = crypto::sha256_hex(b"test data");
    assert_eq!(hex_digest.len(), 64);
}
/// Masking keeps both ends and hides the middle; the email check accepts and rejects.
#[test]
fn test_string_utils() {
    let masked = string::mask_string("secret123456", 2);
    let keeps_ends = masked.starts_with("se") && masked.ends_with("56");
    assert!(keeps_ends);
    assert!(masked.contains("*"));

    for (candidate, expected) in [("test@example.com", true), ("invalid_email", false)] {
        assert_eq!(string::is_valid_email(candidate), expected);
    }
}
/// Spot checks for username and email validation.
#[test]
fn test_validation() {
    assert!(validation::validate_username("test_user").is_ok());
    for rejected in ["", "ab", "_invalid", "invalid@"] {
        assert!(validation::validate_username(rejected).is_err());
    }

    assert!(validation::validate_email("test@example.com").is_ok());
    for rejected in ["", "invalid"] {
        assert!(validation::validate_email(rejected).is_err());
    }
}
/// Three requests are admitted, the fourth is refused, and other keys are unaffected.
#[test]
fn test_rate_limiter() {
    use std::time::Duration;

    let limiter = rate_limit::RateLimiter::new(3, Duration::from_secs(1));
    for _ in 0..3 {
        assert!(limiter.is_allowed("user1"));
    }
    assert!(!limiter.is_allowed("user1"));
    assert!(limiter.is_allowed("user2"));
}
/// Hashing round-trips for very long, symbol-only and non-ASCII passwords, and
/// different passwords produce different hashes.
#[test]
fn test_password_hashing_edge_cases() {
    let long_password = "a".repeat(1000);
    let round_trip_inputs = [
        long_password.as_str(),
        "!@#$%^&*()_+-=[]{}|;:,.<>?",
        "пароль测试🔒",
    ];
    for candidate in round_trip_inputs {
        let hashed = password::hash_password(candidate).unwrap();
        assert!(password::verify_password(candidate, &hashed).unwrap());
    }

    let first_hash = password::hash_password("password123").unwrap();
    let second_hash = password::hash_password("password124").unwrap();
    assert_ne!(first_hash, second_hash);
}
/// Each sample password must reach at least the listed strength level.
///
/// Entries listed as `Weak` are scored but not asserted on, since every level
/// satisfies a `Weak` minimum.
#[test]
fn test_password_strength_comprehensive() {
    use super::password::PasswordStrengthLevel as Level;

    let cases = [
        ("", Level::Weak),
        ("a", Level::Weak),
        ("password", Level::Weak),
        ("password123", Level::Weak),
        ("mypassword123", Level::Medium),
        ("MyPassword123", Level::Medium),
        ("MyPassword123!", Level::Strong),
        ("VerySecureP@ssw0rd2024!", Level::VeryStrong),
    ];

    for (candidate, expected_min_level) in cases {
        let actual = password::check_password_strength(candidate).level;
        match expected_min_level {
            Level::Weak => {}
            Level::Medium => assert!(
                actual != Level::Weak,
                "Password '{}' should be at least Medium strength",
                candidate
            ),
            Level::Strong => assert!(
                actual == Level::Strong || actual == Level::VeryStrong,
                "Password '{}' should be at least Strong",
                candidate
            ),
            Level::VeryStrong => assert!(
                actual == Level::VeryStrong,
                "Password '{}' should be VeryStrong",
                candidate
            ),
        }
    }
}
/// Random strings honour every requested length (including zero) and differ
/// between calls; SHA-256 hex output is 64 characters and deterministic.
#[test]
fn test_crypto_utils_edge_cases() {
    for requested in [0, 1, 8, 16, 32, 64, 128] {
        let first = crypto::generate_random_string(requested);
        assert_eq!(
            first.len(),
            requested,
            "Generated string should have requested length"
        );
        if requested == 0 {
            continue;
        }
        let second = crypto::generate_random_string(requested);
        // Very short strings can legitimately collide, so only longer ones are compared.
        if requested > 4 {
            assert_ne!(first, second, "Random strings should be different");
        }
    }

    let payloads = [
        b"".as_slice(),
        b"a",
        b"hello world",
        &[0u8; 1000],
        "unicode: 测试 🔒".as_bytes(),
    ];
    for payload in payloads {
        let digest_hex = crypto::sha256_hex(payload);
        assert_eq!(
            digest_hex.len(),
            64,
            "SHA256 hex should always be 64 characters"
        );
        let repeated = crypto::sha256_hex(payload);
        assert_eq!(digest_hex, repeated, "Same input should produce same hash");
    }
}
/// Covers each masking case (empty, fully visible, partially masked, fully
/// masked) and a list of accepted and rejected email addresses.
#[test]
fn test_string_utils_comprehensive() {
    let masking_cases = [
        ("", 0),
        ("a", 1),
        ("ab", 1),
        ("secret", 2),
        ("verylongsecret", 3),
        ("short", 10),
    ];
    for (input, reveal_chars) in masking_cases {
        let masked = string::mask_string(input, reveal_chars);
        if input.is_empty() {
            assert_eq!(masked, "");
            continue;
        }
        if reveal_chars >= input.len() {
            assert_eq!(masked, input, "Should not mask if reveal_chars >= length");
            continue;
        }
        if input.len() > reveal_chars * 2 {
            let visible_prefix = &input[..reveal_chars];
            assert!(
                masked.starts_with(visible_prefix),
                "Should preserve first {} characters",
                reveal_chars
            );
            assert!(masked.contains("*"), "Should contain masking characters");
        } else {
            assert!(
                masked.contains("*"),
                "Should contain masking characters for short strings"
            );
        }
    }

    let accepted = [
        "user@example.com",
        "user.name@example.com",
        "user+tag@example.co.uk",
        "user123@example-domain.com",
        "a@b.co",
        "test_email@domain.info",
    ];
    for address in accepted {
        assert!(
            string::is_valid_email(address),
            "Should accept valid email: {}",
            address
        );
    }

    let rejected = [
        "",
        "user",
        "@example.com",
        "user@",
        "user@@example.com",
        "user@example",
        "user @example.com",
        "user@exam ple.com",
        "user@.example.com",
        "user@example..com",
    ];
    for address in rejected {
        assert!(
            !string::is_valid_email(address),
            "Should reject invalid email: {}",
            address
        );
    }
}
/// Runs lists of acceptable and unacceptable usernames and emails through the validators.
#[test]
fn test_validation_comprehensive() {
    let good_usernames = ["user", "user123", "user_name", "user-name", "abc"];
    for name in good_usernames {
        let outcome = validation::validate_username(name);
        assert!(outcome.is_ok(), "Should accept valid username: {}", name);
    }

    let bad_usernames = [
        "",
        "us",
        "a",
        "user name",
        "user@domain",
        "user\0name",
        "_invalid",
    ];
    for name in bad_usernames {
        let outcome = validation::validate_username(name);
        assert!(outcome.is_err(), "Should reject invalid username: {}", name);
    }

    let good_emails = [
        "test@example.com",
        "user.name@domain.co.uk",
        "user+tag@example.org",
    ];
    for address in good_emails {
        let outcome = validation::validate_email(address);
        assert!(outcome.is_ok(), "Should accept valid email: {}", address);
    }

    let bad_emails = ["", "invalid", "@example.com", "user@", "user@@example.com"];
    for address in bad_emails {
        let outcome = validation::validate_email(address);
        assert!(outcome.is_err(), "Should reject invalid email: {}", address);
    }
}
/// A zero-request limiter refuses everything; a short window admits again once it elapses.
#[test]
fn test_rate_limiter_edge_cases() {
    use std::time::Duration;

    let deny_all = rate_limit::RateLimiter::new(0, Duration::from_secs(60));
    assert!(!deny_all.is_allowed("user1"));

    let brief = rate_limit::RateLimiter::new(1, Duration::from_millis(10));
    assert!(brief.is_allowed("user1"));
    assert!(!brief.is_allowed("user1"));

    // Wait past the 10 ms window so the bucket resets on the next call.
    std::thread::sleep(Duration::from_millis(20));
    assert!(brief.is_allowed("user1"));
}
/// Each key gets its own allowance: two admitted requests, then a refusal.
#[test]
fn test_rate_limiter_multiple_users() {
    use std::time::Duration;

    let limiter = rate_limit::RateLimiter::new(2, Duration::from_secs(60));
    for user in ["user1", "user2", "user3"] {
        assert!(limiter.is_allowed(user));
        assert!(limiter.is_allowed(user));
        assert!(!limiter.is_allowed(user));
    }
}
/// A thousand 16-character random strings contain no duplicates.
#[test]
fn test_crypto_random_uniqueness() {
    let mut seen = std::collections::HashSet::new();
    for _ in 0..1000 {
        let candidate = crypto::generate_random_string(16);
        // `insert` returns false when the value was already present.
        assert!(seen.insert(candidate), "Generated duplicate random string");
    }
}
/// Hashing the same password repeatedly yields distinct hashes that all verify.
#[test]
fn test_password_hash_uniqueness() {
    let password = "test_password_123";
    let mut seen = std::collections::HashSet::new();
    for _ in 0..10 {
        let hashed = password::hash_password(password).unwrap();
        assert!(
            !seen.contains(&hashed),
            "Password hashes should be unique due to salt"
        );
        assert!(password::verify_password(password, &hashed).unwrap());
        seen.insert(hashed);
    }
}
}