1use crate::errors::{AuthError, Result};
3use base64::Engine;
4use ring::rand::{SecureRandom, SystemRandom};
5use subtle::ConstantTimeEq;
6use zeroize::{Zeroize, ZeroizeOnDrop};
7
8#[derive(Debug, Clone, ZeroizeOnDrop)]
10pub struct SecureString {
11 data: String,
12}
13
14impl SecureString {
15 pub fn new(data: String) -> Self {
17 Self { data }
18 }
19
20 pub fn as_str(&self) -> &str {
22 &self.data
23 }
24
25 pub fn as_bytes(&self) -> &[u8] {
27 self.data.as_bytes()
28 }
29
30 pub fn len(&self) -> usize {
32 self.data.len()
33 }
34
35 pub fn is_empty(&self) -> bool {
37 self.data.is_empty()
38 }
39}
40
41impl From<String> for SecureString {
42 fn from(data: String) -> Self {
43 Self::new(data)
44 }
45}
46
47impl From<&str> for SecureString {
48 fn from(data: &str) -> Self {
49 Self::new(data.to_string())
50 }
51}
52
53pub struct SecureComparison;
55
56impl SecureComparison {
57 pub fn constant_time_eq(a: &str, b: &str) -> bool {
62 Self::secure_string_compare(a, b)
63 }
64
65 pub fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
69 let max_len = a.len().max(b.len()).min(1024);
70 let mut a_padded = vec![0u8; max_len];
71 let mut b_padded = vec![0u8; max_len];
72
73 a_padded[..a.len().min(max_len)].copy_from_slice(&a[..a.len().min(max_len)]);
74 b_padded[..b.len().min(max_len)].copy_from_slice(&b[..b.len().min(max_len)]);
75
76 let result = a_padded.ct_eq(&b_padded).into() && a.len() == b.len();
77
78 a_padded.zeroize();
79 b_padded.zeroize();
80
81 result
82 }
83
84 pub fn secure_string_compare(a: &str, b: &str) -> bool {
87 let max_len = a.len().max(b.len()).min(1024); let mut a_padded = vec![0u8; max_len];
91 let mut b_padded = vec![0u8; max_len];
92
93 let a_bytes = a.as_bytes();
95 let b_bytes = b.as_bytes();
96
97 a_padded[..a_bytes.len().min(max_len)]
98 .copy_from_slice(&a_bytes[..a_bytes.len().min(max_len)]);
99 b_padded[..b_bytes.len().min(max_len)]
100 .copy_from_slice(&b_bytes[..b_bytes.len().min(max_len)]);
101
102 let result = a_padded.ct_eq(&b_padded).into() && a.len() == b.len();
104
105 a_padded.zeroize();
107 b_padded.zeroize();
108
109 result
110 }
111
112 pub fn verify_token(token1: &str, token2: &str) -> bool {
114 Self::secure_string_compare(token1, token2)
115 }
116}
117
118pub struct SecureRandomGen;
120
121impl SecureRandomGen {
122 pub fn generate_bytes(len: usize) -> Result<Vec<u8>> {
124 let rng = SystemRandom::new();
125 let mut bytes = vec![0u8; len];
126 rng.fill(&mut bytes)
127 .map_err(|_| AuthError::crypto("Failed to generate random bytes".to_string()))?;
128 Ok(bytes)
129 }
130
131 pub fn generate_string(byte_len: usize) -> Result<String> {
133 let bytes = Self::generate_bytes(byte_len)?;
134 Ok(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes))
135 }
136
137 pub fn generate_token() -> Result<String> {
139 Self::generate_string(32) }
141
142 pub fn generate_session_id() -> Result<String> {
144 Self::generate_string(24) }
146
147 pub fn generate_challenge_id() -> Result<String> {
149 Self::generate_string(16) }
151}
152
153pub struct SecureValidation;
155
156impl SecureValidation {
157 pub fn validate_username(username: &str) -> Result<()> {
159 if username.is_empty() {
160 return Err(AuthError::validation(
161 "Username cannot be empty".to_string(),
162 ));
163 }
164
165 if username.len() > 320 {
166 return Err(AuthError::validation("Username too long".to_string()));
167 }
168
169 if username.contains('\0') || username.contains('\r') || username.contains('\n') {
171 return Err(AuthError::validation(
172 "Username contains invalid characters".to_string(),
173 ));
174 }
175
176 if username.chars().any(|c| c.is_control()) {
178 return Err(AuthError::validation(
179 "Username contains invalid control characters".to_string(),
180 ));
181 }
182
183 #[cfg(feature = "unicode-support")]
185 {
186 let normalized = unicode_normalization::UnicodeNormalization::nfc(username.chars())
187 .collect::<String>();
188 if normalized != username {
189 return Err(AuthError::validation(
190 "Username must be in NFC form".to_string(),
191 ));
192 }
193 }
194
195 #[cfg(not(feature = "unicode-support"))]
196 {
197 }
199
200 Ok(())
201 }
202
203 pub fn validate_password(password: &str) -> Result<()> {
205 if password.is_empty() {
206 return Err(AuthError::validation(
207 "Password cannot be empty".to_string(),
208 ));
209 }
210
211 if password.len() > 1000 {
212 return Err(AuthError::validation("Password too long".to_string()));
213 }
214
215 if password.contains('\0') {
217 return Err(AuthError::validation(
218 "Password contains null bytes".to_string(),
219 ));
220 }
221
222 Ok(())
223 }
224
225 pub fn sanitize_input(input: &str) -> String {
227 input
229 .chars()
230 .filter(|&c| !c.is_control() || c == '\n' || c == '\t' || c == ' ')
231 .collect()
232 }
233
234 pub fn validate_email(email: &str) -> Result<String> {
236 let sanitized = Self::sanitize_input(email);
237
238 if sanitized.is_empty() {
239 return Err(AuthError::validation("Email cannot be empty".to_string()));
240 }
241
242 if sanitized.len() > 320 {
243 return Err(AuthError::validation("Email too long".to_string()));
244 }
245
246 if !sanitized.contains('@') || sanitized.starts_with('@') || sanitized.ends_with('@') {
248 return Err(AuthError::validation("Invalid email format".to_string()));
249 }
250
251 if sanitized.matches('@').count() != 1 {
253 return Err(AuthError::validation("Invalid email format".to_string()));
254 }
255
256 let parts: Vec<&str> = sanitized.split('@').collect();
257 let local_part = parts[0];
258 let domain_part = parts[1];
259
260 if local_part.is_empty() || local_part.starts_with('.') || local_part.ends_with('.') {
262 return Err(AuthError::validation("Invalid email format".to_string()));
263 }
264
265 if domain_part.is_empty()
267 || domain_part.starts_with('.')
268 || domain_part.ends_with('.')
269 || domain_part.contains("..")
270 || !domain_part.contains('.')
271 {
272 return Err(AuthError::validation("Invalid email format".to_string()));
273 }
274
275 if sanitized.contains(' ') {
277 return Err(AuthError::validation("Invalid email format".to_string()));
278 }
279
280 Ok(sanitized)
281 }
282}
283
284pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
317 SecureComparison::constant_time_eq_bytes(a, b)
318}
319
320pub fn generate_secure_token(byte_length: usize) -> Result<String> {
351 SecureRandomGen::generate_string(byte_length)
352}
353
354pub fn hash_password(password: &str) -> Result<String> {
386 if password.is_empty() {
387 return Err(AuthError::validation(
388 "Password cannot be empty".to_string(),
389 ));
390 }
391
392 bcrypt::hash(password, bcrypt::DEFAULT_COST)
393 .map_err(|e| AuthError::crypto(format!("Password hashing failed: {}", e)))
394}
395
396pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
431 bcrypt::verify(password, hash)
432 .map_err(|e| AuthError::crypto(format!("Password verification failed: {}", e)))
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_string() {
        let secret = SecureString::new("password123".to_string());
        assert_eq!(secret.as_str(), "password123");
        assert_eq!(secret.as_bytes(), b"password123");
        assert_eq!(secret.len(), 11);
        assert!(!secret.is_empty());
    }

    #[test]
    fn test_secure_string_conversions() {
        let from_str = SecureString::from("abc");
        let from_string = SecureString::from(String::from("abc"));
        assert_eq!(from_str.as_str(), from_string.as_str());
        assert!(SecureString::from("").is_empty());
    }

    #[test]
    fn test_constant_time_comparison() {
        assert!(SecureComparison::constant_time_eq("hello", "hello"));
        assert!(!SecureComparison::constant_time_eq("hello", "world"));
        assert!(!SecureComparison::constant_time_eq("hello", "hello world"));
    }

    #[test]
    fn test_secure_string_compare() {
        assert!(SecureComparison::secure_string_compare("test", "test"));
        assert!(!SecureComparison::secure_string_compare(
            "test",
            "different"
        ));
        assert!(!SecureComparison::secure_string_compare("short", "longer"));
    }

    #[test]
    fn test_token_verification() {
        let token = "abc123def456";
        assert!(SecureComparison::verify_token(token, token));
        assert!(!SecureComparison::verify_token(token, "different"));
    }

    #[test]
    fn test_secure_random_generation() {
        let token1 = SecureRandomGen::generate_token().unwrap();
        let token2 = SecureRandomGen::generate_token().unwrap();

        assert_ne!(token1, token2);
        assert!(!token1.is_empty());
        assert!(!token2.is_empty());
    }

    #[test]
    fn test_input_validation() {
        assert!(SecureValidation::validate_username("user123").is_ok());
        assert!(SecureValidation::validate_username("").is_err());
        assert!(SecureValidation::validate_username("user\0name").is_err());
    }

    #[test]
    fn test_email_validation() {
        assert!(SecureValidation::validate_email("test@example.com").is_ok());
        assert!(SecureValidation::validate_email("").is_err());
        assert!(SecureValidation::validate_email("@example.com").is_err());
        assert!(SecureValidation::validate_email("user@").is_err());
    }

    #[test]
    fn test_input_sanitization() {
        let dirty = "hello\0world\x01test";
        let clean = SecureValidation::sanitize_input(dirty);
        assert_eq!(clean, "helloworldtest");

        let with_newlines = "line1\nline2\tline3";
        let cleaned = SecureValidation::sanitize_input(with_newlines);
        assert_eq!(cleaned, "line1\nline2\tline3");
    }

    #[test]
    fn test_secure_string_zeroization() {
        // The wipe happens in `Drop`. Reading the buffer afterwards would be
        // undefined behaviour, so this checks what is observable: a clone is
        // an independent copy that survives dropping the original.
        let secret = SecureString::new("sensitive_data".to_string());
        let copy = secret.clone();

        assert_eq!(secret.as_str(), "sensitive_data");
        drop(secret);
        assert_eq!(copy.as_str(), "sensitive_data");
    }

    #[test]
    fn test_constant_time_comparison_edge_cases() {
        assert!(SecureComparison::constant_time_eq("", ""));
        assert!(!SecureComparison::constant_time_eq("", "nonempty"));
        assert!(!SecureComparison::constant_time_eq("nonempty", ""));

        let long1 = "a".repeat(1000);
        let long2 = "a".repeat(1000);
        let long3 = "b".repeat(1000);

        assert!(SecureComparison::constant_time_eq(&long1, &long2));
        assert!(!SecureComparison::constant_time_eq(&long1, &long3));

        let almost_same1 = "verylongstringtestX";
        let almost_same2 = "verylongstringtestY";
        assert!(!SecureComparison::constant_time_eq(
            almost_same1,
            almost_same2
        ));
    }

    #[test]
    fn test_secure_random_generation_uniqueness() {
        let mut tokens = std::collections::HashSet::new();

        for _ in 0..100 {
            let token = SecureRandomGen::generate_token().unwrap();
            // `insert` returns false if the value was already present.
            assert!(tokens.insert(token), "Generated duplicate token");
        }
    }

    #[test]
    fn test_secure_random_generation_length() {
        // Unpadded base64 maps n bytes to exactly ceil(4n / 3) characters.
        for byte_len in [1, 2, 8, 16, 32, 64] {
            let token = SecureRandomGen::generate_string(byte_len).unwrap();
            assert_eq!(
                token.len(),
                (byte_len * 4).div_ceil(3),
                "Token length mismatch for {} bytes",
                byte_len
            );
            assert!(
                token
                    .bytes()
                    .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'),
                "Token {} is not URL-safe base64",
                token
            );
        }

        assert_eq!(SecureRandomGen::generate_token().unwrap().len(), 43);
        assert_eq!(SecureRandomGen::generate_session_id().unwrap().len(), 32);
        assert_eq!(SecureRandomGen::generate_challenge_id().unwrap().len(), 22);
    }

    #[test]
    fn test_input_validation_edge_cases() {
        let long_username = "a".repeat(320);
        assert!(SecureValidation::validate_username(&long_username).is_ok());
        let too_long_username = "a".repeat(321);
        assert!(SecureValidation::validate_username(&too_long_username).is_err());

        assert!(SecureValidation::validate_username("user\x01").is_err());
        assert!(SecureValidation::validate_username("user\x1f").is_err());

        assert!(SecureValidation::validate_username("user_ñ").is_ok());
    }

    #[test]
    fn test_email_validation_comprehensive() {
        let valid_emails = [
            "user@example.com",
            "user.name@example.com",
            "user+tag@example.com",
            "user123@example-domain.com",
            "a@b.co",
            "very.long.email.address@very.long.domain.name.com",
        ];

        for email in valid_emails {
            assert!(
                SecureValidation::validate_email(email).is_ok(),
                "Should accept valid email: {}",
                email
            );
        }

        let invalid_emails = [
            "",
            "user",
            "@example.com",
            "user@",
            "user@@example.com",
            "user@example",
            "user @example.com",
            "user@exam ple.com",
            "user@.example.com",
            "user@example..com",
            ".user@example.com",
            "user.@example.com",
        ];

        for email in invalid_emails {
            assert!(
                SecureValidation::validate_email(email).is_err(),
                "Should reject invalid email: {}",
                email
            );
        }
    }

    #[test]
    fn test_input_sanitization_comprehensive() {
        let test_cases = [
            ("hello\0world", "helloworld"),
            ("test\x01\x02\x03", "test"),
            ("normal text", "normal text"),
            ("\x7f", ""),
            ("mix\0ed\x01cont\x02rol", "mixedcontrol"),
            ("", ""),
            (" spaced ", " spaced "),
        ];

        for (input, expected) in test_cases {
            let result = SecureValidation::sanitize_input(input);
            assert_eq!(result, expected, "Sanitization failed for: {:?}", input);
        }
    }

    #[test]
    fn test_password_hashing_security() {
        let password = "test_password_123";

        let hash1 = hash_password(password).unwrap();
        let hash2 = hash_password(password).unwrap();

        assert_ne!(
            hash1, hash2,
            "Password hashes should be different due to random salt"
        );

        assert!(verify_password(password, &hash1).unwrap());
        assert!(verify_password(password, &hash2).unwrap());

        assert!(!verify_password("wrong_password", &hash1).unwrap());
        assert!(!verify_password("wrong_password", &hash2).unwrap());
    }

    #[test]
    fn test_password_hashing_edge_cases() {
        let result = hash_password("");
        assert!(result.is_err(), "Should reject empty password");

        let long_password = "a".repeat(100);
        let hash = hash_password(&long_password).unwrap();
        assert!(verify_password(&long_password, &hash).unwrap());

        let special_password = "p@ssw0rd!#$%^&*()";
        let hash = hash_password(special_password).unwrap();
        assert!(verify_password(special_password, &hash).unwrap());

        let unicode_password = "пароль123测试";
        let hash = hash_password(unicode_password).unwrap();
        assert!(verify_password(unicode_password, &hash).unwrap());
    }

    #[test]
    fn test_secure_comparison_timing() {
        // Functional check across short and long inputs only; timing itself
        // is not measured here.
        let short_a = "a";
        let short_b = "a";
        let long_a = "a".repeat(1000);
        let long_b = "a".repeat(1000);

        assert!(SecureComparison::constant_time_eq(short_a, short_b));
        assert!(SecureComparison::secure_string_compare(short_a, short_b));
        assert!(SecureComparison::verify_token(short_a, short_b));

        assert!(SecureComparison::constant_time_eq(&long_a, &long_b));
        assert!(SecureComparison::secure_string_compare(&long_a, &long_b));
        assert!(SecureComparison::verify_token(&long_a, &long_b));

        let different_short_a = "a";
        let different_short_b = "b";
        let different_long_a = "a".repeat(1000);
        let different_long_b = "b".repeat(1000);

        assert!(!SecureComparison::constant_time_eq(
            different_short_a,
            different_short_b
        ));
        assert!(!SecureComparison::secure_string_compare(
            different_short_a,
            different_short_b
        ));
        assert!(!SecureComparison::verify_token(
            different_short_a,
            different_short_b
        ));

        assert!(!SecureComparison::constant_time_eq(
            &different_long_a,
            &different_long_b
        ));
        assert!(!SecureComparison::secure_string_compare(
            &different_long_a,
            &different_long_b
        ));
        assert!(!SecureComparison::verify_token(
            &different_long_a,
            &different_long_b
        ));
    }

    #[test]
    fn test_secure_string_multiple_operations() {
        let secret1 = SecureString::new("password1".to_string());
        let secret2 = SecureString::new("password2".to_string());

        assert_ne!(secret1.as_str(), secret2.as_str());
        assert!(SecureComparison::verify_token(
            secret1.as_str(),
            secret1.as_str()
        ));
        assert!(!SecureComparison::verify_token(
            secret1.as_str(),
            secret2.as_str()
        ));

        assert_eq!(secret1.len(), 9);
        assert_eq!(secret2.len(), 9);
        assert!(!secret1.is_empty());
        assert!(!secret2.is_empty());
    }

    #[test]
    fn test_token_verification_false_positives() {
        let token = "secure_token_123";
        let similar_token = "secure_token_124";
        let prefix_token = "secure_token_12";
        let longer_token = "secure_token_1234";

        assert!(SecureComparison::verify_token(token, token));
        assert!(!SecureComparison::verify_token(token, similar_token));
        assert!(!SecureComparison::verify_token(token, prefix_token));
        assert!(!SecureComparison::verify_token(token, longer_token));
    }
}