1use parking_lot::RwLock;
27use ring::rand::{SecureRandom, SystemRandom};
28use serde::{Deserialize, Serialize};
29use sha2::{Digest, Sha256};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32use std::time::{Duration, Instant};
33use thiserror::Error;
34use tracing::{debug, warn};
35
/// Errors returned by the authentication and authorization layer.
#[derive(Error, Debug, Clone)]
pub enum AuthError {
    /// Generic authentication failure. `AuthManager::authenticate` returns this
    /// for an unknown principal, a disabled principal and a wrong password alike.
    #[error("Authentication failed")]
    AuthenticationFailed,

    /// The SASL payload was malformed (bad UTF-8, missing fields, bad nonce, ...).
    #[error("Invalid credentials")]
    InvalidCredentials,

    /// The named principal does not exist.
    #[error("Principal not found: {0}")]
    PrincipalNotFound(String),

    /// A principal with this name is already registered.
    #[error("Principal already exists: {0}")]
    PrincipalAlreadyExists(String),

    /// Access was denied; the payload is a free-form explanation.
    #[error("Access denied: {0}")]
    AccessDenied(String),

    /// The principal lacks `permission` on `resource` (returned by `AuthManager::authorize`).
    #[error("Permission denied: {principal} lacks {permission} on {resource}")]
    PermissionDenied {
        principal: String,
        permission: String,
        resource: String,
    },

    /// The named role does not exist.
    #[error("Role not found: {0}")]
    RoleNotFound(String),

    /// A token was rejected as invalid.
    #[error("Invalid token")]
    InvalidToken,

    /// The session has expired (returned by `AuthManager::authorize`).
    #[error("Token expired")]
    TokenExpired,

    /// The username or client IP is currently locked out after repeated failures.
    #[error("Rate limited: too many authentication failures")]
    RateLimited,

    /// Catch-all for validation failures and internal faults such as RNG errors.
    #[error("Internal error: {0}")]
    Internal(String),
}
79
/// Convenience alias: a `Result` whose error type is [`AuthError`].
pub type AuthResult<T> = std::result::Result<T, AuthError>;
81
/// A resource that a permission can be granted on.
///
/// The `String` payload is the resource name, except for `TopicPattern`,
/// where it is a `*` wildcard pattern matched against topic names
/// (see `ResourceType::matches`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ResourceType {
    /// The cluster as a whole.
    Cluster,
    /// A single topic, by exact name.
    Topic(String),
    /// A wildcard pattern over topic names.
    TopicPattern(String),
    /// A consumer group.
    ConsumerGroup(String),
    /// A schema.
    Schema(String),
    /// A transactional id.
    TransactionalId(String),
}
102
103impl ResourceType {
104 pub fn matches(&self, other: &ResourceType) -> bool {
106 match (self, other) {
107 (a, b) if a == b => true,
109
110 (ResourceType::TopicPattern(pattern), ResourceType::Topic(name)) => {
112 Self::glob_match(pattern, name)
113 }
114 (ResourceType::Topic(name), ResourceType::TopicPattern(pattern)) => {
115 Self::glob_match(pattern, name)
116 }
117
118 _ => false,
119 }
120 }
121
122 fn glob_match(pattern: &str, text: &str) -> bool {
124 if pattern == "*" {
125 return true;
126 }
127
128 if let Some(prefix) = pattern.strip_suffix('*') {
129 return text.starts_with(prefix);
130 }
131
132 if let Some(suffix) = pattern.strip_prefix('*') {
133 return text.ends_with(suffix);
134 }
135
136 if let Some(idx) = pattern.find('*') {
138 let prefix = &pattern[..idx];
139 let suffix = &pattern[idx + 1..];
140 return text.starts_with(prefix)
141 && text.ends_with(suffix)
142 && text.len() >= prefix.len() + suffix.len();
143 }
144
145 pattern == text
146 }
147}
148
/// An operation that can be granted on a [`ResourceType`].
///
/// `All` implies every other permission, and `Alter`, `Write` and `Read`
/// each imply `Describe` (see `Permission::implies`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Permission {
    Read,
    Write,
    Create,
    Delete,
    Alter,
    Describe,
    GroupRead,
    GroupDelete,
    ClusterAction,
    IdempotentWrite,
    AlterConfigs,
    DescribeConfigs,
    /// Wildcard permission: implies every other permission.
    All,
}
179
180impl Permission {
181 pub fn implies(&self, other: &Permission) -> bool {
184 if self == other {
186 return true;
187 }
188
189 match self {
190 Permission::All => true, Permission::Alter | Permission::Write | Permission::Read => {
193 matches!(other, Permission::Describe)
194 }
195 _ => false,
196 }
197 }
198}
199
/// The kind of identity a [`Principal`] represents.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrincipalType {
    /// A human user.
    User,
    /// A non-human service identity.
    ServiceAccount,
    /// An unauthenticated caller.
    Anonymous,
}
211
/// An identity that can authenticate, together with its credentials and roles.
///
/// `Debug` is implemented by hand so the password hash is never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct Principal {
    /// Unique name; also the key in the manager's principal map.
    pub name: String,

    /// What kind of identity this is.
    pub principal_type: PrincipalType,

    /// Salted credential material used to verify the password.
    pub password_hash: PasswordHash,

    /// Names of the roles assigned to this principal.
    pub roles: HashSet<String>,

    /// Disabled principals are rejected by `AuthManager::authenticate`.
    pub enabled: bool,

    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,

    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,
}
241
242impl std::fmt::Debug for Principal {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 f.debug_struct("Principal")
245 .field("name", &self.name)
246 .field("principal_type", &self.principal_type)
247 .field("password_hash", &"[REDACTED]")
248 .field("roles", &self.roles)
249 .field("enabled", &self.enabled)
250 .field("metadata", &self.metadata)
251 .field("created_at", &self.created_at)
252 .finish()
253 }
254}
255
/// Salted password verifier in the SCRAM-SHA-256 layout.
///
/// The plaintext password is never stored; only values derived from
/// `PBKDF2-HMAC-SHA256(password, salt, iterations)` are kept.
#[derive(Clone, Serialize, Deserialize)]
pub struct PasswordHash {
    /// Random salt fed to PBKDF2.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count.
    pub iterations: u32,
    /// `HMAC(salted_password, "Server Key")`.
    pub server_key: Vec<u8>,
    /// `SHA256(HMAC(salted_password, "Client Key"))`.
    pub stored_key: Vec<u8>,
}
273
274impl std::fmt::Debug for PasswordHash {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("PasswordHash")
277 .field("salt", &"[REDACTED]")
278 .field("iterations", &self.iterations)
279 .field("server_key", &"[REDACTED]")
280 .field("stored_key", &"[REDACTED]")
281 .finish()
282 }
283}
284
285impl PasswordHash {
286 pub fn new(password: &str) -> Self {
288 let rng = SystemRandom::new();
289 let mut salt = vec![0u8; 32];
290 rng.fill(&mut salt).expect("Failed to generate salt");
291
292 Self::with_salt(password, &salt, 4096)
293 }
294
295 pub fn with_salt(password: &str, salt: &[u8], iterations: u32) -> Self {
297 let salted_password = Self::pbkdf2_sha256(password.as_bytes(), salt, iterations);
299
300 let client_key = Self::hmac_sha256(&salted_password, b"Client Key");
302 let server_key = Self::hmac_sha256(&salted_password, b"Server Key");
303
304 let stored_key = Sha256::digest(&client_key).to_vec();
306
307 PasswordHash {
308 salt: salt.to_vec(),
309 iterations,
310 server_key,
311 stored_key,
312 }
313 }
314
315 pub fn verify(&self, password: &str) -> bool {
317 let salted_password = Self::pbkdf2_sha256(password.as_bytes(), &self.salt, self.iterations);
318 let client_key = Self::hmac_sha256(&salted_password, b"Client Key");
319 let stored_key = Sha256::digest(&client_key);
320
321 Self::constant_time_compare(&stored_key, &self.stored_key)
323 }
324
325 pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
327 if a.len() != b.len() {
328 return false;
329 }
330
331 let mut result = 0u8;
333 for (x, y) in a.iter().zip(b.iter()) {
334 result |= x ^ y;
335 }
336 result == 0
337 }
338
    /// PBKDF2 with HMAC-SHA256, producing a single 32-byte output block.
    ///
    /// `result` is the XOR of `U1 ^ U2 ^ ... ^ Un`, where
    /// `U1 = HMAC(password, salt || 0x00000001)` and `Uk = HMAC(password, U(k-1))`.
    /// With `iterations` of 0 or 1 the loop does not run and `U1` is returned.
    fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
        use hmac::{Hmac, Mac};
        type HmacSha256 = Hmac<Sha256>;

        let mut result = vec![0u8; 32];

        // U1: keyed on the password, over the salt followed by the big-endian block index 1.
        let mut mac = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
        mac.update(salt);
        mac.update(&1u32.to_be_bytes());
        let mut u = mac.finalize().into_bytes();
        result.copy_from_slice(&u);

        // U2..Un: each round MACs the previous round's output and folds it into the result.
        for _ in 1..iterations {
            let mut mac =
                HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
            mac.update(&u);
            u = mac.finalize().into_bytes();

            for (r, ui) in result.iter_mut().zip(u.iter()) {
                *r ^= ui;
            }
        }

        result
    }
