use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2_hmac;
use rand::Rng;
use sha2::{Digest, Sha256, Sha512};
use std::fmt;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use crate::error::{KrafkaError, Result};
/// Smallest PBKDF2 iteration count accepted from a server's `i=` attribute.
///
/// RFC 7677 says the iteration count announced by a server SHOULD be at least 4096.
pub const MIN_PBKDF2_ITERATIONS: u32 = 4_096;
/// Largest PBKDF2 iteration count accepted from a server's `i=` attribute.
///
/// Bounds the CPU time a hostile or misconfigured server can make the client spend.
pub const MAX_PBKDF2_ITERATIONS: u32 = 1_000_000;
/// Channel-binding mode requested by the client in its GS2 header.
#[derive(Clone, Debug)]
pub enum ChannelBinding {
    /// No channel binding; the GS2 header is `n,,`.
    None,
    /// `tls-server-end-point` binding.
    ///
    /// Carries the raw binding data, which is appended to the GS2 header inside
    /// the `c=` attribute of the client-final message.
    TlsServerEndPoint(Vec<u8>),
}
/// Hash family used by a SCRAM exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScramMechanism {
    /// `SCRAM-SHA-256` (HMAC/PBKDF2 over SHA-256).
    Sha256,
    /// `SCRAM-SHA-512` (HMAC/PBKDF2 over SHA-512).
    Sha512,
}
impl ScramMechanism {
    /// Returns the SASL mechanism name for this hash family.
    #[inline]
    pub fn mechanism_name(&self) -> &'static str {
        match *self {
            Self::Sha256 => "SCRAM-SHA-256",
            Self::Sha512 => "SCRAM-SHA-512",
        }
    }

    /// Returns the digest size in bytes of the underlying hash.
    ///
    /// `ScramClient` uses this as the PBKDF2 output length.
    #[inline]
    pub fn hash_length(&self) -> usize {
        match *self {
            Self::Sha512 => 64,
            Self::Sha256 => 32,
        }
    }

    /// Encodes the mechanism as a compact `i8` code (1 = SHA-256, 2 = SHA-512).
    #[inline]
    pub fn to_wire_byte(self) -> i8 {
        match self {
            Self::Sha512 => 2,
            Self::Sha256 => 1,
        }
    }

    /// Decodes a code produced by [`ScramMechanism::to_wire_byte`].
    ///
    /// # Errors
    ///
    /// Returns a protocol error for any code other than 1 or 2.
    #[inline]
    pub fn from_wire_byte(b: i8) -> Result<Self> {
        let mechanism = match b {
            1 => Self::Sha256,
            2 => Self::Sha512,
            other => {
                return Err(KrafkaError::protocol(format!(
                    "unknown SCRAM mechanism code: {other}"
                )))
            }
        };
        Ok(mechanism)
    }
}
/// Displays the SASL mechanism name, e.g. `SCRAM-SHA-256`.
impl fmt::Display for ScramMechanism {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.mechanism_name())
    }
}
/// Position of a `ScramClient` within the SCRAM handshake.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScramState {
    /// Created; no message sent yet.
    Initial,
    /// Client-first message produced; awaiting the server-first message.
    WaitingServerFirst,
    /// NOTE(review): never entered by `ScramClient` in this file; presumably kept
    /// for API compatibility — confirm before removing.
    WaitingClientFinal,
    /// Client-final message produced; awaiting the server-final message.
    WaitingServerFinal,
    /// Server signature verified; authentication succeeded.
    Complete,
    /// A step was rejected; the exchange cannot continue.
    Failed,
}
/// Client side of a single SCRAM authentication exchange.
///
/// The password and password-derived values are zeroized when the client is dropped.
pub struct ScramClient {
    // Credentials and negotiated parameters, fixed at construction.
    username: String,
    password: String,
    mechanism: ScramMechanism,
    channel_binding: ChannelBinding,
    // Handshake progress plus the client-side inputs to the AuthMessage.
    state: ScramState,
    client_nonce: String,
    client_first_bare: String,
    // Parameters learned from the server-first message.
    server_nonce: Option<String>,
    salt: Option<Vec<u8>>,
    iteration_count: Option<u32>,
    // Password-derived material; wiped on drop.
    salted_password: Option<Vec<u8>>,
    server_signature: Option<Vec<u8>>,
}
impl Drop for ScramClient {
fn drop(&mut self) {
self.password.zeroize();
if let Some(ref mut salted) = self.salted_password {
salted.zeroize();
}
if let Some(ref mut sig) = self.server_signature {
sig.zeroize();
}
self.client_first_bare.zeroize();
}
}
impl ScramClient {
    /// Creates a client for one SCRAM exchange with a freshly generated client nonce.
    ///
    /// `channel_binding` selects the GS2 header. `ChannelBinding::None` sends `n,,`.
    /// `TlsServerEndPoint` sends `p=tls-server-end-point,,` and mixes the supplied
    /// binding data into the `c=` attribute of the client-final message.
    pub fn new(
        username: &str,
        password: &str,
        mechanism: ScramMechanism,
        channel_binding: ChannelBinding,
    ) -> Self {
        Self {
            username: username.to_string(),
            password: password.to_string(),
            mechanism,
            channel_binding,
            client_nonce: generate_nonce(),
            state: ScramState::Initial,
            client_first_bare: String::new(),
            server_nonce: None,
            salt: None,
            iteration_count: None,
            salted_password: None,
            server_signature: None,
        }
    }

