Skip to main content

rivven_core/
auth.rs

1//! Authentication and Authorization (RBAC/ACL) for Rivven
2//!
3//! This module provides production-grade security for Rivven including:
4//! - SASL/PLAIN authentication (compatible with Kafka clients)
5//! - SCRAM-SHA-256 authentication (more secure, salted)
6//! - Role-Based Access Control (RBAC)
7//! - Topic/Schema-level Access Control Lists (ACLs)
8//! - Principal management (users, service accounts)
9//!
10//! ## Security Model
11//!
12//! Rivven uses a principal-based security model:
13//! - **Principal**: An authenticated identity (user, service account)
14//! - **Role**: A set of permissions (admin, producer, consumer)
15//! - **ACL**: Fine-grained access rules for specific resources
16//!
17//! ## Threat Model
18//!
19//! This implementation defends against:
20//! - Credential stuffing (rate limiting on auth failures)
21//! - Timing attacks (constant-time password comparison)
22//! - Replay attacks (nonce-based challenge-response)
23//! - Privilege escalation (strict role hierarchy)
24//! - Resource enumeration (deny by default)
25
26use parking_lot::RwLock;
27use ring::rand::{SecureRandom, SystemRandom};
28use serde::{Deserialize, Serialize};
29use sha2::{Digest, Sha256};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32use std::time::{Duration, Instant};
33use thiserror::Error;
34use tracing::{debug, warn};
35
36// ============================================================================
37// Error Types
38// ============================================================================
39
/// Every failure this module can report, from login problems to
/// authorization denials and management-API misuse.
#[derive(Error, Debug, Clone)]
pub enum AuthError {
    /// Login was rejected. Intentionally unspecific: unknown and disabled
    /// principals both surface as this variant so callers cannot probe
    /// which accounts exist.
    #[error("Authentication failed")]
    AuthenticationFailed,

    /// The supplied credentials were rejected as invalid.
    #[error("Invalid credentials")]
    InvalidCredentials,

    /// No principal with the given name is registered.
    #[error("Principal not found: {0}")]
    PrincipalNotFound(String),

    /// A principal with the given name is already registered.
    #[error("Principal already exists: {0}")]
    PrincipalAlreadyExists(String),

    /// Access was refused; the payload describes why.
    #[error("Access denied: {0}")]
    AccessDenied(String),

    /// A specific principal lacks a specific permission on a resource.
    #[error("Permission denied: {principal} lacks {permission} on {resource}")]
    PermissionDenied {
        /// Name of the principal that was denied.
        principal: String,
        /// The permission that was required.
        permission: String,
        /// The resource the permission was required on.
        resource: String,
    },

    /// No role with the given name is registered.
    #[error("Role not found: {0}")]
    RoleNotFound(String),

    /// A presented token could not be accepted.
    #[error("Invalid token")]
    InvalidToken,

    /// A presented token is past its expiry.
    #[error("Token expired")]
    TokenExpired,

    /// Too many recent failures for this username or client IP; the caller
    /// is temporarily locked out.
    #[error("Rate limited: too many authentication failures")]
    RateLimited,

    /// Catch-all error. In this module it also carries validation failures
    /// such as an invalid principal name or a too-short password.
    #[error("Internal error: {0}")]
    Internal(String),
}

