use parking_lot::RwLock;
use ring::rand::{SecureRandom, SystemRandom};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tracing::{debug, warn};
/// Errors produced by the authentication and authorization code in this module.
#[derive(Error, Debug, Clone)]
pub enum AuthError {
/// Credentials were rejected. `AuthManager::authenticate` returns this for an
/// unknown principal, a disabled principal and a wrong password alike.
#[error("Authentication failed")]
AuthenticationFailed,
/// The credential payload itself was malformed (e.g. an unparsable SASL message).
#[error("Invalid credentials")]
InvalidCredentials,
/// No principal with the given name is registered.
#[error("Principal not found: {0}")]
PrincipalNotFound(String),
/// A principal with the given name is already registered.
#[error("Principal already exists: {0}")]
PrincipalAlreadyExists(String),
/// Access was denied; the payload is a free-form explanation.
#[error("Access denied: {0}")]
AccessDenied(String),
/// `principal` does not hold `permission` on `resource`
/// (returned by `AuthManager::authorize` when `default_deny` is set).
#[error("Permission denied: {principal} lacks {permission} on {resource}")]
PermissionDenied {
principal: String,
permission: String,
resource: String,
},
/// No role with the given name is registered.
#[error("Role not found: {0}")]
RoleNotFound(String),
/// A token could not be accepted.
#[error("Invalid token")]
InvalidToken,
/// The session used for an operation has passed its expiry instant.
#[error("Token expired")]
TokenExpired,
/// The username or client IP is currently locked out after repeated failures.
#[error("Rate limited: too many authentication failures")]
RateLimited,
/// Catch-all error; this module also uses it for input-validation failures
/// such as a weak password or an invalid principal name.
#[error("Internal error: {0}")]
Internal(String),
}
/// Convenience alias for results whose error type is [`AuthError`].
pub type AuthResult<T> = std::result::Result<T, AuthError>;
/// A resource that a permission or ACL entry can be attached to.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ResourceType {
/// The cluster as a whole.
Cluster,
/// A single topic, identified by its exact name.
Topic(String),
/// A glob pattern over topic names; `*` is the wildcard (see `ResourceType::matches`).
TopicPattern(String),
/// A consumer group, identified by name.
ConsumerGroup(String),
/// A schema, identified by name.
Schema(String),
/// A transactional id.
TransactionalId(String),
}
impl ResourceType {
    /// Returns `true` when `self` and `other` designate the same resource.
    ///
    /// Two resources match when they are equal, or when one side is a
    /// [`ResourceType::TopicPattern`] whose glob accepts the name carried by a
    /// [`ResourceType::Topic`] on the other side (the check is symmetric).
    ///
    /// NOTE(review): only topics get wildcard treatment; `ConsumerGroup("*")`
    /// matches nothing but the literal group name `*` — confirm that is intended.
    pub fn matches(&self, other: &ResourceType) -> bool {
        match (self, other) {
            (a, b) if a == b => true,
            (ResourceType::TopicPattern(pattern), ResourceType::Topic(name))
            | (ResourceType::Topic(name), ResourceType::TopicPattern(pattern)) => {
                Self::glob_match(pattern, name)
            }
            _ => false,
        }
    }

    /// Matches `text` against `pattern`, where every `*` in the pattern stands
    /// for any (possibly empty) run of characters and everything else is literal.
    ///
    /// Patterns may contain any number of wildcards (`*orders*`, `eu-*-v*`).
    /// A pattern without `*` must equal `text` exactly.
    fn glob_match(pattern: &str, text: &str) -> bool {
        // Splitting on '*' yields the literal segments that must appear in order.
        let mut segments = pattern.split('*');

        // The first segment is anchored at the start of the text.
        let first = segments.next().unwrap_or("");
        let Some(mut rest) = text.strip_prefix(first) else {
            return false;
        };

        // No further segment means the pattern had no '*': require an exact match.
        let Some(last) = segments.next_back() else {
            return rest.is_empty();
        };

        // Middle segments: taking the leftmost occurrence is always optimal
        // because it leaves the longest possible remainder for later segments.
        for segment in segments {
            match rest.find(segment) {
                Some(at) => rest = &rest[at + segment.len()..],
                None => return false,
            }
        }

        // The final segment is anchored at the end of what is left.
        rest.ends_with(last)
    }
}
/// An operation that can be granted on a [`ResourceType`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Permission {
// General resource operations.
Read, Write, Create, Delete, Alter, Describe,
// Consumer-group operations.
GroupRead, GroupDelete,
// Cluster-level operations.
ClusterAction, IdempotentWrite,
// Configuration operations.
AlterConfigs, DescribeConfigs,
// Wildcard: `Permission::implies` treats `All` as implying every other permission.
All, }
impl Permission {
    /// Reports whether holding `self` is sufficient when `other` is required.
    ///
    /// Every permission implies itself, `All` implies everything, and each of
    /// `Alter`, `Write` and `Read` additionally implies `Describe`.
    pub fn implies(&self, other: &Permission) -> bool {
        match (self, other) {
            (Permission::All, _) => true,
            (held, wanted) if held == wanted => true,
            (Permission::Alter | Permission::Write | Permission::Read, Permission::Describe) => {
                true
            }
            _ => false,
        }
    }
}
/// The kind of identity a principal or session represents.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrincipalType {
/// A user identity (also assigned to sessions created from JWT claims).
User,
/// A service identity (also assigned to sessions created from API keys).
ServiceAccount,
/// An unauthenticated caller.
Anonymous,
}
/// A registered identity together with its credentials and role assignments.
///
/// `Debug` is implemented by hand so the password hash is never printed.
/// NOTE(review): `Serialize` is derived, so serializing a `Principal` does emit
/// the password hash material — confirm this is only used for trusted storage.
#[derive(Clone, Serialize, Deserialize)]
pub struct Principal {
/// Unique name; used as the key in the manager's principal map.
pub name: String,
/// Kind of identity.
pub principal_type: PrincipalType,
/// Salted, SCRAM-style password verifier.
pub password_hash: PasswordHash,
/// Names of the roles assigned to this principal.
pub roles: HashSet<String>,
/// Password authentication is refused while this is `false`.
pub enabled: bool,
/// Free-form key/value annotations.
pub metadata: HashMap<String, String>,
/// Creation time in seconds since the UNIX epoch.
pub created_at: u64,
}
impl std::fmt::Debug for Principal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Principal")
.field("name", &self.name)
.field("principal_type", &self.principal_type)
.field("password_hash", &"[REDACTED]")
.field("roles", &self.roles)
.field("enabled", &self.enabled)
.field("metadata", &self.metadata)
.field("created_at", &self.created_at)
.finish()
}
}
/// Salted password verifier holding SCRAM-style derived keys rather than the password.
#[derive(Clone, Serialize, Deserialize)]
pub struct PasswordHash {
/// Salt fed to PBKDF2 (`PasswordHash::new` generates 32 random bytes).
pub salt: Vec<u8>,
/// PBKDF2 iteration count used to derive the salted password.
pub iterations: u32,
/// HMAC-SHA256 of the salted password over the label `"Server Key"`.
pub server_key: Vec<u8>,
/// SHA-256 of the client key (HMAC-SHA256 of the salted password over `"Client Key"`).
pub stored_key: Vec<u8>,
}
impl std::fmt::Debug for PasswordHash {
    /// Prints only the iteration count; the salt and both keys are replaced by
    /// a `"[REDACTED]"` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const HIDDEN: &str = "[REDACTED]";

        let mut out = f.debug_struct("PasswordHash");
        out.field("salt", &HIDDEN);
        out.field("iterations", &self.iterations);
        out.field("server_key", &HIDDEN);
        out.field("stored_key", &HIDDEN);
        out.finish()
    }
}
impl PasswordHash {
    /// Hashes `password` with a fresh 32-byte random salt and 600 000 PBKDF2 iterations.
    ///
    /// This is CPU-heavy; async callers should prefer [`PasswordHash::new_async`].
    ///
    /// # Panics
    /// Panics if the system random number generator fails.
    pub fn new(password: &str) -> Self {
        let rng = SystemRandom::new();
        let mut salt = vec![0u8; 32];
        rng.fill(&mut salt).expect("Failed to generate salt");
        Self::with_salt(password, &salt, 600_000)
    }

    /// Derives the stored and server keys for `password` using the given salt
    /// and iteration count.
    pub fn with_salt(password: &str, salt: &[u8], iterations: u32) -> Self {
        let salted_password = Self::pbkdf2_sha256(password.as_bytes(), salt, iterations);
        let client_key = Self::hmac_sha256(&salted_password, b"Client Key");
        let server_key = Self::hmac_sha256(&salted_password, b"Server Key");
        let stored_key = Sha256::digest(&client_key).to_vec();
        PasswordHash {
            salt: salt.to_vec(),
            iterations,
            server_key,
            stored_key,
        }
    }

    /// Returns `true` when `password` re-derives to the stored key.
    ///
    /// Runs the full PBKDF2 derivation, so it costs about as much as hashing.
    pub fn verify(&self, password: &str) -> bool {
        let salted_password = Self::pbkdf2_sha256(password.as_bytes(), &self.salt, self.iterations);
        let client_key = Self::hmac_sha256(&salted_password, b"Client Key");
        let stored_key = Sha256::digest(&client_key);
        Self::constant_time_compare(&stored_key, &self.stored_key)
    }

    /// Runs [`PasswordHash::new`] on Tokio's blocking thread pool.
    ///
    /// # Panics
    /// Panics if the blocking task panics.
    pub async fn new_async(password: &str) -> Self {
        let password = password.to_string();
        tokio::task::spawn_blocking(move || Self::new(&password))
            .await
            .expect("PasswordHash::new_async: spawn_blocking panicked")
    }

