auth_framework/
utils.rs

1//! Utility functions for the authentication framework.
2
3use crate::errors::{AuthError, Result};
4use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
5use argon2::password_hash::{rand_core::OsRng, SaltString};
6use ring::digest;
7use std::time::{SystemTime, UNIX_EPOCH};
8
9/// Password hashing utilities.
10pub mod password {
11    use super::*;
12
13    /// Hash a password using Argon2.
14    pub fn hash_password(password: &str) -> Result<String> {
15        let salt = SaltString::generate(&mut OsRng);
16        let argon2 = Argon2::default();
17        
18        let password_hash = argon2.hash_password(password.as_bytes(), &salt)
19            .map_err(|e| AuthError::crypto(format!("Password hashing failed: {e}")))?;
20        
21        Ok(password_hash.to_string())
22    }
23
24    /// Verify a password against a hash.
25    pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
26        let parsed_hash = PasswordHash::new(hash)
27            .map_err(|e| AuthError::crypto(format!("Invalid password hash: {e}")))?;
28        
29        match Argon2::default().verify_password(password.as_bytes(), &parsed_hash) {
30            Ok(()) => Ok(true),
31            Err(_) => Ok(false),
32        }
33    }
34
35    /// Hash a password using bcrypt (alternative implementation).
36    pub fn hash_password_bcrypt(password: &str) -> Result<String> {
37        bcrypt::hash(password, bcrypt::DEFAULT_COST)
38            .map_err(|e| AuthError::crypto(format!("BCrypt hashing failed: {e}")))
39    }
40
41    /// Verify a password against a bcrypt hash.
42    pub fn verify_password_bcrypt(password: &str, hash: &str) -> Result<bool> {
43        bcrypt::verify(password, hash)
44            .map_err(|e| AuthError::crypto(format!("BCrypt verification failed: {e}")))
45    }
46
47    /// Generate a secure random password.
48    pub fn generate_password(length: usize) -> String {
49        use rand::Rng;
50        const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*";
51        
52        let mut rng = rand::thread_rng();
53        (0..length)
54            .map(|_| {
55                let idx = rng.gen_range(0..CHARSET.len());
56                CHARSET[idx] as char
57            })
58            .collect()
59    }
60
61    /// Check password strength.
62    pub fn check_password_strength(password: &str) -> PasswordStrength {
63        let mut score = 0;
64        let mut feedback = Vec::new();
65
66        // Length check
67        if password.len() >= 8 {
68            score += 1;
69        } else {
70            feedback.push("Password should be at least 8 characters long".to_string());
71        }
72
73        if password.len() >= 12 {
74            score += 1;
75        }
76
77        // Character variety checks
78        if password.chars().any(|c| c.is_lowercase()) {
79            score += 1;
80        } else {
81            feedback.push("Password should contain lowercase letters".to_string());
82        }
83
84        if password.chars().any(|c| c.is_uppercase()) {
85            score += 1;
86        } else {
87            feedback.push("Password should contain uppercase letters".to_string());
88        }
89
90        if password.chars().any(|c| c.is_ascii_digit()) {
91            score += 1;
92        } else {
93            feedback.push("Password should contain numbers".to_string());
94        }
95
96        if password.chars().any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c)) {
97            score += 1;
98        } else {
99            feedback.push("Password should contain special characters".to_string());
100        }
101
102        // Common password check (basic)
103        let common_passwords = ["password", "123456", "password123", "admin", "letmein"];
104        if common_passwords.contains(&password.to_lowercase().as_str()) {
105            score = 0;
106            feedback.push("Password is too common".to_string());
107        }
108
109        let strength = match score {
110            0..=2 => PasswordStrengthLevel::Weak,
111            3..=4 => PasswordStrengthLevel::Medium,
112            5..=6 => PasswordStrengthLevel::Strong,
113            _ => PasswordStrengthLevel::VeryStrong,
114        };
115
116        PasswordStrength {
117            level: strength,
118            score,
119            feedback,
120        }
121    }
122
123    /// Password strength assessment.
124    #[derive(Debug, Clone)]
125    pub struct PasswordStrength {
126        pub level: PasswordStrengthLevel,
127        pub score: u8,
128        pub feedback: Vec<String>,
129    }
130
131    /// Password strength levels.
132    #[derive(Debug, Clone, PartialEq)]
133    pub enum PasswordStrengthLevel {
134        Weak,
135        Medium,
136        Strong,
137        VeryStrong,
138    }
139}
140
141/// Cryptographic utilities.
142pub mod crypto {
143    use super::*;
144
145    /// Generate a secure random string.
146    pub fn generate_random_string(length: usize) -> String {
147        use rand::Rng;
148        const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
149        
150        let mut rng = rand::thread_rng();
151        (0..length)
152            .map(|_| {
153                let idx = rng.gen_range(0..CHARSET.len());
154                CHARSET[idx] as char
155            })
156            .collect()
157    }
158
159    /// Generate a secure random byte array.
160    pub fn generate_random_bytes(length: usize) -> Vec<u8> {
161        use rand::RngCore;
162        let mut bytes = vec![0u8; length];
163        rand::thread_rng().fill_bytes(&mut bytes);
164        bytes
165    }
166
167    /// Compute SHA256 hash.
168    pub fn sha256(data: &[u8]) -> Vec<u8> {
169        let digest = digest::digest(&digest::SHA256, data);
170        digest.as_ref().to_vec()
171    }
172
173    /// Compute SHA256 hash and return as hex string.
174    pub fn sha256_hex(data: &[u8]) -> String {
175        hex::encode(sha256(data))
176    }
177
178    /// Generate a secure token.
179    pub fn generate_token(length: usize) -> String {
180        use base64::Engine;
181        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(generate_random_bytes(length))
182    }
183
184    /// Constant-time string comparison.
185    pub fn constant_time_eq(a: &str, b: &str) -> bool {
186        if a.len() != b.len() {
187            return false;
188        }
189        
190        // Use a simple constant-time comparison
191        let mut result = 0u8;
192        for (byte_a, byte_b) in a.as_bytes().iter().zip(b.as_bytes().iter()) {
193            result |= byte_a ^ byte_b;
194        }
195        result == 0
196    }
197}
198
/// Time utilities.
pub mod time {
    use super::*;
    use std::time::Duration;