/// Shorthand for `Result<T, AuthError>` used throughout this module.
pub type AuthResult<T> = std::result::Result<T, AuthError>;
81
82// ============================================================================
83// Resource Types for ACLs
84// ============================================================================
85
86/// Types of resources that can have ACLs
87#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub enum ResourceType {
89    /// All cluster operations
90    Cluster,
91    /// Specific topic
92    Topic(String),
93    /// Pattern-matched topics (e.g., "orders-*")
94    TopicPattern(String),
95    /// Consumer group
96    ConsumerGroup(String),
97    /// Schema subject
98    Schema(String),
99    /// Transactional ID
100    TransactionalId(String),
101}
102
103impl ResourceType {
104    /// Check if this resource matches another (for pattern matching)
105    pub fn matches(&self, other: &ResourceType) -> bool {
106        match (self, other) {
107            // Exact match
108            (a, b) if a == b => true,
109
110            // Topic pattern matching
111            (ResourceType::TopicPattern(pattern), ResourceType::Topic(name)) => {
112                Self::glob_match(pattern, name)
113            }
114            (ResourceType::Topic(name), ResourceType::TopicPattern(pattern)) => {
115                Self::glob_match(pattern, name)
116            }
117
118            _ => false,
119        }
120    }
121
122    /// Simple glob matching for patterns like "orders-*"
123    fn glob_match(pattern: &str, text: &str) -> bool {
124        if pattern == "*" {
125            return true;
126        }
127
128        if let Some(prefix) = pattern.strip_suffix('*') {
129            return text.starts_with(prefix);
130        }
131
132        if let Some(suffix) = pattern.strip_prefix('*') {
133            return text.ends_with(suffix);
134        }
135
136        // Check for middle wildcard (e.g., "pre*suf")
137        if let Some(idx) = pattern.find('*') {
138            let prefix = &pattern[..idx];
139            let suffix = &pattern[idx + 1..];
140            return text.starts_with(prefix)
141                && text.ends_with(suffix)
142                && text.len() >= prefix.len() + suffix.len();
143        }
144
145        pattern == text
146    }
147}
148
149// ============================================================================
150// Permissions
151// ============================================================================
152
153/// Operations that can be performed on resources
154#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
155pub enum Permission {
156    // Topic operations
157    Read,     // Consume from topic
158    Write,    // Produce to topic
159    Create,   // Create topic
160    Delete,   // Delete topic
161    Alter,    // Modify topic config
162    Describe, // View topic metadata
163
164    // Consumer group operations
165    GroupRead,   // Read group state
166    GroupDelete, // Delete consumer group
167
168    // Cluster operations
169    ClusterAction,   // Cluster-wide actions (rebalance, etc.)
170    IdempotentWrite, // Idempotent producer
171
172    // Admin operations
173    AlterConfigs,    // Modify broker configs
174    DescribeConfigs, // View broker configs
175
176    // Full access
177    All, // All permissions (super admin)
178}
179
180impl Permission {
181    /// Check if this permission implies another permission
182    /// A permission implies itself + any subordinate permissions
183    pub fn implies(&self, other: &Permission) -> bool {
184        // Same permission always implies itself
185        if self == other {
186            return true;
187        }
188
189        match self {
190            Permission::All => true, // All implies everything
191            // These permissions imply Describe
192            Permission::Alter | Permission::Write | Permission::Read => {
193                matches!(other, Permission::Describe)
194            }
195            _ => false,
196        }
197    }
198}
199
200// ============================================================================
201// Principal Types
202// ============================================================================
203
204/// Type of principal (for audit logging and quotas)
205#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
206pub enum PrincipalType {
207    User,
208    ServiceAccount,
209    Anonymous,
210}
211
212/// A security principal (identity)
213///
214/// # Security Note
215///
216/// This struct implements a custom Debug that redacts the password_hash field
217/// to prevent accidental leakage to logs.
218#[derive(Clone, Serialize, Deserialize)]
219pub struct Principal {
220    /// Principal name (unique identifier)
221    pub name: String,
222
223    /// Type of principal
224    pub principal_type: PrincipalType,
225
226    /// Hashed password (SCRAM-SHA-256 format)
227    pub password_hash: PasswordHash,
228
229    /// Roles assigned to this principal
230    pub roles: HashSet<String>,
231
232    /// Whether the principal is enabled
233    pub enabled: bool,
234
235    /// Optional metadata (tags, labels)
236    pub metadata: HashMap<String, String>,
237
238    /// Creation timestamp
239    pub created_at: u64,
240}
241
242impl std::fmt::Debug for Principal {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        f.debug_struct("Principal")
245            .field("name", &self.name)
246            .field("principal_type", &self.principal_type)
247            .field("password_hash", &"[REDACTED]")
248            .field("roles", &self.roles)
249            .field("enabled", &self.enabled)
250            .field("metadata", &self.metadata)
251            .field("created_at", &self.created_at)
252            .finish()
253    }
254}
255
256/// SCRAM-SHA-256 password hash with salt and iterations
257///
258/// # Security Note
259///
260/// This struct implements a custom Debug that redacts sensitive key material
261/// to prevent accidental leakage to logs.
262#[derive(Clone, Serialize, Deserialize)]
263pub struct PasswordHash {
264    /// Salt (32 bytes, base64 encoded for storage)
265    pub salt: Vec<u8>,
266    /// Number of iterations
267    pub iterations: u32,
268    /// Server key
269    pub server_key: Vec<u8>,
270    /// Stored key
271    pub stored_key: Vec<u8>,
272}
273
274impl std::fmt::Debug for PasswordHash {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        f.debug_struct("PasswordHash")
277            .field("salt", &"[REDACTED]")
278            .field("iterations", &self.iterations)
279            .field("server_key", &"[REDACTED]")
280            .field("stored_key", &"[REDACTED]")
281            .finish()
282    }
283}
284
285impl PasswordHash {
286    /// Create a new password hash from plaintext
287    pub fn new(password: &str) -> Self {
288        let rng = SystemRandom::new();
289        let mut salt = vec![0u8; 32];
290        rng.fill(&mut salt).expect("Failed to generate salt");
291
292        Self::with_salt(password, &salt, 4096)
293    }
294
295    /// Create a password hash with a specific salt (for testing/migration)
296    pub fn with_salt(password: &str, salt: &[u8], iterations: u32) -> Self {
297        // PBKDF2-HMAC-SHA256 derivation
298        let salted_password = Self::pbkdf2_sha256(password.as_bytes(), salt, iterations);
299
300        // Derive client and server keys
301        let client_key = Self::hmac_sha256(&salted_password, b"Client Key");
302        let server_key = Self::hmac_sha256(&salted_password, b"Server Key");
303
304        // Stored key = H(client_key)
305        let stored_key = Sha256::digest(&client_key).to_vec();
306
307        PasswordHash {
308            salt: salt.to_vec(),
309            iterations,
310            server_key,
311            stored_key,
312        }
313    }
314
315    /// Verify a password against this hash (constant-time comparison)
316    pub fn verify(&self, password: &str) -> bool {
317        let salted_password = Self::pbkdf2_sha256(password.as_bytes(), &self.salt, self.iterations);
318        let client_key = Self::hmac_sha256(&salted_password, b"Client Key");
319        let stored_key = Sha256::digest(&client_key);
320
321        // Constant-time comparison to prevent timing attacks
322        Self::constant_time_compare(&stored_key, &self.stored_key)
323    }
324
325    /// Constant-time comparison to prevent timing attacks
326    pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
327        if a.len() != b.len() {
328            return false;
329        }
330
331        // XOR all bytes and accumulate - timing is constant regardless of where mismatch occurs
332        let mut result = 0u8;
333        for (x, y) in a.iter().zip(b.iter()) {
334            result |= x ^ y;
335        }
336        result == 0
337    }
338
339    /// PBKDF2-HMAC-SHA256 key derivation
340    fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
341        use hmac::{Hmac, Mac};
342        type HmacSha256 = Hmac<Sha256>;
343
344        let mut result = vec![0u8; 32];
345
346        // U1 = PRF(Password, Salt || INT(1))
347        let mut mac = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
348        mac.update(salt);
349        mac.update(&1u32.to_be_bytes());
350        let mut u = mac.finalize().into_bytes();
351        result.copy_from_slice(&u);
352
353        // Ui = PRF(Password, Ui-1)
354        for _ in 1..iterations {
355            let mut mac =
356                HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
357            mac.update(&u);
358            u = mac.finalize().into_bytes();
359
360            for (r, ui) in result.iter_mut().zip(u.iter()) {
361                *r ^= ui;
362            }
363        }
364
365        result
366    }
367
368    /// HMAC-SHA256
369    pub fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
370        use hmac::{Hmac, Mac};
371        type HmacSha256 = Hmac<Sha256>;
372
373        let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
374        mac.update(data);
375        mac.finalize().into_bytes().to_vec()
376    }
377}
378
379// ============================================================================
380// Roles
381// ============================================================================
382
383/// A role with a set of permissions
384#[derive(Debug, Clone, Serialize, Deserialize)]
385pub struct Role {
386    /// Role name
387    pub name: String,
388
389    /// Description
390    pub description: String,
391
392    /// Permissions granted by this role
393    pub permissions: HashSet<(ResourceType, Permission)>,
394
395    /// Whether this is a built-in role
396    pub builtin: bool,
397}
398
399impl Role {
400    /// Create a built-in admin role
401    pub fn admin() -> Self {
402        let mut permissions = HashSet::new();
403        permissions.insert((ResourceType::Cluster, Permission::All));
404
405        Role {
406            name: "admin".to_string(),
407            description: "Full administrative access to all resources".to_string(),
408            permissions,
409            builtin: true,
410        }
411    }
412
413    /// Create a built-in producer role
414    pub fn producer() -> Self {
415        let mut permissions = HashSet::new();
416        permissions.insert((
417            ResourceType::TopicPattern("*".to_string()),
418            Permission::Write,
419        ));
420        permissions.insert((
421            ResourceType::TopicPattern("*".to_string()),
422            Permission::Describe,
423        ));
424        permissions.insert((ResourceType::Cluster, Permission::IdempotentWrite));
425
426        Role {
427            name: "producer".to_string(),
428            description: "Can produce to all topics".to_string(),
429            permissions,
430            builtin: true,
431        }
432    }
433
434    /// Create a built-in consumer role
435    pub fn consumer() -> Self {
436        let mut permissions = HashSet::new();
437        permissions.insert((
438            ResourceType::TopicPattern("*".to_string()),
439            Permission::Read,
440        ));
441        permissions.insert((
442            ResourceType::TopicPattern("*".to_string()),
443            Permission::Describe,
444        ));
445        permissions.insert((
446            ResourceType::ConsumerGroup("*".to_string()),
447            Permission::GroupRead,
448        ));
449
450        Role {
451            name: "consumer".to_string(),
452            description: "Can consume from all topics".to_string(),
453            permissions,
454            builtin: true,
455        }
456    }
457
458    /// Create a built-in read-only role
459    pub fn read_only() -> Self {
460        let mut permissions = HashSet::new();
461        permissions.insert((
462            ResourceType::TopicPattern("*".to_string()),
463            Permission::Read,
464        ));
465        permissions.insert((
466            ResourceType::TopicPattern("*".to_string()),
467            Permission::Describe,
468        ));
469
470        Role {
471            name: "read-only".to_string(),
472            description: "Read-only access to all topics".to_string(),
473            permissions,
474            builtin: true,
475        }
476    }
477}
478
479// ============================================================================
480// Access Control List (ACL)
481// ============================================================================
482
483/// An ACL entry
484#[derive(Debug, Clone, Serialize, Deserialize)]
485pub struct AclEntry {
486    /// Principal this ACL applies to (use "*" for all)
487    pub principal: String,
488
489    /// Resource this ACL applies to
490    pub resource: ResourceType,
491
492    /// Permission granted or denied
493    pub permission: Permission,
494
495    /// Whether this is an allow or deny rule
496    pub allow: bool,
497
498    /// Host pattern (IP or hostname, "*" for all)
499    pub host: String,
500}
501
502// ============================================================================
503// Session and Token
504// ============================================================================
505
506/// An authenticated session
507#[derive(Debug, Clone)]
508pub struct AuthSession {
509    /// Session ID
510    pub id: String,
511
512    /// Authenticated principal name
513    pub principal_name: String,
514
515    /// Principal type
516    pub principal_type: PrincipalType,
517
518    /// Resolved permissions (cached for performance)
519    pub permissions: HashSet<(ResourceType, Permission)>,
520
521    /// Session creation time
522    pub created_at: Instant,
523
524    /// Session expiration time
525    pub expires_at: Instant,
526
527    /// Client IP address
528    pub client_ip: String,
529}
530
531impl AuthSession {
532    /// Check if the session has expired
533    pub fn is_expired(&self) -> bool {
534        Instant::now() >= self.expires_at
535    }
536
537    /// Check if this session has a specific permission on a resource
538    pub fn has_permission(&self, resource: &ResourceType, permission: &Permission) -> bool {
539        // Admin has all permissions
540        if self
541            .permissions
542            .contains(&(ResourceType::Cluster, Permission::All))
543        {
544            return true;
545        }
546
547        // Check direct permission
548        if self.permissions.contains(&(resource.clone(), *permission)) {
549            return true;
550        }
551
552        // Check pattern matches and permission implies
553        for (res, perm) in &self.permissions {
554            // Check if the granted resource (res) matches the requested resource
555            // AND the granted permission (perm) implies the requested permission
556            let resource_matches = res.matches(resource);
557            let permission_implies = perm.implies(permission);
558            if resource_matches && permission_implies {
559                return true;
560            }
561        }
562
563        false
564    }
565}
566
567// ============================================================================
568// Auth Manager
569// ============================================================================
570
/// Settings for [`AuthManager`]: session lifetime, brute-force lockout
/// policy, and whether authentication and ACL enforcement are enabled.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// How long a session remains valid.
    pub session_timeout: Duration,

    /// Failed attempts within the lockout window that trigger a lockout.
    pub max_failed_attempts: u32,

    /// How long a lockout lasts (also the window for counting failures).
    pub lockout_duration: Duration,

    /// `false` allows anonymous access.
    pub require_authentication: bool,

    /// Whether ACL entries are enforced.
    pub enable_acls: bool,

    /// `true` denies any access that is not explicitly allowed.
    pub default_deny: bool,
}