    /// Returns the current handshake state.
    #[inline]
    pub fn state(&self) -> &ScramState {
        &self.state
    }

    /// Returns the hash family used by this exchange.
    #[inline]
    pub fn mechanism(&self) -> ScramMechanism {
        self.mechanism
    }

    /// Builds the `client-first-message` and moves to `WaitingServerFirst`.
    ///
    /// The message is the GS2 header followed by `n=<username>,r=<client nonce>`.
    /// `=` and `,` in the username are escaped as `=3D` / `=2C`; no other
    /// normalization (e.g. SASLprep) is applied.
    pub fn client_first_message(&mut self) -> Vec<u8> {
        self.client_first_bare = format!(
            "n={},r={}",
            escape_username(&self.username),
            self.client_nonce
        );
        let message = format!("{}{}", self.gs2_header(), self.client_first_bare);
        self.state = ScramState::WaitingServerFirst;
        message.into_bytes()
    }

    /// Consumes the `server-first-message` and returns the `client-final-message`
    /// (`c=<binding>,r=<nonce>,p=<proof>`).
    ///
    /// On success the expected server signature is retained for
    /// [`ScramClient::verify_server_final`] and the state becomes
    /// `WaitingServerFinal`.
    ///
    /// # Errors
    ///
    /// Returns an auth error, and moves the client to `Failed`, when:
    /// - it is called in any state other than `WaitingServerFirst`;
    /// - the message is not UTF-8;
    /// - the message starts with a mandatory `m=` extension;
    /// - the `r=`/`s=`/`i=` attributes are missing or malformed;
    /// - the iteration count is outside
    ///   `MIN_PBKDF2_ITERATIONS..=MAX_PBKDF2_ITERATIONS`;
    /// - the server nonce does not start with the client nonce.
    pub fn process_server_first(&mut self, server_first: &[u8]) -> Result<Vec<u8>> {
        if self.state != ScramState::WaitingServerFirst {
            self.state = ScramState::Failed;
            return Err(KrafkaError::auth(
                "Invalid SCRAM state: expected WaitingServerFirst",
            ));
        }
        // Every failure, including parse errors, must leave the exchange unusable.
        let outcome = self.build_client_final(server_first);
        self.state = if outcome.is_ok() {
            ScramState::WaitingServerFinal
        } else {
            ScramState::Failed
        };
        outcome
    }

    /// Checks the `server-final-message` against the expected server signature.
    ///
    /// Any `,`-separated extensions after the first attribute are ignored, as
    /// RFC 5802 requires for unrecognized attributes. On success the state
    /// becomes `Complete`.
    ///
    /// # Errors
    ///
    /// Returns an auth error, and moves the client to `Failed`, when:
    /// - it is called in any state other than `WaitingServerFinal`;
    /// - the message is not UTF-8;
    /// - the server reports `e=<error>`;
    /// - the `v=` verifier is missing or not base64;
    /// - the signature does not match.
    pub fn verify_server_final(&mut self, server_final: &[u8]) -> Result<()> {
        if self.state != ScramState::WaitingServerFinal {
            self.state = ScramState::Failed;
            return Err(KrafkaError::auth(
                "Invalid SCRAM state: expected WaitingServerFinal",
            ));
        }
        let outcome = self.check_server_final(server_final);
        self.state = if outcome.is_ok() {
            ScramState::Complete
        } else {
            ScramState::Failed
        };
        outcome
    }

    /// Returns `true` once the server signature has been verified.
    #[inline]
    pub fn is_complete(&self) -> bool {
        self.state == ScramState::Complete
    }