    /// Time elapsed since the Unix epoch according to the system clock.
    ///
    /// A system clock set before the epoch yields a zero duration instead of
    /// panicking, since a misconfigured clock is an environmental fault, not a bug.
    fn since_epoch() -> Duration {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
    }

    /// Get the current Unix timestamp in whole seconds.
    ///
    /// Returns 0 if the system clock reports a time before the Unix epoch.
    pub fn current_timestamp() -> u64 {
        since_epoch().as_secs()
    }

    /// Get the current Unix timestamp in milliseconds.
    ///
    /// Returns 0 if the system clock reports a time before the Unix epoch. The
    /// `u128` millisecond count saturates at `u64::MAX` instead of being truncated.
    pub fn current_timestamp_millis() -> u64 {
        let millis = since_epoch().as_millis();
        // After clamping, the cast is lossless.
        millis.min(u128::from(u64::MAX)) as u64
    }

    /// Convert a duration to whole seconds (sub-second part is discarded).
    pub fn duration_to_seconds(duration: Duration) -> u64 {
        duration.as_secs()
    }

    /// Convert a number of seconds to a duration.
    pub fn seconds_to_duration(seconds: u64) -> Duration {
        Duration::from_secs(seconds)
    }

    /// Check whether a Unix timestamp (in seconds) has been reached.
    ///
    /// A timestamp counts as expired once the current time reaches it
    /// (`now >= expires_at`), matching the JWT `exp` rule that the current time must
    /// be *before* the expiry. This agrees with [`time_until_expiry`], which reports
    /// no time remaining at that same instant.
    pub fn is_expired(expires_at: u64) -> bool {
        current_timestamp() >= expires_at
    }