    /// Runs [`PasswordHash::verify`] on Tokio's blocking thread pool.
    ///
    /// The hash and password are copied because the blocking closure must own its data.
    ///
    /// # Panics
    /// Panics if the blocking task panics.
    pub async fn verify_async(&self, password: &str) -> bool {
        let hash = self.clone();
        let password = password.to_string();
        tokio::task::spawn_blocking(move || hash.verify(&password))
            .await
            .expect("PasswordHash::verify_async: spawn_blocking panicked")
    }

    /// Compares two byte strings without exiting early on the first differing byte.
    ///
    /// Slices of different length compare unequal immediately, so only the
    /// contents (not the lengths) are protected from timing observation.
    pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        let mut result = 0u8;
        for (x, y) in a.iter().zip(b.iter()) {
            result |= x ^ y;
        }
        result == 0
    }

    /// PBKDF2-HMAC-SHA256 producing a single 32-byte block.
    ///
    /// An iteration count of 0 behaves like 1, because the first round is
    /// always computed before the loop.
    fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
        use hmac::{Hmac, Mac};
        type HmacSha256 = Hmac<Sha256>;

        // Key the MAC once and clone the keyed state for each round; re-keying
        // on every round would repeat the key schedule `iterations` times for
        // the same result.
        let keyed = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");

        let mut result = vec![0u8; 32];
        let mut mac = keyed.clone();
        mac.update(salt);
        // Block index 1, big-endian: only the first output block is needed.
        mac.update(&1u32.to_be_bytes());
        let mut u = mac.finalize().into_bytes();
        result.copy_from_slice(&u);

        for _ in 1..iterations {
            let mut mac = keyed.clone();
            mac.update(&u);
            u = mac.finalize().into_bytes();
            for (r, ui) in result.iter_mut().zip(u.iter()) {
                *r ^= ui;
            }
        }
        result
    }

    /// Computes HMAC-SHA256 of `data` under `key` and returns the 32-byte tag.
    pub fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
        use hmac::{Hmac, Mac};
        type HmacSha256 = Hmac<Sha256>;
        let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
        mac.update(data);
        mac.finalize().into_bytes().to_vec()
    }
}
/// A named bundle of `(resource, permission)` grants that can be assigned to principals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Role {
/// Unique role name; used as the key in the manager's role map.
pub name: String,
/// Human-readable description.
pub description: String,
/// The grants this role confers.
pub permissions: HashSet<(ResourceType, Permission)>,
/// Built-in roles cannot be deleted through `AuthManager::delete_role`.
pub builtin: bool,
}
impl Role {
pub fn admin() -> Self {
let mut permissions = HashSet::new();
permissions.insert((ResourceType::Cluster, Permission::All));
Role {
name: "admin".to_string(),
description: "Full administrative access to all resources".to_string(),
permissions,
builtin: true,
}
}
pub fn producer() -> Self {
let mut permissions = HashSet::new();
permissions.insert((
ResourceType::TopicPattern("*".to_string()),
Permission::Write,
));
permissions.insert((
ResourceType::TopicPattern("*".to_string()),
Permission::Describe,
));
permissions.insert((ResourceType::Cluster, Permission::IdempotentWrite));
Role {
name: "producer".to_string(),
description: "Can produce to all topics".to_string(),
permissions,
builtin: true,
}
}
pub fn consumer() -> Self {
let mut permissions = HashSet::new();
permissions.insert((
ResourceType::TopicPattern("*".to_string()),
Permission::Read,
));
permissions.insert((
ResourceType::TopicPattern("*".to_string()),
Permission::Describe,
));
permissions.insert((
ResourceType::ConsumerGroup("*".to_string()),
Permission::GroupRead,
));
Role {
name: "consumer".to_string(),
description: "Can consume from all topics".to_string(),
permissions,
builtin: true,
}
}
pub fn read_only() -> Self {
let mut permissions = HashSet::new();
permissions.insert((
ResourceType::TopicPattern("*".to_string()),
Permission::Read,
));
permissions.insert((
ResourceType::TopicPattern("*".to_string()),
Permission::Describe,
));
Role {
name: "read-only".to_string(),
description: "Read-only access to all topics".to_string(),
permissions,
builtin: true,
}
}
}
/// A single access-control rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AclEntry {
/// Principal name the rule applies to; `"*"` applies to every principal.
pub principal: String,
/// Resource the rule covers (compared with `ResourceType::matches`).
pub resource: ResourceType,
/// Permission the rule covers; `Permission::All` covers every permission.
pub permission: Permission,
/// `true` for an allow rule, `false` for a deny rule.
pub allow: bool,
/// Client IP the rule applies to; `"*"` applies to every host.
pub host: String,
}
/// ACL entries grouped by principal, so a lookup does not scan every entry.
#[derive(Debug, Default)]
struct AclIndex {
// Entries for a specific principal, keyed by principal name.
by_principal: HashMap<String, Vec<AclEntry>>,
// Entries whose principal is "*"; returned for every lookup.
wildcard: Vec<AclEntry>,
}
impl AclIndex {
fn new() -> Self {
Self::default()
}
fn push(&mut self, entry: AclEntry) {
if entry.principal == "*" {
self.wildcard.push(entry);
} else {
self.by_principal
.entry(entry.principal.clone())
.or_default()
.push(entry);
}
}
fn retain<F: Fn(&AclEntry) -> bool>(&mut self, predicate: F) {
self.wildcard.retain(&predicate);
self.by_principal.retain(|_, entries| {
entries.retain(&predicate);
!entries.is_empty()
});
}
fn lookup(&self, principal: &str) -> impl Iterator<Item = &AclEntry> {
let specific = self
.by_principal
.get(principal)
.map(|v| v.as_slice())
.unwrap_or(&[]);
specific.iter().chain(self.wildcard.iter())
}
fn to_vec(&self) -> Vec<AclEntry> {
self.wildcard
.iter()
.chain(self.by_principal.values().flatten())
.cloned()
.collect()
}
}
/// Read-only view of an authenticated session.
pub trait Session: Send + Sync + std::fmt::Debug {
/// Unique identifier of the session.
fn session_id(&self) -> &str;
/// Name of the principal the session belongs to.
fn principal(&self) -> &str;
/// Whether the session has passed its expiry time.
fn is_expired(&self) -> bool;
/// Client address recorded when the session was created.
fn client_ip(&self) -> &str;
}
/// An authenticated session with a snapshot of the permissions resolved when it was created.
#[derive(Debug, Clone)]
pub struct AuthSession {
/// Session identifier (the manager generates 32 random bytes, hex-encoded).
pub id: String,
/// Name of the authenticated principal.
pub principal_name: String,
/// Kind of the authenticated principal.
pub principal_type: PrincipalType,
/// Grants copied from the principal's roles at creation; not refreshed afterwards.
pub permissions: HashSet<(ResourceType, Permission)>,
/// Monotonic instant at which the session was created.
pub created_at: Instant,
/// Monotonic instant from which the session counts as expired.
pub expires_at: Instant,
/// Client address, or a mechanism tag (`"scram"`, `"api-key"`, `"jwt"`) for
/// sessions not created through password authentication.
pub client_ip: String,
}
impl Session for AuthSession {
    /// Returns the session identifier.
    fn session_id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns the name of the principal that owns the session.
    fn principal(&self) -> &str {
        self.principal_name.as_str()
    }

    /// A session is expired once the current instant reaches `expires_at`.
    fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }

    /// Returns the client address (or mechanism tag) recorded at creation.
    fn client_ip(&self) -> &str {
        self.client_ip.as_str()
    }
}
impl AuthSession {
pub fn is_expired(&self) -> bool {
<Self as Session>::is_expired(self)
}
pub fn has_permission(&self, resource: &ResourceType, permission: &Permission) -> bool {
if self
.permissions
.contains(&(ResourceType::Cluster, Permission::All))
{
return true;
}
if self.permissions.contains(&(resource.clone(), *permission)) {
return true;
}
for (res, perm) in &self.permissions {
let resource_matches = res.matches(resource);
let permission_implies = perm.implies(permission);
if resource_matches && permission_implies {
return true;
}
}
false
}
}
/// Tunables for `AuthManager`.
#[derive(Debug, Clone)]
pub struct AuthConfig {
/// Lifetime of a newly created session.
pub session_timeout: Duration,
/// Failed attempts per username before lockout (the per-IP threshold is twice this).
pub max_failed_attempts: u32,
/// How long a lockout lasts; also the window within which failures are counted.
pub lockout_duration: Duration,
/// When `false`, anonymous requests are authorized.
pub require_authentication: bool,
/// When `true`, ACL entries are consulted if role permissions do not grant access.
pub enable_acls: bool,
/// When `true`, a request that nothing grants is rejected; otherwise it is allowed.
pub default_deny: bool,
}
impl Default for AuthConfig {
    /// Secure defaults: one-hour sessions, lockout for five minutes after five
    /// failures, authentication and ACLs required, deny by default.
    ///
    /// Outside of test builds this also emits an info-level log line.
    fn default() -> Self {
        #[cfg(not(test))]
        tracing::info!("AuthConfig::default() — authentication and ACLs ENABLED by default.");

        let session_timeout = Duration::from_secs(3600);
        let lockout_duration = Duration::from_secs(300);
        Self {
            session_timeout,
            max_failed_attempts: 5,
            lockout_duration,
            require_authentication: true,
            enable_acls: true,
            default_deny: true,
        }
    }
}
/// Tracks recent authentication failures per identifier (username or client IP)
/// and the lockouts that result from them.
struct FailedAttemptTracker {
// Timestamps of recent failures, keyed by identifier.
attempts: HashMap<String, Vec<Instant>>,
// Instant at which each currently locked-out identifier was locked.
lockouts: HashMap<String, Instant>,
}
/// Upper bound on the number of distinct identifiers kept in
/// `FailedAttemptTracker::attempts` before older entries are evicted.
const MAX_TRACKED_IDENTIFIERS: usize = 10_000;
impl FailedAttemptTracker {
    /// Creates a tracker with no recorded failures or lockouts.
    fn new() -> Self {
        Self {
            attempts: HashMap::new(),
            lockouts: HashMap::new(),
        }
    }

