1use crate::errors::{AuthError, Result};
3use rand::Rng;
4pub use rate_limit::RateLimiter;
5use ring::digest;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8pub mod password {
10 use super::*;
11
12 pub fn hash_password(password: &str) -> Result<String> {
14 bcrypt::hash(password, bcrypt::DEFAULT_COST)
15 .map_err(|e| AuthError::crypto(format!("Password hashing failed: {e}")))
16 }
17
18 pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
20 bcrypt::verify(password, hash)
21 .map_err(|e| AuthError::crypto(format!("Password verification failed: {e}")))
22 }
23
24 pub fn generate_password(length: usize) -> String {
26 const CHARSET: &[u8] =
27 b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*";
28 let mut rng = rand::rng();
29 (0..length)
30 .map(|_| CHARSET[rng.random_range(0..CHARSET.len())] as char)
31 .collect()
32 }
33
34 pub fn check_password_strength(password: &str) -> PasswordStrength {
36 let mut score = 0;
37 let mut feedback = Vec::new();
38
39 if password.len() >= 8 {
41 score += 1;
42 } else {
43 feedback.push("Password should be at least 8 characters long".to_string());
44 }
45
46 if password.len() >= 12 {
47 score += 1;
48 }
49
50 if password.len() >= 16 {
51 score += 1; }
53
54 if password.chars().any(|c| c.is_lowercase()) {
56 score += 1;
57 } else {
58 feedback.push("Password should contain lowercase letters".to_string());
59 }
60
61 if password.chars().any(|c| c.is_uppercase()) {
62 score += 1;
63 } else {
64 feedback.push("Password should contain uppercase letters".to_string());
65 }
66
67 if password.chars().any(|c| c.is_ascii_digit()) {
68 score += 1;
69 } else {
70 feedback.push("Password should contain numbers".to_string());
71 }
72
73 if password
74 .chars()
75 .any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c))
76 {
77 score += 1;
78 } else {
79 feedback.push("Password should contain special characters".to_string());
80 }
81
82 let common_passwords = ["password", "123456", "password123", "admin", "letmein"];
84 if common_passwords.contains(&password.to_lowercase().as_str()) {
85 score = 0;
86 feedback.push("Password is too common".to_string());
87 }
88
89 let strength = match score {
90 0..=2 => PasswordStrengthLevel::Weak,
91 3..=4 => PasswordStrengthLevel::Medium,
92 5..=6 => PasswordStrengthLevel::Strong,
93 _ => PasswordStrengthLevel::VeryStrong,
94 };
95
96 PasswordStrength {
97 level: strength,
98 score,
99 feedback,
100 }
101 }
102
103 #[derive(Debug, Clone)]
105 pub struct PasswordStrength {
106 pub level: PasswordStrengthLevel,
107 pub score: u8,
108 pub feedback: Vec<String>,
109 }
110
111 #[derive(Debug, Clone, PartialEq)]
113 pub enum PasswordStrengthLevel {
114 Weak,
115 Medium,
116 Strong,
117 VeryStrong,
118 }
119}
120
121pub mod crypto {
123 use super::*;
124
125 pub fn generate_random_string(length: usize) -> String {
127 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
128
129 let mut rng = rand::rng();
130 (0..length)
131 .map(|_| {
132 let idx = rng.random_range(0..CHARSET.len());
133 CHARSET[idx] as char
134 })
135 .collect()
136 }
137
138 pub fn generate_random_bytes(length: usize) -> Vec<u8> {
140 use rand::RngCore;
141 let mut bytes = vec![0u8; length];
142 rand::rng().fill_bytes(&mut bytes);
143 bytes
144 }
145
146 pub fn sha256(data: &[u8]) -> Vec<u8> {
148 let digest = digest::digest(&digest::SHA256, data);
149 digest.as_ref().to_vec()
150 }
151
152 pub fn sha256_hex(data: &[u8]) -> String {
154 hex::encode(sha256(data))
155 }
156
157 pub fn generate_token(length: usize) -> String {
159 use base64::Engine;
160 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(generate_random_bytes(length))
161 }
162
163 pub fn constant_time_eq(a: &str, b: &str) -> bool {
165 if a.len() != b.len() {
166 return false;
167 }
168
169 let mut result = 0u8;
171 for (byte_a, byte_b) in a.as_bytes().iter().zip(b.as_bytes().iter()) {
172 result |= byte_a ^ byte_b;
173 }
174 result == 0
175 }
176}
177
pub mod time {
    //! Wall-clock helpers expressed as seconds or milliseconds since the UNIX epoch.