impl Default for AuthConfig {
    /// Development-oriented defaults: 1-hour sessions, lockout after 5
    /// failures for 5 minutes, authentication and ACLs off, default-deny on.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        Self {
            session_timeout: ONE_HOUR,
            max_failed_attempts: 5,
            lockout_duration: FIVE_MINUTES,
            // Open by default for development.
            require_authentication: false,
            enable_acls: false,
            default_deny: true,
        }
    }
}
605
606/// Tracks failed authentication attempts for rate limiting
607struct FailedAttemptTracker {
608    attempts: HashMap<String, Vec<Instant>>,
609    lockouts: HashMap<String, Instant>,
610}
611
612impl FailedAttemptTracker {
613    fn new() -> Self {
614        Self {
615            attempts: HashMap::new(),
616            lockouts: HashMap::new(),
617        }
618    }
619
620    /// Check if an identifier is currently locked out
621    fn is_locked_out(&self, identifier: &str, lockout_duration: Duration) -> bool {
622        if let Some(lockout_time) = self.lockouts.get(identifier) {
623            if lockout_time.elapsed() < lockout_duration {
624                return true;
625            }
626        }
627        false
628    }
629
630    /// Record a failed attempt
631    fn record_failure(
632        &mut self,
633        identifier: &str,
634        max_attempts: u32,
635        lockout_duration: Duration,
636    ) -> bool {
637        let now = Instant::now();
638
639        // Clean up old lockouts
640        self.lockouts.retain(|_, t| t.elapsed() < lockout_duration);
641
642        // Get or create attempt list
643        let attempts = self.attempts.entry(identifier.to_string()).or_default();
644
645        // Remove attempts older than lockout duration
646        attempts.retain(|t| t.elapsed() < lockout_duration);
647
648        // Add this attempt
649        attempts.push(now);
650
651        // Check if we've exceeded max attempts
652        if attempts.len() >= max_attempts as usize {
653            warn!(
654                "Principal '{}' locked out after {} failed attempts",
655                identifier, max_attempts
656            );
657            self.lockouts.insert(identifier.to_string(), now);
658            return true;
659        }
660
661        false
662    }
663
664    /// Clear failures for an identifier (on successful auth)
665    fn clear_failures(&mut self, identifier: &str) {
666        self.attempts.remove(identifier);
667        self.lockouts.remove(identifier);
668    }
669}
670
671/// The main authentication and authorization manager
672pub struct AuthManager {
673    config: AuthConfig,
674
675    /// Principals (users/service accounts)
676    principals: RwLock<HashMap<String, Principal>>,
677
678    /// Roles
679    roles: RwLock<HashMap<String, Role>>,
680
681    /// ACL entries
682    acls: RwLock<Vec<AclEntry>>,
683
684    /// Active sessions
685    sessions: RwLock<HashMap<String, AuthSession>>,
686
687    /// Failed attempt tracking
688    failed_attempts: RwLock<FailedAttemptTracker>,
689
690    /// Random number generator for session IDs
691    rng: SystemRandom,
692}
693
694impl AuthManager {
695    /// Create a new authentication manager
696    pub fn new(config: AuthConfig) -> Self {
697        let manager = Self {
698            config,
699            principals: RwLock::new(HashMap::new()),
700            roles: RwLock::new(HashMap::new()),
701            acls: RwLock::new(Vec::new()),
702            sessions: RwLock::new(HashMap::new()),
703            failed_attempts: RwLock::new(FailedAttemptTracker::new()),
704            rng: SystemRandom::new(),
705        };
706
707        // Initialize built-in roles
708        manager.init_builtin_roles();
709
710        manager
711    }
712
713    /// Create with default config
714    pub fn new_default() -> Self {
715        Self::new(AuthConfig::default())
716    }
717
718    /// Create an auth manager with authentication enabled
719    pub fn with_auth_enabled() -> Self {
720        Self::new(AuthConfig {
721            require_authentication: true,
722            enable_acls: true,
723            ..Default::default()
724        })
725    }
726
727    /// Initialize built-in roles
728    fn init_builtin_roles(&self) {
729        let mut roles = self.roles.write();
730        roles.insert("admin".to_string(), Role::admin());
731        roles.insert("producer".to_string(), Role::producer());
732        roles.insert("consumer".to_string(), Role::consumer());
733        roles.insert("read-only".to_string(), Role::read_only());
734    }
735
736    // ========================================================================
737    // Principal Management
738    // ========================================================================
739
740    /// Create a new principal (user or service account)
741    pub fn create_principal(
742        &self,
743        name: &str,
744        password: &str,
745        principal_type: PrincipalType,
746        roles: HashSet<String>,
747    ) -> AuthResult<()> {
748        // Validate principal name
749        if name.is_empty() || name.len() > 255 {
750            return Err(AuthError::Internal("Invalid principal name".to_string()));
751        }
752
753        // Validate password strength (minimum requirements)
754        if password.len() < 8 {
755            return Err(AuthError::Internal(
756                "Password must be at least 8 characters".to_string(),
757            ));
758        }
759
760        // Validate roles exist
761        {
762            let role_map = self.roles.read();
763            for role in &roles {
764                if !role_map.contains_key(role) {
765                    return Err(AuthError::RoleNotFound(role.clone()));
766                }
767            }
768        }
769
770        let mut principals = self.principals.write();
771
772        if principals.contains_key(name) {
773            return Err(AuthError::PrincipalAlreadyExists(name.to_string()));
774        }
775
776        let principal = Principal {
777            name: name.to_string(),
778            principal_type,
779            password_hash: PasswordHash::new(password),
780            roles,
781            enabled: true,
782            metadata: HashMap::new(),
783            created_at: std::time::SystemTime::now()
784                .duration_since(std::time::UNIX_EPOCH)
785                .unwrap_or_default()
786                .as_secs(),
787        };
788
789        principals.insert(name.to_string(), principal);
790        debug!("Created principal: {}", name);
791
792        Ok(())
793    }
794
795    /// Delete a principal
796    pub fn delete_principal(&self, name: &str) -> AuthResult<()> {
797        let mut principals = self.principals.write();
798
799        if principals.remove(name).is_none() {
800            return Err(AuthError::PrincipalNotFound(name.to_string()));
801        }
802
803        // Also invalidate any active sessions for this principal
804        let mut sessions = self.sessions.write();
805        sessions.retain(|_, s| s.principal_name != name);
806
807        debug!("Deleted principal: {}", name);
808        Ok(())
809    }
810
811    /// Get a principal by name
812    pub fn get_principal(&self, name: &str) -> Option<Principal> {
813        self.principals.read().get(name).cloned()
814    }
815
816    /// List all principals
817    pub fn list_principals(&self) -> Vec<String> {
818        self.principals.read().keys().cloned().collect()
819    }
820
821    /// Update principal password
822    pub fn update_password(&self, name: &str, new_password: &str) -> AuthResult<()> {
823        if new_password.len() < 8 {
824            return Err(AuthError::Internal(
825                "Password must be at least 8 characters".to_string(),
826            ));
827        }
828
829        let mut principals = self.principals.write();
830
831        let principal = principals
832            .get_mut(name)
833            .ok_or_else(|| AuthError::PrincipalNotFound(name.to_string()))?;
834
835        principal.password_hash = PasswordHash::new(new_password);
836
837        // Invalidate sessions
838        let mut sessions = self.sessions.write();
839        sessions.retain(|_, s| s.principal_name != name);
840
841        debug!("Updated password for principal: {}", name);
842        Ok(())
843    }
844
845    /// Add a role to a principal
846    pub fn add_role_to_principal(&self, principal_name: &str, role_name: &str) -> AuthResult<()> {
847        // Validate role exists
848        if !self.roles.read().contains_key(role_name) {
849            return Err(AuthError::RoleNotFound(role_name.to_string()));
850        }
851
852        let mut principals = self.principals.write();
853
854        let principal = principals
855            .get_mut(principal_name)
856            .ok_or_else(|| AuthError::PrincipalNotFound(principal_name.to_string()))?;
857
858        principal.roles.insert(role_name.to_string());
859
860        debug!(
861            "Added role '{}' to principal '{}'",
862            role_name, principal_name
863        );
864        Ok(())
865    }
866
867    /// Remove a role from a principal
868    pub fn remove_role_from_principal(
869        &self,
870        principal_name: &str,
871        role_name: &str,
872    ) -> AuthResult<()> {
873        let mut principals = self.principals.write();
874
875        let principal = principals
876            .get_mut(principal_name)
877            .ok_or_else(|| AuthError::PrincipalNotFound(principal_name.to_string()))?;
878
879        principal.roles.remove(role_name);
880
881        debug!(
882            "Removed role '{}' from principal '{}'",
883            role_name, principal_name
884        );
885        Ok(())
886    }
887
888    // ========================================================================
889    // Role Management
890    // ========================================================================
891
892    /// Create a custom role
893    pub fn create_role(&self, role: Role) -> AuthResult<()> {
894        let mut roles = self.roles.write();
895
896        if roles.contains_key(&role.name) {
897            return Err(AuthError::Internal(format!(
898                "Role '{}' already exists",
899                role.name
900            )));
901        }
902
903        debug!("Created role: {}", role.name);
904        roles.insert(role.name.clone(), role);
905        Ok(())
906    }
907
908    /// Delete a custom role
909    pub fn delete_role(&self, name: &str) -> AuthResult<()> {
910        let mut roles = self.roles.write();
911
912        if let Some(role) = roles.get(name) {
913            if role.builtin {
914                return Err(AuthError::Internal(
915                    "Cannot delete built-in role".to_string(),
916                ));
917            }
918        } else {
919            return Err(AuthError::RoleNotFound(name.to_string()));
920        }
921
922        roles.remove(name);
923        debug!("Deleted role: {}", name);
924        Ok(())
925    }
926
927    /// Get a role by name
928    pub fn get_role(&self, name: &str) -> Option<Role> {
929        self.roles.read().get(name).cloned()
930    }
931
932    /// List all roles
933    pub fn list_roles(&self) -> Vec<String> {
934        self.roles.read().keys().cloned().collect()
935    }
936
937    // ========================================================================
938    // ACL Management
939    // ========================================================================
940
941    /// Add an ACL entry
942    pub fn add_acl(&self, entry: AclEntry) {
943        let mut acls = self.acls.write();
944        acls.push(entry);
945    }
946
947    /// Remove ACL entries matching criteria
948    pub fn remove_acls(&self, principal: Option<&str>, resource: Option<&ResourceType>) {
949        let mut acls = self.acls.write();
950        acls.retain(|acl| {
951            let principal_match =
952                principal.is_none_or(|p| acl.principal == p || acl.principal == "*");
953            let resource_match = resource.is_none_or(|r| &acl.resource == r);
954            !(principal_match && resource_match)
955        });
956    }
957
958    /// List ACL entries
959    pub fn list_acls(&self) -> Vec<AclEntry> {
960        self.acls.read().clone()
961    }
962
963    // ========================================================================
964    // Authentication
965    // ========================================================================
966
967    /// Authenticate a principal and create a session
968    pub fn authenticate(
969        &self,
970        username: &str,
971        password: &str,
972        client_ip: &str,
973    ) -> AuthResult<AuthSession> {
974        // Check rate limiting
975        {
976            let tracker = self.failed_attempts.read();
977            if tracker.is_locked_out(username, self.config.lockout_duration) {
978                warn!(
979                    "Authentication attempt for locked-out principal: {}",
980                    username
981                );
982                return Err(AuthError::RateLimited);
983            }
984            if tracker.is_locked_out(client_ip, self.config.lockout_duration) {
985                warn!("Authentication attempt from locked-out IP: {}", client_ip);
986                return Err(AuthError::RateLimited);
987            }
988        }
989
990        // Look up principal
991        let principal = {
992            let principals = self.principals.read();
993            principals.get(username).cloned()
994        };
995
996        let principal = match principal {
997            Some(p) if p.enabled => p,
998            Some(_) => {
999                // Disabled account - don't leak this info
1000                self.record_auth_failure(username, client_ip);
1001                return Err(AuthError::AuthenticationFailed);
1002            }
1003            None => {
1004                // Unknown principal - still do constant-time password check
1005                // to prevent timing attacks that enumerate users
1006                let dummy = PasswordHash::new("dummy");
1007                let _ = dummy.verify(password);
1008                self.record_auth_failure(username, client_ip);
1009                return Err(AuthError::AuthenticationFailed);
1010            }
1011        };
1012
1013        // Verify password (constant-time comparison)
1014        if !principal.password_hash.verify(password) {
1015            self.record_auth_failure(username, client_ip);
1016            return Err(AuthError::AuthenticationFailed);
1017        }
1018
1019        // Clear any failed attempt tracking
1020        self.failed_attempts.write().clear_failures(username);
1021        self.failed_attempts.write().clear_failures(client_ip);
1022
1023        // Build session with resolved permissions
1024        let permissions = self.resolve_permissions(&principal);
1025
1026        // Generate session ID
1027        let mut session_id = vec![0u8; 32];
1028        self.rng
1029            .fill(&mut session_id)
1030            .map_err(|_| AuthError::Internal("RNG failed".to_string()))?;
1031        let session_id = hex::encode(&session_id);
1032
1033        let now = Instant::now();
1034        let session = AuthSession {
1035            id: session_id.clone(),
1036            principal_name: principal.name.clone(),
1037            principal_type: principal.principal_type.clone(),
1038            permissions,
1039            created_at: now,
1040            expires_at: now + self.config.session_timeout,
1041            client_ip: client_ip.to_string(),
1042        };
1043
1044        // Store session
1045        self.sessions.write().insert(session_id, session.clone());
1046
1047        debug!("Authenticated principal '{}' from {}", username, client_ip);
1048        Ok(session)
1049    }
1050
1051    /// Record a failed authentication attempt
1052    fn record_auth_failure(&self, username: &str, client_ip: &str) {
1053        let mut tracker = self.failed_attempts.write();
1054        tracker.record_failure(
1055            username,
1056            self.config.max_failed_attempts,
1057            self.config.lockout_duration,
1058        );
1059        tracker.record_failure(
1060            client_ip,
1061            self.config.max_failed_attempts * 2,
1062            self.config.lockout_duration,
1063        );
1064    }
1065
1066    /// Get an active session by ID
1067    pub fn get_session(&self, session_id: &str) -> Option<AuthSession> {
1068        let sessions = self.sessions.read();
1069        sessions.get(session_id).and_then(|s| {
1070            if s.is_expired() {
1071                None
1072            } else {
1073                Some(s.clone())
1074            }
1075        })
1076    }
1077
1078    /// Invalidate a session (logout)
1079    pub fn invalidate_session(&self, session_id: &str) {
1080        self.sessions.write().remove(session_id);
1081    }
1082
1083    /// Invalidate all sessions for a principal
1084    pub fn invalidate_all_sessions(&self, principal_name: &str) {
1085        self.sessions
1086            .write()
1087            .retain(|_, s| s.principal_name != principal_name);
1088    }
1089
1090    /// Clean up expired sessions
1091    pub fn cleanup_expired_sessions(&self) {
1092        self.sessions.write().retain(|_, s| !s.is_expired());
1093    }
1094
1095    /// Create a session for a principal (used by SCRAM after successful auth)
1096    pub fn create_session(&self, principal: &Principal) -> AuthSession {
1097        let permissions = self.resolve_permissions(principal);
1098
1099        let mut session_id = vec![0u8; 32];
1100        self.rng.fill(&mut session_id).expect("RNG failed");
1101        let session_id = hex::encode(&session_id);
1102
1103        let now = Instant::now();
1104        let session = AuthSession {
1105            id: session_id.clone(),
1106            principal_name: principal.name.clone(),
1107            principal_type: principal.principal_type.clone(),
1108            permissions,
1109            created_at: now,
1110            expires_at: now + self.config.session_timeout,
1111            client_ip: "scram".to_string(),
1112        };
1113
1114        self.sessions.write().insert(session_id, session.clone());
1115        session
1116    }
1117
1118    // ========================================================================
1119    // Authorization
1120    // ========================================================================
1121
1122    /// Resolve all permissions for a principal
1123    fn resolve_permissions(&self, principal: &Principal) -> HashSet<(ResourceType, Permission)> {
1124        let mut permissions = HashSet::new();
1125
1126        let roles = self.roles.read();
1127
1128        // Collect permissions from all roles
1129        for role_name in &principal.roles {
1130            if let Some(role) = roles.get(role_name) {
1131                permissions.extend(role.permissions.iter().cloned());
1132            }
1133        }
1134
1135        permissions
1136    }
1137
1138    /// Check if a session/principal has permission on a resource
1139    pub fn authorize(
1140        &self,
1141        session: &AuthSession,
1142        resource: &ResourceType,
1143        permission: Permission,
1144        client_ip: &str,
1145    ) -> AuthResult<()> {
1146        // If auth is not required and no ACLs, allow everything
1147        if !self.config.require_authentication && !self.config.enable_acls {
1148            return Ok(());
1149        }
1150
1151        // Check session expiration
1152        if session.is_expired() {
1153            return Err(AuthError::TokenExpired);
1154        }
1155
1156        // Check role-based permissions
1157        if session.has_permission(resource, &permission) {
1158            return Ok(());
1159        }
1160
1161        // Check ACL entries
1162        if self.config.enable_acls
1163            && self.check_acls(&session.principal_name, resource, permission, client_ip)
1164        {
1165            return Ok(());
1166        }
1167
1168        // Default deny
1169        if self.config.default_deny {
1170            warn!(
1171                "Access denied: {} attempted {} on {:?} from {}",
1172                session.principal_name,
1173                format!("{:?}", permission),
1174                resource,
1175                client_ip
1176            );
1177            return Err(AuthError::PermissionDenied {
1178                principal: session.principal_name.clone(),
1179                permission: format!("{:?}", permission),
1180                resource: format!("{:?}", resource),
1181            });
1182        }
1183
1184        Ok(())
1185    }
1186
1187    /// Check ACL entries for authorization
1188    fn check_acls(
1189        &self,
1190        principal: &str,
1191        resource: &ResourceType,
1192        permission: Permission,
1193        client_ip: &str,
1194    ) -> bool {
1195        let acls = self.acls.read();
1196
1197        // Check deny rules first (deny takes precedence)
1198        for acl in acls.iter() {
1199            if !acl.allow
1200                && (acl.principal == principal || acl.principal == "*")
1201                && (acl.host == client_ip || acl.host == "*")
1202                && acl.resource.matches(resource)
1203                && (acl.permission == permission || acl.permission == Permission::All)
1204            {
1205                return false; // Explicit deny
1206            }
1207        }
1208
1209        // Check allow rules
1210        for acl in acls.iter() {
1211            if acl.allow
1212                && (acl.principal == principal || acl.principal == "*")
1213                && (acl.host == client_ip || acl.host == "*")
1214                && acl.resource.matches(resource)
1215                && (acl.permission == permission || acl.permission == Permission::All)
1216            {
1217                return true;
1218            }
1219        }
1220
1221        false
1222    }
1223
1224    /// Simple authorization check without session (for internal use)
1225    #[allow(unused_variables)]
1226    pub fn authorize_anonymous(
1227        &self,
1228        resource: &ResourceType,
1229        permission: Permission,
1230    ) -> AuthResult<()> {
1231        if !self.config.require_authentication {
1232            return Ok(());
1233        }
1234
1235        Err(AuthError::AuthenticationFailed)
1236    }
1237}
1238
1239// ============================================================================
1240// SASL/PLAIN Support (Kafka-compatible)
1241// ============================================================================
1242
/// SASL/PLAIN authentication handler.
///
/// Parses PLAIN messages and hands the extracted username and password to the
/// shared [`AuthManager`], which performs the credential check, lockout
/// handling and session creation.
pub struct SaslPlainAuth {
    // Shared authentication state (principals, sessions, failure tracking).
    auth_manager: Arc<AuthManager>,
}
1247
1248impl SaslPlainAuth {
1249    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
1250        Self { auth_manager }
1251    }
1252
1253    /// Parse and authenticate a SASL/PLAIN request
1254    /// Format: \[authzid\] NUL authcid NUL passwd
1255    pub fn authenticate(&self, sasl_bytes: &[u8], client_ip: &str) -> AuthResult<AuthSession> {
1256        // Parse SASL/PLAIN format: [authzid] \0 authcid \0 password
1257        let parts: Vec<&[u8]> = sasl_bytes.split(|&b| b == 0).collect();
1258
1259        if parts.len() < 2 {
1260            return Err(AuthError::InvalidCredentials);
1261        }
1262
1263        // Handle both 2-part (authcid, passwd) and 3-part (authzid, authcid, passwd)
1264        let (username, password) = if parts.len() == 2 {
1265            (
1266                std::str::from_utf8(parts[0]).map_err(|_| AuthError::InvalidCredentials)?,
1267                std::str::from_utf8(parts[1]).map_err(|_| AuthError::InvalidCredentials)?,
1268            )
1269        } else {
1270            // 3-part format - authzid is ignored, use authcid
1271            (
1272                std::str::from_utf8(parts[1]).map_err(|_| AuthError::InvalidCredentials)?,
1273                std::str::from_utf8(parts[2]).map_err(|_| AuthError::InvalidCredentials)?,
1274            )
1275        };
1276
1277        self.auth_manager
1278            .authenticate(username, password, client_ip)
1279    }
1280}
1281
1282// ============================================================================
1283// SCRAM-SHA-256 Support (RFC 5802 / RFC 7677)
1284// ============================================================================
1285
/// SCRAM-SHA-256 authentication state machine.
///
/// Implements the full SCRAM protocol for secure password-based authentication.
/// This is significantly more secure than PLAIN because:
/// 1. The password is never sent over the wire, not even in encrypted form
/// 2. The server stores derived keys, not the password
/// 3. Mutual authentication (the server proves it knows the password too)
/// 4. Protection against replay attacks via nonces
///
/// The caller keeps this value between the two client messages: it is returned
/// by `SaslScramAuth::process_client_first` and passed back into
/// `SaslScramAuth::process_client_final`.
#[derive(Debug, Clone)]
pub enum ScramState {
    /// Waiting for client-first-message
    Initial,
    /// Waiting for client-final-message; holds everything needed to verify it.
    ServerFirstSent {
        /// Unescaped username taken from the client's `n=` attribute.
        username: String,
        /// Nonce taken from the client's `r=` attribute.
        client_nonce: String,
        /// Base64-encoded server nonce appended to the client nonce.
        server_nonce: String,
        /// Salt advertised in server-first-message (a stand-in value for unknown users).
        salt: Vec<u8>,
        /// Iteration count advertised in server-first-message.
        iterations: u32,
        /// Expected AuthMessage, used to check the client proof and sign the reply.
        auth_message: String,
    },
    /// Exchange finished. `SaslScramAuth` never constructs this variant itself;
    /// presumably callers set it once `process_client_final` returns — confirm.
    Complete,
}
1310
/// SCRAM-SHA-256 authentication handler.
///
/// Holds no per-exchange state of its own: that lives in a [`ScramState`] the
/// caller keeps between `process_client_first` and `process_client_final`.
/// Credentials are read from, and sessions created in, the shared [`AuthManager`].
pub struct SaslScramAuth {
    // Shared authentication state (principals, sessions).
    auth_manager: Arc<AuthManager>,
}
1315
1316impl SaslScramAuth {
1317    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
1318        Self { auth_manager }
1319    }
1320
1321    /// Process client-first-message and return server-first-message
1322    ///
1323    /// Client-first-message format: `n,,n=<username>,r=<client-nonce>`
1324    /// Server-first-message format: `r=<combined-nonce>,s=<salt>,i=<iterations>`
1325    pub fn process_client_first(
1326        &self,
1327        client_first: &[u8],
1328        client_ip: &str,
1329    ) -> AuthResult<(ScramState, Vec<u8>)> {
1330        let client_first_str =
1331            std::str::from_utf8(client_first).map_err(|_| AuthError::InvalidCredentials)?;
1332
1333        // Parse client-first-message
1334        // Format: gs2-header,client-first-message-bare
1335        // gs2-header: n,, (no channel binding)
1336        // client-first-message-bare: n=<user>,r=<nonce>
1337
1338        let parts: Vec<&str> = client_first_str.splitn(3, ',').collect();
1339        if parts.len() < 3 {
1340            return Err(AuthError::InvalidCredentials);
1341        }
1342
1343        // Skip gs2-header (parts[0] and parts[1])
1344        let client_first_bare = if parts[0] == "n" || parts[0] == "y" || parts[0] == "p" {
1345            // gs2-header present, skip first two parts
1346            &client_first_str[parts[0].len() + 1 + parts[1].len() + 1..]
1347        } else {
1348            // No gs2-header, message is just client-first-message-bare
1349            client_first_str
1350        };
1351
1352        // Parse client-first-message-bare
1353        let mut username = None;
1354        let mut client_nonce = None;
1355
1356        for attr in client_first_bare.split(',') {
1357            if let Some(value) = attr.strip_prefix("n=") {
1358                username = Some(Self::unescape_username(value));
1359            } else if let Some(value) = attr.strip_prefix("r=") {
1360                client_nonce = Some(value.to_string());
1361            }
1362        }
1363
1364        let username = username.ok_or(AuthError::InvalidCredentials)?;
1365        let client_nonce = client_nonce.ok_or(AuthError::InvalidCredentials)?;
1366
1367        // Look up principal to get salt and iterations
1368        let (salt, iterations) = match self.auth_manager.get_principal(&username) {
1369            Some(principal) => (
1370                principal.password_hash.salt.clone(),
1371                principal.password_hash.iterations,
1372            ),
1373            None => {
1374                // User not found - generate fake salt to prevent enumeration
1375                // Still continue with the protocol to not leak timing info
1376                warn!(
1377                    "SCRAM auth for unknown user '{}' from {}",
1378                    username, client_ip
1379                );
1380                let rng = SystemRandom::new();
1381                let mut fake_salt = vec![0u8; 32];
1382                rng.fill(&mut fake_salt).expect("Failed to generate salt");
1383                (fake_salt, 4096)
1384            }
1385        };
1386
1387        // Generate server nonce (random bytes, base64 encoded)
1388        let rng = SystemRandom::new();
1389        let mut server_nonce_bytes = vec![0u8; 24];
1390        rng.fill(&mut server_nonce_bytes)
1391            .expect("Failed to generate nonce");
1392        let server_nonce = base64_encode(&server_nonce_bytes);
1393        let combined_nonce = format!("{}{}", client_nonce, server_nonce);
1394
1395        // Build server-first-message
1396        let salt_b64 = base64_encode(&salt);
1397        let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, iterations);
1398
1399        // Store auth message for later verification
1400        let auth_message = format!(
1401            "{},{},c=biws,r={}",
1402            client_first_bare, server_first, combined_nonce
1403        );
1404
1405        let state = ScramState::ServerFirstSent {
1406            username,
1407            client_nonce,
1408            server_nonce,
1409            salt,
1410            iterations,
1411            auth_message,
1412        };
1413
1414        Ok((state, server_first.into_bytes()))
1415    }
1416
1417    /// Process client-final-message and return server-final-message
1418    ///
1419    /// Client-final-message format: `c=<channel-binding>,r=<nonce>,p=<proof>`
1420    /// Server-final-message format: `v=<verifier>` (on success) or `e=<error>`
1421    pub fn process_client_final(
1422        &self,
1423        state: &ScramState,
1424        client_final: &[u8],
1425        client_ip: &str,
1426    ) -> AuthResult<(AuthSession, Vec<u8>)> {
1427        let ScramState::ServerFirstSent {
1428            username,
1429            client_nonce,
1430            server_nonce,
1431            salt: _,       // Not needed for verification, stored in principal
1432            iterations: _, // Not needed for verification, stored in principal
1433            auth_message,
1434        } = state
1435        else {
1436            return Err(AuthError::Internal("Invalid SCRAM state".to_string()));
1437        };
1438
1439        let client_final_str =
1440            std::str::from_utf8(client_final).map_err(|_| AuthError::InvalidCredentials)?;
1441
1442        // Parse client-final-message
1443        let mut channel_binding = None;
1444        let mut nonce = None;
1445        let mut proof = None;
1446
1447        for attr in client_final_str.split(',') {
1448            if let Some(value) = attr.strip_prefix("c=") {
1449                channel_binding = Some(value.to_string());
1450            } else if let Some(value) = attr.strip_prefix("r=") {
1451                nonce = Some(value.to_string());
1452            } else if let Some(value) = attr.strip_prefix("p=") {
1453                proof = Some(value.to_string());
1454            }
1455        }
1456
1457        let _channel_binding = channel_binding.ok_or(AuthError::InvalidCredentials)?;
1458        let nonce = nonce.ok_or(AuthError::InvalidCredentials)?;
1459        let proof_b64 = proof.ok_or(AuthError::InvalidCredentials)?;
1460
1461        // Verify nonce
1462        let expected_nonce = format!("{}{}", client_nonce, server_nonce);
1463        if nonce != expected_nonce {
1464            warn!("SCRAM nonce mismatch for '{}' from {}", username, client_ip);
1465            return Err(AuthError::InvalidCredentials);
1466        }
1467
1468        // Get principal
1469        let principal = self
1470            .auth_manager
1471            .get_principal(username)
1472            .ok_or(AuthError::AuthenticationFailed)?;
1473
1474        // Verify client proof
1475        // ClientProof = ClientKey XOR ClientSignature
1476        // ClientSignature = HMAC(StoredKey, AuthMessage)
1477        // We need to verify: H(ClientKey) == StoredKey
1478
1479        let client_proof = base64_decode(&proof_b64).map_err(|_| AuthError::InvalidCredentials)?;
1480
1481        // Compute expected client signature
1482        let client_signature =
1483            PasswordHash::hmac_sha256(&principal.password_hash.stored_key, auth_message.as_bytes());
1484
1485        // Recover ClientKey = ClientProof XOR ClientSignature
1486        if client_proof.len() != client_signature.len() {
1487            return Err(AuthError::InvalidCredentials);
1488        }
1489
1490        let client_key: Vec<u8> = client_proof
1491            .iter()
1492            .zip(client_signature.iter())
1493            .map(|(p, s)| p ^ s)
1494            .collect();
1495
1496        // Verify: H(ClientKey) == StoredKey (constant-time comparison)
1497        let computed_stored_key = Sha256::digest(&client_key);
1498        if !PasswordHash::constant_time_compare(
1499            &computed_stored_key,
1500            &principal.password_hash.stored_key,
1501        ) {
1502            warn!(
1503                "SCRAM authentication failed for '{}' from {}",
1504                username, client_ip
1505            );
1506            return Err(AuthError::AuthenticationFailed);
1507        }
1508
1509        // Compute server signature for mutual authentication
1510        let server_signature =
1511            PasswordHash::hmac_sha256(&principal.password_hash.server_key, auth_message.as_bytes());
1512        let server_final = format!("v={}", base64_encode(&server_signature));
1513
1514        // Create session
1515        let session = self.auth_manager.create_session(&principal);
1516        debug!(
1517            "SCRAM authentication successful for '{}' from {}",
1518            username, client_ip
1519        );
1520
1521        Ok((session, server_final.into_bytes()))
1522    }
1523
1524    /// Unescape SCRAM username (=2C -> , and =3D -> =)
1525    fn unescape_username(s: &str) -> String {
1526        s.replace("=2C", ",").replace("=3D", "=")
1527    }
1528}
1529
1530/// Base64 encode (standard alphabet)
1531fn base64_encode(data: &[u8]) -> String {
1532    use base64::{engine::general_purpose::STANDARD, Engine as _};
1533    STANDARD.encode(data)
1534}
1535
1536/// Base64 decode (standard alphabet)
1537fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
1538    use base64::{engine::general_purpose::STANDARD, Engine as _};
1539    STANDARD.decode(s)
1540}
1541
1542// ============================================================================
1543// Tests
1544// ============================================================================
1545
1546#[cfg(test)]
1547mod tests {
1548    use super::*;
1549
1550    #[test]
1551    fn test_password_hash_verify() {
1552        let hash = PasswordHash::new("test_password_123");
1553        assert!(hash.verify("test_password_123"));
1554        assert!(!hash.verify("wrong_password"));
1555        assert!(!hash.verify(""));
1556        assert!(!hash.verify("test_password_12")); // Off by one
1557    }
1558
1559    #[test]
1560    fn test_password_hash_timing_attack_resistant() {
1561        // Both wrong passwords should take similar time
1562        // (This is more of a design assertion than a precise timing test)
1563        let hash = PasswordHash::new("correct_password");
1564
1565        // Wrong but similar length
1566        assert!(!hash.verify("wrong_password"));
1567
1568        // Wrong and very different
1569        assert!(!hash.verify("x"));
1570
1571        // Both should still return false (constant-time)
1572    }
1573
1574    #[test]
1575    fn test_create_principal() {
1576        let auth = AuthManager::new_default();
1577
1578        let mut roles = HashSet::new();
1579        roles.insert("producer".to_string());
1580
1581        auth.create_principal(
1582            "alice",
1583            "secure_pass_123",
1584            PrincipalType::User,
1585            roles.clone(),
1586        )
1587        .expect("Failed to create principal");
1588
1589        // Duplicate should fail
1590        assert!(auth
1591            .create_principal("alice", "other_pass", PrincipalType::User, roles.clone())
1592            .is_err());
1593
1594        // Verify principal exists
1595        let principal = auth.get_principal("alice").expect("Principal not found");
1596        assert_eq!(principal.name, "alice");
1597        assert!(principal.roles.contains("producer"));
1598    }
1599
1600    #[test]
1601    fn test_authentication_success() {
1602        let auth = AuthManager::new_default();
1603
1604        let mut roles = HashSet::new();
1605        roles.insert("producer".to_string());
1606
1607        auth.create_principal("bob", "bob_password", PrincipalType::User, roles)
1608            .unwrap();
1609
1610        let session = auth
1611            .authenticate("bob", "bob_password", "127.0.0.1")
1612            .expect("Authentication should succeed");
1613
1614        assert_eq!(session.principal_name, "bob");
1615        assert!(!session.is_expired());
1616    }
1617
1618    #[test]
1619    fn test_authentication_failure() {
1620        let auth = AuthManager::new_default();
1621
1622        let mut roles = HashSet::new();
1623        roles.insert("producer".to_string());
1624
1625        auth.create_principal("charlie", "correct_password", PrincipalType::User, roles)
1626            .unwrap();
1627
1628        // Wrong password
1629        let result = auth.authenticate("charlie", "wrong_password", "127.0.0.1");
1630        assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
1631
1632        // Unknown user
1633        let result = auth.authenticate("unknown", "password", "127.0.0.1");
1634        assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
1635    }
1636
1637    #[test]
1638    fn test_rate_limiting() {
1639        let config = AuthConfig {
1640            max_failed_attempts: 3,
1641            lockout_duration: Duration::from_secs(1),
1642            ..Default::default()
1643        };
1644        let auth = AuthManager::new(config);
1645
1646        let mut roles = HashSet::new();
1647        roles.insert("consumer".to_string());
1648        auth.create_principal("eve", "password", PrincipalType::User, roles)
1649            .unwrap();
1650
1651        // Fail 3 times
1652        for _ in 0..3 {
1653            let _ = auth.authenticate("eve", "wrong", "192.168.1.1");
1654        }
1655
1656        // Now should be rate limited
1657        let result = auth.authenticate("eve", "password", "192.168.1.1");
1658        assert!(matches!(result, Err(AuthError::RateLimited)));
1659
1660        // Wait for lockout to expire
1661        std::thread::sleep(Duration::from_millis(1100));
1662
1663        // Should work now
1664        let result = auth.authenticate("eve", "password", "192.168.1.1");
1665        assert!(result.is_ok());
1666    }
1667
1668    #[test]
1669    fn test_role_permissions() {
1670        let auth = AuthManager::with_auth_enabled();
1671
1672        let mut roles = HashSet::new();
1673        roles.insert("producer".to_string());
1674        auth.create_principal("producer_user", "password", PrincipalType::User, roles)
1675            .unwrap();
1676
1677        let session = auth
1678            .authenticate("producer_user", "password", "127.0.0.1")
1679            .unwrap();
1680
1681        // Producer should have write permission on topics
1682        assert!(session.has_permission(
1683            &ResourceType::Topic("orders".to_string()),
1684            &Permission::Write
1685        ));
1686
1687        // Producer should not have delete permission
1688        assert!(!session.has_permission(
1689            &ResourceType::Topic("orders".to_string()),
1690            &Permission::Delete
1691        ));
1692    }
1693
1694    #[test]
1695    fn test_admin_has_all_permissions() {
1696        let auth = AuthManager::with_auth_enabled();
1697
1698        let mut roles = HashSet::new();
1699        roles.insert("admin".to_string());
1700        auth.create_principal("admin_user", "admin_pass", PrincipalType::User, roles)
1701            .unwrap();
1702
1703        let session = auth
1704            .authenticate("admin_user", "admin_pass", "127.0.0.1")
1705            .unwrap();
1706
1707        // Admin should have all permissions
1708        assert!(session.has_permission(&ResourceType::Cluster, &Permission::All));
1709        assert!(session.has_permission(
1710            &ResourceType::Topic("any_topic".to_string()),
1711            &Permission::Delete
1712        ));
1713    }
1714
1715    #[test]
1716    fn test_resource_pattern_matching() {
1717        assert!(ResourceType::TopicPattern("*".to_string())
1718            .matches(&ResourceType::Topic("anything".to_string())));
1719
1720        assert!(ResourceType::TopicPattern("orders-*".to_string())
1721            .matches(&ResourceType::Topic("orders-us".to_string())));
1722
1723        assert!(ResourceType::TopicPattern("orders-*".to_string())
1724            .matches(&ResourceType::Topic("orders-eu".to_string())));
1725
1726        assert!(!ResourceType::TopicPattern("orders-*".to_string())
1727            .matches(&ResourceType::Topic("events-us".to_string())));
1728    }
1729
1730    #[test]
1731    fn test_acl_enforcement() {
1732        let auth = AuthManager::new(AuthConfig {
1733            require_authentication: true,
1734            enable_acls: true,
1735            default_deny: true,
1736            ..Default::default()
1737        });
1738
1739        let mut roles = HashSet::new();
1740        roles.insert("read-only".to_string());
1741        auth.create_principal("reader", "password", PrincipalType::User, roles)
1742            .unwrap();
1743
1744        // Add ACL allowing write to specific topic
1745        auth.add_acl(AclEntry {
1746            principal: "reader".to_string(),
1747            resource: ResourceType::Topic("special-topic".to_string()),
1748            permission: Permission::Write,
1749            allow: true,
1750            host: "*".to_string(),
1751        });
1752
1753        let session = auth
1754            .authenticate("reader", "password", "127.0.0.1")
1755            .unwrap();
1756
1757        // Should be able to write to special-topic via ACL
1758        let result = auth.authorize(
1759            &session,
1760            &ResourceType::Topic("special-topic".to_string()),
1761            Permission::Write,
1762            "127.0.0.1",
1763        );
1764        assert!(result.is_ok());
1765
1766        // Should NOT be able to write to other topics
1767        let result = auth.authorize(
1768            &session,
1769            &ResourceType::Topic("other-topic".to_string()),
1770            Permission::Write,
1771            "127.0.0.1",
1772        );
1773        assert!(result.is_err());
1774    }
1775
1776    #[test]
1777    fn test_sasl_plain_authentication() {
1778        let auth = Arc::new(AuthManager::new_default());
1779
1780        let mut roles = HashSet::new();
1781        roles.insert("producer".to_string());
1782        auth.create_principal("sasl_user", "sasl_password", PrincipalType::User, roles)
1783            .unwrap();
1784
1785        let sasl = SaslPlainAuth::new(auth);
1786
1787        // Test 2-part format: username\0password
1788        let two_part = b"sasl_user\0sasl_password";
1789        let result = sasl.authenticate(two_part, "127.0.0.1");
1790        assert!(result.is_ok());
1791
1792        // Test 3-part format: authzid\0username\0password
1793        let three_part = b"\0sasl_user\0sasl_password";
1794        let result = sasl.authenticate(three_part, "127.0.0.1");
1795        assert!(result.is_ok());
1796    }
1797
1798    #[test]
1799    fn test_session_expiration() {
1800        let config = AuthConfig {
1801            session_timeout: Duration::from_millis(100),
1802            ..Default::default()
1803        };
1804        let auth = AuthManager::new(config);
1805
1806        let mut roles = HashSet::new();
1807        roles.insert("producer".to_string());
1808        auth.create_principal("expiring", "password", PrincipalType::User, roles)
1809            .unwrap();
1810
1811        let session = auth
1812            .authenticate("expiring", "password", "127.0.0.1")
1813            .unwrap();
1814        assert!(!session.is_expired());
1815
1816        // Wait for session to expire
1817        std::thread::sleep(Duration::from_millis(150));
1818
1819        // Session should be expired
1820        let session = AuthSession {
1821            expires_at: session.expires_at,
1822            ..session
1823        };
1824        assert!(session.is_expired());
1825    }
1826
1827    #[test]
1828    fn test_delete_principal_invalidates_sessions() {
1829        let auth = AuthManager::new_default();
1830
1831        let mut roles = HashSet::new();
1832        roles.insert("producer".to_string());
1833        auth.create_principal("deleteme", "password", PrincipalType::User, roles)
1834            .unwrap();
1835
1836        let session = auth
1837            .authenticate("deleteme", "password", "127.0.0.1")
1838            .unwrap();
1839
1840        // Session should exist
1841        assert!(auth.get_session(&session.id).is_some());
1842
1843        // Delete principal
1844        auth.delete_principal("deleteme").unwrap();
1845
1846        // Session should be gone
1847        assert!(auth.get_session(&session.id).is_none());
1848    }
1849
1850    #[test]
1851    fn test_disabled_principal_cannot_authenticate() {
1852        let auth = AuthManager::new_default();
1853
1854        let mut roles = HashSet::new();
1855        roles.insert("producer".to_string());
1856        auth.create_principal("disabled_user", "password", PrincipalType::User, roles)
1857            .unwrap();
1858
1859        // Disable the principal
1860        {
1861            let mut principals = auth.principals.write();
1862            if let Some(p) = principals.get_mut("disabled_user") {
1863                p.enabled = false;
1864            }
1865        }
1866
1867        // Should fail to authenticate
1868        let result = auth.authenticate("disabled_user", "password", "127.0.0.1");
1869        assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
1870    }
1871
1872    #[test]
1873    fn test_password_hash_debug_redacts_sensitive_data() {
1874        let hash = PasswordHash::new("super_secret_password");
1875        let debug_output = format!("{:?}", hash);
1876
1877        // Should contain REDACTED markers
1878        assert!(
1879            debug_output.contains("[REDACTED]"),
1880            "Debug output should contain [REDACTED]"
1881        );
1882
1883        // Should NOT contain actual salt or key material
1884        // Salt and keys are binary data, but let's ensure no suspicious patterns
1885        assert!(
1886            !debug_output.contains("super_secret_password"),
1887            "Debug output should not contain password"
1888        );
1889
1890        // Should show iterations (not sensitive)
1891        assert!(
1892            debug_output.contains("iterations"),
1893            "Debug output should show iterations field"
1894        );
1895    }
1896
1897    #[test]
1898    fn test_principal_debug_redacts_password_hash() {
1899        let principal = Principal {
1900            name: "test_user".to_string(),
1901            principal_type: PrincipalType::User,
1902            password_hash: PasswordHash::new("secret_password"),
1903            roles: HashSet::from(["admin".to_string()]),
1904            enabled: true,
1905            metadata: HashMap::new(),
1906            created_at: 1234567890,
1907        };
1908
1909        let debug_output = format!("{:?}", principal);
1910
1911        // Should contain REDACTED for password_hash
1912        assert!(
1913            debug_output.contains("[REDACTED]"),
1914            "Debug output should contain [REDACTED]: {}",
1915            debug_output
1916        );
1917
1918        // Should still show non-sensitive fields
1919        assert!(
1920            debug_output.contains("test_user"),
1921            "Debug output should show name"
1922        );
1923        assert!(
1924            debug_output.contains("admin"),
1925            "Debug output should show roles"
1926        );
1927    }
1928
1929    // ========================================================================
1930    // SCRAM-SHA-256 Tests
1931    // ========================================================================
1932
1933    #[test]
1934    fn test_scram_full_handshake() {
1935        use sha2::{Digest, Sha256};
1936
1937        let auth = Arc::new(AuthManager::new_default());
1938
1939        // Create a user
1940        let mut roles = HashSet::new();
1941        roles.insert("producer".to_string());
1942        auth.create_principal("scram_user", "scram_password", PrincipalType::User, roles)
1943            .expect("Failed to create principal");
1944
1945        let scram = SaslScramAuth::new(auth.clone());
1946
1947        // Step 1: Client sends client-first-message
1948        let client_nonce = "rOprNGfwEbeRWgbNEkqO";
1949        let client_first = format!("n,,n=scram_user,r={}", client_nonce);
1950
1951        let (state, server_first) = scram
1952            .process_client_first(client_first.as_bytes(), "127.0.0.1")
1953            .expect("client-first processing should succeed");
1954
1955        // Verify server-first-message format
1956        let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
1957        assert!(server_first_str.starts_with(&format!("r={}", client_nonce)));
1958        assert!(server_first_str.contains(",s="));
1959        assert!(server_first_str.contains(",i="));
1960
1961        // Parse server-first-message to build client-final
1962        let ScramState::ServerFirstSent {
1963            username: _,
1964            client_nonce: _,
1965            server_nonce: _,
1966            salt,
1967            iterations,
1968            auth_message: _,
1969        } = &state
1970        else {
1971            panic!("Expected ServerFirstSent state");
1972        };
1973
1974        // Step 2: Client computes proof and sends client-final-message
1975        // ClientProof = ClientKey XOR ClientSignature
1976        let salted_password = compute_salted_password("scram_password", salt, *iterations);
1977        let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
1978        let stored_key = Sha256::digest(&client_key);
1979
1980        // Build auth message
1981        let client_first_bare = format!("n=scram_user,r={}", client_nonce);
1982        let combined_nonce: String = server_first_str
1983            .split(',')
1984            .find(|s| s.starts_with("r="))
1985            .map(|s| &s[2..])
1986            .unwrap()
1987            .to_string();
1988
1989        let auth_message = format!(
1990            "{},{},c=biws,r={}",
1991            client_first_bare, server_first_str, combined_nonce
1992        );
1993
1994        let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
1995        let client_proof: Vec<u8> = client_key
1996            .iter()
1997            .zip(client_signature.iter())
1998            .map(|(k, s)| k ^ s)
1999            .collect();
2000
2001        let client_final = format!(
2002            "c=biws,r={},p={}",
2003            combined_nonce,
2004            base64_encode(&client_proof)
2005        );
2006
2007        // Step 3: Server verifies and responds
2008        let (session, server_final) = scram
2009            .process_client_final(&state, client_final.as_bytes(), "127.0.0.1")
2010            .expect("client-final processing should succeed");
2011
2012        // Verify session was created
2013        assert_eq!(session.principal_name, "scram_user");
2014        assert!(!session.is_expired());
2015
2016        // Verify server-final-message (mutual authentication)
2017        let server_final_str = std::str::from_utf8(&server_final).expect("valid UTF-8");
2018        assert!(server_final_str.starts_with("v="));
2019    }
2020
2021    #[test]
2022    fn test_scram_wrong_password() {
2023        let auth = Arc::new(AuthManager::new_default());
2024
2025        let mut roles = HashSet::new();
2026        roles.insert("producer".to_string());
2027        auth.create_principal(
2028            "scram_user2",
2029            "correct_password",
2030            PrincipalType::User,
2031            roles,
2032        )
2033        .expect("Failed to create principal");
2034
2035        let scram = SaslScramAuth::new(auth.clone());
2036
2037        // Client-first with correct username
2038        let client_nonce = "test_nonce_12345";
2039        let client_first = format!("n,,n=scram_user2,r={}", client_nonce);
2040
2041        let (state, server_first) = scram
2042            .process_client_first(client_first.as_bytes(), "127.0.0.1")
2043            .expect("client-first processing should succeed");
2044
2045        // Parse server response
2046        let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
2047        let combined_nonce: String = server_first_str
2048            .split(',')
2049            .find(|s| s.starts_with("r="))
2050            .map(|s| &s[2..])
2051            .unwrap()
2052            .to_string();
2053
2054        // Compute proof with WRONG password
2055        let ScramState::ServerFirstSent {
2056            salt, iterations, ..
2057        } = &state
2058        else {
2059            panic!("Expected ServerFirstSent state");
2060        };
2061
2062        let salted_password = compute_salted_password("wrong_password", salt, *iterations);
2063        let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
2064        let stored_key = sha2::Sha256::digest(&client_key);
2065
2066        let client_first_bare = format!("n=scram_user2,r={}", client_nonce);
2067        let auth_message = format!(
2068            "{},{},c=biws,r={}",
2069            client_first_bare, server_first_str, combined_nonce
2070        );
2071
2072        let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
2073        let client_proof: Vec<u8> = client_key
2074            .iter()
2075            .zip(client_signature.iter())
2076            .map(|(k, s)| k ^ s)
2077            .collect();
2078
2079        let client_final = format!(
2080            "c=biws,r={},p={}",
2081            combined_nonce,
2082            base64_encode(&client_proof)
2083        );
2084
2085        // Should fail
2086        let result = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
2087        assert!(result.is_err());
2088        assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
2089    }
2090
2091    #[test]
2092    fn test_scram_nonexistent_user() {
2093        let auth = Arc::new(AuthManager::new_default());
2094        let scram = SaslScramAuth::new(auth.clone());
2095
2096        // Client-first for nonexistent user
2097        let client_first = "n,,n=nonexistent_user,r=test_nonce";
2098
2099        // Should still return a server-first (to prevent enumeration)
2100        let result = scram.process_client_first(client_first.as_bytes(), "127.0.0.1");
2101        assert!(
2102            result.is_ok(),
2103            "Should return fake server-first to prevent enumeration"
2104        );
2105
2106        let (state, server_first) = result.unwrap();
2107        let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
2108
2109        // Should have valid format (fake salt/iterations)
2110        assert!(server_first_str.contains("r=test_nonce"));
2111        assert!(server_first_str.contains(",s="));
2112        assert!(server_first_str.contains(",i="));
2113
2114        // Final step should fail
2115        let combined_nonce: String = server_first_str
2116            .split(',')
2117            .find(|s| s.starts_with("r="))
2118            .map(|s| &s[2..])
2119            .unwrap()
2120            .to_string();
2121
2122        let client_final = format!("c=biws,r={},p=dW5rbm93bg==", combined_nonce);
2123        let result = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
2124        assert!(result.is_err());
2125    }
2126
2127    #[test]
2128    fn test_scram_nonce_mismatch() {
2129        let auth = Arc::new(AuthManager::new_default());
2130
2131        let mut roles = HashSet::new();
2132        roles.insert("producer".to_string());
2133        auth.create_principal("scram_user3", "password", PrincipalType::User, roles)
2134            .expect("Failed to create principal");
2135
2136        let scram = SaslScramAuth::new(auth.clone());
2137
2138        let client_first = "n,,n=scram_user3,r=original_nonce";
2139        let (state, _server_first) = scram
2140            .process_client_first(client_first.as_bytes(), "127.0.0.1")
2141            .expect("client-first should succeed");
2142
2143        // Client-final with different nonce prefix (attack attempt)
2144        let client_final = "c=biws,r=tampered_nonce_plus_server,p=dW5rbm93bg==";
2145        let result = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
2146        assert!(result.is_err());
2147        assert!(matches!(result, Err(AuthError::InvalidCredentials)));
2148    }
2149
2150    /// Helper: Compute salted password (PBKDF2)
2151    fn compute_salted_password(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
2152        use hmac::{Hmac, Mac};
2153        type HmacSha256 = Hmac<sha2::Sha256>;
2154
2155        let mut result = vec![0u8; 32];
2156
2157        let mut mac =
2158            HmacSha256::new_from_slice(password.as_bytes()).expect("HMAC accepts any key length");
2159        mac.update(salt);
2160        mac.update(&1u32.to_be_bytes());
2161        let mut u = mac.finalize().into_bytes();
2162        result.copy_from_slice(&u);
2163
2164        for _ in 1..iterations {
2165            let mut mac = HmacSha256::new_from_slice(password.as_bytes())
2166                .expect("HMAC accepts any key length");
2167            mac.update(&u);
2168            u = mac.finalize().into_bytes();
2169
2170            for (r, ui) in result.iter_mut().zip(u.iter()) {
2171                *r ^= ui;
2172            }
2173        }
2174
2175        result
2176    }
2177}