    /// Returns `true` while `identifier` has a lockout younger than `lockout_duration`.
    fn is_locked_out(&self, identifier: &str, lockout_duration: Duration) -> bool {
        self.lockouts
            .get(identifier)
            .is_some_and(|locked_at| locked_at.elapsed() < lockout_duration)
    }

    /// Records one failed attempt for `identifier`.
    ///
    /// Failures older than `lockout_duration` are forgotten. When the number
    /// of remaining failures reaches `max_attempts`, the identifier is locked
    /// out and `true` is returned.
    ///
    /// The table holds at most [`MAX_TRACKED_IDENTIFIERS`] identifiers. When a
    /// new identifier arrives at capacity, identifiers with only stale
    /// failures are dropped first; if that frees nothing, the
    /// least-recently-failing tenth (at least one) is evicted together with
    /// any lockout it holds.
    fn record_failure(
        &mut self,
        identifier: &str,
        max_attempts: u32,
        lockout_duration: Duration,
    ) -> bool {
        let now = Instant::now();

        // Expired lockouts carry no information any more.
        self.lockouts
            .retain(|_, locked_at| locked_at.elapsed() < lockout_duration);

        let at_capacity = self.attempts.len() >= MAX_TRACKED_IDENTIFIERS;
        if at_capacity && !self.attempts.contains_key(identifier) {
            // Drop failures outside the window, then the buckets this leaves
            // empty. Buckets are never empty on their own (every insert pushes
            // a timestamp), so the stale timestamps must be pruned here for
            // the sweep to free anything.
            self.attempts.retain(|_, stamps| {
                stamps.retain(|t| t.elapsed() < lockout_duration);
                !stamps.is_empty()
            });

            if self.attempts.len() >= MAX_TRACKED_IDENTIFIERS {
                // Still full of live entries: evict the oldest tenth.
                let mut by_last_failure: Vec<(String, Instant)> = self
                    .attempts
                    .iter()
                    .filter_map(|(key, stamps)| stamps.last().map(|t| (key.clone(), *t)))
                    .collect();
                by_last_failure.sort_unstable_by_key(|(_, t)| *t);
                let evict = (by_last_failure.len() / 10).max(1);
                for (key, _) in by_last_failure.into_iter().take(evict) {
                    self.attempts.remove(&key);
                    self.lockouts.remove(&key);
                }
            }
        }

        let stamps = self.attempts.entry(identifier.to_string()).or_default();
        stamps.retain(|t| t.elapsed() < lockout_duration);
        stamps.push(now);

        let exceeded = stamps.len() >= max_attempts as usize;
        if exceeded {
            warn!(
                "Principal '{}' locked out after {} failed attempts",
                identifier, max_attempts
            );
            self.lockouts.insert(identifier.to_string(), now);
        }
        exceeded
    }