    /// GS2 header matching the configured channel-binding mode.
    fn gs2_header(&self) -> &'static str {
        match &self.channel_binding {
            ChannelBinding::None => "n,,",
            ChannelBinding::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
        }
    }

    /// Base64 value of the `c=` attribute: the GS2 header plus any binding data.
    fn channel_binding_value(&self) -> String {
        let mut input = self.gs2_header().as_bytes().to_vec();
        if let ChannelBinding::TlsServerEndPoint(cb_data) = &self.channel_binding {
            input.extend_from_slice(cb_data);
        }
        BASE64.encode(&input)
    }

    /// Parses `server-first-message` into (server nonce, salt, iteration count).
    fn parse_server_first(message: &str) -> Result<(String, Vec<u8>, u32)> {
        // RFC 5802 §5.1: a mandatory extension ("m=") that the client does not
        // understand MUST cause authentication to fail.
        if message.starts_with("m=") {
            return Err(KrafkaError::auth(
                "Unsupported mandatory extension in server-first",
            ));
        }
        let mut nonce = None;
        let mut salt = None;
        let mut iterations = None;
        for part in message.split(',') {
            if let Some(value) = part.strip_prefix("r=") {
                nonce = Some(value.to_string());
            } else if let Some(value) = part.strip_prefix("s=") {
                salt = Some(
                    BASE64
                        .decode(value)
                        .map_err(|_| KrafkaError::auth("Invalid base64 salt in server-first"))?,
                );
            } else if let Some(value) = part.strip_prefix("i=") {
                iterations = Some(value.parse::<u32>().map_err(|_| {
                    KrafkaError::auth("Invalid iteration count in server-first")
                })?);
            }
        }
        let nonce = nonce.ok_or_else(|| KrafkaError::auth("Missing nonce in server-first"))?;
        let salt = salt.ok_or_else(|| KrafkaError::auth("Missing salt in server-first"))?;
        let iterations = iterations
            .ok_or_else(|| KrafkaError::auth("Missing iteration count in server-first"))?;
        Ok((nonce, salt, iterations))
    }

    /// Validates server-first, derives keys, and builds the client-final message.
    ///
    /// State transitions are left to the caller.
    fn build_client_final(&mut self, server_first: &[u8]) -> Result<Vec<u8>> {
        let server_first_str = std::str::from_utf8(server_first)
            .map_err(|_| KrafkaError::auth("Invalid UTF-8 in server-first message"))?;
        let (server_nonce, salt, iteration_count) = Self::parse_server_first(server_first_str)?;

        if iteration_count < MIN_PBKDF2_ITERATIONS {
            return Err(KrafkaError::auth(format!(
                "PBKDF2 iteration count {iteration_count} is below minimum {MIN_PBKDF2_ITERATIONS}"
            )));
        }
        if iteration_count > MAX_PBKDF2_ITERATIONS {
            return Err(KrafkaError::auth(format!(
                "PBKDF2 iteration count {iteration_count} exceeds maximum {MAX_PBKDF2_ITERATIONS}"
            )));
        }
        if !server_nonce.starts_with(&self.client_nonce) {
            return Err(KrafkaError::auth(
                "Server nonce doesn't contain client nonce",
            ));
        }

        // RFC 5802 key derivation:
        // ClientKey = HMAC(SaltedPassword, "Client Key"), StoredKey = H(ClientKey),
        // ServerKey = HMAC(SaltedPassword, "Server Key").
        let salted_password = self.compute_salted_password(&salt, iteration_count);
        let mut client_key = self.compute_client_key(&salted_password);
        let mut stored_key = self.hash(&client_key);
        let mut server_key = self.compute_server_key(&salted_password);

        let client_final_without_proof =
            format!("c={},r={}", self.channel_binding_value(), server_nonce);
        let auth_message = format!(
            "{},{},{}",
            self.client_first_bare, server_first_str, client_final_without_proof
        );

        let mut client_signature = self.compute_hmac(&stored_key, auth_message.as_bytes());
        let client_proof = xor_bytes(&client_key, &client_signature);
        let server_signature = self.compute_hmac(&server_key, auth_message.as_bytes());

        // ClientKey (and ClientSignature, from which ClientKey = proof XOR signature)
        // is sufficient to authenticate as the user, so wipe the intermediates
        // instead of leaving them in freed memory. The proof itself is sent on
        // the wire and needs no wiping.
        client_key.zeroize();
        stored_key.zeroize();
        server_key.zeroize();
        client_signature.zeroize();

        let client_final = format!(
            "{},p={}",
            client_final_without_proof,
            BASE64.encode(&client_proof)
        );

        // Move (rather than clone) the derived values into `self`; the `Drop`
        // impl wipes them later.
        self.server_nonce = Some(server_nonce);
        self.salt = Some(salt);
        self.iteration_count = Some(iteration_count);
        self.salted_password = Some(salted_password);
        self.server_signature = Some(server_signature);

        Ok(client_final.into_bytes())
    }

    /// Parses server-final and compares its verifier with the expected signature.
    ///
    /// State transitions are left to the caller.
    fn check_server_final(&self, server_final: &[u8]) -> Result<()> {
        let server_final_str = std::str::from_utf8(server_final)
            .map_err(|_| KrafkaError::auth("Invalid UTF-8 in server-final message"))?;
        // server-final-message = (server-error / verifier) ["," extensions]
        let head = server_final_str
            .split_once(',')
            .map_or(server_final_str, |(first, _extensions)| first);

        if let Some(error) = head.strip_prefix("e=") {
            return Err(KrafkaError::auth(format!("SCRAM server error: {error}")));
        }
        let server_sig_b64 = head
            .strip_prefix("v=")
            .ok_or_else(|| KrafkaError::auth("Missing verifier in server-final"))?;
        let server_signature = BASE64
            .decode(server_sig_b64)
            .map_err(|_| KrafkaError::auth("Invalid base64 in server-final verifier"))?;
        let expected = self
            .server_signature
            .as_deref()
            .ok_or_else(|| KrafkaError::auth("Server signature not computed"))?;
        if !constant_time_compare(&server_signature, expected) {
            return Err(KrafkaError::auth("Server signature verification failed"));
        }
        Ok(())
    }

    /// PBKDF2-HMAC of the password with the mechanism's hash (`Hi()` in RFC 5802).
    fn compute_salted_password(&self, salt: &[u8], iterations: u32) -> Vec<u8> {
        let mut output = vec![0u8; self.mechanism.hash_length()];
        match self.mechanism {
            ScramMechanism::Sha256 => {
                pbkdf2_hmac::<Sha256>(self.password.as_bytes(), salt, iterations, &mut output);
            }
            ScramMechanism::Sha512 => {
                pbkdf2_hmac::<Sha512>(self.password.as_bytes(), salt, iterations, &mut output);
            }
        }
        output
    }

    /// HMAC of `data` under `key` with the mechanism's hash.
    fn compute_hmac(&self, key: &[u8], data: &[u8]) -> Vec<u8> {
        match self.mechanism {
            ScramMechanism::Sha256 => {
                let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(key) else {
                    unreachable!("HMAC accepts any key length per RFC 2104");
                };
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
            ScramMechanism::Sha512 => {
                let Ok(mut mac) = Hmac::<Sha512>::new_from_slice(key) else {
                    unreachable!("HMAC accepts any key length per RFC 2104");
                };
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
        }
    }

    /// Plain digest of `data` with the mechanism's hash (`H()` in RFC 5802).
    fn hash(&self, data: &[u8]) -> Vec<u8> {
        match self.mechanism {
            ScramMechanism::Sha256 => Sha256::digest(data).to_vec(),
            ScramMechanism::Sha512 => Sha512::digest(data).to_vec(),
        }
    }

    /// `ClientKey = HMAC(SaltedPassword, "Client Key")`.
    fn compute_client_key(&self, salted_password: &[u8]) -> Vec<u8> {
        self.compute_hmac(salted_password, b"Client Key")
    }

    /// `ServerKey = HMAC(SaltedPassword, "Server Key")`.
    fn compute_server_key(&self, salted_password: &[u8]) -> Vec<u8> {
        self.compute_hmac(salted_password, b"Server Key")
    }
}
/// Debug output that never reveals the password.
///
/// Only the username, mechanism and state are shown. The password is replaced
/// by a placeholder, and the remaining (partly secret) fields are omitted.
/// `finish_non_exhaustive` marks that omission with a trailing `..`.
impl fmt::Debug for ScramClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramClient")
            .field("username", &self.username)
            .field("password", &"[REDACTED]")
            .field("mechanism", &self.mechanism)
            .field("state", &self.state)
            .finish_non_exhaustive()
    }
}
/// Returns a fresh client nonce: 24 random bytes, standard-base64 encoded
/// (32 characters).
///
/// The base64 alphabet never contains `,`, which RFC 5802 forbids inside a nonce.
fn generate_nonce() -> String {
    let random_bytes = rand::rng().random::<[u8; 24]>();
    BASE64.encode(random_bytes)
}
/// Escapes a username for the `n=` attribute (RFC 5802 `saslname`).
///
/// `=` becomes `=3D` and `,` becomes `=2C`; all other characters pass through.
fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// XORs two byte slices element-wise.
///
/// The result has the length of the shorter input.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}
/// Compares two byte slices without a data-dependent early exit.
///
/// A length mismatch returns `false` immediately; lengths are not secret here.
/// Equal-length inputs are compared with `subtle`'s constant-time equality.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && bool::from(a.ct_eq(b))
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_scram_mechanism_name() {
        assert_eq!(ScramMechanism::Sha256.mechanism_name(), "SCRAM-SHA-256");
        assert_eq!(ScramMechanism::Sha512.mechanism_name(), "SCRAM-SHA-512");
    }

    #[test]
    fn test_scram_mechanism_hash_length() {
        assert_eq!(ScramMechanism::Sha256.hash_length(), 32);
        assert_eq!(ScramMechanism::Sha512.hash_length(), 64);
    }

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username("user"), "user");
        assert_eq!(escape_username("user=name"), "user=3Dname");
        assert_eq!(escape_username("user,name"), "user=2Cname");
        assert_eq!(escape_username("a=b,c"), "a=3Db=2Cc");
    }

    #[test]
    fn test_xor_bytes() {
        let a = vec![0x01, 0x02, 0x03];
        let b = vec![0x01, 0x00, 0x01];
        let result = xor_bytes(&a, &b);
        assert_eq!(result, vec![0x00, 0x02, 0x02]);
    }

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare(b"hello", b"hello"));
        assert!(!constant_time_compare(b"hello", b"world"));
        assert!(!constant_time_compare(b"hello", b"hell"));
    }

    #[test]
    fn test_scram_client_initial_state() {
        let client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        assert_eq!(client.state(), &ScramState::Initial);
        assert_eq!(client.mechanism(), ScramMechanism::Sha256);
    }

    #[test]
    fn test_scram_client_first_message() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        let msg = client.client_first_message();
        let msg_str = String::from_utf8(msg).unwrap();
        assert!(msg_str.starts_with("n,,n=user,r="));
        assert_eq!(client.state(), &ScramState::WaitingServerFirst);
    }

    #[test]
    fn test_scram_client_first_message_escaped() {
        let mut client = ScramClient::new(
            "user=name",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        let msg = client.client_first_message();
        let msg_str = String::from_utf8(msg).unwrap();
        assert!(msg_str.contains("n=user=3Dname"));
    }

    #[test]
    fn test_scram_client_invalid_server_first() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let result = client.process_server_first(b"invalid");
        assert!(result.is_err());
    }

    #[test]
    fn test_scram_client_wrong_nonce() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = "r=wrongnonce,s=c2FsdA==,i=4096";
        let result = client.process_server_first(server_first.as_bytes());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("client nonce"));
    }

    #[test]
    fn test_generate_nonce() {
        let n1 = generate_nonce();
        let n2 = generate_nonce();
        assert_ne!(n1, n2);
        assert_eq!(n1.len(), 32);
    }

    /// Full SCRAM-SHA-256 exchange using the test vector from RFC 7677 §3
    /// (user "user", password "pencil", no channel binding).
    #[test]
    fn test_scram_sha256_full_flow() {
        let mut client = ScramClient::new(
            "user",
            "pencil",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_nonce = "rOprNGfwEbeRWgbNEkqO".to_string();

        let first = client.client_first_message();
        let first_str = String::from_utf8(first).unwrap();
        assert_eq!(first_str, "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
                            s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
        let client_final = client
            .process_server_first(server_first.as_bytes())
            .unwrap();
        let client_final_str = String::from_utf8(client_final).unwrap();
        assert_eq!(
            client_final_str,
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
             p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );
        assert_eq!(client.state(), &ScramState::WaitingServerFinal);

        client
            .verify_server_final(b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=")
            .unwrap();
        assert!(client.is_complete());
    }

    #[test]
    fn test_scram_rejects_wrong_server_signature() {
        let mut client = ScramClient::new(
            "user",
            "pencil",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=4096", client.client_nonce);
        client
            .process_server_first(server_first.as_bytes())
            .unwrap();
        let bogus = format!("v={}", BASE64.encode([0u8; 32]));
        let result = client.verify_server_final(bogus.as_bytes());
        assert!(result.is_err());
        assert_eq!(client.state(), &ScramState::Failed);
        assert!(!client.is_complete());
    }

    #[test]
    fn test_scram_sha512_client() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha512,
            ChannelBinding::None,
        );
        let first = client.client_first_message();
        let first_str = String::from_utf8(first).unwrap();
        assert!(first_str.starts_with("n,,n=user,r="));
        assert_eq!(client.mechanism().hash_length(), 64);
    }

    #[test]
    fn test_pbkdf2_iteration_too_low() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=100", client.client_nonce);
        let result = client.process_server_first(server_first.as_bytes());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("below minimum"),
            "Expected 'below minimum' in: {}",
            err
        );
    }

    #[test]
    fn test_pbkdf2_iteration_too_high() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=2000000", client.client_nonce);
        let result = client.process_server_first(server_first.as_bytes());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("exceeds maximum"),
            "Expected 'exceeds maximum' in: {}",
            err
        );
    }

    #[test]
    fn test_pbkdf2_iteration_at_boundaries() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=4096", client.client_nonce);
        let result = client.process_server_first(server_first.as_bytes());
        assert!(result.is_ok());

        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=1000000", client.client_nonce);
        let result = client.process_server_first(server_first.as_bytes());
        assert!(result.is_ok());
    }

    #[test]
    fn test_scram_debug_redacts_password() {
        let client = ScramClient::new(
            "user",
            "secret_password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        let debug_output = format!("{:?}", client);
        assert!(
            !debug_output.contains("secret_password"),
            "Password leaked in Debug output"
        );
        assert!(debug_output.contains("[REDACTED]"));
    }

    /// Exercises the `Drop` path with derived secrets populated. Wiping itself
    /// cannot be observed safely from a test, so this only checks that dropping
    /// a client holding secrets does not panic.
    #[test]
    fn test_scram_zeroize_on_drop() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=4096", client.client_nonce);
        assert!(client.process_server_first(server_first.as_bytes()).is_ok());
        assert!(client.salted_password.is_some());
        assert!(client.server_signature.is_some());
        drop(client);
    }

    #[test]
    fn test_channel_binding_none_gs2_header() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        let msg = client.client_first_message();
        let msg_str = String::from_utf8(msg).unwrap();
        assert!(
            msg_str.starts_with("n,,"),
            "Expected 'n,,' GS2 header, got: {msg_str}"
        );
    }

    #[test]
    fn test_channel_binding_tls_server_end_point_gs2_header() {
        let cb_data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::TlsServerEndPoint(cb_data),
        );
        let msg = client.client_first_message();
        let msg_str = String::from_utf8(msg).unwrap();
        assert!(
            msg_str.starts_with("p=tls-server-end-point,,"),
            "Expected 'p=tls-server-end-point,,' GS2 header, got: {msg_str}"
        );
    }

    #[test]
    fn test_channel_binding_tls_server_end_point_c_field() {
        let cb_data = vec![0x01, 0x02, 0x03, 0x04];
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::TlsServerEndPoint(cb_data.clone()),
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=4096", client.client_nonce);
        let client_final = client
            .process_server_first(server_first.as_bytes())
            .unwrap();
        let client_final_str = String::from_utf8(client_final).unwrap();
        let c_value = client_final_str
            .split(',')
            .find(|p| p.starts_with("c="))
            .unwrap()
            .strip_prefix("c=")
            .unwrap();
        let decoded = BASE64.decode(c_value).unwrap();
        let expected_prefix = b"p=tls-server-end-point,,";
        assert!(
            decoded.starts_with(expected_prefix),
            "c= field should start with GS2 header"
        );
        assert_eq!(
            &decoded[expected_prefix.len()..],
            &cb_data,
            "c= field should end with channel binding data"
        );
    }

    #[test]
    fn test_channel_binding_none_c_field() {
        let mut client = ScramClient::new(
            "user",
            "password",
            ScramMechanism::Sha256,
            ChannelBinding::None,
        );
        client.client_first_message();
        let server_first = format!("r={}extra,s=c2FsdA==,i=4096", client.client_nonce);
        let client_final = client
            .process_server_first(server_first.as_bytes())
            .unwrap();
        let client_final_str = String::from_utf8(client_final).unwrap();
        let c_value = client_final_str
            .split(',')
            .find(|p| p.starts_with("c="))
            .unwrap()
            .strip_prefix("c=")
            .unwrap();
        let decoded = BASE64.decode(c_value).unwrap();
        assert_eq!(decoded, b"n,,");
    }
}