    /// Get the time remaining until `expires_at` (Unix seconds).
    ///
    /// Returns `None` once the timestamp has been reached or passed.
    pub fn time_until_expiry(expires_at: u64) -> Option<Duration> {
        expires_at
            .checked_sub(current_timestamp())
            .filter(|&secs| secs > 0)
            .map(Duration::from_secs)
    }
}
245
246/// String utilities.
247pub mod string {
248    /// Mask a string for safe logging.
249    pub fn mask_string(input: &str, visible_chars: usize) -> String {
250        if input.len() <= visible_chars * 2 {
251            "*".repeat(input.len().min(8))
252        } else {
253            format!(
254                "{}{}{}",
255                &input[..visible_chars],
256                "*".repeat(input.len() - visible_chars * 2),
257                &input[input.len() - visible_chars..]
258            )
259        }
260    }
261
262    /// Truncate a string to a maximum length.
263    pub fn truncate(input: &str, max_length: usize) -> String {
264        if input.len() <= max_length {
265            input.to_string()
266        } else {
267            format!("{}...", &input[..max_length.saturating_sub(3)])
268        }
269    }
270
271    /// Check if a string is a valid email address (basic check).
272    pub fn is_valid_email(email: &str) -> bool {
273        email.contains('@') && email.contains('.') && email.len() > 5
274    }
275
276    /// Normalize an email address.
277    pub fn normalize_email(email: &str) -> String {
278        email.trim().to_lowercase()
279    }
280
281    /// Generate a random identifier.
282    pub fn generate_id(prefix: Option<&str>) -> String {
283        let id = uuid::Uuid::new_v4().to_string();
284        match prefix {
285            Some(prefix) => format!("{prefix}_{id}"),
286            None => id,
287        }
288    }
289}
290
291/// Validation utilities.
292pub mod validation {
293    use super::*;
294
295    /// Validate username format.
296    pub fn validate_username(username: &str) -> Result<()> {
297        if username.is_empty() {
298            return Err(AuthError::validation("Username cannot be empty"));
299        }
300
301        if username.len() < 3 {
302            return Err(AuthError::validation("Username must be at least 3 characters long"));
303        }
304
305        if username.len() > 50 {
306            return Err(AuthError::validation("Username cannot be longer than 50 characters"));
307        }
308
309        // Check for valid characters (alphanumeric, underscore, hyphen)
310        if !username.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-') {
311            return Err(AuthError::validation(
312                "Username can only contain letters, numbers, underscores, and hyphens"
313            ));
314        }
315
316        // Cannot start or end with special characters
317        if username.starts_with('_') || username.starts_with('-') ||
318           username.ends_with('_') || username.ends_with('-') {
319            return Err(AuthError::validation(
320                "Username cannot start or end with underscore or hyphen"
321            ));
322        }
323
324        Ok(())
325    }
326
327    /// Validate email format.
328    pub fn validate_email(email: &str) -> Result<()> {
329        if email.is_empty() {
330            return Err(AuthError::validation("Email cannot be empty"));
331        }
332
333        if !string::is_valid_email(email) {
334            return Err(AuthError::validation("Invalid email format"));
335        }
336
337        if email.len() > 254 {
338            return Err(AuthError::validation("Email address is too long"));
339        }
340
341        Ok(())
342    }
343
344    /// Validate password according to policy.
345    pub fn validate_password(password: &str, min_length: usize, require_complexity: bool) -> Result<()> {
346        if password.is_empty() {
347            return Err(AuthError::validation("Password cannot be empty"));
348        }
349
350        if password.len() < min_length {
351            return Err(AuthError::validation(
352                format!("Password must be at least {min_length} characters long")
353            ));
354        }
355
356        if require_complexity {
357            let strength = password::check_password_strength(password);
358            if matches!(strength.level, password::PasswordStrengthLevel::Weak) {
359                return Err(AuthError::validation(
360                    format!("Password is too weak: {}", strength.feedback.join(", "))
361                ));
362            }
363        }
364
365        Ok(())
366    }
367
368    /// Validate API key format.
369    pub fn validate_api_key(api_key: &str, expected_prefix: Option<&str>) -> Result<()> {
370        if api_key.is_empty() {
371            return Err(AuthError::validation("API key cannot be empty"));
372        }
373
374        if let Some(prefix) = expected_prefix {
375            if !api_key.starts_with(prefix) {
376                return Err(AuthError::validation(
377                    format!("API key must start with '{prefix}'")
378                ));
379            }
380        }
381
382        // Basic length check
383        if api_key.len() < 16 {
384            return Err(AuthError::validation("API key is too short"));
385        }
386
387        if api_key.len() > 128 {
388            return Err(AuthError::validation("API key is too long"));
389        }
390
391        Ok(())
392    }
393}
394
/// Rate limiting utilities.
pub mod rate_limit {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, Instant};

    /// In-memory, fixed-window rate limiter keyed by arbitrary strings.
    ///
    /// Each key has an independent counter allowing at most `max_requests` requests
    /// per `window`. A key's window starts at its first request and is restarted by
    /// the first request arriving after it has elapsed. Buckets are not evicted
    /// automatically; call [`RateLimiter::cleanup`] to drop ones whose window elapsed.
    #[derive(Debug)]
    pub struct RateLimiter {
        buckets: Arc<Mutex<HashMap<String, Bucket>>>,
        max_requests: u32,
        window: Duration,
    }

    /// Request counter for a single key within its current window.
    #[derive(Debug)]
    struct Bucket {
        count: u32,
        window_start: Instant,
    }

    impl Bucket {
        /// An empty bucket whose window begins at `start`.
        fn starting_at(start: Instant) -> Self {
            Bucket {
                count: 0,
                window_start: start,
            }
        }

        /// Whether this bucket's window has fully elapsed as of `now`.
        fn has_expired(&self, now: Instant, window: Duration) -> bool {
            now.duration_since(self.window_start) >= window
        }
    }

    impl RateLimiter {
        /// Build a limiter that allows `max_requests` per key in each `window`.
        pub fn new(max_requests: u32, window: Duration) -> Self {
            RateLimiter {
                buckets: Arc::new(Mutex::new(HashMap::new())),
                max_requests,
                window,
            }
        }

        /// Record a request for `key` and report whether it is within the limit.
        ///
        /// Rejected requests are not counted. Panics if the internal mutex is poisoned.
        pub fn is_allowed(&self, key: &str) -> bool {
            let mut table = self.buckets.lock().unwrap();
            let now = Instant::now();

            let bucket = table
                .entry(key.to_string())
                .or_insert_with(|| Bucket::starting_at(now));

            // A stale window is replaced by a fresh one starting at this request.
            if bucket.has_expired(now, self.window) {
                *bucket = Bucket::starting_at(now);
            }

            let under_limit = bucket.count < self.max_requests;
            if under_limit {
                bucket.count += 1;
            }
            under_limit
        }

        /// Number of requests `key` may still make in its current window.
        ///
        /// Unknown keys and keys whose window has elapsed report the full allowance.
        pub fn remaining_requests(&self, key: &str) -> u32 {
            let table = self.buckets.lock().unwrap();
            match table.get(key) {
                Some(bucket) if !bucket.has_expired(Instant::now(), self.window) => {
                    self.max_requests.saturating_sub(bucket.count)
                }
                _ => self.max_requests,
            }
        }

        /// Drop every bucket whose window has elapsed.
        pub fn cleanup(&self) {
            let mut table = self.buckets.lock().unwrap();
            let now = Instant::now();
            table.retain(|_, bucket| !bucket.has_expired(now, self.window));
        }
    }
}
476
#[cfg(test)]
mod tests {
    use super::*;
    use super::password::PasswordStrengthLevel;