367
368 pub fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
370 use hmac::{Hmac, Mac};
371 type HmacSha256 = Hmac<Sha256>;
372
373 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
374 mac.update(data);
375 mac.finalize().into_bytes().to_vec()
376 }
377}
378
/// A named bundle of permissions that can be assigned to principals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Role {
    /// Unique role name; also the key in the manager's role map.
    pub name: String,

    /// Human-readable description.
    pub description: String,

    /// The `(resource, permission)` pairs this role grants.
    pub permissions: HashSet<(ResourceType, Permission)>,

    /// Built-in roles cannot be deleted (see `AuthManager::delete_role`).
    pub builtin: bool,
}
398
399impl Role {
400 pub fn admin() -> Self {
402 let mut permissions = HashSet::new();
403 permissions.insert((ResourceType::Cluster, Permission::All));
404
405 Role {
406 name: "admin".to_string(),
407 description: "Full administrative access to all resources".to_string(),
408 permissions,
409 builtin: true,
410 }
411 }
412
413 pub fn producer() -> Self {
415 let mut permissions = HashSet::new();
416 permissions.insert((
417 ResourceType::TopicPattern("*".to_string()),
418 Permission::Write,
419 ));
420 permissions.insert((
421 ResourceType::TopicPattern("*".to_string()),
422 Permission::Describe,
423 ));
424 permissions.insert((ResourceType::Cluster, Permission::IdempotentWrite));
425
426 Role {
427 name: "producer".to_string(),
428 description: "Can produce to all topics".to_string(),
429 permissions,
430 builtin: true,
431 }
432 }
433
434 pub fn consumer() -> Self {
436 let mut permissions = HashSet::new();
437 permissions.insert((
438 ResourceType::TopicPattern("*".to_string()),
439 Permission::Read,
440 ));
441 permissions.insert((
442 ResourceType::TopicPattern("*".to_string()),
443 Permission::Describe,
444 ));
445 permissions.insert((
446 ResourceType::ConsumerGroup("*".to_string()),
447 Permission::GroupRead,
448 ));
449
450 Role {
451 name: "consumer".to_string(),
452 description: "Can consume from all topics".to_string(),
453 permissions,
454 builtin: true,
455 }
456 }
457
458 pub fn read_only() -> Self {
460 let mut permissions = HashSet::new();
461 permissions.insert((
462 ResourceType::TopicPattern("*".to_string()),
463 Permission::Read,
464 ));
465 permissions.insert((
466 ResourceType::TopicPattern("*".to_string()),
467 Permission::Describe,
468 ));
469
470 Role {
471 name: "read-only".to_string(),
472 description: "Read-only access to all topics".to_string(),
473 permissions,
474 builtin: true,
475 }
476 }
477}
478
/// A single access-control rule evaluated by `AuthManager::check_acls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AclEntry {
    /// Principal name this rule applies to, or `"*"` for every principal.
    pub principal: String,

    /// Resource (or topic pattern) this rule covers.
    pub resource: ResourceType,

    /// Permission this rule covers; `Permission::All` covers every permission.
    pub permission: Permission,

    /// `true` for an allow rule, `false` for a deny rule. Deny rules win.
    pub allow: bool,

    /// Client host this rule applies to, or `"*"` for every host.
    pub host: String,
}
501
/// An authenticated session with a snapshot of the principal's permissions.
#[derive(Debug, Clone)]
pub struct AuthSession {
    /// Hex-encoded random session identifier.
    pub id: String,

    /// Name of the authenticated principal.
    pub principal_name: String,

    /// Kind of the authenticated principal.
    pub principal_type: PrincipalType,

    /// Permissions resolved from the principal's roles when the session was created.
    pub permissions: HashSet<(ResourceType, Permission)>,

    /// When the session was created.
    pub created_at: Instant,

    /// When the session stops being valid.
    pub expires_at: Instant,

    /// Client address recorded at authentication time.
    pub client_ip: String,
}
530
531impl AuthSession {
532 pub fn is_expired(&self) -> bool {
534 Instant::now() >= self.expires_at
535 }
536
537 pub fn has_permission(&self, resource: &ResourceType, permission: &Permission) -> bool {
539 if self
541 .permissions
542 .contains(&(ResourceType::Cluster, Permission::All))
543 {
544 return true;
545 }
546
547 if self.permissions.contains(&(resource.clone(), *permission)) {
549 return true;
550 }
551
552 for (res, perm) in &self.permissions {
554 let resource_matches = res.matches(resource);
557 let permission_implies = perm.implies(permission);
558 if resource_matches && permission_implies {
559 return true;
560 }
561 }
562
563 false
564 }
565}
566
/// Tunables for [`AuthManager`].
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Lifetime of a newly created session.
    pub session_timeout: Duration,

    /// Failed attempts per username before lockout (the per-IP limit is twice this).
    pub max_failed_attempts: u32,

    /// How long a lockout lasts; also the window over which failures are counted.
    pub lockout_duration: Duration,

    /// When `false`, anonymous access is allowed.
    pub require_authentication: bool,

    /// When `true`, ACL entries are consulted during authorization.
    pub enable_acls: bool,

    /// When `true`, a request granted by neither roles nor ACLs is rejected.
    pub default_deny: bool,
}
592
593impl Default for AuthConfig {
594 fn default() -> Self {
595 AuthConfig {
596 session_timeout: Duration::from_secs(3600), max_failed_attempts: 5,
598 lockout_duration: Duration::from_secs(300), require_authentication: false, enable_acls: false,
601 default_deny: true,
602 }
603 }
604}
605
/// Bookkeeping for failed authentication attempts, keyed by an identifier
/// (the callers in this file pass usernames and client IPs).
struct FailedAttemptTracker {
    /// Timestamps of recent failures per identifier.
    attempts: HashMap<String, Vec<Instant>>,
    /// Time at which each currently locked-out identifier was locked.
    lockouts: HashMap<String, Instant>,
}
611
612impl FailedAttemptTracker {
613 fn new() -> Self {
614 Self {
615 attempts: HashMap::new(),
616 lockouts: HashMap::new(),
617 }
618 }
619
620 fn is_locked_out(&self, identifier: &str, lockout_duration: Duration) -> bool {
622 if let Some(lockout_time) = self.lockouts.get(identifier) {
623 if lockout_time.elapsed() < lockout_duration {
624 return true;
625 }
626 }
627 false
628 }
629
630 fn record_failure(
632 &mut self,
633 identifier: &str,
634 max_attempts: u32,
635 lockout_duration: Duration,
636 ) -> bool {
637 let now = Instant::now();
638
639 self.lockouts.retain(|_, t| t.elapsed() < lockout_duration);
641
642 let attempts = self.attempts.entry(identifier.to_string()).or_default();
644
645 attempts.retain(|t| t.elapsed() < lockout_duration);
647
648 attempts.push(now);
650
651 if attempts.len() >= max_attempts as usize {
653 warn!(
654 "Principal '{}' locked out after {} failed attempts",
655 identifier, max_attempts
656 );
657 self.lockouts.insert(identifier.to_string(), now);
658 return true;
659 }
660
661 false
662 }
663
664 fn clear_failures(&mut self, identifier: &str) {
666 self.attempts.remove(identifier);
667 self.lockouts.remove(identifier);
668 }
669}
670
/// Central store for principals, roles, ACLs and sessions, and the entry point
/// for authentication and authorization. Each collection sits behind its own lock.
pub struct AuthManager {
    /// Behaviour switches and limits.
    config: AuthConfig,