    use super::*;
    use std::time::Duration;

    /// Returns the time elapsed since the UNIX epoch according to the system clock.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before 1970-01-01T00:00:00Z.
    /// Panicking is deliberate: returning 0 instead would make [`is_expired`] treat
    /// every expiry as still in the future.
    fn since_epoch() -> Duration {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
    }

    /// Current UNIX time in whole seconds.
    ///
    /// # Panics
    ///
    /// Panics if the system clock is before the UNIX epoch.
    pub fn current_timestamp() -> u64 {
        since_epoch().as_secs()
    }

    /// Current UNIX time in whole milliseconds, saturating at `u64::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if the system clock is before the UNIX epoch.
    pub fn current_timestamp_millis() -> u64 {
        // `as_millis` returns u128. Clamp it rather than letting `as` silently truncate.
        since_epoch().as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Converts `duration` to whole seconds, discarding any sub-second part.
    pub fn duration_to_seconds(duration: Duration) -> u64 {
        duration.as_secs()
    }

    /// Builds a [`Duration`] of `seconds` whole seconds.
    pub fn seconds_to_duration(seconds: u64) -> Duration {
        Duration::from_secs(seconds)
    }

    /// Returns `true` once the current UNIX second is strictly past `expires_at`.
    ///
    /// At exactly `expires_at` this returns `false`.
    pub fn is_expired(expires_at: u64) -> bool {
        current_timestamp() > expires_at
    }

    /// Returns the time left until `expires_at` (UNIX seconds).
    ///
    /// Returns `None` once `expires_at` has been reached. That includes the exact
    /// second `expires_at`, for which [`is_expired`] still returns `false`.
    pub fn time_until_expiry(expires_at: u64) -> Option<Duration> {
        let now = current_timestamp();
        if expires_at > now {
            Some(Duration::from_secs(expires_at - now))
        } else {
            None
        }
    }
}
224
225pub mod string {
227 pub fn mask_string(input: &str, visible_chars: usize) -> String {
229 if input.is_empty() {
230 return String::new();
231 }
232
233 if visible_chars >= input.len() {
234 return input.to_string();
235 }
236
237 if input.len() <= visible_chars * 2 {
238 "*".repeat(input.len().min(8))
239 } else {
240 format!(
241 "{}{}{}",
242 &input[..visible_chars],
243 "*".repeat(input.len() - visible_chars * 2),
244 &input[input.len() - visible_chars..]
245 )
246 }
247 }
248
249 pub fn truncate(input: &str, max_length: usize) -> String {
251 if input.len() <= max_length {
252 input.to_string()
253 } else {
254 format!("{}...", &input[..max_length.saturating_sub(3)])
255 }
256 }
257
258 pub fn is_valid_email(email: &str) -> bool {
260 if email.len() <= 5 || !email.contains('@') || !email.contains('.') {
261 return false;
262 }
263
264 if email.starts_with('@') || email.ends_with('@') {
266 return false;
267 }
268
269 if email.contains(' ') {
271 return false;
272 }
273
274 if email.matches('@').count() != 1 {
276 return false;
277 }
278
279 let parts: Vec<&str> = email.split('@').collect();
280 let local = parts[0];
281 let domain = parts[1];
282
283 if local.is_empty() {
285 return false;
286 }
287
288 if domain.is_empty() || !domain.contains('.') {
290 return false;
291 }
292
293 if domain.starts_with('.') || domain.ends_with('.') {
295 return false;
296 }
297
298 if domain.contains("..") {
300 return false;
301 }
302
303 true
304 }
305 pub fn normalize_email(email: &str) -> String {
307 email.trim().to_lowercase()
308 }
309
310 pub fn generate_id(prefix: Option<&str>) -> String {
312 let id = uuid::Uuid::new_v4().to_string();
313 match prefix {
314 Some(prefix) => format!("{prefix}_{id}"),
315 None => id,
316 }
317 }
318}
319
320pub mod validation {
322 use super::*;
323
324 pub fn validate_username(username: &str) -> Result<()> {
326 if username.is_empty() {
327 return Err(AuthError::validation("Username cannot be empty"));
328 }
329
330 if username.len() < 3 {
331 return Err(AuthError::validation(
332 "Username must be at least 3 characters long",
333 ));
334 }
335
336 if username.len() > 50 {
337 return Err(AuthError::validation(
338 "Username cannot be longer than 50 characters",
339 ));
340 }
341
342 if !username
344 .chars()
345 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
346 {
347 return Err(AuthError::validation(
348 "Username can only contain letters, numbers, underscores, and hyphens",
349 ));
350 }
351
352 if username.starts_with('_')
354 || username.starts_with('-')
355 || username.ends_with('_')
356 || username.ends_with('-')
357 {
358 return Err(AuthError::validation(
359 "Username cannot start or end with underscore or hyphen",
360 ));
361 }
362
363 Ok(())
364 }
365
366 pub fn validate_email(email: &str) -> Result<()> {
368 use crate::security::secure_utils::SecureValidation;
369 SecureValidation::validate_email(email).map(|_| ())
370 }
371
372 pub fn validate_password(
374 password: &str,
375 min_length: usize,
376 require_complexity: bool,
377 ) -> Result<()> {
378 if password.is_empty() {
379 return Err(AuthError::validation("Password cannot be empty"));
380 }
381
382 if password.len() < min_length {
383 return Err(AuthError::validation(format!(
384 "Password must be at least {min_length} characters long"
385 )));
386 }
387
388 if require_complexity {
389 let strength = password::check_password_strength(password);
390 if matches!(strength.level, password::PasswordStrengthLevel::Weak) {
391 return Err(AuthError::validation(format!(
392 "Password is too weak: {}",
393 strength.feedback.join(", ")
394 )));
395 }
396 }
397
398 Ok(())
399 }
400
401 pub fn validate_api_key(api_key: &str, expected_prefix: Option<&str>) -> Result<()> {
403 if api_key.is_empty() {
404 return Err(AuthError::validation("API key cannot be empty"));
405 }
406
407 if let Some(prefix) = expected_prefix
408 && !api_key.starts_with(prefix)
409 {
410 return Err(AuthError::validation(format!(
411 "API key must start with '{prefix}'"
412 )));
413 }
414
415 if api_key.len() < 16 {
417 return Err(AuthError::validation("API key is too short"));
418 }
419
420 if api_key.len() > 128 {
421 return Err(AuthError::validation("API key is too long"));
422 }
423
424 Ok(())
425 }
426}
427
428pub mod rate_limit {
430 use dashmap::DashMap;
431
432 use std::sync::Arc;
433 use std::time::{Duration, Instant};
434
435 #[derive(Debug)]
437 pub struct RateLimiter {
438 buckets: Arc<DashMap<String, Bucket>>,
439 max_requests: u32,
440 window: Duration,
441 }
442
443 #[derive(Debug)]
444 struct Bucket {
445 count: u32,
446 window_start: Instant,
447 }
448
449 impl RateLimiter {
450 pub fn new(max_requests: u32, window: Duration) -> Self {
452 Self {
453 buckets: Arc::new(DashMap::new()),
454 max_requests,
455 window,
456 }
457 }
458
459 pub fn is_allowed(&self, key: &str) -> bool {
461 let now = Instant::now();
462
463 let mut bucket = self.buckets.entry(key.to_string()).or_insert(Bucket {
465 count: 0,
466 window_start: now,
467 });
468
469 if now.duration_since(bucket.window_start) >= self.window {
471 bucket.count = 0;
472 bucket.window_start = now;
473 }
474
475 if bucket.count < self.max_requests {
477 bucket.count += 1;
478 true
479 } else {
480 false
481 }
482 }
483
484 pub fn remaining_requests(&self, key: &str) -> u32 {
486 if let Some(bucket_ref) = self.buckets.get(key) {
487 let bucket = bucket_ref.value();
488 let now = Instant::now();
489 if now.duration_since(bucket.window_start) >= self.window {
490 self.max_requests
491 } else {
492 self.max_requests.saturating_sub(bucket.count)
493 }
494 } else {
495 self.max_requests
496 }
497 }
498
499 pub fn cleanup(&self) {
501 let now = Instant::now();
502 self.buckets
503 .retain(|_, bucket| now.duration_since(bucket.window_start) < self.window);
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use super::password::PasswordStrengthLevel as Level;
    use std::collections::HashSet;
    use std::time::Duration;

    #[test]
    fn test_password_hashing() {
        let secret = "test_password_123";
        let stored = password::hash_password(secret).unwrap();

        assert!(password::verify_password(secret, &stored).unwrap());
        assert!(!password::verify_password("wrong_password", &stored).unwrap());
    }

    #[test]
    fn test_password_strength() {
        let weak = password::check_password_strength("123");
        assert!(matches!(weak.level, Level::Weak));

        let strong = password::check_password_strength("MySecureP@ssw0rd!");
        assert!(matches!(strong.level, Level::Strong | Level::VeryStrong));
    }

    #[test]
    fn test_crypto_utils() {
        assert_eq!(crypto::generate_random_string(16).len(), 16);
        assert_eq!(crypto::sha256_hex(b"test data").len(), 64);
    }

    #[test]
    fn test_string_utils() {
        let masked = string::mask_string("secret123456", 2);
        assert!(masked.starts_with("se"));
        assert!(masked.ends_with("56"));
        assert!(masked.contains('*'));

        assert!(string::is_valid_email("test@example.com"));
        assert!(!string::is_valid_email("invalid_email"));
    }

    #[test]
    fn test_validation() {
        assert!(validation::validate_username("test_user").is_ok());
        for rejected in ["", "ab", "_invalid", "invalid@"] {
            assert!(validation::validate_username(rejected).is_err());
        }

        assert!(validation::validate_email("test@example.com").is_ok());
        for rejected in ["", "invalid"] {
            assert!(validation::validate_email(rejected).is_err());
        }
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = rate_limit::RateLimiter::new(3, Duration::from_secs(1));

        for _ in 0..3 {
            assert!(limiter.is_allowed("user1"));
        }
        assert!(!limiter.is_allowed("user1"));

        // Budgets are tracked per key.
        assert!(limiter.is_allowed("user2"));
    }

    #[test]
    fn test_password_hashing_edge_cases() {
        let samples = [
            "a".repeat(1000),
            String::from("!@#$%^&*()_+-=[]{}|;:,.<>?"),
            String::from("пароль测试🔒"),
        ];
        for sample in &samples {
            let stored = password::hash_password(sample).unwrap();
            assert!(password::verify_password(sample, &stored).unwrap());
        }

        let first = password::hash_password("password123").unwrap();
        let second = password::hash_password("password124").unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_password_strength_comprehensive() {
        let cases = [
            ("", Level::Weak),
            ("a", Level::Weak),
            ("password", Level::Weak),
            ("password123", Level::Weak),
            ("mypassword123", Level::Medium),
            ("MyPassword123", Level::Medium),
            ("MyPassword123!", Level::Strong),
            ("VerySecureP@ssw0rd2024!", Level::VeryStrong),
        ];

        for (candidate, minimum) in cases {
            let strength = password::check_password_strength(candidate);
            match minimum {
                // Weak is the floor, so there is no lower bound to enforce.
                Level::Weak => {}
                Level::Medium => assert!(
                    strength.level != Level::Weak,
                    "Password '{}' should be at least Medium strength",
                    candidate
                ),
                Level::Strong => assert!(
                    matches!(strength.level, Level::Strong | Level::VeryStrong),
                    "Password '{}' should be at least Strong",
                    candidate
                ),
                Level::VeryStrong => assert!(
                    strength.level == Level::VeryStrong,
                    "Password '{}' should be VeryStrong",
                    candidate
                ),
            }
        }
    }

    #[test]
    fn test_crypto_utils_edge_cases() {
        for length in [0usize, 1, 8, 16, 32, 64, 128] {
            let generated = crypto::generate_random_string(length);
            assert_eq!(
                generated.len(),
                length,
                "Generated string should have requested length"
            );

            if length == 0 {
                continue;
            }
            let other = crypto::generate_random_string(length);
            // Very short outputs can legitimately collide, so only compare longer ones.
            if length > 4 {
                assert_ne!(generated, other, "Random strings should be different");
            }
        }

        let inputs: [&[u8]; 5] = [
            b"",
            b"a",
            b"hello world",
            &[0u8; 1000],
            "unicode: 测试 🔒".as_bytes(),
        ];
        for input in inputs {
            let first = crypto::sha256_hex(input);
            assert_eq!(first.len(), 64, "SHA256 hex should always be 64 characters");

            let second = crypto::sha256_hex(input);
            assert_eq!(first, second, "Same input should produce same hash");
        }
    }

    #[test]
    fn test_string_utils_comprehensive() {
        let masking_cases = [
            ("", 0),
            ("a", 1),
            ("ab", 1),
            ("secret", 2),
            ("verylongsecret", 3),
            ("short", 10),
        ];

        for (input, reveal_chars) in masking_cases {
            let masked = string::mask_string(input, reveal_chars);
            if input.is_empty() {
                assert_eq!(masked, "");
                continue;
            }

            if reveal_chars >= input.len() {
                assert_eq!(masked, input, "Should not mask if reveal_chars >= length");
            } else if input.len() > reveal_chars * 2 {
                assert!(
                    masked.starts_with(&input[..reveal_chars]),
                    "Should preserve first {} characters",
                    reveal_chars
                );
                assert!(masked.contains("*"), "Should contain masking characters");
            } else {
                assert!(
                    masked.contains("*"),
                    "Should contain masking characters for short strings"
                );
            }
        }

        let accepted = [
            "user@example.com",
            "user.name@example.com",
            "user+tag@example.co.uk",
            "user123@example-domain.com",
            "a@b.co",
            "test_email@domain.info",
        ];
        let rejected = [
            "",
            "user",
            "@example.com",
            "user@",
            "user@@example.com",
            "user@example",
            "user @example.com",
            "user@exam ple.com",
            "user@.example.com",
            "user@example..com",
        ];

        for email in accepted {
            assert!(
                string::is_valid_email(email),
                "Should accept valid email: {}",
                email
            );
        }
        for email in rejected {
            assert!(
                !string::is_valid_email(email),
                "Should reject invalid email: {}",
                email
            );
        }
    }

    #[test]
    fn test_validation_comprehensive() {
        let good_usernames = ["user", "user123", "user_name", "user-name", "abc"];
        let bad_usernames = [
            "",
            "us",
            "a",
            "user name",
            "user@domain",
            "user\0name",
            "_invalid",
        ];

        for username in good_usernames {
            assert!(
                validation::validate_username(username).is_ok(),
                "Should accept valid username: {}",
                username
            );
        }
        for username in bad_usernames {
            assert!(
                validation::validate_username(username).is_err(),
                "Should reject invalid username: {}",
                username
            );
        }

        let good_emails = [
            "test@example.com",
            "user.name@domain.co.uk",
            "user+tag@example.org",
        ];
        let bad_emails = ["", "invalid", "@example.com", "user@", "user@@example.com"];

        for email in good_emails {
            assert!(
                validation::validate_email(email).is_ok(),
                "Should accept valid email: {}",
                email
            );
        }
        for email in bad_emails {
            assert!(
                validation::validate_email(email).is_err(),
                "Should reject invalid email: {}",
                email
            );
        }
    }

    #[test]
    fn test_rate_limiter_edge_cases() {
        // A zero budget rejects even the very first request.
        let zero_limiter = rate_limit::RateLimiter::new(0, Duration::from_secs(60));
        assert!(!zero_limiter.is_allowed("user1"));

        // A tiny window refills once it has elapsed.
        let short_limiter = rate_limit::RateLimiter::new(1, Duration::from_millis(10));
        assert!(short_limiter.is_allowed("user1"));
        assert!(!short_limiter.is_allowed("user1"));
        std::thread::sleep(Duration::from_millis(20));
        assert!(short_limiter.is_allowed("user1"));
    }

    #[test]
    fn test_rate_limiter_multiple_users() {
        let limiter = rate_limit::RateLimiter::new(2, Duration::from_secs(60));

        for user in ["user1", "user2", "user3"] {
            assert!(limiter.is_allowed(user));
            assert!(limiter.is_allowed(user));
            assert!(!limiter.is_allowed(user));
        }
    }

    #[test]
    fn test_crypto_random_uniqueness() {
        let mut seen = HashSet::new();
        for _ in 0..1000 {
            let candidate = crypto::generate_random_string(16);
            // `insert` returns false when the value was already present.
            assert!(seen.insert(candidate), "Generated duplicate random string");
        }
    }

    #[test]
    fn test_password_hash_uniqueness() {
        let secret = "test_password_123";
        let mut seen = HashSet::new();

        for _ in 0..10 {
            let stored = password::hash_password(secret).unwrap();
            assert!(
                !seen.contains(&stored),
                "Password hashes should be unique due to salt"
            );
            assert!(password::verify_password(secret, &stored).unwrap());
            seen.insert(stored);
        }
    }
}
909}