    /// Forgets all failures and any lockout for `identifier`.
    fn clear_failures(&mut self, identifier: &str) {
        self.attempts.remove(identifier);
        self.lockouts.remove(identifier);
    }
}
/// Checks `password` against the password policy.
///
/// The password must be at least 8 characters long (counted as Unicode scalar
/// values, not bytes) and contain an ASCII uppercase letter, an ASCII
/// lowercase letter, an ASCII digit and one non-alphanumeric character.
///
/// # Errors
/// Returns [`AuthError::Internal`] describing the first rule that failed, in
/// the order: length, uppercase, lowercase, digit, special character.
fn validate_password_strength(password: &str) -> AuthResult<()> {
    let reject = |reason: &str| -> AuthResult<()> { Err(AuthError::Internal(reason.to_string())) };

    // Count characters rather than bytes so multi-byte characters do not let a
    // shorter password satisfy the length rule.
    if password.chars().count() < 8 {
        return reject("Password must be at least 8 characters");
    }

    let mut has_uppercase = false;
    let mut has_lowercase = false;
    let mut has_digit = false;
    let mut has_special = false;
    for c in password.chars() {
        has_uppercase |= c.is_ascii_uppercase();
        has_lowercase |= c.is_ascii_lowercase();
        has_digit |= c.is_ascii_digit();
        has_special |= !c.is_alphanumeric();
    }

    if !has_uppercase {
        return reject("Password must contain at least one uppercase letter");
    }
    if !has_lowercase {
        return reject("Password must contain at least one lowercase letter");
    }
    if !has_digit {
        return reject("Password must contain at least one digit");
    }
    if !has_special {
        return reject("Password must contain at least one special character");
    }
    Ok(())
}
/// Central store for principals, roles, ACLs and sessions, with password
/// authentication and authorization checks on top.
///
/// Each collection sits behind its own lock, so the methods take `&self`.
pub struct AuthManager {
// Policy settings supplied at construction.
config: AuthConfig,
// Registered principals, keyed by name.
principals: RwLock<HashMap<String, Principal>>,
// Registered roles (built-in and custom), keyed by name.
roles: RwLock<HashMap<String, Role>>,
// ACL entries indexed by principal.
acls: RwLock<AclIndex>,
// Live sessions, keyed by session id.
sessions: RwLock<HashMap<String, AuthSession>>,
// Failure counters and lockouts for usernames and client IPs.
failed_attempts: RwLock<FailedAttemptTracker>,
// Random source used for session ids.
rng: SystemRandom,
}
impl AuthManager {
/// Creates a manager with the given configuration, no principals, ACLs or
/// sessions, and the four built-in roles registered.
pub fn new(config: AuthConfig) -> Self {
    let principals = RwLock::new(HashMap::new());
    let roles = RwLock::new(HashMap::new());
    let acls = RwLock::new(AclIndex::new());
    let sessions = RwLock::new(HashMap::new());
    let failed_attempts = RwLock::new(FailedAttemptTracker::new());

    let manager = Self {
        config,
        principals,
        roles,
        acls,
        sessions,
        failed_attempts,
        rng: SystemRandom::new(),
    };
    manager.init_builtin_roles();
    manager
}
/// Creates a manager configured with [`AuthConfig::default`].
pub fn new_default() -> Self {
    let config = AuthConfig::default();
    Self::new(config)
}
pub fn with_auth_enabled() -> Self {
Self::new(AuthConfig {
require_authentication: true,
enable_acls: true,
..Default::default()
})
}
/// Registers the built-in `admin`, `producer`, `consumer` and `read-only`
/// roles, each keyed by its own name.
fn init_builtin_roles(&self) {
    let builtins = [
        Role::admin(),
        Role::producer(),
        Role::consumer(),
        Role::read_only(),
    ];
    let mut roles = self.roles.write();
    for role in builtins {
        roles.insert(role.name.clone(), role);
    }
}
/// Registers a new principal with the given password and roles.
///
/// The password hash (600 000 PBKDF2 iterations) is computed *before* the
/// principal map is write-locked, so concurrent readers such as
/// `authenticate` are not stalled for the duration of the derivation.
///
/// # Errors
/// * [`AuthError::Internal`] if `name` is empty or longer than 255 bytes, or
///   the password fails the strength policy.
/// * [`AuthError::RoleNotFound`] if any requested role is not registered.
/// * [`AuthError::PrincipalAlreadyExists`] if the name is taken.
pub fn create_principal(
    &self,
    name: &str,
    password: &str,
    principal_type: PrincipalType,
    roles: HashSet<String>,
) -> AuthResult<()> {
    if name.is_empty() || name.len() > 255 {
        return Err(AuthError::Internal("Invalid principal name".to_string()));
    }
    validate_password_strength(password)?;
    {
        let role_map = self.roles.read();
        if let Some(missing) = roles.iter().find(|role| !role_map.contains_key(*role)) {
            return Err(AuthError::RoleNotFound(missing.clone()));
        }
    }

    // Cheap early exit so a duplicate name does not pay for hashing.
    if self.principals.read().contains_key(name) {
        return Err(AuthError::PrincipalAlreadyExists(name.to_string()));
    }

    // Expensive: deliberately done while no lock is held.
    let password_hash = PasswordHash::new(password);
    let created_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    let mut principals = self.principals.write();
    // Re-check under the write lock: another thread may have created the same
    // name while the hash was being computed.
    if principals.contains_key(name) {
        return Err(AuthError::PrincipalAlreadyExists(name.to_string()));
    }
    let principal = Principal {
        name: name.to_string(),
        principal_type,
        password_hash,
        roles,
        enabled: true,
        metadata: HashMap::new(),
        created_at,
    };
    principals.insert(name.to_string(), principal);
    debug!("Created principal: {}", name);
    Ok(())
}
/// Removes the principal `name` and drops every session that belongs to it.
///
/// # Errors
/// [`AuthError::PrincipalNotFound`] if no such principal exists.
pub fn delete_principal(&self, name: &str) -> AuthResult<()> {
    let mut principals = self.principals.write();
    match principals.remove(name) {
        None => Err(AuthError::PrincipalNotFound(name.to_string())),
        Some(_removed) => {
            // The principal lock is still held here, as in the lookup above.
            self.sessions
                .write()
                .retain(|_, session| session.principal_name != name);
            debug!("Deleted principal: {}", name);
            Ok(())
        }
    }
}
/// Returns a copy of the principal named `name`, if it exists.
pub fn get_principal(&self, name: &str) -> Option<Principal> {
    let principals = self.principals.read();
    principals.get(name).cloned()
}
/// Returns the names of all registered principals, in no particular order.
pub fn list_principals(&self) -> Vec<String> {
    let principals = self.principals.read();
    let names: Vec<String> = principals.keys().cloned().collect();
    names
}
/// Replaces the password of principal `name` and invalidates all of its sessions.
///
/// The new hash is derived before the principal map is write-locked so the
/// expensive PBKDF2 run does not block concurrent readers.
///
/// # Errors
/// * [`AuthError::Internal`] if the password fails the strength policy.
/// * [`AuthError::PrincipalNotFound`] if no such principal exists.
pub fn update_password(&self, name: &str, new_password: &str) -> AuthResult<()> {
    validate_password_strength(new_password)?;

    // Cheap early exit so an unknown name does not pay for hashing.
    if !self.principals.read().contains_key(name) {
        return Err(AuthError::PrincipalNotFound(name.to_string()));
    }

    // Expensive: deliberately done while no lock is held.
    let password_hash = PasswordHash::new(new_password);

    let mut principals = self.principals.write();
    // Look the principal up again: it may have been deleted in the meantime.
    let principal = principals
        .get_mut(name)
        .ok_or_else(|| AuthError::PrincipalNotFound(name.to_string()))?;
    principal.password_hash = password_hash;

    let mut sessions = self.sessions.write();
    sessions.retain(|_, s| s.principal_name != name);
    debug!("Updated password for principal: {}", name);
    Ok(())
}
/// Async counterpart of `create_principal`: validates the input, derives the
/// password hash on the blocking pool, then inserts the principal.
///
/// No lock guard is held across the `.await`; the name is checked before the
/// hash is computed and again under the write lock afterwards.
///
/// # Errors
/// Same as `create_principal`: `Internal` for an invalid name or weak
/// password, `RoleNotFound`, `PrincipalAlreadyExists`.
pub async fn create_principal_async(
    &self,
    name: &str,
    password: &str,
    principal_type: PrincipalType,
    roles: HashSet<String>,
) -> AuthResult<()> {
    let name_is_valid = !name.is_empty() && name.len() <= 255;
    if !name_is_valid {
        return Err(AuthError::Internal("Invalid principal name".to_string()));
    }
    validate_password_strength(password)?;

    let missing_role = {
        let role_map = self.roles.read();
        let missing = roles
            .iter()
            .find(|role| !role_map.contains_key(*role))
            .cloned();
        missing
    };
    if let Some(role) = missing_role {
        return Err(AuthError::RoleNotFound(role));
    }

    let already_exists = self.principals.read().contains_key(name);
    if already_exists {
        return Err(AuthError::PrincipalAlreadyExists(name.to_string()));
    }

    let password_hash = PasswordHash::new_async(password).await;

    let mut principals = self.principals.write();
    if principals.contains_key(name) {
        return Err(AuthError::PrincipalAlreadyExists(name.to_string()));
    }
    let created_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    principals.insert(
        name.to_string(),
        Principal {
            name: name.to_string(),
            principal_type,
            password_hash,
            roles,
            enabled: true,
            metadata: HashMap::new(),
            created_at,
        },
    );
    debug!("Created principal: {}", name);
    Ok(())
}
/// Async counterpart of `update_password`: derives the new hash on the
/// blocking pool, stores it and invalidates the principal's sessions.
///
/// # Errors
/// `Internal` for a weak password; `PrincipalNotFound` if the principal is
/// missing before or after the hash is computed.
pub async fn update_password_async(&self, name: &str, new_password: &str) -> AuthResult<()> {
    validate_password_strength(new_password)?;

    let known = self.principals.read().contains_key(name);
    if !known {
        return Err(AuthError::PrincipalNotFound(name.to_string()));
    }

    let password_hash = PasswordHash::new_async(new_password).await;

    let mut principals = self.principals.write();
    match principals.get_mut(name) {
        Some(principal) => principal.password_hash = password_hash,
        None => return Err(AuthError::PrincipalNotFound(name.to_string())),
    }
    // The principal lock stays held while the sessions are purged.
    self.sessions
        .write()
        .retain(|_, session| session.principal_name != name);
    debug!("Updated password for principal: {}", name);
    Ok(())
}
/// Assigns the role `role_name` to the principal `principal_name`.
///
/// Existing sessions keep the permission snapshot they were created with.
///
/// # Errors
/// `RoleNotFound` if the role is not registered; `PrincipalNotFound` if the
/// principal does not exist.
pub fn add_role_to_principal(&self, principal_name: &str, role_name: &str) -> AuthResult<()> {
    let role_exists = self.roles.read().contains_key(role_name);
    if !role_exists {
        return Err(AuthError::RoleNotFound(role_name.to_string()));
    }

    let mut principals = self.principals.write();
    let Some(principal) = principals.get_mut(principal_name) else {
        return Err(AuthError::PrincipalNotFound(principal_name.to_string()));
    };
    principal.roles.insert(role_name.to_string());
    debug!(
        "Added role '{}' to principal '{}'",
        role_name, principal_name
    );
    Ok(())
}
/// Removes the role `role_name` from the principal `principal_name`.
/// Succeeds even when the principal did not hold the role.
///
/// NOTE(review): sessions created earlier keep the permissions resolved at
/// creation time, so a revoked role stays effective until those sessions
/// expire or are invalidated — confirm callers invalidate sessions if needed.
///
/// # Errors
/// `PrincipalNotFound` if the principal does not exist.
pub fn remove_role_from_principal(
    &self,
    principal_name: &str,
    role_name: &str,
) -> AuthResult<()> {
    let mut principals = self.principals.write();
    let Some(principal) = principals.get_mut(principal_name) else {
        return Err(AuthError::PrincipalNotFound(principal_name.to_string()));
    };
    principal.roles.remove(role_name);
    debug!(
        "Removed role '{}' from principal '{}'",
        role_name, principal_name
    );
    Ok(())
}
/// Registers `role` under its own name.
///
/// # Errors
/// `Internal` if a role with the same name already exists.
pub fn create_role(&self, role: Role) -> AuthResult<()> {
    let mut roles = self.roles.write();
    let name_taken = roles.contains_key(&role.name);
    if name_taken {
        let message = format!("Role '{}' already exists", role.name);
        return Err(AuthError::Internal(message));
    }
    debug!("Created role: {}", role.name);
    let key = role.name.clone();
    roles.insert(key, role);
    Ok(())
}
/// Deletes the custom role `name`.
///
/// Principals that reference the role keep its name in their role set; it
/// simply resolves to no permissions afterwards.
///
/// # Errors
/// `RoleNotFound` if the role does not exist; `Internal` if it is built in.
pub fn delete_role(&self, name: &str) -> AuthResult<()> {
    let mut roles = self.roles.write();
    match roles.get(name) {
        None => return Err(AuthError::RoleNotFound(name.to_string())),
        Some(role) if role.builtin => {
            return Err(AuthError::Internal(
                "Cannot delete built-in role".to_string(),
            ));
        }
        Some(_) => {}
    }
    roles.remove(name);
    debug!("Deleted role: {}", name);
    Ok(())
}
/// Returns a copy of the role named `name`, if it exists.
pub fn get_role(&self, name: &str) -> Option<Role> {
    let roles = self.roles.read();
    roles.get(name).cloned()
}
/// Returns the names of all registered roles, in no particular order.
pub fn list_roles(&self) -> Vec<String> {
    let roles = self.roles.read();
    let names: Vec<String> = roles.keys().cloned().collect();
    names
}
/// Adds `entry` to the ACL index. Duplicates are not detected.
pub fn add_acl(&self, entry: AclEntry) {
    self.acls.write().push(entry);
}
/// Removes every ACL entry selected by both filters; a `None` filter selects
/// everything, so `remove_acls(None, None)` clears the index.
///
/// The resource filter requires an equal resource. The principal filter
/// selects entries for that principal *and* entries whose principal is `"*"`.
///
/// NOTE(review): because of the `"*"` clause, removing one principal's ACLs
/// also removes wildcard rules that apply to everyone — confirm this is intended.
pub fn remove_acls(&self, principal: Option<&str>, resource: Option<&ResourceType>) {
    let mut acls = self.acls.write();
    acls.retain(|acl| {
        let principal_selected = match principal {
            None => true,
            Some(p) => acl.principal == p || acl.principal == "*",
        };
        let resource_selected = match resource {
            None => true,
            Some(r) => &acl.resource == r,
        };
        // Keep the entry unless both filters select it.
        !principal_selected || !resource_selected
    });
}
/// Returns a copy of every ACL entry currently in the index.
pub fn list_acls(&self) -> Vec<AclEntry> {
    let acls = self.acls.read();
    acls.to_vec()
}
/// Authenticates `username` with `password` and opens a session on success.
///
/// Steps: reject locked-out usernames/IPs, look up the principal, verify the
/// password, reset the failure counters, resolve the principal's permissions
/// and register a session with a random 32-byte (hex-encoded) id.
///
/// Unknown and disabled principals each pay for exactly one PBKDF2
/// derivation, the same as a password check for a valid principal, so
/// response time does not reveal whether an account exists or is disabled.
///
/// # Errors
/// * [`AuthError::RateLimited`] if the username or client IP is locked out.
/// * [`AuthError::AuthenticationFailed`] for an unknown or disabled principal
///   or a wrong password (each is recorded as a failure).
/// * [`AuthError::Internal`] if the random number generator fails.
pub fn authenticate(
    &self,
    username: &str,
    password: &str,
    client_ip: &str,
) -> AuthResult<AuthSession> {
    {
        let tracker = self.failed_attempts.read();
        if tracker.is_locked_out(username, self.config.lockout_duration) {
            warn!(
                "Authentication attempt for locked-out principal: {}",
                username
            );
            return Err(AuthError::RateLimited);
        }
        if tracker.is_locked_out(client_ip, self.config.lockout_duration) {
            warn!("Authentication attempt from locked-out IP: {}", client_ip);
            return Err(AuthError::RateLimited);
        }
    }

    // Clone the principal so the read lock is released before the slow hash work.
    let lookup = self.principals.read().get(username).cloned();
    let principal = match lookup {
        Some(p) if p.enabled => p,
        Some(disabled) => {
            // Burn the same work as a real check; the result is ignored.
            std::hint::black_box(disabled.password_hash.verify(password));
            self.record_auth_failure(username, client_ip);
            return Err(AuthError::AuthenticationFailed);
        }
        None => {
            // One derivation with the iteration count `PasswordHash::new` uses,
            // so an unknown user costs the same as a known one.
            std::hint::black_box(PasswordHash::with_salt(password, &[0u8; 32], 600_000));
            self.record_auth_failure(username, client_ip);
            return Err(AuthError::AuthenticationFailed);
        }
    };

    if !principal.password_hash.verify(password) {
        self.record_auth_failure(username, client_ip);
        return Err(AuthError::AuthenticationFailed);
    }

    {
        // Take the write lock once for both identifiers.
        let mut tracker = self.failed_attempts.write();
        tracker.clear_failures(username);
        tracker.clear_failures(client_ip);
    }

    let permissions = self.resolve_permissions(&principal);
    let mut session_id = vec![0u8; 32];
    self.rng
        .fill(&mut session_id)
        .map_err(|_| AuthError::Internal("RNG failed".to_string()))?;
    let session_id = hex::encode(&session_id);
    let now = Instant::now();
    let session = AuthSession {
        id: session_id.clone(),
        principal_name: principal.name.clone(),
        principal_type: principal.principal_type.clone(),
        permissions,
        created_at: now,
        expires_at: now + self.config.session_timeout,
        client_ip: client_ip.to_string(),
    };
    self.sessions.write().insert(session_id, session.clone());
    debug!("Authenticated principal '{}' from {}", username, client_ip);
    Ok(session)
}
/// Records one failed attempt against both the username and the client IP.
///
/// The IP threshold is twice the per-username threshold, so several users
/// behind one address are not locked out as quickly as a single account.
/// The doubling saturates instead of overflowing for very large settings.
fn record_auth_failure(&self, username: &str, client_ip: &str) {
    let per_user_limit = self.config.max_failed_attempts;
    let per_ip_limit = per_user_limit.saturating_mul(2);
    let lockout = self.config.lockout_duration;

    let mut tracker = self.failed_attempts.write();
    tracker.record_failure(username, per_user_limit, lockout);
    tracker.record_failure(client_ip, per_ip_limit, lockout);
}
/// Returns a copy of the session `session_id` if it exists and has not expired.
/// Expired sessions are not removed here.
pub fn get_session(&self, session_id: &str) -> Option<AuthSession> {
    let sessions = self.sessions.read();
    let live = sessions
        .get(session_id)
        .filter(|session| !session.is_expired())
        .cloned();
    live
}
/// Removes the session `session_id`; does nothing if it does not exist.
pub fn invalidate_session(&self, session_id: &str) {
    let mut sessions = self.sessions.write();
    sessions.remove(session_id);
}
/// Removes every session that belongs to `principal_name`.
pub fn invalidate_all_sessions(&self, principal_name: &str) {
    let mut sessions = self.sessions.write();
    sessions.retain(|_, session| session.principal_name != principal_name);
}
/// Removes every session that has passed its expiry time.
pub fn cleanup_expired_sessions(&self) {
    let mut sessions = self.sessions.write();
    sessions.retain(|_, session| !session.is_expired());
}
/// Opens and registers a session for an already-verified `principal`, tagging
/// its `client_ip` as `"scram"`.
///
/// Permissions are resolved from the principal's roles at call time.
///
/// # Panics
/// Panics if the random number generator fails.
pub fn create_session(&self, principal: &Principal) -> AuthSession {
    let permissions = self.resolve_permissions(principal);

    let mut raw_id = [0u8; 32];
    self.rng.fill(&mut raw_id).expect("RNG failed");
    let id = hex::encode(raw_id);

    let created_at = Instant::now();
    let session = AuthSession {
        id: id.clone(),
        principal_name: principal.name.clone(),
        principal_type: principal.principal_type.clone(),
        permissions,
        created_at,
        expires_at: created_at + self.config.session_timeout,
        client_ip: "scram".to_string(),
    };
    self.sessions.write().insert(id, session.clone());
    session
}
pub fn create_api_key_session(&self, principal_name: &str, roles: &[String]) -> AuthSession {
let mut permissions = HashSet::new();
{
let roles_map = self.roles.read();
for role_name in roles {
if let Some(role) = roles_map.get(role_name) {
permissions.extend(role.permissions.iter().cloned());
}
}
}
let mut session_id = vec![0u8; 32];
self.rng.fill(&mut session_id).expect("RNG failed");
let session_id = hex::encode(&session_id);
let now = Instant::now();
let session = AuthSession {
id: session_id.clone(),
principal_name: principal_name.to_string(),
principal_type: PrincipalType::ServiceAccount,
permissions,
created_at: now,
expires_at: now + self.config.session_timeout,
client_ip: "api-key".to_string(),
};
self.sessions.write().insert(session_id, session.clone());
debug!(principal = %principal_name, "Created API key session");
session
}
pub fn create_jwt_session(&self, principal_name: &str, groups: &[String]) -> AuthSession {
let mut permissions = HashSet::new();
{
let roles_map = self.roles.read();
for group in groups {
if let Some(role) = roles_map.get(group) {
permissions.extend(role.permissions.iter().cloned());
}
}
}
let mut session_id = vec![0u8; 32];
self.rng.fill(&mut session_id).expect("RNG failed");
let session_id = hex::encode(&session_id);
let now = Instant::now();
let session = AuthSession {
id: session_id.clone(),
principal_name: principal_name.to_string(),
principal_type: PrincipalType::User,
permissions,
created_at: now,
expires_at: now + self.config.session_timeout,
client_ip: "jwt".to_string(),
};
self.sessions.write().insert(session_id, session.clone());
debug!(principal = %principal_name, groups = ?groups, "Created JWT session");
session
}
/// Returns a copy of some unexpired session owned by `principal_name`.
/// When several exist, which one is returned depends on map iteration order.
pub fn get_session_by_principal(&self, principal_name: &str) -> Option<AuthSession> {
    let sessions = self.sessions.read();
    let found = sessions
        .values()
        .find(|session| session.principal_name == principal_name && !session.is_expired())
        .cloned();
    found
}
/// Collects the union of the grants of every role assigned to `principal`.
/// Role names with no registered role contribute nothing.
fn resolve_permissions(&self, principal: &Principal) -> HashSet<(ResourceType, Permission)> {
    let role_table = self.roles.read();
    let granted: HashSet<(ResourceType, Permission)> = principal
        .roles
        .iter()
        .filter_map(|role_name| role_table.get(role_name))
        .flat_map(|role| role.permissions.iter().cloned())
        .collect();
    granted
}
/// Decides whether `session` may perform `permission` on `resource` from `client_ip`.
///
/// Order of evaluation:
/// 1. If both authentication and ACLs are disabled, everything is allowed.
/// 2. An expired session is rejected.
/// 3. The session's own permission snapshot is consulted.
/// 4. If ACLs are enabled, the ACL entries are consulted.
/// 5. Otherwise the request is denied only when `default_deny` is set.
///
/// Because step 3 runs first, a deny ACL does not override a permission the
/// session already holds through its roles.
///
/// # Errors
/// `TokenExpired` for an expired session; `PermissionDenied` when nothing
/// grants the request and `default_deny` is set.
pub fn authorize(
    &self,
    session: &AuthSession,
    resource: &ResourceType,
    permission: Permission,
    client_ip: &str,
) -> AuthResult<()> {
    let checks_disabled = !self.config.require_authentication && !self.config.enable_acls;
    if checks_disabled {
        return Ok(());
    }
    if session.is_expired() {
        return Err(AuthError::TokenExpired);
    }

    let granted = session.has_permission(resource, &permission)
        || (self.config.enable_acls
            && self.check_acls(&session.principal_name, resource, permission, client_ip));
    if granted || !self.config.default_deny {
        return Ok(());
    }

    let permission_name = format!("{:?}", permission);
    warn!(
        "Access denied: {} attempted {} on {:?} from {}",
        session.principal_name,
        permission_name,
        resource,
        client_ip
    );
    Err(AuthError::PermissionDenied {
        principal: session.principal_name.clone(),
        permission: permission_name,
        resource: format!("{:?}", resource),
    })
}
fn check_acls(
&self,
principal: &str,
resource: &ResourceType,
permission: Permission,
client_ip: &str,
) -> bool {
let acls = self.acls.read();
for acl in acls.lookup(principal) {
if !acl.allow
&& (acl.host == client_ip || acl.host == "*")
&& acl.resource.matches(resource)
&& (acl.permission == permission || acl.permission == Permission::All)
{
return false; }
}
for acl in acls.lookup(principal) {
if acl.allow
&& (acl.host == client_ip || acl.host == "*")
&& acl.resource.matches(resource)
&& (acl.permission == permission || acl.permission == Permission::All)
{
return true;
}
}
false
}
/// Authorizes a request made without a session.
///
/// Anonymous access is all-or-nothing: it is granted whenever authentication
/// is not required and refused otherwise. `resource` and `permission` are
/// accepted for interface symmetry but do not influence the outcome, hence
/// the `unused_variables` allowance.
///
/// # Errors
/// `AuthenticationFailed` when authentication is required.
#[allow(unused_variables)]
pub fn authorize_anonymous(
    &self,
    resource: &ResourceType,
    permission: Permission,
) -> AuthResult<()> {
    if self.config.require_authentication {
        Err(AuthError::AuthenticationFailed)
    } else {
        Ok(())
    }
}
}
/// SASL PLAIN front end: parses the client message and delegates to `AuthManager::authenticate`.
pub struct SaslPlainAuth {
// Shared manager that performs the actual credential check.
auth_manager: Arc<AuthManager>,
}
impl SaslPlainAuth {
    /// Creates a PLAIN authenticator backed by `auth_manager`.
    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
        Self { auth_manager }
    }

    /// Parses a SASL PLAIN message and authenticates the embedded credentials.
    ///
    /// Accepted layouts, with fields separated by NUL bytes:
    /// * `authzid NUL authcid NUL passwd` — the standard form; the
    ///   authorization identity is ignored and the session is always opened
    ///   for `authcid`.
    /// * `authcid NUL passwd` — a lenient two-field form.
    ///
    /// Any other field count is rejected. In particular a message with more
    /// than three fields is malformed (a PLAIN password cannot contain NUL)
    /// and is refused rather than having its password silently truncated.
    ///
    /// # Errors
    /// [`AuthError::InvalidCredentials`] for a malformed message or non-UTF-8
    /// fields; otherwise whatever `AuthManager::authenticate` returns.
    pub fn authenticate(&self, sasl_bytes: &[u8], client_ip: &str) -> AuthResult<AuthSession> {
        let fields: Vec<&[u8]> = sasl_bytes.split(|&b| b == 0).collect();
        let (user_bytes, pass_bytes) = match fields.as_slice() {
            [user, pass] => (*user, *pass),
            [_authzid, user, pass] => (*user, *pass),
            _ => return Err(AuthError::InvalidCredentials),
        };

        let username =
            std::str::from_utf8(user_bytes).map_err(|_| AuthError::InvalidCredentials)?;
        let password =
            std::str::from_utf8(pass_bytes).map_err(|_| AuthError::InvalidCredentials)?;

        self.auth_manager
            .authenticate(username, password, client_ip)
    }
}
/// Server-side state of a SCRAM exchange, carried between the two processing steps.
#[derive(Debug, Clone)]
pub enum ScramState {
/// No message has been processed yet.
Initial,
/// The server-first message has been sent; the client-final message is awaited.
ServerFirstSent {
/// Username taken from the client-first message (already unescaped).
username: String,
/// Nonce supplied by the client.
client_nonce: String,
/// Nonce generated by the server (base64 text).
server_nonce: String,
/// Salt announced to the client.
salt: Vec<u8>,
/// Iteration count announced to the client.
iterations: u32,
/// Authentication message that the client proof and server signature are computed over.
auth_message: String,
},
/// The exchange has finished.
Complete,
}
/// SASL SCRAM-SHA-256 front end that verifies client proofs against the
/// principals stored in an `AuthManager`.
pub struct SaslScramAuth {
// Shared manager providing principals and session creation.
auth_manager: Arc<AuthManager>,
}
impl SaslScramAuth {
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
Self { auth_manager }
}
pub fn process_client_first(
&self,
client_first: &[u8],
client_ip: &str,
) -> AuthResult<(ScramState, Vec<u8>)> {
let client_first_str =
std::str::from_utf8(client_first).map_err(|_| AuthError::InvalidCredentials)?;
let parts: Vec<&str> = client_first_str.splitn(3, ',').collect();
if parts.len() < 3 {
return Err(AuthError::InvalidCredentials);
}
let client_first_bare = if parts[0] == "n" || parts[0] == "y" || parts[0] == "p" {
&client_first_str[parts[0].len() + 1 + parts[1].len() + 1..]
} else {
client_first_str
};
let mut username = None;
let mut client_nonce = None;
for attr in client_first_bare.split(',') {
if let Some(value) = attr.strip_prefix("n=") {
username = Some(Self::unescape_username(value));
} else if let Some(value) = attr.strip_prefix("r=") {
client_nonce = Some(value.to_string());
}
}
let username = username.ok_or(AuthError::InvalidCredentials)?;
let client_nonce = client_nonce.ok_or(AuthError::InvalidCredentials)?;
let (salt, iterations) = match self.auth_manager.get_principal(&username) {
Some(principal) => (
principal.password_hash.salt.clone(),
principal.password_hash.iterations,
),
None => {
warn!(
"SCRAM auth for unknown user '{}' from {}",
username, client_ip
);
let rng = SystemRandom::new();
let mut fake_salt = vec![0u8; 32];
rng.fill(&mut fake_salt).expect("Failed to generate salt");
(fake_salt, 600_000)
}
};
let rng = SystemRandom::new();
let mut server_nonce_bytes = vec![0u8; 24];
rng.fill(&mut server_nonce_bytes)
.expect("Failed to generate nonce");
let server_nonce = base64_encode(&server_nonce_bytes);
let combined_nonce = format!("{}{}", client_nonce, server_nonce);
let salt_b64 = base64_encode(&salt);
let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, iterations);
let auth_message = format!(
"{},{},c=biws,r={}",
client_first_bare, server_first, combined_nonce
);
let state = ScramState::ServerFirstSent {
username,
client_nonce,
server_nonce,
salt,
iterations,
auth_message,
};
Ok((state, server_first.into_bytes()))
}
/// Handles the SCRAM client-final message: verifies the client proof and, on
/// success, opens a session and returns the server signature (`v=…`).
///
/// The proof is checked by recovering the client key (`proof XOR signature`),
/// hashing it, and comparing the result with the stored key in constant time.
///
/// This path applies the same policy as password authentication: locked-out
/// usernames or client IPs are refused, disabled principals cannot log in,
/// failed proofs count towards lockout, and a success resets the counters.
///
/// # Errors
/// * [`AuthError::Internal`] if `state` is not `ServerFirstSent`.
/// * [`AuthError::RateLimited`] if the username or client IP is locked out.
/// * [`AuthError::InvalidCredentials`] for a malformed message, a nonce
///   mismatch or an undecodable/wrong-length proof.
/// * [`AuthError::AuthenticationFailed`] for an unknown or disabled principal
///   or a wrong proof.
///
/// # Panics
/// Panics if session creation fails to obtain random bytes.
pub fn process_client_final(
    &self,
    state: &ScramState,
    client_final: &[u8],
    client_ip: &str,
) -> AuthResult<(AuthSession, Vec<u8>)> {
    let ScramState::ServerFirstSent {
        username,
        client_nonce,
        server_nonce,
        auth_message,
        ..
    } = state
    else {
        return Err(AuthError::Internal("Invalid SCRAM state".to_string()));
    };

    // Refuse locked-out callers before doing any work, as `authenticate` does.
    {
        let lockout = self.auth_manager.config.lockout_duration;
        let tracker = self.auth_manager.failed_attempts.read();
        if tracker.is_locked_out(username, lockout) {
            warn!(
                "Authentication attempt for locked-out principal: {}",
                username
            );
            return Err(AuthError::RateLimited);
        }
        if tracker.is_locked_out(client_ip, lockout) {
            warn!("Authentication attempt from locked-out IP: {}", client_ip);
            return Err(AuthError::RateLimited);
        }
    }

    let client_final_str =
        std::str::from_utf8(client_final).map_err(|_| AuthError::InvalidCredentials)?;
    let mut channel_binding = None;
    let mut nonce = None;
    let mut proof = None;
    for attr in client_final_str.split(',') {
        if let Some(value) = attr.strip_prefix("c=") {
            channel_binding = Some(value.to_string());
        } else if let Some(value) = attr.strip_prefix("r=") {
            nonce = Some(value.to_string());
        } else if let Some(value) = attr.strip_prefix("p=") {
            proof = Some(value.to_string());
        }
    }
    // Only presence is required: a value other than the assumed `biws` makes
    // the client's proof disagree with `auth_message` and fail below.
    let _channel_binding = channel_binding.ok_or(AuthError::InvalidCredentials)?;
    let nonce = nonce.ok_or(AuthError::InvalidCredentials)?;
    let proof_b64 = proof.ok_or(AuthError::InvalidCredentials)?;

    let expected_nonce = format!("{}{}", client_nonce, server_nonce);
    if nonce != expected_nonce {
        warn!("SCRAM nonce mismatch for '{}' from {}", username, client_ip);
        return Err(AuthError::InvalidCredentials);
    }

    // Unknown and disabled principals are rejected identically and counted.
    let Some(principal) = self
        .auth_manager
        .get_principal(username)
        .filter(|p| p.enabled)
    else {
        self.auth_manager.record_auth_failure(username, client_ip);
        return Err(AuthError::AuthenticationFailed);
    };

    let client_proof = base64_decode(&proof_b64).map_err(|_| AuthError::InvalidCredentials)?;
    let client_signature =
        PasswordHash::hmac_sha256(&principal.password_hash.stored_key, auth_message.as_bytes());
    if client_proof.len() != client_signature.len() {
        return Err(AuthError::InvalidCredentials);
    }
    // ClientKey = ClientProof XOR ClientSignature.
    let client_key: Vec<u8> = client_proof
        .iter()
        .zip(client_signature.iter())
        .map(|(p, s)| p ^ s)
        .collect();
    let computed_stored_key = Sha256::digest(&client_key);
    if !PasswordHash::constant_time_compare(
        &computed_stored_key,
        &principal.password_hash.stored_key,
    ) {
        warn!(
            "SCRAM authentication failed for '{}' from {}",
            username, client_ip
        );
        self.auth_manager.record_auth_failure(username, client_ip);
        return Err(AuthError::AuthenticationFailed);
    }

    {
        let mut tracker = self.auth_manager.failed_attempts.write();
        tracker.clear_failures(username);
        tracker.clear_failures(client_ip);
    }

    let server_signature =
        PasswordHash::hmac_sha256(&principal.password_hash.server_key, auth_message.as_bytes());
    let server_final = format!("v={}", base64_encode(&server_signature));
    let session = self.auth_manager.create_session(&principal);
    debug!(
        "SCRAM authentication successful for '{}' from {}",
        username, client_ip
    );
    Ok((session, server_final.into_bytes()))
}
/// Decodes the SCRAM username escapes in `s`: `=2C` becomes `,` and `=3D` becomes `=`.
///
/// Any other `=` is copied through unchanged; the escapes are matched case-sensitively.
fn unescape_username(s: &str) -> String {
    let mut decoded = String::with_capacity(s.len());
    let mut rest = s;
    // Walk from one '=' to the next, copying the plain text in between verbatim.
    while let Some(pos) = rest.find('=') {
        decoded.push_str(&rest[..pos]);
        let tail = &rest[pos..];
        if let Some(after) = tail.strip_prefix("=2C") {
            decoded.push(',');
            rest = after;
        } else if let Some(after) = tail.strip_prefix("=3D") {
            decoded.push('=');
            rest = after;
        } else {
            // Not a recognised escape: keep the '=' and continue right after it.
            decoded.push('=');
            rest = &tail[1..];
        }
    }
    decoded.push_str(rest);
    decoded
}
}
/// Encodes `data` to a base64 string using the `base64` crate's `STANDARD` engine.
fn base64_encode(data: &[u8]) -> String {
    use base64::{engine::general_purpose, Engine};
    Engine::encode(&general_purpose::STANDARD, data)
}
/// Decodes the base64 string `s` using the `base64` crate's `STANDARD` engine.
///
/// # Errors
///
/// Returns the engine's `DecodeError` when `s` cannot be decoded.
fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
    use base64::{engine::general_purpose, Engine};
    Engine::decode(&general_purpose::STANDARD, s)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_password_hash_verify() {
    let stored = PasswordHash::new("Test@Pass123");
    assert!(stored.verify("Test@Pass123"));
    // A different password, the empty string and a truncated password are all rejected.
    for rejected in ["Wrong@Pass1", "", "Test@Pass12"] {
        assert!(!stored.verify(rejected));
    }
}
#[test]
fn test_password_hash_timing_attack_resistant() {
    // Only asserts that wrong guesses of differing lengths are rejected;
    // no timing measurement is performed here.
    let stored = PasswordHash::new("Correct@Pass1");
    for guess in ["Wrong@Pass1", "x"] {
        assert!(!stored.verify(guess));
    }
}
#[test]
fn test_create_principal() {
    let manager = AuthManager::new_default();
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal(
            "alice",
            "Secure@Pass123",
            PrincipalType::User,
            roles.clone(),
        )
        .expect("Failed to create principal");

    // Registering the same name a second time must be refused.
    let duplicate =
        manager.create_principal("alice", "Other@Pass1", PrincipalType::User, roles.clone());
    assert!(duplicate.is_err());

    let stored = manager.get_principal("alice").expect("Principal not found");
    assert_eq!(stored.name, "alice");
    assert!(stored.roles.contains("producer"));
}
#[test]
fn test_authentication_success() {
    let manager = AuthManager::new_default();
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("bob", "Bob@Pass123", PrincipalType::User, roles)
        .unwrap();

    // Correct credentials yield a live session bound to the principal.
    let granted = manager
        .authenticate("bob", "Bob@Pass123", "127.0.0.1")
        .expect("Authentication should succeed");
    assert_eq!(granted.principal_name, "bob");
    assert!(!granted.is_expired());
}
#[test]
fn test_authentication_failure() {
    let manager = AuthManager::new_default();
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("charlie", "Correct@Pass1", PrincipalType::User, roles)
        .unwrap();

    // A wrong password and an unknown user must produce the same error variant.
    let wrong_password = manager.authenticate("charlie", "Wrong@Pass1", "127.0.0.1");
    assert!(matches!(
        wrong_password,
        Err(AuthError::AuthenticationFailed)
    ));
    let unknown_user = manager.authenticate("unknown", "password", "127.0.0.1");
    assert!(matches!(unknown_user, Err(AuthError::AuthenticationFailed)));
}
#[test]
fn test_rate_limiting() {
    let manager = AuthManager::new(AuthConfig {
        max_failed_attempts: 3,
        lockout_duration: Duration::from_secs(120),
        ..Default::default()
    });
    let roles = HashSet::from(["consumer".to_string()]);
    manager
        .create_principal("eve", "Eve@Pass123", PrincipalType::User, roles)
        .unwrap();

    // Use up the allowed number of failures.
    (0..3).for_each(|_| {
        let _ = manager.authenticate("eve", "wrong", "192.168.1.1");
    });

    // Even the correct password is now refused.
    let locked_out = manager.authenticate("eve", "Eve@Pass123", "192.168.1.1");
    assert!(matches!(locked_out, Err(AuthError::RateLimited)));

    // Clear the recorded failures for both the user and the address; the inner scope
    // releases the write guard before authenticating again.
    {
        let mut tracker = manager.failed_attempts.write();
        for key in ["eve", "192.168.1.1"] {
            tracker.clear_failures(key);
        }
    }
    let after_reset = manager.authenticate("eve", "Eve@Pass123", "192.168.1.1");
    assert!(after_reset.is_ok());
}
#[test]
fn test_role_permissions() {
    let manager = AuthManager::with_auth_enabled();
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("producer_user", "Prod@Pass123", PrincipalType::User, roles)
        .unwrap();
    let session = manager
        .authenticate("producer_user", "Prod@Pass123", "127.0.0.1")
        .unwrap();

    // The producer role may write to a topic but not delete it.
    let orders = ResourceType::Topic("orders".to_string());
    assert!(session.has_permission(&orders, &Permission::Write));
    assert!(!session.has_permission(&orders, &Permission::Delete));
}
#[test]
fn test_admin_has_all_permissions() {
    let manager = AuthManager::with_auth_enabled();
    let roles = HashSet::from(["admin".to_string()]);
    manager
        .create_principal("admin_user", "Admin@Pass1", PrincipalType::User, roles)
        .unwrap();
    let session = manager
        .authenticate("admin_user", "Admin@Pass1", "127.0.0.1")
        .unwrap();

    // Admin holds cluster-wide `All` and can delete an arbitrary topic.
    assert!(session.has_permission(&ResourceType::Cluster, &Permission::All));
    let any_topic = ResourceType::Topic("any_topic".to_string());
    assert!(session.has_permission(&any_topic, &Permission::Delete));
}
#[test]
fn test_resource_pattern_matching() {
    // Small constructors keep the assertions readable.
    let pattern = |glob: &str| ResourceType::TopicPattern(glob.to_string());
    let topic = |name: &str| ResourceType::Topic(name.to_string());

    // A bare wildcard matches every topic.
    assert!(pattern("*").matches(&topic("anything")));
    // A trailing wildcard matches by prefix only.
    assert!(pattern("orders-*").matches(&topic("orders-us")));
    assert!(pattern("orders-*").matches(&topic("orders-eu")));
    assert!(!pattern("orders-*").matches(&topic("events-us")));
}
#[test]
fn test_acl_enforcement() {
    let config = AuthConfig {
        require_authentication: true,
        enable_acls: true,
        default_deny: true,
        ..Default::default()
    };
    let manager = AuthManager::new(config);
    let roles = HashSet::from(["read-only".to_string()]);
    manager
        .create_principal("reader", "Read@Pass123", PrincipalType::User, roles)
        .unwrap();

    // Grant the reader write access to exactly one topic, from any host.
    manager.add_acl(AclEntry {
        principal: "reader".to_string(),
        resource: ResourceType::Topic("special-topic".to_string()),
        permission: Permission::Write,
        allow: true,
        host: "*".to_string(),
    });

    let session = manager
        .authenticate("reader", "Read@Pass123", "127.0.0.1")
        .unwrap();
    let try_write = |topic: &str| {
        manager.authorize(
            &session,
            &ResourceType::Topic(topic.to_string()),
            Permission::Write,
            "127.0.0.1",
        )
    };

    // The ACL-covered topic is writable; any other topic is denied.
    assert!(try_write("special-topic").is_ok());
    assert!(try_write("other-topic").is_err());
}
#[test]
fn test_sasl_plain_authentication() {
    let manager = Arc::new(AuthManager::new_default());
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("sasl_user", "Sasl@Pass123", PrincipalType::User, roles)
        .unwrap();
    let sasl = SaslPlainAuth::new(manager);

    // Both the two-part form and the three-part form with a leading NUL are accepted.
    let payloads: [&[u8]; 2] = [b"sasl_user\0Sasl@Pass123", b"\0sasl_user\0Sasl@Pass123"];
    for payload in payloads {
        let outcome = sasl.authenticate(payload, "127.0.0.1");
        assert!(outcome.is_ok());
    }
}
#[test]
fn test_session_expiration() {
    let config = AuthConfig {
        session_timeout: Duration::from_millis(100),
        ..Default::default()
    };
    let auth = AuthManager::new(config);
    let mut roles = HashSet::new();
    roles.insert("producer".to_string());
    auth.create_principal("expiring", "Expiry@Pass1", PrincipalType::User, roles)
        .unwrap();
    let session = auth
        .authenticate("expiring", "Expiry@Pass1", "127.0.0.1")
        .unwrap();
    assert!(!session.is_expired());
    // Sleep past the 100 ms session timeout configured above, then re-check the very
    // same session value; rebuilding it field-for-field would change nothing.
    std::thread::sleep(Duration::from_millis(150));
    assert!(session.is_expired());
}
#[test]
fn test_delete_principal_invalidates_sessions() {
    let manager = AuthManager::new_default();
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("deleteme", "Delete@Pass1", PrincipalType::User, roles)
        .unwrap();
    let session = manager
        .authenticate("deleteme", "Delete@Pass1", "127.0.0.1")
        .unwrap();

    // The session can be looked up until its principal is deleted.
    let before_delete = manager.get_session(&session.id);
    assert!(before_delete.is_some());
    manager.delete_principal("deleteme").unwrap();
    let after_delete = manager.get_session(&session.id);
    assert!(after_delete.is_none());
}
#[test]
fn test_disabled_principal_cannot_authenticate() {
    let manager = AuthManager::new_default();
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("disabled_user", "Disable@Pass1", PrincipalType::User, roles)
        .unwrap();

    // Flip the flag directly in the principal map; the write guard is a temporary of
    // this statement and is released before `authenticate` runs.
    if let Some(entry) = manager.principals.write().get_mut("disabled_user") {
        entry.enabled = false;
    }

    let outcome = manager.authenticate("disabled_user", "Disable@Pass1", "127.0.0.1");
    assert!(matches!(outcome, Err(AuthError::AuthenticationFailed)));
}
#[test]
fn test_password_hash_debug_redacts_sensitive_data() {
    let rendered = format!("{:?}", PasswordHash::new("super_secret_password"));

    // The Debug output carries the redaction marker and never the password itself.
    let has_marker = rendered.contains("[REDACTED]");
    assert!(has_marker, "Debug output should contain [REDACTED]");
    let leaks_password = rendered.contains("super_secret_password");
    assert!(!leaks_password, "Debug output should not contain password");
    // The iterations field is still listed by name.
    let shows_iterations = rendered.contains("iterations");
    assert!(
        shows_iterations,
        "Debug output should show iterations field"
    );
}
#[test]
fn test_principal_debug_redacts_password_hash() {
    let subject = Principal {
        name: "test_user".to_string(),
        principal_type: PrincipalType::User,
        password_hash: PasswordHash::new("secret_password"),
        roles: HashSet::from(["admin".to_string()]),
        enabled: true,
        metadata: HashMap::new(),
        created_at: 1234567890,
    };
    let rendered = format!("{:?}", subject);

    // The password hash is redacted while the name and roles stay visible.
    let has_marker = rendered.contains("[REDACTED]");
    assert!(
        has_marker,
        "Debug output should contain [REDACTED]: {}",
        rendered
    );
    let shows_name = rendered.contains("test_user");
    assert!(shows_name, "Debug output should show name");
    let shows_roles = rendered.contains("admin");
    assert!(shows_roles, "Debug output should show roles");
}
#[test]
fn test_scram_full_handshake() {
    use sha2::{Digest, Sha256};
    let manager = Arc::new(AuthManager::new_default());
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("scram_user", "Scram@Pass123", PrincipalType::User, roles)
        .expect("Failed to create principal");
    let scram = SaslScramAuth::new(manager.clone());

    // Round 1: send client-first and inspect the server-first reply.
    let client_nonce = "rOprNGfwEbeRWgbNEkqO";
    let client_first = format!("n,,n=scram_user,r={}", client_nonce);
    let (state, server_first) = scram
        .process_client_first(client_first.as_bytes(), "127.0.0.1")
        .expect("client-first processing should succeed");
    let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
    assert!(server_first_str.starts_with(&format!("r={}", client_nonce)));
    for marker in [",s=", ",i="] {
        assert!(server_first_str.contains(marker));
    }

    // Read the salt and iteration count back out of the server state.
    let ScramState::ServerFirstSent {
        salt, iterations, ..
    } = &state
    else {
        panic!("Expected ServerFirstSent state");
    };

    // Client-side key derivation from the correct password.
    let salted_password = compute_salted_password("Scram@Pass123", salt, *iterations);
    let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
    let stored_key = Sha256::digest(&client_key);

    // Rebuild the auth message the same way the server does.
    let client_first_bare = format!("n=scram_user,r={}", client_nonce);
    let combined_nonce = server_first_str
        .split(',')
        .find_map(|attr| attr.strip_prefix("r="))
        .unwrap();
    let auth_message = format!(
        "{},{},c=biws,r={}",
        client_first_bare, server_first_str, combined_nonce
    );

    // ClientProof = ClientKey XOR ClientSignature.
    let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
    let mut client_proof: Vec<u8> = Vec::with_capacity(client_key.len());
    for (key_byte, sig_byte) in client_key.iter().zip(client_signature.iter()) {
        client_proof.push(key_byte ^ sig_byte);
    }

    // Round 2: send client-final and expect a session plus a `v=` reply.
    let client_final = format!(
        "c=biws,r={},p={}",
        combined_nonce,
        base64_encode(&client_proof)
    );
    let (session, server_final) = scram
        .process_client_final(&state, client_final.as_bytes(), "127.0.0.1")
        .expect("client-final processing should succeed");
    assert_eq!(session.principal_name, "scram_user");
    assert!(!session.is_expired());
    let server_final_str = std::str::from_utf8(&server_final).expect("valid UTF-8");
    assert!(server_final_str.starts_with("v="));
}
#[test]
fn test_scram_wrong_password() {
    let manager = Arc::new(AuthManager::new_default());
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("scram_user2", "Correct@Pass1", PrincipalType::User, roles)
        .expect("Failed to create principal");
    let scram = SaslScramAuth::new(manager.clone());

    let client_nonce = "test_nonce_12345";
    let client_first = format!("n,,n=scram_user2,r={}", client_nonce);
    let (state, server_first) = scram
        .process_client_first(client_first.as_bytes(), "127.0.0.1")
        .expect("client-first processing should succeed");
    let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
    let combined_nonce = server_first_str
        .split(',')
        .find_map(|attr| attr.strip_prefix("r="))
        .unwrap();
    let ScramState::ServerFirstSent {
        salt, iterations, ..
    } = &state
    else {
        panic!("Expected ServerFirstSent state");
    };

    // Derive the client keys from the WRONG password; the proof is well-formed
    // but must not verify against the stored key.
    let salted_password = compute_salted_password("Wrong@Pass1", salt, *iterations);
    let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
    let stored_key = sha2::Sha256::digest(&client_key);
    let client_first_bare = format!("n=scram_user2,r={}", client_nonce);
    let auth_message = format!(
        "{},{},c=biws,r={}",
        client_first_bare, server_first_str, combined_nonce
    );
    let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
    let mut client_proof: Vec<u8> = Vec::with_capacity(client_key.len());
    for (key_byte, sig_byte) in client_key.iter().zip(client_signature.iter()) {
        client_proof.push(key_byte ^ sig_byte);
    }

    let client_final = format!(
        "c=biws,r={},p={}",
        combined_nonce,
        base64_encode(&client_proof)
    );
    let outcome = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
    assert!(outcome.is_err());
    assert!(matches!(outcome, Err(AuthError::AuthenticationFailed)));
}
#[test]
fn test_scram_nonexistent_user() {
    let manager = Arc::new(AuthManager::new_default());
    let scram = SaslScramAuth::new(manager.clone());

    // client-first for an unknown user still gets a normal-looking server-first.
    let client_first = "n,,n=nonexistent_user,r=test_nonce";
    let first_reply = scram.process_client_first(client_first.as_bytes(), "127.0.0.1");
    assert!(
        first_reply.is_ok(),
        "Should return fake server-first to prevent enumeration"
    );
    let (state, server_first) = first_reply.unwrap();
    let server_first_str = std::str::from_utf8(&server_first).expect("valid UTF-8");
    for marker in ["r=test_nonce", ",s=", ",i="] {
        assert!(server_first_str.contains(marker));
    }

    // client-final with the right nonce but an arbitrary proof must be rejected.
    let combined_nonce = server_first_str
        .split(',')
        .find_map(|attr| attr.strip_prefix("r="))
        .unwrap();
    let client_final = format!("c=biws,r={},p=dW5rbm93bg==", combined_nonce);
    let outcome = scram.process_client_final(&state, client_final.as_bytes(), "127.0.0.1");
    assert!(outcome.is_err());
}
#[test]
fn test_scram_nonce_mismatch() {
    let manager = Arc::new(AuthManager::new_default());
    let roles = HashSet::from(["producer".to_string()]);
    manager
        .create_principal("scram_user3", "Scram3@Pass1", PrincipalType::User, roles)
        .expect("Failed to create principal");
    let scram = SaslScramAuth::new(manager.clone());

    let client_first = "n,,n=scram_user3,r=original_nonce";
    let (state, _server_first) = scram
        .process_client_first(client_first.as_bytes(), "127.0.0.1")
        .expect("client-first should succeed");

    // The nonce in client-final does not start with the original client nonce,
    // so the exchange is rejected as invalid credentials.
    let tampered_final = "c=biws,r=tampered_nonce_plus_server,p=dW5rbm93bg==";
    let outcome = scram.process_client_final(&state, tampered_final.as_bytes(), "127.0.0.1");
    assert!(outcome.is_err());
    assert!(matches!(outcome, Err(AuthError::InvalidCredentials)));
}
/// Test helper that derives the SCRAM salted password from `password`, `salt` and
/// `iterations`.
///
/// Computes `U1 = HMAC(password, salt || 1u32 big-endian)`, then
/// `Ui = HMAC(password, U(i-1))` for the remaining rounds, and returns the XOR of all
/// `U` values as 32 bytes. An `iterations` of 0 or 1 both return `U1` alone.
fn compute_salted_password(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
    use hmac::{Hmac, Mac};
    type HmacSha256 = Hmac<sha2::Sha256>;

    // HMAC-SHA256 keyed with the password over the concatenation of `parts`.
    let keyed_hmac = |parts: &[&[u8]]| {
        let mut mac = HmacSha256::new_from_slice(password.as_bytes())
            .expect("HMAC accepts any key length");
        for part in parts {
            mac.update(part);
        }
        mac.finalize().into_bytes()
    };

    let mut block = keyed_hmac(&[salt, &1u32.to_be_bytes()[..]]);
    let mut derived = block.to_vec();
    for _ in 1..iterations {
        block = keyed_hmac(&[&block[..]]);
        // Fold this round's output into the running XOR.
        derived
            .iter_mut()
            .zip(block.iter())
            .for_each(|(acc, byte)| *acc ^= byte);
    }
    derived
}
}