    /// Registered principals, keyed by name.
    principals: RwLock<HashMap<String, Principal>>,

    /// Known roles (built-in and custom), keyed by name.
    roles: RwLock<HashMap<String, Role>>,

    /// ACL entries in insertion order.
    acls: RwLock<Vec<AclEntry>>,

    /// Live sessions, keyed by session id.
    sessions: RwLock<HashMap<String, AuthSession>>,

    /// Failed-attempt counters and lockouts for usernames and client IPs.
    failed_attempts: RwLock<FailedAttemptTracker>,

    /// Random source for session ids.
    rng: SystemRandom,
}
693
694impl AuthManager {
695 pub fn new(config: AuthConfig) -> Self {
697 let manager = Self {
698 config,
699 principals: RwLock::new(HashMap::new()),
700 roles: RwLock::new(HashMap::new()),
701 acls: RwLock::new(Vec::new()),
702 sessions: RwLock::new(HashMap::new()),
703 failed_attempts: RwLock::new(FailedAttemptTracker::new()),
704 rng: SystemRandom::new(),
705 };
706
707 manager.init_builtin_roles();
709
710 manager
711 }
712
    /// Creates a manager using `AuthConfig::default()`.
    pub fn new_default() -> Self {
        Self::new(AuthConfig::default())
    }
717
718 pub fn with_auth_enabled() -> Self {
720 Self::new(AuthConfig {
721 require_authentication: true,
722 enable_acls: true,
723 ..Default::default()
724 })
725 }
726
727 fn init_builtin_roles(&self) {
729 let mut roles = self.roles.write();
730 roles.insert("admin".to_string(), Role::admin());
731 roles.insert("producer".to_string(), Role::producer());
732 roles.insert("consumer".to_string(), Role::consumer());
733 roles.insert("read-only".to_string(), Role::read_only());
734 }
735
736 pub fn create_principal(
742 &self,
743 name: &str,
744 password: &str,
745 principal_type: PrincipalType,
746 roles: HashSet<String>,
747 ) -> AuthResult<()> {
748 if name.is_empty() || name.len() > 255 {
750 return Err(AuthError::Internal("Invalid principal name".to_string()));
751 }
752
753 if password.len() < 8 {
755 return Err(AuthError::Internal(
756 "Password must be at least 8 characters".to_string(),
757 ));
758 }
759
760 {
762 let role_map = self.roles.read();
763 for role in &roles {
764 if !role_map.contains_key(role) {
765 return Err(AuthError::RoleNotFound(role.clone()));
766 }
767 }
768 }
769
770 let mut principals = self.principals.write();
771
772 if principals.contains_key(name) {
773 return Err(AuthError::PrincipalAlreadyExists(name.to_string()));
774 }
775
776 let principal = Principal {
777 name: name.to_string(),
778 principal_type,
779 password_hash: PasswordHash::new(password),
780 roles,
781 enabled: true,
782 metadata: HashMap::new(),
783 created_at: std::time::SystemTime::now()
784 .duration_since(std::time::UNIX_EPOCH)
785 .unwrap_or_default()
786 .as_secs(),
787 };
788
789 principals.insert(name.to_string(), principal);
790 debug!("Created principal: {}", name);
791
792 Ok(())
793 }
794
795 pub fn delete_principal(&self, name: &str) -> AuthResult<()> {
797 let mut principals = self.principals.write();
798
799 if principals.remove(name).is_none() {
800 return Err(AuthError::PrincipalNotFound(name.to_string()));
801 }
802
803 let mut sessions = self.sessions.write();
805 sessions.retain(|_, s| s.principal_name != name);
806
807 debug!("Deleted principal: {}", name);
808 Ok(())
809 }
810
811 pub fn get_principal(&self, name: &str) -> Option<Principal> {
813 self.principals.read().get(name).cloned()
814 }
815
816 pub fn list_principals(&self) -> Vec<String> {
818 self.principals.read().keys().cloned().collect()
819 }
820
821 pub fn update_password(&self, name: &str, new_password: &str) -> AuthResult<()> {
823 if new_password.len() < 8 {
824 return Err(AuthError::Internal(
825 "Password must be at least 8 characters".to_string(),
826 ));
827 }
828
829 let mut principals = self.principals.write();
830
831 let principal = principals
832 .get_mut(name)
833 .ok_or_else(|| AuthError::PrincipalNotFound(name.to_string()))?;
834
835 principal.password_hash = PasswordHash::new(new_password);
836
837 let mut sessions = self.sessions.write();
839 sessions.retain(|_, s| s.principal_name != name);
840
841 debug!("Updated password for principal: {}", name);
842 Ok(())
843 }
844
845 pub fn add_role_to_principal(&self, principal_name: &str, role_name: &str) -> AuthResult<()> {
847 if !self.roles.read().contains_key(role_name) {
849 return Err(AuthError::RoleNotFound(role_name.to_string()));
850 }
851
852 let mut principals = self.principals.write();
853
854 let principal = principals
855 .get_mut(principal_name)
856 .ok_or_else(|| AuthError::PrincipalNotFound(principal_name.to_string()))?;
857
858 principal.roles.insert(role_name.to_string());
859
860 debug!(
861 "Added role '{}' to principal '{}'",
862 role_name, principal_name
863 );
864 Ok(())
865 }
866
867 pub fn remove_role_from_principal(
869 &self,
870 principal_name: &str,
871 role_name: &str,
872 ) -> AuthResult<()> {
873 let mut principals = self.principals.write();
874
875 let principal = principals
876 .get_mut(principal_name)
877 .ok_or_else(|| AuthError::PrincipalNotFound(principal_name.to_string()))?;
878
879 principal.roles.remove(role_name);
880
881 debug!(
882 "Removed role '{}' from principal '{}'",
883 role_name, principal_name
884 );
885 Ok(())
886 }
887
888 pub fn create_role(&self, role: Role) -> AuthResult<()> {
894 let mut roles = self.roles.write();
895
896 if roles.contains_key(&role.name) {
897 return Err(AuthError::Internal(format!(
898 "Role '{}' already exists",
899 role.name
900 )));
901 }
902
903 debug!("Created role: {}", role.name);
904 roles.insert(role.name.clone(), role);
905 Ok(())
906 }
907
908 pub fn delete_role(&self, name: &str) -> AuthResult<()> {
910 let mut roles = self.roles.write();
911
912 if let Some(role) = roles.get(name) {
913 if role.builtin {
914 return Err(AuthError::Internal(
915 "Cannot delete built-in role".to_string(),
916 ));
917 }
918 } else {
919 return Err(AuthError::RoleNotFound(name.to_string()));
920 }
921
922 roles.remove(name);
923 debug!("Deleted role: {}", name);
924 Ok(())
925 }
926
927 pub fn get_role(&self, name: &str) -> Option<Role> {
929 self.roles.read().get(name).cloned()
930 }
931
932 pub fn list_roles(&self) -> Vec<String> {
934 self.roles.read().keys().cloned().collect()
935 }
936
937 pub fn add_acl(&self, entry: AclEntry) {
943 let mut acls = self.acls.write();
944 acls.push(entry);
945 }
946
947 pub fn remove_acls(&self, principal: Option<&str>, resource: Option<&ResourceType>) {
949 let mut acls = self.acls.write();
950 acls.retain(|acl| {
951 let principal_match =
952 principal.is_none_or(|p| acl.principal == p || acl.principal == "*");
953 let resource_match = resource.is_none_or(|r| &acl.resource == r);
954 !(principal_match && resource_match)
955 });
956 }
957
958 pub fn list_acls(&self) -> Vec<AclEntry> {
960 self.acls.read().clone()
961 }
962
963 pub fn authenticate(
969 &self,
970 username: &str,
971 password: &str,
972 client_ip: &str,
973 ) -> AuthResult<AuthSession> {
974 {
976 let tracker = self.failed_attempts.read();
977 if tracker.is_locked_out(username, self.config.lockout_duration) {
978 warn!(
979 "Authentication attempt for locked-out principal: {}",
980 username
981 );
982 return Err(AuthError::RateLimited);
983 }
984 if tracker.is_locked_out(client_ip, self.config.lockout_duration) {
985 warn!("Authentication attempt from locked-out IP: {}", client_ip);
986 return Err(AuthError::RateLimited);
987 }
988 }
989
990 let principal = {
992 let principals = self.principals.read();
993 principals.get(username).cloned()
994 };
995
996 let principal = match principal {
997 Some(p) if p.enabled => p,
998 Some(_) => {
999 self.record_auth_failure(username, client_ip);
1001 return Err(AuthError::AuthenticationFailed);
1002 }
1003 None => {
1004 let dummy = PasswordHash::new("dummy");
1007 let _ = dummy.verify(password);
1008 self.record_auth_failure(username, client_ip);
1009 return Err(AuthError::AuthenticationFailed);
1010 }
1011 };
1012
1013 if !principal.password_hash.verify(password) {
1015 self.record_auth_failure(username, client_ip);
1016 return Err(AuthError::AuthenticationFailed);
1017 }
1018
1019 self.failed_attempts.write().clear_failures(username);
1021 self.failed_attempts.write().clear_failures(client_ip);
1022
1023 let permissions = self.resolve_permissions(&principal);
1025
1026 let mut session_id = vec![0u8; 32];
1028 self.rng
1029 .fill(&mut session_id)
1030 .map_err(|_| AuthError::Internal("RNG failed".to_string()))?;
1031 let session_id = hex::encode(&session_id);
1032
1033 let now = Instant::now();
1034 let session = AuthSession {
1035 id: session_id.clone(),
1036 principal_name: principal.name.clone(),
1037 principal_type: principal.principal_type.clone(),
1038 permissions,
1039 created_at: now,
1040 expires_at: now + self.config.session_timeout,
1041 client_ip: client_ip.to_string(),
1042 };
1043
1044 self.sessions.write().insert(session_id, session.clone());
1046
1047 debug!("Authenticated principal '{}' from {}", username, client_ip);
1048 Ok(session)
1049 }
1050
1051 fn record_auth_failure(&self, username: &str, client_ip: &str) {
1053 let mut tracker = self.failed_attempts.write();
1054 tracker.record_failure(
1055 username,
1056 self.config.max_failed_attempts,
1057 self.config.lockout_duration,
1058 );
1059 tracker.record_failure(
1060 client_ip,
1061 self.config.max_failed_attempts * 2,
1062 self.config.lockout_duration,
1063 );
1064 }
1065
1066 pub fn get_session(&self, session_id: &str) -> Option<AuthSession> {
1068 let sessions = self.sessions.read();
1069 sessions.get(session_id).and_then(|s| {
1070 if s.is_expired() {
1071 None
1072 } else {
1073 Some(s.clone())
1074 }
1075 })
1076 }
1077
1078 pub fn invalidate_session(&self, session_id: &str) {
1080 self.sessions.write().remove(session_id);
1081 }
1082
1083 pub fn invalidate_all_sessions(&self, principal_name: &str) {
1085 self.sessions
1086 .write()
1087 .retain(|_, s| s.principal_name != principal_name);
1088 }
1089
1090 pub fn cleanup_expired_sessions(&self) {
1092 self.sessions.write().retain(|_, s| !s.is_expired());
1093 }
1094
1095 pub fn create_session(&self, principal: &Principal) -> AuthSession {
1097 let permissions = self.resolve_permissions(principal);
1098
1099 let mut session_id = vec![0u8; 32];
1100 self.rng.fill(&mut session_id).expect("RNG failed");
1101 let session_id = hex::encode(&session_id);
1102
1103 let now = Instant::now();
1104 let session = AuthSession {
1105 id: session_id.clone(),
1106 principal_name: principal.name.clone(),
1107 principal_type: principal.principal_type.clone(),
1108 permissions,
1109 created_at: now,
1110 expires_at: now + self.config.session_timeout,
1111 client_ip: "scram".to_string(),
1112 };
1113
1114 self.sessions.write().insert(session_id, session.clone());
1115 session
1116 }
1117
1118 fn resolve_permissions(&self, principal: &Principal) -> HashSet<(ResourceType, Permission)> {
1124 let mut permissions = HashSet::new();
1125
1126 let roles = self.roles.read();
1127
1128 for role_name in &principal.roles {
1130 if let Some(role) = roles.get(role_name) {
1131 permissions.extend(role.permissions.iter().cloned());
1132 }
1133 }
1134
1135 permissions
1136 }
1137
1138 pub fn authorize(
1140 &self,
1141 session: &AuthSession,
1142 resource: &ResourceType,
1143 permission: Permission,
1144 client_ip: &str,
1145 ) -> AuthResult<()> {
1146 if !self.config.require_authentication && !self.config.enable_acls {
1148 return Ok(());
1149 }
1150
1151 if session.is_expired() {
1153 return Err(AuthError::TokenExpired);
1154 }
1155
1156 if session.has_permission(resource, &permission) {
1158 return Ok(());
1159 }
1160
1161 if self.config.enable_acls
1163 && self.check_acls(&session.principal_name, resource, permission, client_ip)
1164 {
1165 return Ok(());
1166 }
1167
1168 if self.config.default_deny {
1170 warn!(
1171 "Access denied: {} attempted {} on {:?} from {}",
1172 session.principal_name,
1173 format!("{:?}", permission),
1174 resource,
1175 client_ip
1176 );
1177 return Err(AuthError::PermissionDenied {
1178 principal: session.principal_name.clone(),
1179 permission: format!("{:?}", permission),
1180 resource: format!("{:?}", resource),
1181 });
1182 }
1183
1184 Ok(())
1185 }
1186
1187 fn check_acls(
1189 &self,
1190 principal: &str,
1191 resource: &ResourceType,
1192 permission: Permission,
1193 client_ip: &str,
1194 ) -> bool {
1195 let acls = self.acls.read();
1196
1197 for acl in acls.iter() {
1199 if !acl.allow
1200 && (acl.principal == principal || acl.principal == "*")
1201 && (acl.host == client_ip || acl.host == "*")
1202 && acl.resource.matches(resource)
1203 && (acl.permission == permission || acl.permission == Permission::All)
1204 {
1205 return false; }
1207 }
1208
1209 for acl in acls.iter() {
1211 if acl.allow
1212 && (acl.principal == principal || acl.principal == "*")
1213 && (acl.host == client_ip || acl.host == "*")
1214 && acl.resource.matches(resource)
1215 && (acl.permission == permission || acl.permission == Permission::All)
1216 {
1217 return true;
1218 }
1219 }
1220
1221 false
1222 }
1223
1224 #[allow(unused_variables)]
1226 pub fn authorize_anonymous(
1227 &self,
1228 resource: &ResourceType,
1229 permission: Permission,
1230 ) -> AuthResult<()> {
1231 if !self.config.require_authentication {
1232 return Ok(());
1233 }
1234
1235 Err(AuthError::AuthenticationFailed)
1236 }
1237}
1238
/// SASL PLAIN front end: parses the NUL-separated payload and delegates to
/// `AuthManager::authenticate`.
pub struct SaslPlainAuth {
    /// Shared manager that performs the actual credential check.
    auth_manager: Arc<AuthManager>,
}
1247
1248impl SaslPlainAuth {
1249 pub fn new(auth_manager: Arc<AuthManager>) -> Self {
1250 Self { auth_manager }
1251 }
1252
1253 pub fn authenticate(&self, sasl_bytes: &[u8], client_ip: &str) -> AuthResult<AuthSession> {
1256 let parts: Vec<&[u8]> = sasl_bytes.split(|&b| b == 0).collect();
1258
1259 if parts.len() < 2 {
1260 return Err(AuthError::InvalidCredentials);
1261 }
1262
1263 let (username, password) = if parts.len() == 2 {
1265 (
1266 std::str::from_utf8(parts[0]).map_err(|_| AuthError::InvalidCredentials)?,
1267 std::str::from_utf8(parts[1]).map_err(|_| AuthError::InvalidCredentials)?,
1268 )
1269 } else {
1270 (
1272 std::str::from_utf8(parts[1]).map_err(|_| AuthError::InvalidCredentials)?,
1273 std::str::from_utf8(parts[2]).map_err(|_| AuthError::InvalidCredentials)?,
1274 )
1275 };
1276
1277 self.auth_manager
1278 .authenticate(username, password, client_ip)
1279 }
1280}
1281
/// Server-side state of one SCRAM exchange, carried between the two steps.
#[derive(Debug, Clone)]
pub enum ScramState {
    /// Nothing has been received yet.
    Initial,
    /// The server-first message has been sent; waiting for client-final.
    ServerFirstSent {
        /// Username taken from the client-first message (already unescaped).
        username: String,
        /// Nonce chosen by the client.
        client_nonce: String,
        /// Nonce chosen by the server.
        server_nonce: String,
        /// Salt that was sent to the client.
        salt: Vec<u8>,
        /// Iteration count that was sent to the client.
        iterations: u32,
        /// The message both sides sign to produce and verify proofs.
        auth_message: String,
    },
    /// The exchange has finished.
    Complete,
}
1310
/// SASL SCRAM-SHA-256 front end built on the verifiers stored in [`AuthManager`].
pub struct SaslScramAuth {
    /// Shared manager used to look up principals and create sessions.
    auth_manager: Arc<AuthManager>,
}
1315
1316impl SaslScramAuth {
1317 pub fn new(auth_manager: Arc<AuthManager>) -> Self {
1318 Self { auth_manager }
1319 }
1320
1321 pub fn process_client_first(
1326 &self,
1327 client_first: &[u8],
1328 client_ip: &str,
1329 ) -> AuthResult<(ScramState, Vec<u8>)> {
1330 let client_first_str =
1331 std::str::from_utf8(client_first).map_err(|_| AuthError::InvalidCredentials)?;
1332
1333 let parts: Vec<&str> = client_first_str.splitn(3, ',').collect();
1339 if parts.len() < 3 {
1340 return Err(AuthError::InvalidCredentials);
1341 }
1342
1343 let client_first_bare = if parts[0] == "n" || parts[0] == "y" || parts[0] == "p" {
1345 &client_first_str[parts[0].len() + 1 + parts[1].len() + 1..]
1347 } else {
1348 client_first_str
1350 };
1351
1352 let mut username = None;
1354 let mut client_nonce = None;
1355
1356 for attr in client_first_bare.split(',') {
1357 if let Some(value) = attr.strip_prefix("n=") {
1358 username = Some(Self::unescape_username(value));
1359 } else if let Some(value) = attr.strip_prefix("r=") {
1360 client_nonce = Some(value.to_string());
1361 }
1362 }
1363
1364 let username = username.ok_or(AuthError::InvalidCredentials)?;
1365 let client_nonce = client_nonce.ok_or(AuthError::InvalidCredentials)?;
1366
1367 let (salt, iterations) = match self.auth_manager.get_principal(&username) {
1369 Some(principal) => (
1370 principal.password_hash.salt.clone(),
1371 principal.password_hash.iterations,
1372 ),
1373 None => {
1374 warn!(
1377 "SCRAM auth for unknown user '{}' from {}",
1378 username, client_ip
1379 );
1380 let rng = SystemRandom::new();
1381 let mut fake_salt = vec![0u8; 32];
1382 rng.fill(&mut fake_salt).expect("Failed to generate salt");
1383 (fake_salt, 4096)
1384 }
1385 };
1386
1387 let rng = SystemRandom::new();
1389 let mut server_nonce_bytes = vec![0u8; 24];
1390 rng.fill(&mut server_nonce_bytes)
1391 .expect("Failed to generate nonce");
1392 let server_nonce = base64_encode(&server_nonce_bytes);
1393 let combined_nonce = format!("{}{}", client_nonce, server_nonce);
1394
1395 let salt_b64 = base64_encode(&salt);
1397 let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, iterations);
1398
1399 let auth_message = format!(
1401 "{},{},c=biws,r={}",
1402 client_first_bare, server_first, combined_nonce
1403 );
1404
1405 let state = ScramState::ServerFirstSent {
1406 username,
1407 client_nonce,
1408 server_nonce,
1409 salt,
1410 iterations,
1411 auth_message,
1412 };
1413
1414 Ok((state, server_first.into_bytes()))
1415 }
1416
1417 pub fn process_client_final(
1422 &self,
1423 state: &ScramState,
1424 client_final: &[u8],
1425 client_ip: &str,
1426 ) -> AuthResult<(AuthSession, Vec<u8>)> {
1427 let ScramState::ServerFirstSent {
1428 username,
1429 client_nonce,
1430 server_nonce,
1431 salt: _, iterations: _, auth_message,
1434 } = state
1435 else {
1436 return Err(AuthError::Internal("Invalid SCRAM state".to_string()));
1437 };
1438
1439 let client_final_str =
1440 std::str::from_utf8(client_final).map_err(|_| AuthError::InvalidCredentials)?;
1441
1442 let mut channel_binding = None;
1444 let mut nonce = None;
1445 let mut proof = None;
1446
1447 for attr in client_final_str.split(',') {
1448 if let Some(value) = attr.strip_prefix("c=") {
1449 channel_binding = Some(value.to_string());
1450 } else if let Some(value) = attr.strip_prefix("r=") {
1451 nonce = Some(value.to_string());
1452 } else if let Some(value) = attr.strip_prefix("p=") {
1453 proof = Some(value.to_string());
1454 }
1455 }
1456
1457 let _channel_binding = channel_binding.ok_or(AuthError::InvalidCredentials)?;
1458 let nonce = nonce.ok_or(AuthError::InvalidCredentials)?;
1459 let proof_b64 = proof.ok_or(AuthError::InvalidCredentials)?;
1460
1461 let expected_nonce = format!("{}{}", client_nonce, server_nonce);
1463 if nonce != expected_nonce {
1464 warn!("SCRAM nonce mismatch for '{}' from {}", username, client_ip);
1465 return Err(AuthError::InvalidCredentials);
1466 }
1467
1468 let principal = self
1470 .auth_manager
1471 .get_principal(username)
1472 .ok_or(AuthError::AuthenticationFailed)?;
1473
1474 let client_proof = base64_decode(&proof_b64).map_err(|_| AuthError::InvalidCredentials)?;
1480
1481 let client_signature =
1483 PasswordHash::hmac_sha256(&principal.password_hash.stored_key, auth_message.as_bytes());
1484
1485 if client_proof.len() != client_signature.len() {
1487 return Err(AuthError::InvalidCredentials);
1488 }
1489
1490 let client_key: Vec<u8> = client_proof
1491 .iter()
1492 .zip(client_signature.iter())
1493 .map(|(p, s)| p ^ s)
1494 .collect();
1495
1496 let computed_stored_key = Sha256::digest(&client_key);
1498 if !PasswordHash::constant_time_compare(
1499 &computed_stored_key,
1500 &principal.password_hash.stored_key,
1501 ) {
1502 warn!(
1503 "SCRAM authentication failed for '{}' from {}",
1504 username, client_ip
1505 );
1506 return Err(AuthError::AuthenticationFailed);
1507 }
1508
1509 let server_signature =
1511 PasswordHash::hmac_sha256(&principal.password_hash.server_key, auth_message.as_bytes());
1512 let server_final = format!("v={}", base64_encode(&server_signature));
1513
1514 let session = self.auth_manager.create_session(&principal);
1516 debug!(
1517 "SCRAM authentication successful for '{}' from {}",
1518 username, client_ip
1519 );
1520
1521 Ok((session, server_final.into_bytes()))
1522 }
1523
1524 fn unescape_username(s: &str) -> String {
1526 s.replace("=2C", ",").replace("=3D", "=")
1527 }
1528}
1529
1530fn base64_encode(data: &[u8]) -> String {
1532 use base64::{engine::general_purpose::STANDARD, Engine as _};
1533 STANDARD.encode(data)
1534}
1535
1536fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
1538 use base64::{engine::general_purpose::STANDARD, Engine as _};
1539 STANDARD.decode(s)
1540}
1541
1542#[cfg(test)]
1547mod tests {
1548 use super::*;
1549
1550 #[test]
1551 fn test_password_hash_verify() {
1552 let hash = PasswordHash::new("test_password_123");
1553 assert!(hash.verify("test_password_123"));
1554 assert!(!hash.verify("wrong_password"));
1555 assert!(!hash.verify(""));
1556 assert!(!hash.verify("test_password_12")); }
1558
1559 #[test]
1560 fn test_password_hash_timing_attack_resistant() {
1561 let hash = PasswordHash::new("correct_password");
1564
1565 assert!(!hash.verify("wrong_password"));
1567
1568 assert!(!hash.verify("x"));
1570
1571 }
1573
1574 #[test]
1575 fn test_create_principal() {
1576 let auth = AuthManager::new_default();
1577
1578 let mut roles = HashSet::new();
1579 roles.insert("producer".to_string());
1580
1581 auth.create_principal(
1582 "alice",
1583 "secure_pass_123",
1584 PrincipalType::User,
1585 roles.clone(),
1586 )
1587 .expect("Failed to create principal");
1588
1589 assert!(auth
1591 .create_principal("alice", "other_pass", PrincipalType::User, roles.clone())
1592 .is_err());
1593
1594 let principal = auth.get_principal("alice").expect("Principal not found");
1596 assert_eq!(principal.name, "alice");
1597 assert!(principal.roles.contains("producer"));
1598 }
1599
1600 #[test]
1601 fn test_authentication_success() {
1602 let auth = AuthManager::new_default();
1603
1604 let mut roles = HashSet::new();
1605 roles.insert("producer".to_string());
1606
1607 auth.create_principal("bob", "bob_password", PrincipalType::User, roles)
1608 .unwrap();
1609
1610 let session = auth
1611 .authenticate("bob", "bob_password", "127.0.0.1")
1612 .expect("Authentication should succeed");
1613
1614 assert_eq!(session.principal_name, "bob");
1615 assert!(!session.is_expired());
1616 }
1617
1618 #[test]
1619 fn test_authentication_failure() {
1620 let auth = AuthManager::new_default();
1621
1622 let mut roles = HashSet::new();
1623 roles.insert("producer".to_string());
1624
1625 auth.create_principal("charlie", "correct_password", PrincipalType::User, roles)
1626 .unwrap();
1627
1628 let result = auth.authenticate("charlie", "wrong_password", "127.0.0.1");
1630 assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
1631
1632 let result = auth.authenticate("unknown", "password", "127.0.0.1");
1634 assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
1635 }
1636
1637 #[test]
1638 fn test_rate_limiting() {
1639 let config = AuthConfig {
1640 max_failed_attempts: 3,
1641 lockout_duration: Duration::from_secs(1),
1642 ..Default::default()
1643 };
1644 let auth = AuthManager::new(config);
1645
1646 let mut roles = HashSet::new();
1647 roles.insert("consumer".to_string());
1648 auth.create_principal("eve", "password", PrincipalType::User, roles)
1649 .unwrap();
1650
1651 for _ in 0..3 {
1653 let _ = auth.authenticate("eve", "wrong", "192.168.1.1");
1654 }
1655
1656 let result = auth.authenticate("eve", "password", "192.168.1.1");
1658 assert!(matches!(result, Err(AuthError::RateLimited)));
1659
1660 std::thread::sleep(Duration::from_millis(1100));
1662
1663 let result = auth.authenticate("eve", "password", "192.168.1.1");
1665 assert!(result.is_ok());
1666 }
1667
1668 #[test]
1669 fn test_role_permissions() {
1670 let auth = AuthManager::with_auth_enabled();
1671
1672 let mut roles = HashSet::new();
1673 roles.insert("producer".to_string());
1674 auth.create_principal("producer_user", "password", PrincipalType::User, roles)
1675 .unwrap();
1676
1677 let session = auth
1678 .authenticate("producer_user", "password", "127.0.0.1")
1679 .unwrap();
1680
1681 assert!(session.has_permission(
1683 &ResourceType::Topic("orders".to_string()),
1684 &Permission::Write
1685 ));
1686
1687 assert!(!session.has_permission(
1689 &ResourceType::Topic("orders".to_string()),
1690 &Permission::Delete
1691 ));
1692 }
1693
1694 #[test]
1695 fn test_admin_has_all_permissions() {
1696 let auth = AuthManager::with_auth_enabled();
1697
1698 let mut roles = HashSet::new();
1699 roles.insert("admin".to_string());
1700 auth.create_principal("admin_user", "admin_pass", PrincipalType::User, roles)
1701 .unwrap();
1702
1703 let session = auth
1704 .authenticate("admin_user", "admin_pass", "127.0.0.1")
1705 .unwrap();
1706
1707 assert!(session.has_permission(&ResourceType::Cluster, &Permission::All));
1709 assert!(session.has_permission(
1710 &ResourceType::Topic("any_topic".to_string()),
1711 &Permission::Delete
1712 ));
1713 }
1714
1715 #[test]
1716 fn test_resource_pattern_matching() {
1717 assert!(ResourceType::TopicPattern("*".to_string())
1718 .matches(&ResourceType::Topic("anything".to_string())));
1719
1720 assert!(ResourceType::TopicPattern("orders-*".to_string())
1721 .matches(&ResourceType::Topic("orders-us".to_string())));
1722
1723 assert!(ResourceType::TopicPattern("orders-*".to_string())
1724 .matches(&ResourceType::Topic("orders-eu".to_string())));
1725
1726 assert!(!ResourceType::TopicPattern("orders-*".to_string())
1727 .matches(&ResourceType::Topic("events-us".to_string())));
1728 }
1729
1730 #[test]
1731 fn test_acl_enforcement() {
1732 let auth = AuthManager::new(AuthConfig {
1733 require_authentication: true,
1734 enable_acls: true,
1735 default_deny: true,
1736 ..Default::default()
1737 });
1738
1739 let mut roles = HashSet::new();
1740 roles.insert("read-only".to_string());
1741 auth.create_principal("reader", "password", PrincipalType::User, roles)
1742 .unwrap();
1743
1744 auth.add_acl(AclEntry {
1746 principal: "reader".to_string(),
1747 resource: ResourceType::Topic("special-topic".to_string()),
1748 permission: Permission::Write,
1749 allow: true,
1750 host: "*".to_string(),
1751 });
1752
1753 let session = auth
1754 .authenticate("reader", "password", "127.0.0.1")
1755 .unwrap();
1756
1757 let result = auth.authorize(
1759 &session,
1760 &ResourceType::Topic("special-topic".to_string()),
1761 Permission::Write,
1762 "127.0.0.1",
1763 );
1764 assert!(result.is_ok());
1765
1766 let result = auth.authorize(
1768 &session,
1769 &ResourceType::Topic("other-topic".to_string()),
1770 Permission::Write,
1771 "127.0.0.1",
1772 );
1773 assert!(result.is_err());
1774 }
1775
1776 #[test]
1777 fn test_sasl_plain_authentication() {
1778 let auth = Arc::new(AuthManager::new_default());
1779
1780 let mut roles = HashSet::new();
1781 roles.insert("producer".to_string());
1782 auth.create_principal("sasl_user", "sasl_password", PrincipalType::User, roles)
1783 .unwrap();
1784
1785 let sasl = SaslPlainAuth::new(auth);
1786
1787 let two_part = b"sasl_user\0sasl_password";
1789 let result = sasl.authenticate(two_part, "127.0.0.1");
1790 assert!(result.is_ok());
1791
1792 let three_part = b"\0sasl_user\0sasl_password";
1794 let result = sasl.authenticate(three_part, "127.0.0.1");
1795 assert!(result.is_ok());
1796 }
1797
1798 #[test]
1799 fn test_session_expiration() {
1800 let config = AuthConfig {
1801 session_timeout: Duration::from_millis(100),
1802 ..Default::default()
1803 };
1804 let auth = AuthManager::new(config);
1805
1806 let mut roles = HashSet::new();
1807 roles.insert("producer".to_string());
1808 auth.create_principal("expiring", "password", PrincipalType::User, roles)
1809 .unwrap();
1810
1811 let session = auth
1812 .authenticate("expiring", "password", "127.0.0.1")
1813 .unwrap();
1814 assert!(!session.is_expired());
1815
1816 std::thread::sleep(Duration::from_millis(150));
1818
1819 let session = AuthSession {
1821 expires_at: session.expires_at,
1822 ..session
1823 };
1824 assert!(session.is_expired());
1825 }
1826
1827 #[test]
1828 fn test_delete_principal_invalidates_sessions() {
1829 let auth = AuthManager::new_default();
1830
1831 let mut roles = HashSet::new();
1832 roles.insert("producer".to_string());
1833 auth.create_principal("deleteme", "password", PrincipalType::User, roles)
1834 .unwrap();
1835
1836 let session = auth
1837 .authenticate("deleteme", "password", "127.0.0.1")
1838 .unwrap();
1839
1840 assert!(auth.get_session(&session.id).is_some());
1842
1843 auth.delete_principal("deleteme").unwrap();
1845
1846 assert!(auth.get_session(&session.id).is_none());
1848 }
1849
1850 #[test]
1851 fn test_disabled_principal_cannot_authenticate() {
1852 let auth = AuthManager::new_default();
1853
1854 let mut roles = HashSet::new();
1855 roles.insert("producer".to_string());
1856 auth.create_principal("disabled_user", "password", PrincipalType::User, roles)
1857 .unwrap();
1858
1859 {
1861 let mut principals = auth.principals.write();
1862 if let Some(p) = principals.get_mut("disabled_user") {
1863 p.enabled = false;
1864 }
1865 }
1866
1867 let result = auth.authenticate("disabled_user", "password", "127.0.0.1");
1869 assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
1870 }
1871
1872 #[test]
1873 fn test_password_hash_debug_redacts_sensitive_data() {
1874 let hash = PasswordHash::new("super_secret_password");
1875 let debug_output = format!("{:?}", hash);
1876
1877 assert!(
1879 debug_output.contains("[REDACTED]"),
1880 "Debug output should contain [REDACTED]"
1881 );
1882
1883 assert!(
1886 !debug_output.contains("super_secret_password"),
1887 "Debug output should not contain password"
1888 );
1889
1890 assert!(
1892 debug_output.contains("iterations"),
1893 "Debug output should show iterations field"
1894 );
1895 }
1896
1897 #[test]
1898 fn test_principal_debug_redacts_password_hash() {
1899 let principal = Principal {
1900 name: "test_user".to_string(),
1901 principal_type: PrincipalType::User,
1902 password_hash: PasswordHash::new("secret_password"),
1903 roles: HashSet::from(["admin".to_string()]),
1904 enabled: true,
1905 metadata: HashMap::new(),
1906 created_at: 1234567890,
1907 };
1908
1909 let debug_output = format!("{:?}", principal);
1910
1911 assert!(
1913 debug_output.contains("[REDACTED]"),
1914 "Debug output should contain [REDACTED]: {}",
1915 debug_output
1916 );
1917
1918 assert!(
1920 debug_output.contains("test_user"),
1921 "Debug output should show name"
1922 );
1923 assert!(
1924 debug_output.contains("admin"),
1925 "Debug output should show roles"
1926 );
1927 }
1928
1929 #[test]
1934 fn test_scram_full_handshake() {
1935 use sha2::{Digest, Sha256};
1936
1937 let auth = Arc::new(AuthManager::new_default());
1938
1939 let mut roles = HashSet::new();
1941 roles.insert("producer".to_string());
1942 auth.create_principal("scram_user", "scram_password", PrincipalType::User, roles)
1943 .expect("Failed to create principal");
1944
1945 let scram = SaslScramAuth::new(auth.clone());
1946
1947 let client_nonce = "rOprNGfwEbeRWgbNEkqO";
1949 let client_first = format!("n,,n=scram_user,r={}", client_nonce);
1950
1951 let (state, server_first) = scram
1952 .process_client_first(client_first.as_bytes(), "127.0.0.1")
1953 .expect("client-first processing should succeed");
1954
1955 let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
1957 assert!(server_first_str.starts_with(&format!("r={}", client_nonce)));
1958 assert!(server_first_str.contains(",s="));
1959 assert!(server_first_str.contains(",i="));
1960
1961 let ScramState::ServerFirstSent {
1963 username: _,
1964 client_nonce: _,
1965 server_nonce: _,
1966 salt,
1967 iterations,
1968 auth_message: _,
1969 } = &state
1970 else {
1971 panic!("Expected ServerFirstSent state");
1972 };
1973
1974 let salted_password = compute_salted_password("scram_password", salt, *iterations);
1977 let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
1978 let stored_key = Sha256::digest(&client_key);
1979
1980 let client_first_bare = format!("n=scram_user,r={}", client_nonce);
1982 let combined_nonce: String = server_first_str
1983 .split(',')
1984 .find(|s| s.starts_with("r="))
1985 .map(|s| &s[2..])
1986 .unwrap()
1987 .to_string();
1988
1989 let auth_message = format!(
1990 "{},{},c=biws,r={}",
1991 client_first_bare, server_first_str, combined_nonce
1992 );
1993
1994 let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
1995 let client_proof: Vec<u8> = client_key
1996 .iter()
1997 .zip(client_signature.iter())
1998 .map(|(k, s)| k ^ s)
1999 .collect();
2000
2001 let client_final = format!(
2002 "c=biws,r={},p={}",
2003 combined_nonce,
2004 base64_encode(&client_proof)
2005 );
2006
2007 let (session, server_final) = scram
2009 .process_client_final(&state, client_final.as_bytes(), "127.0.0.1")
2010 .expect("client-final processing should succeed");
2011
2012 assert_eq!(session.principal_name, "scram_user");
2014 assert!(!session.is_expired());
2015
2016 let server_final_str = std::str::from_utf8(&server_final).expect("valid UTF-8");
2018 assert!(server_final_str.starts_with("v="));
2019 }
2020
2021 #[test]
2022 fn test_scram_wrong_password() {
2023 let auth = Arc::new(AuthManager::new_default());
2024
2025 let mut roles = HashSet::new();
2026 roles.insert("producer".to_string());
2027 auth.create_principal(
2028 "scram_user2",
2029 "correct_password",
2030 PrincipalType::User,
2031 roles,
2032 )
2033 .expect("Failed to create principal");
2034
2035 let scram = SaslScramAuth::new(auth.clone());
2036
2037 let client_nonce = "test_nonce_12345";
2039 let client_first = format!("n,,n=scram_user2,r={}", client_nonce);
2040
2041 let (state, server_first) = scram
2042 .process_client_first(client_first.as_bytes(), "127.0.0.1")
2043 .expect("client-first processing should succeed");
2044
2045 let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
2047 let combined_nonce: String = server_first_str
2048 .split(',')
2049 .find(|s| s.starts_with("r="))
2050 .map(|s| &s[2..])
2051 .unwrap()
2052 .to_string();
2053
2054 let ScramState::ServerFirstSent {
2056 salt, iterations, ..
2057 } = &state
2058 else {
2059 panic!("Expected ServerFirstSent state");
2060 };
2061
2062 let salted_password = compute_salted_password("wrong_password", salt, *iterations);
2063 let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
2064 let stored_key = sha2::Sha256::digest(&client_key);
2065
2066 let client_first_bare = format!("n=scram_user2,r={}", client_nonce);
2067 let auth_message = format!(
2068 "{},{},c=biws,r={}",
2069 client_first_bare, server_first_str, combined_nonce
2070 );
2071
2072 let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
2073 let client_proof: Vec<u8> = client_key
2074 .iter()
2075 .zip(client_signature.iter())
2076 .map(|(k, s)| k ^ s)
2077 .collect();
2078
2079 let client_final = format!(
2080 "c=biws,r={},p={}",
2081 combined_nonce,
2082 base64_encode(&client_proof)
2083 );
2084
2085 let result = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
2087 assert!(result.is_err());
2088 assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
2089 }
2090
2091 #[test]
2092 fn test_scram_nonexistent_user() {
2093 let auth = Arc::new(AuthManager::new_default());
2094 let scram = SaslScramAuth::new(auth.clone());
2095
2096 let client_first = "n,,n=nonexistent_user,r=test_nonce";
2098
2099 let result = scram.process_client_first(client_first.as_bytes(), "127.0.0.1");
2101 assert!(
2102 result.is_ok(),
2103 "Should return fake server-first to prevent enumeration"
2104 );
2105
2106 let (state, server_first) = result.unwrap();
2107 let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
2108
2109 assert!(server_first_str.contains("r=test_nonce"));
2111 assert!(server_first_str.contains(",s="));
2112 assert!(server_first_str.contains(",i="));
2113
2114 let combined_nonce: String = server_first_str
2116 .split(',')
2117 .find(|s| s.starts_with("r="))
2118 .map(|s| &s[2..])
2119 .unwrap()
2120 .to_string();
2121
2122 let client_final = format!("c=biws,r={},p=dW5rbm93bg==", combined_nonce);
2123 let result = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
2124 assert!(result.is_err());
2125 }
2126
2127 #[test]
2128 fn test_scram_nonce_mismatch() {
2129 let auth = Arc::new(AuthManager::new_default());
2130
2131 let mut roles = HashSet::new();
2132 roles.insert("producer".to_string());
2133 auth.create_principal("scram_user3", "password", PrincipalType::User, roles)
2134 .expect("Failed to create principal");
2135
2136 let scram = SaslScramAuth::new(auth.clone());
2137
2138 let client_first = "n,,n=scram_user3,r=original_nonce";
2139 let (state, _server_first) = scram
2140 .process_client_first(client_first.as_bytes(), "127.0.0.1")
2141 .expect("client-first should succeed");
2142
2143 let client_final = "c=biws,r=tampered_nonce_plus_server,p=dW5rbm93bg==";
2145 let result = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
2146 assert!(result.is_err());
2147 assert!(matches!(result, Err(AuthError::InvalidCredentials)));
2148 }
2149
2150 fn compute_salted_password(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
2152 use hmac::{Hmac, Mac};
2153 type HmacSha256 = Hmac<sha2::Sha256>;
2154
2155 let mut result = vec![0u8; 32];
2156
2157 let mut mac =
2158 HmacSha256::new_from_slice(password.as_bytes()).expect("HMAC accepts any key length");
2159 mac.update(salt);
2160 mac.update(&1u32.to_be_bytes());
2161 let mut u = mac.finalize().into_bytes();
2162 result.copy_from_slice(&u);
2163
2164 for _ in 1..iterations {
2165 let mut mac = HmacSha256::new_from_slice(password.as_bytes())
2166 .expect("HMAC accepts any key length");
2167 mac.update(&u);
2168 u = mac.finalize().into_bytes();
2169
2170 for (r, ui) in result.iter_mut().zip(u.iter()) {
2171 *r ^= ui;
2172 }
2173 }
2174
2175 result
2176 }
2177}