    #[test]
    fn test_password_hashing() {
        let secret = "test_password_123";
        let stored = password::hash_password(secret).unwrap();

        assert!(
            password::verify_password(secret, &stored).unwrap(),
            "the original password must verify"
        );
        assert!(
            !password::verify_password("wrong_password", &stored).unwrap(),
            "a different password must be rejected"
        );
    }

    #[test]
    fn test_password_strength() {
        assert_eq!(
            password::check_password_strength("123").level,
            PasswordStrengthLevel::Weak
        );

        let strong = password::check_password_strength("MySecureP@ssw0rd!").level;
        assert!(
            matches!(strong, PasswordStrengthLevel::Strong | PasswordStrengthLevel::VeryStrong),
            "unexpected strength level: {strong:?}"
        );
    }

    #[test]
    fn test_crypto_utils() {
        assert_eq!(crypto::generate_random_string(16).len(), 16);

        // A SHA-256 digest is 32 bytes, i.e. 64 hex characters.
        assert_eq!(crypto::sha256_hex(b"test data").len(), 64);
    }

    #[test]
    fn test_string_utils() {
        let masked = string::mask_string("secret123456", 2);
        assert!(masked.starts_with("se"));
        assert!(masked.ends_with("56"));
        assert!(masked.contains('*'));

        assert!(string::is_valid_email("test@example.com"));
        assert!(!string::is_valid_email("invalid_email"));
    }

    #[test]
    fn test_validation() {
        assert!(validation::validate_username("test_user").is_ok());
        for bad in ["", "ab", "_invalid", "invalid@"] {
            assert!(
                validation::validate_username(bad).is_err(),
                "username {bad:?} should be rejected"
            );
        }

        assert!(validation::validate_email("test@example.com").is_ok());
        for bad in ["", "invalid"] {
            assert!(
                validation::validate_email(bad).is_err(),
                "email {bad:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = rate_limit::RateLimiter::new(3, std::time::Duration::from_secs(1));

        // The first three requests fit in the budget.
        for attempt in 1..=3 {
            assert!(limiter.is_allowed("user1"), "request {attempt} should be allowed");
        }

        // The fourth is over the limit.
        assert!(!limiter.is_allowed("user1"));

        // Budgets are tracked per key.
        assert!(limiter.is_allowed("user2"));
    }
}