1use crate::errors::{AuthError, Result};
3use base64::Engine;
4use ring::rand::{SecureRandom, SystemRandom};
5use subtle::ConstantTimeEq;
6use zeroize::{Zeroize, ZeroizeOnDrop};
7
8#[derive(Debug, Clone, ZeroizeOnDrop)]
10pub struct SecureString {
11 data: String,
12}
13
14impl SecureString {
15 pub fn new(data: String) -> Self {
17 Self { data }
18 }
19
20 pub fn as_str(&self) -> &str {
22 &self.data
23 }
24
25 pub fn as_bytes(&self) -> &[u8] {
27 self.data.as_bytes()
28 }
29
30 pub fn len(&self) -> usize {
32 self.data.len()
33 }
34
35 pub fn is_empty(&self) -> bool {
37 self.data.is_empty()
38 }
39}
40
41impl From<String> for SecureString {
42 fn from(data: String) -> Self {
43 Self::new(data)
44 }
45}
46
47impl From<&str> for SecureString {
48 fn from(data: &str) -> Self {
49 Self::new(data.to_string())
50 }
51}
52
53pub struct SecureComparison;
55
56impl SecureComparison {
57 pub fn constant_time_eq(a: &str, b: &str) -> bool {
59 if a.len() != b.len() {
60 return false;
61 }
62 a.as_bytes().ct_eq(b.as_bytes()).into()
63 }
64
65 pub fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
67 if a.len() != b.len() {
68 return false;
69 }
70 a.ct_eq(b).into()
71 }
72
73 pub fn secure_string_compare(a: &str, b: &str) -> bool {
76 let max_len = a.len().max(b.len()).min(1024); let mut a_padded = vec![0u8; max_len];
80 let mut b_padded = vec![0u8; max_len];
81
82 let a_bytes = a.as_bytes();
84 let b_bytes = b.as_bytes();
85
86 a_padded[..a_bytes.len().min(max_len)]
87 .copy_from_slice(&a_bytes[..a_bytes.len().min(max_len)]);
88 b_padded[..b_bytes.len().min(max_len)]
89 .copy_from_slice(&b_bytes[..b_bytes.len().min(max_len)]);
90
91 let result = a_padded.ct_eq(&b_padded).into() && a.len() == b.len();
93
94 a_padded.zeroize();
96 b_padded.zeroize();
97
98 result
99 }
100
101 pub fn verify_token(token1: &str, token2: &str) -> bool {
103 Self::secure_string_compare(token1, token2)
104 }
105}
106
107pub struct SecureRandomGen;
109
110impl SecureRandomGen {
111 pub fn generate_bytes(len: usize) -> Result<Vec<u8>> {
113 let rng = SystemRandom::new();
114 let mut bytes = vec![0u8; len];
115 rng.fill(&mut bytes)
116 .map_err(|_| AuthError::crypto("Failed to generate random bytes".to_string()))?;
117 Ok(bytes)
118 }
119
120 pub fn generate_string(byte_len: usize) -> Result<String> {
122 let bytes = Self::generate_bytes(byte_len)?;
123 Ok(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes))
124 }
125
126 pub fn generate_token() -> Result<String> {
128 Self::generate_string(32) }
130
131 pub fn generate_session_id() -> Result<String> {
133 Self::generate_string(24) }
135
136 pub fn generate_challenge_id() -> Result<String> {
138 Self::generate_string(16) }
140}
141
142pub struct SecureValidation;
144
145impl SecureValidation {
146 pub fn validate_username(username: &str) -> Result<()> {
148 if username.is_empty() {
149 return Err(AuthError::validation(
150 "Username cannot be empty".to_string(),
151 ));
152 }
153
154 if username.len() > 320 {
155 return Err(AuthError::validation("Username too long".to_string()));
156 }
157
158 if username.contains('\0') || username.contains('\r') || username.contains('\n') {
160 return Err(AuthError::validation(
161 "Username contains invalid characters".to_string(),
162 ));
163 }
164
165 #[cfg(feature = "unicode-support")]
167 {
168 let normalized = unicode_normalization::UnicodeNormalization::nfc(username.chars())
169 .collect::<String>();
170 if normalized != username {
171 return Err(AuthError::validation(
172 "Username must be in NFC form".to_string(),
173 ));
174 }
175 }
176
177 #[cfg(not(feature = "unicode-support"))]
178 {
179 if username.chars().any(|c| c.is_control()) {
181 return Err(AuthError::validation(
182 "Username contains invalid control characters".to_string(),
183 ));
184 }
185 }
186
187 Ok(())
188 }
189
190 pub fn validate_password(password: &str) -> Result<()> {
192 if password.is_empty() {
193 return Err(AuthError::validation(
194 "Password cannot be empty".to_string(),
195 ));
196 }
197
198 if password.len() > 1000 {
199 return Err(AuthError::validation("Password too long".to_string()));
200 }
201
202 if password.contains('\0') {
204 return Err(AuthError::validation(
205 "Password contains null bytes".to_string(),
206 ));
207 }
208
209 Ok(())
210 }
211
212 pub fn sanitize_input(input: &str) -> String {
214 input
216 .chars()
217 .filter(|&c| !c.is_control() || c == '\n' || c == '\t' || c == ' ')
218 .collect()
219 }
220
221 pub fn validate_email(email: &str) -> Result<String> {
223 let sanitized = Self::sanitize_input(email);
224
225 if sanitized.is_empty() {
226 return Err(AuthError::validation("Email cannot be empty".to_string()));
227 }
228
229 if sanitized.len() > 320 {
230 return Err(AuthError::validation("Email too long".to_string()));
231 }
232
233 if !sanitized.contains('@') || sanitized.starts_with('@') || sanitized.ends_with('@') {
235 return Err(AuthError::validation("Invalid email format".to_string()));
236 }
237
238 if sanitized.matches('@').count() != 1 {
240 return Err(AuthError::validation("Invalid email format".to_string()));
241 }
242
243 let parts: Vec<&str> = sanitized.split('@').collect();
244 let local_part = parts[0];
245 let domain_part = parts[1];
246
247 if local_part.is_empty() || local_part.starts_with('.') || local_part.ends_with('.') {
249 return Err(AuthError::validation("Invalid email format".to_string()));
250 }
251
252 if domain_part.is_empty()
254 || domain_part.starts_with('.')
255 || domain_part.ends_with('.')
256 || domain_part.contains("..")
257 || !domain_part.contains('.')
258 {
259 return Err(AuthError::validation("Invalid email format".to_string()));
260 }
261
262 if sanitized.contains(' ') {
264 return Err(AuthError::validation("Invalid email format".to_string()));
265 }
266
267 Ok(sanitized)
268 }
269}
270
271pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
304 SecureComparison::constant_time_eq_bytes(a, b)
305}
306
307pub fn generate_secure_token(byte_length: usize) -> Result<String> {
338 SecureRandomGen::generate_string(byte_length)
339}
340
341pub fn hash_password(password: &str) -> Result<String> {
373 if password.is_empty() {
374 return Err(AuthError::validation(
375 "Password cannot be empty".to_string(),
376 ));
377 }
378
379 bcrypt::hash(password, bcrypt::DEFAULT_COST)
380 .map_err(|e| AuthError::crypto(format!("Password hashing failed: {}", e)))
381}
382
383pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
418 bcrypt::verify(password, hash)
419 .map_err(|e| AuthError::crypto(format!("Password verification failed: {}", e)))
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_string() {
        let secret = SecureString::new("password123".to_string());
        assert_eq!(secret.as_str(), "password123");
        assert_eq!(secret.len(), 11);
    }

    #[test]
    fn test_constant_time_comparison() {
        assert!(SecureComparison::constant_time_eq("hello", "hello"));
        assert!(!SecureComparison::constant_time_eq("hello", "world"));
        assert!(!SecureComparison::constant_time_eq("hello", "hello world"));
    }

    #[test]
    fn test_secure_string_compare() {
        assert!(SecureComparison::secure_string_compare("test", "test"));
        assert!(!SecureComparison::secure_string_compare(
            "test",
            "different"
        ));
        assert!(!SecureComparison::secure_string_compare("short", "longer"));
    }

    #[test]
    fn test_token_verification() {
        let token = "abc123def456";
        assert!(SecureComparison::verify_token(token, token));
        assert!(!SecureComparison::verify_token(token, "different"));
    }

    #[test]
    fn test_secure_random_generation() {
        let token1 = SecureRandomGen::generate_token().unwrap();
        let token2 = SecureRandomGen::generate_token().unwrap();

        assert_ne!(token1, token2);
        assert!(!token1.is_empty());
        assert!(!token2.is_empty());
    }

    #[test]
    fn test_input_validation() {
        assert!(SecureValidation::validate_username("user123").is_ok());
        assert!(SecureValidation::validate_username("").is_err());
        assert!(SecureValidation::validate_username("user\0name").is_err());
    }

    #[test]
    fn test_email_validation() {
        assert!(SecureValidation::validate_email("test@example.com").is_ok());
        assert!(SecureValidation::validate_email("").is_err());
        assert!(SecureValidation::validate_email("@example.com").is_err());
        assert!(SecureValidation::validate_email("user@").is_err());
    }

    #[test]
    fn test_input_sanitization() {
        let dirty = "hello\0world\x01test";
        let clean = SecureValidation::sanitize_input(dirty);
        assert_eq!(clean, "helloworldtest");

        let with_newlines = "line1\nline2\tline3";
        let cleaned = SecureValidation::sanitize_input(with_newlines);
        assert_eq!(cleaned, "line1\nline2\tline3");
    }

    #[test]
    fn test_secure_string_zeroization() {
        let secret = SecureString::new("sensitive_data".to_string());
        let copy = secret.clone();
        assert_eq!(secret.as_str(), "sensitive_data");

        // Dropping runs the zeroizing destructor. The freed buffer cannot be
        // inspected safely afterwards, but an independent clone must be
        // unaffected by the wipe.
        drop(secret);
        assert_eq!(copy.as_str(), "sensitive_data");
    }

    #[test]
    fn test_constant_time_comparison_edge_cases() {
        assert!(SecureComparison::constant_time_eq("", ""));
        assert!(!SecureComparison::constant_time_eq("", "nonempty"));
        assert!(!SecureComparison::constant_time_eq("nonempty", ""));

        let long1 = "a".repeat(1000);
        let long2 = "a".repeat(1000);
        let long3 = "b".repeat(1000);

        assert!(SecureComparison::constant_time_eq(&long1, &long2));
        assert!(!SecureComparison::constant_time_eq(&long1, &long3));

        let almost_same1 = "verylongstringtestX";
        let almost_same2 = "verylongstringtestY";
        assert!(!SecureComparison::constant_time_eq(
            almost_same1,
            almost_same2
        ));
    }

    #[test]
    fn test_secure_random_generation_uniqueness() {
        let mut tokens = std::collections::HashSet::new();

        for _ in 0..100 {
            let token = SecureRandomGen::generate_token().unwrap();
            // `insert` returns false when the value was already present.
            assert!(tokens.insert(token), "Generated duplicate token");
        }
    }

    #[test]
    fn test_secure_random_generation_length() {
        for byte_len in [8, 16, 32, 64] {
            let token = SecureRandomGen::generate_string(byte_len).unwrap();
            // Unpadded base64 encodes n bytes into exactly ceil(4n / 3) characters.
            let expected_len = (byte_len * 4 + 2) / 3;
            assert_eq!(
                token.len(),
                expected_len,
                "Token length {} does not match expected {} for {} bytes",
                token.len(),
                expected_len,
                byte_len
            );
        }
    }

    #[test]
    fn test_input_validation_edge_cases() {
        let long_username = "a".repeat(320);
        assert!(SecureValidation::validate_username(&long_username).is_ok());
        let too_long_username = "a".repeat(321);
        assert!(SecureValidation::validate_username(&too_long_username).is_err());

        assert!(SecureValidation::validate_username("user\x01").is_err());
        assert!(SecureValidation::validate_username("user\x1f").is_err());

        assert!(SecureValidation::validate_username("user_ñ").is_ok());
    }

    #[test]
    fn test_email_validation_comprehensive() {
        let valid_emails = vec![
            "user@example.com",
            "user.name@example.com",
            "user+tag@example.com",
            "user123@example-domain.com",
            "a@b.co",
            "very.long.email.address@very.long.domain.name.com",
        ];

        for email in valid_emails {
            assert!(
                SecureValidation::validate_email(email).is_ok(),
                "Should accept valid email: {}",
                email
            );
        }

        let invalid_emails = vec![
            "",
            "user",
            "@example.com",
            "user@",
            "user@@example.com",
            "user@example",
            "user @example.com",
            "user@exam ple.com",
            "user@.example.com",
            "user@example..com",
            ".user@example.com",
            "user.@example.com",
        ];

        for email in invalid_emails {
            assert!(
                SecureValidation::validate_email(email).is_err(),
                "Should reject invalid email: {}",
                email
            );
        }
    }

    #[test]
    fn test_input_sanitization_comprehensive() {
        let test_cases = vec![
            ("hello\0world", "helloworld"),
            ("test\x01\x02\x03", "test"),
            ("normal text", "normal text"),
            ("\x7f", ""),
            ("mix\0ed\x01cont\x02rol", "mixedcontrol"),
            ("", ""),
            (" spaced ", " spaced "),
        ];

        for (input, expected) in test_cases {
            let result = SecureValidation::sanitize_input(input);
            assert_eq!(result, expected, "Sanitization failed for: {:?}", input);
        }
    }

    #[test]
    fn test_password_hashing_security() {
        let password = "test_password_123";

        let hash1 = hash_password(password).unwrap();
        let hash2 = hash_password(password).unwrap();

        assert_ne!(
            hash1, hash2,
            "Password hashes should be different due to random salt"
        );

        assert!(verify_password(password, &hash1).unwrap());
        assert!(verify_password(password, &hash2).unwrap());

        assert!(!verify_password("wrong_password", &hash1).unwrap());
        assert!(!verify_password("wrong_password", &hash2).unwrap());
    }

    #[test]
    fn test_password_hashing_edge_cases() {
        let result = hash_password("");
        assert!(result.is_err(), "Should reject empty password");

        let long_password = "a".repeat(100);
        let hash = hash_password(&long_password).unwrap();
        assert!(verify_password(&long_password, &hash).unwrap());

        let special_password = "p@ssw0rd!#$%^&*()";
        let hash = hash_password(special_password).unwrap();
        assert!(verify_password(special_password, &hash).unwrap());

        let unicode_password = "пароль123测试";
        let hash = hash_password(unicode_password).unwrap();
        assert!(verify_password(unicode_password, &hash).unwrap());
    }

    // Checks comparison results for short and long inputs; it does not measure timing.
    #[test]
    fn test_secure_comparison_timing() {
        let short_a = "a";
        let short_b = "a";
        let long_a = "a".repeat(1000);
        let long_b = "a".repeat(1000);

        assert!(SecureComparison::constant_time_eq(short_a, short_b));
        assert!(SecureComparison::secure_string_compare(short_a, short_b));
        assert!(SecureComparison::verify_token(short_a, short_b));

        assert!(SecureComparison::constant_time_eq(&long_a, &long_b));
        assert!(SecureComparison::secure_string_compare(&long_a, &long_b));
        assert!(SecureComparison::verify_token(&long_a, &long_b));

        let different_short_a = "a";
        let different_short_b = "b";
        let different_long_a = "a".repeat(1000);
        let different_long_b = "b".repeat(1000);

        assert!(!SecureComparison::constant_time_eq(
            different_short_a,
            different_short_b
        ));
        assert!(!SecureComparison::secure_string_compare(
            different_short_a,
            different_short_b
        ));
        assert!(!SecureComparison::verify_token(
            different_short_a,
            different_short_b
        ));

        assert!(!SecureComparison::constant_time_eq(
            &different_long_a,
            &different_long_b
        ));
        assert!(!SecureComparison::secure_string_compare(
            &different_long_a,
            &different_long_b
        ));
        assert!(!SecureComparison::verify_token(
            &different_long_a,
            &different_long_b
        ));
    }

    #[test]
    fn test_secure_string_multiple_operations() {
        let secret1 = SecureString::new("password1".to_string());
        let secret2 = SecureString::new("password2".to_string());

        assert_ne!(secret1.as_str(), secret2.as_str());
        assert!(SecureComparison::verify_token(
            secret1.as_str(),
            secret1.as_str()
        ));
        assert!(!SecureComparison::verify_token(
            secret1.as_str(),
            secret2.as_str()
        ));

        assert_eq!(secret1.len(), 9);
        assert_eq!(secret2.len(), 9);
        assert!(!secret1.is_empty());
        assert!(!secret2.is_empty());
    }

    #[test]
    fn test_token_verification_false_positives() {
        let token = "secure_token_123";
        let similar_token = "secure_token_124";
        let prefix_token = "secure_token_12";
        let longer_token = "secure_token_1234";

        assert!(SecureComparison::verify_token(token, token));
        assert!(!SecureComparison::verify_token(token, similar_token));
        assert!(!SecureComparison::verify_token(token, prefix_token));
        assert!(!SecureComparison::verify_token(token, longer_token));
    }
}