use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use hmac::{Hmac, Mac};
use num_bigint::BigUint;
use sha2::{Digest, Sha256, Sha512};
use subtle::ConstantTimeEq;
use crate::config::SaslMechanism;
/// GS2 header sent before the client-first message (RFC 5802 `gs2-header`:
/// channel-binding flag `n`, empty authorization identity).
const GS2_HEADER: &str = "n,,";
/// Smallest SCRAM iteration count this client accepts from a server.
pub const MIN_ITERATIONS: i32 = 4_096;
/// Largest SCRAM iteration count accepted when deriving the salted password.
pub(crate) const MAX_ITERATIONS: i32 = 16_384;
/// Client-side state for one SCRAM SASL exchange.
///
/// `Debug` is implemented by hand (below) rather than derived so that the
/// plaintext `password` and the secret-equivalent `salted_password` are never
/// written to logs or error messages that format the client with `{:?}`.
#[derive(Clone)]
pub(crate) struct ScramClient {
    /// SCRAM mechanism in use; `ScramClient::new` rejects non-SCRAM mechanisms.
    mechanism: SaslMechanism,
    /// Plaintext password; redacted in `Debug` output.
    password: String,
    /// Random nonce chosen by this client; the server nonce must start with it.
    client_nonce: String,
    /// `n=<escaped user>,r=<client nonce>`: the client-first message without the GS2 header.
    client_first_message_bare: String,
    /// Parsed server-first message, stored once it has been handled successfully.
    server_first_message: Option<ServerFirstMessage>,
    /// `c=...,r=...` part of the client-final message, reused to rebuild the auth message.
    client_final_message_without_proof: Option<String>,
    /// Derived salted password; secret-equivalent, redacted in `Debug` output.
    salted_password: Option<Vec<u8>>,
}

impl std::fmt::Debug for ScramClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("ScramClient")
            .field("mechanism", &self.mechanism)
            .field("password", &REDACTED)
            .field("client_nonce", &self.client_nonce)
            .field("client_first_message_bare", &self.client_first_message_bare)
            .field("server_first_message", &self.server_first_message)
            .field(
                "client_final_message_without_proof",
                &self.client_final_message_without_proof,
            )
            // Show only whether the value has been derived, never its bytes.
            .field(
                "salted_password",
                &self.salted_password.as_ref().map(|_| REDACTED),
            )
            .finish()
    }
}
/// Parsed SCRAM server-first message (`r=<nonce>,s=<salt>,i=<iterations>[,...]`).
#[derive(Debug, Clone)]
struct ServerFirstMessage {
    /// The raw message text, kept verbatim because it is part of the auth message.
    message: String,
    /// Combined nonce from `r=`; the client checks that it starts with its own nonce.
    nonce: String,
    /// Salt from `s=`, already base64-decoded.
    salt: Vec<u8>,
    /// Iteration count from `i=`; parsing rejects values <= 0.
    iterations: i32,
}
/// Outcome carried by a SCRAM server-final message.
enum ServerFinalMessage {
    /// Server signature from the `v=` attribute, already base64-decoded.
    Verifier(Vec<u8>),
    /// Error value from a leading `e=` attribute (text up to the first comma).
    Error(String),
}
impl ScramClient {
pub(crate) fn new(
mechanism: SaslMechanism,
username: String,
password: String,
) -> Result<Self> {
if !mechanism.is_scram() {
bail!("SCRAM client requires a SCRAM mechanism");
}
let client_nonce = secure_random_string()?;
let client_first_message_bare = format!("n={},r={client_nonce}", sasl_name(&username));
Ok(Self {
mechanism,
password,
client_nonce,
client_first_message_bare,
server_first_message: None,
client_final_message_without_proof: None,
salted_password: None,
})
}
pub(crate) fn client_first_message(&self) -> Vec<u8> {
format!("{GS2_HEADER}{}", self.client_first_message_bare).into_bytes()
}
pub(crate) fn handle_server_first_message(&mut self, challenge: &[u8]) -> Result<Vec<u8>> {
let server_first = ServerFirstMessage::parse(challenge)?;
if !server_first.nonce.starts_with(&self.client_nonce) {
bail!("invalid SCRAM server nonce: does not start with client nonce");
}
if server_first.iterations < MIN_ITERATIONS {
bail!(
"requested SCRAM iterations {} is less than the minimum {} for {}",
server_first.iterations,
MIN_ITERATIONS,
self.mechanism.as_str()
);
}
let salted_password = salted_password(
self.mechanism,
self.password.as_bytes(),
&server_first.salt,
server_first.iterations,
)?;
let client_final_message_without_proof =
format!("c={},r={}", BASE64.encode(GS2_HEADER), server_first.nonce);
let auth_message = auth_message(
&self.client_first_message_bare,
&server_first.message,
&client_final_message_without_proof,
);
let client_proof = client_proof(self.mechanism, &salted_password, auth_message.as_bytes())?;
let client_final_message = format!(
"{client_final_message_without_proof},p={}",
BASE64.encode(client_proof)
);
self.server_first_message = Some(server_first);
self.client_final_message_without_proof = Some(client_final_message_without_proof);
self.salted_password = Some(salted_password);
Ok(client_final_message.into_bytes())
}
pub(crate) fn handle_server_final_message(&self, challenge: &[u8]) -> Result<()> {
match ServerFinalMessage::parse(challenge)? {
ServerFinalMessage::Error(error) => {
bail!(
"SASL authentication using {} failed with error: {error}",
self.mechanism.as_str()
)
}
ServerFinalMessage::Verifier(signature) => {
let server_first = self
.server_first_message
.as_ref()
.context("SCRAM server-first message was not processed")?;
let client_final_message_without_proof = self
.client_final_message_without_proof
.as_ref()
.context("SCRAM client-final message without proof was not generated")?;
let salted_password = self
.salted_password
.as_ref()
.context("SCRAM salted password was not generated")?;
let server_key = server_key(self.mechanism, salted_password)?;
let expected = server_signature(
self.mechanism,
&server_key,
auth_message(
&self.client_first_message_bare,
&server_first.message,
client_final_message_without_proof,
)
.as_bytes(),
)?;
if signature.ct_eq(&expected).unwrap_u8() != 1 {
bail!("invalid SCRAM server signature in server final message");
}
Ok(())
}
}
}
}
impl ServerFirstMessage {
    /// Parses a server-first message from its raw bytes.
    ///
    /// Attributes are comma-separated. `r=`, `s=` and `i=` are extracted and
    /// anything else is ignored. If an attribute repeats, the last occurrence
    /// wins. The original text is kept verbatim for the auth message.
    ///
    /// # Errors
    /// Fails on non-UTF-8 input, invalid base64 salt, a non-numeric or
    /// non-positive iteration count, or a missing nonce, salt or iteration count
    /// (checked in that order).
    fn parse(challenge: &[u8]) -> Result<Self> {
        let text = std::str::from_utf8(challenge)
            .context("SCRAM server-first message is not UTF-8")?;

        let (mut nonce, mut salt, mut iterations) = (None, None, None);
        for attribute in text.split(',') {
            if let Some(raw) = attribute.strip_prefix("r=") {
                nonce = Some(raw.to_owned());
            } else if let Some(raw) = attribute.strip_prefix("s=") {
                let decoded = BASE64
                    .decode(raw)
                    .context("invalid SCRAM salt encoding")?;
                salt = Some(decoded);
            } else if let Some(raw) = attribute.strip_prefix("i=") {
                let count = raw
                    .parse::<i32>()
                    .context("invalid SCRAM iteration count")?;
                if count <= 0 {
                    bail!("invalid SCRAM iteration count {count}");
                }
                iterations = Some(count);
            }
        }

        let nonce = nonce.context("SCRAM server-first message did not include nonce")?;
        let salt = salt.context("SCRAM server-first message did not include salt")?;
        let iterations =
            iterations.context("SCRAM server-first message did not include iteration count")?;
        Ok(Self {
            message: text.to_owned(),
            nonce,
            salt,
            iterations,
        })
    }
}
impl ServerFinalMessage {
    /// Parses a server-final message.
    ///
    /// A message starting with `e=` yields `Error` with the text up to the first
    /// comma. Otherwise the first `v=` attribute is base64-decoded into `Verifier`.
    ///
    /// # Errors
    /// Fails on non-UTF-8 input, an undecodable signature, or a missing `v=` attribute.
    fn parse(challenge: &[u8]) -> Result<Self> {
        let text =
            std::str::from_utf8(challenge).context("SCRAM server-final message is not UTF-8")?;

        if let Some(rest) = text.strip_prefix("e=") {
            let reason = rest.split_once(',').map_or(rest, |(head, _)| head);
            return Ok(Self::Error(reason.to_owned()));
        }

        let encoded = text
            .split(',')
            .find_map(|attribute| attribute.strip_prefix("v="))
            .ok_or_else(|| anyhow!("SCRAM server-final message did not include verifier"))?;
        let signature = BASE64
            .decode(encoded)
            .context("invalid SCRAM server signature encoding")?;
        Ok(Self::Verifier(signature))
    }
}
/// Escapes a username for the SCRAM `n=` attribute: `=` becomes `=3D` and `,` becomes `=2C`.
fn sasl_name(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Returns a fresh random nonce as bytes: the ASCII encoding of a base-36 nonce string.
///
/// # Errors
/// Fails if random number generation fails.
pub fn secure_random_bytes() -> Result<Vec<u8>> {
    secure_random_string().map(String::into_bytes)
}
fn secure_random_string() -> Result<String> {
let mut bytes = [0_u8; 17];
getrandom::fill(&mut bytes)
.map_err(|error| anyhow!("failed to generate SCRAM nonce: {error}"))?;
bytes[0] &= 0x03;
Ok(BigUint::from_bytes_be(&bytes).to_str_radix(36))
}
/// Builds the SCRAM AuthMessage: the three message parts joined with commas.
fn auth_message(
    client_first_message_bare: &str,
    server_first_message: &str,
    client_final_message_without_proof: &str,
) -> String {
    [
        client_first_message_bare,
        server_first_message,
        client_final_message_without_proof,
    ]
    .join(",")
}
/// Derives the SCRAM salted password `Hi(password, salt, iterations)` for `mechanism`.
///
/// # Errors
/// Fails if `iterations` is outside `MIN_ITERATIONS..=MAX_ITERATIONS`, if
/// `mechanism` is `SaslMechanism::Plain`, or if the HMAC cannot be keyed.
pub fn salted_password(
    mechanism: SaslMechanism,
    password: &[u8],
    salt: &[u8],
    iterations: i32,
) -> Result<Vec<u8>> {
    let supported = MIN_ITERATIONS..=MAX_ITERATIONS;
    if !supported.contains(&iterations) {
        bail!(
            "SCRAM iterations {iterations} outside supported range {MIN_ITERATIONS}..={MAX_ITERATIONS}"
        );
    }
    match mechanism {
        SaslMechanism::Plain => bail!("SASL/PLAIN does not use SCRAM hi()"),
        SaslMechanism::ScramSha512 => hi_with_mac::<Hmac<Sha512>>(password, salt, iterations),
        SaslMechanism::ScramSha256 => hi_with_mac::<Hmac<Sha256>>(password, salt, iterations),
    }
}
/// Computes SCRAM `Hi(password, salt, iterations)` (RFC 5802 §2.2) with MAC type `M`.
///
/// `U1 = HMAC(password, salt || INT(1))`, `Ui = HMAC(password, U(i-1))`, and the
/// result is `U1 XOR U2 XOR ... XOR U(iterations)`. For `iterations <= 1` only
/// `U1` is returned; range checking is the caller's job.
///
/// The MAC is keyed with `password` once, and the keyed state is cloned for
/// every round. Re-keying it thousands of times would repeat the same key
/// setup in every iteration.
///
/// # Errors
/// Fails if `M` cannot be keyed with `password`.
fn hi_with_mac<M>(password: &[u8], salt: &[u8], iterations: i32) -> Result<Vec<u8>>
where
    M: Mac + hmac::digest::KeyInit + Clone,
{
    let keyed = <M as hmac::digest::KeyInit>::new_from_slice(password)
        .context("failed to initialize SCRAM HMAC")?;

    // U1 = HMAC(password, salt || INT(1)); INT(1) is a 4-byte big-endian 1.
    let mut mac = keyed.clone();
    mac.update(salt);
    mac.update(&[0, 0, 0, 1]);
    let mut previous = mac.finalize().into_bytes().to_vec();
    let mut result = previous.clone();

    // Ui = HMAC(password, U(i-1)), XOR-folded into the result.
    for _ in 2..=iterations {
        let mut mac = keyed.clone();
        mac.update(&previous);
        // Each round uses the same MAC type, so the output length never changes
        // and the buffer can be overwritten in place.
        previous.copy_from_slice(&mac.finalize().into_bytes());
        xor_in_place(&mut result, &previous)?;
    }
    Ok(result)
}
/// Computes the SCRAM ClientProof: `ClientKey XOR HMAC(H(ClientKey), AuthMessage)`,
/// where `ClientKey = HMAC(SaltedPassword, "Client Key")`.
///
/// # Errors
/// Fails if the HMAC cannot be computed for `mechanism`.
fn client_proof(
    mechanism: SaslMechanism,
    salted_password: &[u8],
    auth_message: &[u8],
) -> Result<Vec<u8>> {
    let key = hmac(mechanism, salted_password, b"Client Key")?;
    let stored = hash(mechanism, &key);
    let signature = hmac(mechanism, &stored, auth_message)?;
    xor(&key, &signature)
}
/// Computes the SCRAM ServerKey: `HMAC(SaltedPassword, "Server Key")`.
fn server_key(mechanism: SaslMechanism, salted_password: &[u8]) -> Result<Vec<u8>> {
    const LABEL: &[u8] = b"Server Key";
    hmac(mechanism, salted_password, LABEL)
}
/// Computes the SCRAM ServerSignature: `HMAC(ServerKey, AuthMessage)`.
fn server_signature(
    mechanism: SaslMechanism,
    server_key: &[u8],
    auth_message: &[u8],
) -> Result<Vec<u8>> {
    let signature = hmac(mechanism, server_key, auth_message)?;
    Ok(signature)
}
fn hmac(mechanism: SaslMechanism, key: &[u8], bytes: &[u8]) -> Result<Vec<u8>> {
match mechanism {
SaslMechanism::ScramSha256 => hmac_with_mac::<Hmac<Sha256>>(key, bytes),
SaslMechanism::ScramSha512 => hmac_with_mac::<Hmac<Sha512>>(key, bytes),
SaslMechanism::Plain => bail!("SASL/PLAIN does not use SCRAM HMAC"),
}
}
/// Computes a single MAC tag over `bytes` with MAC type `M` keyed by `key`.
fn hmac_with_mac<M>(key: &[u8], bytes: &[u8]) -> Result<Vec<u8>>
where
    M: Mac + hmac::digest::KeyInit,
{
    let mut tagger = <M as hmac::digest::KeyInit>::new_from_slice(key)
        .context("failed to initialize SCRAM HMAC")?;
    tagger.update(bytes);
    let tag = tagger.finalize().into_bytes();
    Ok(tag.to_vec())
}
/// Hashes `bytes` with the digest selected by `mechanism`.
///
/// # Panics
/// Panics for `SaslMechanism::Plain`. The only caller first computes an HMAC,
/// which already rejects that mechanism.
fn hash(mechanism: SaslMechanism, bytes: &[u8]) -> Vec<u8> {
    match mechanism {
        SaslMechanism::ScramSha512 => Vec::from(Sha512::digest(bytes).as_slice()),
        SaslMechanism::ScramSha256 => Vec::from(Sha256::digest(bytes).as_slice()),
        SaslMechanism::Plain => unreachable!("SASL/PLAIN does not use SCRAM hash"),
    }
}
/// Returns the byte-wise XOR of two equal-length slices.
///
/// # Errors
/// Fails if the slices differ in length.
fn xor(first: &[u8], second: &[u8]) -> Result<Vec<u8>> {
    if first.len() == second.len() {
        let mut combined = first.to_vec();
        for (dst, src) in combined.iter_mut().zip(second) {
            *dst ^= src;
        }
        Ok(combined)
    } else {
        bail!("SCRAM XOR inputs must be the same length")
    }
}
/// XORs `bytes` into `target` element by element.
///
/// # Errors
/// Fails, leaving `target` untouched, if the slices differ in length.
fn xor_in_place(target: &mut [u8], bytes: &[u8]) -> Result<()> {
    if target.len() != bytes.len() {
        bail!("SCRAM XOR inputs must be the same length");
    }
    target
        .iter_mut()
        .zip(bytes)
        .for_each(|(dst, src)| *dst ^= *src);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn sasl_name_escapes_comma_and_equals_like_kafka() {
        assert_eq!(sasl_name("a=b,c"), "a=3Db=2Cc");
    }
    #[test]
    fn server_first_parses_required_fields_and_extensions() {
        let parsed = ServerFirstMessage::parse(b"r=abcxyz,s=QSBTYWx0,i=4096,extra=value").unwrap();
        assert_eq!(parsed.nonce, "abcxyz");
        assert_eq!(parsed.salt, b"A Salt");
        assert_eq!(parsed.iterations, 4096);
        assert_eq!(parsed.message, "r=abcxyz,s=QSBTYWx0,i=4096,extra=value");
    }
    #[test]
    fn server_first_rejects_malformed_messages() {
        assert!(ServerFirstMessage::parse(b"r=abc,s=bad*,i=4096").is_err());
        assert!(ServerFirstMessage::parse(b"r=abc,s=YQ==,i=0").is_err());
        assert!(ServerFirstMessage::parse(b"r=abc,s=YQ==,i=not-number").is_err());
        assert!(ServerFirstMessage::parse(b"s=YQ==,i=4096").is_err());
        assert!(ServerFirstMessage::parse(b"r=abc,i=4096").is_err());
        assert!(ServerFirstMessage::parse(b"r=abc,s=YQ==").is_err());
        assert!(ServerFirstMessage::parse(&[0xff]).is_err());
    }
    #[test]
    fn scram_client_rejects_invalid_state_and_server_messages() {
        assert!(
            ScramClient::new(SaslMechanism::Plain, "user".to_owned(), "pw".to_owned()).is_err()
        );
        let mut client = ScramClient::new(
            SaslMechanism::ScramSha256,
            "user".to_owned(),
            "pw".to_owned(),
        )
        .unwrap();
        assert!(client.handle_server_final_message(b"v=ZmFrZQ==").is_err());
        assert!(
            client
                .handle_server_first_message(b"r=server,s=YQ==,i=4096")
                .is_err()
        );
        assert!(
            client
                .handle_server_first_message(b"r=server,s=YQ==,i=4095")
                .is_err()
        );
    }
    #[test]
    fn server_final_rejects_error_and_invalid_signature() {
        let mut client = ScramClient::new(
            SaslMechanism::ScramSha256,
            "user".to_owned(),
            "password".to_owned(),
        )
        .unwrap();
        let client_nonce = client.client_nonce.clone();
        client
            .handle_server_first_message(format!("r={client_nonce}server,s=YQ==,i=4096").as_bytes())
            .unwrap();
        assert!(
            client
                .handle_server_final_message(b"e=invalid-proof")
                .is_err()
        );
        assert!(
            client
                .handle_server_final_message(b"v=not-base64*")
                .is_err()
        );
        assert!(
            client
                .handle_server_final_message(b"v=ZmFrZS1zaWduYXR1cmU=")
                .is_err()
        );
        assert!(client.handle_server_final_message(b"").is_err());
    }
    /// Known-answer test: the SCRAM-SHA-256 example exchange from RFC 7677 §3
    /// (user "user", password "pencil"). It pins Hi(), ClientProof and
    /// ServerSignature to published values, not only their lengths.
    #[test]
    fn scram_sha256_matches_rfc7677_example_exchange() {
        let mut client = ScramClient::new(
            SaslMechanism::ScramSha256,
            "user".to_owned(),
            "pencil".to_owned(),
        )
        .unwrap();
        // Replace the random nonce with the RFC's fixed nonce so the exchange is deterministic.
        client.client_nonce = "rOprNGfwEbeRWgbNEkqO".to_owned();
        client.client_first_message_bare = "n=user,r=rOprNGfwEbeRWgbNEkqO".to_owned();
        assert_eq!(
            client.client_first_message(),
            b"n,,n=user,r=rOprNGfwEbeRWgbNEkqO"
        );
        let client_final = client
            .handle_server_first_message(
                b"r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096",
            )
            .unwrap();
        assert_eq!(
            client_final,
            b"c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );
        client
            .handle_server_final_message(b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=")
            .unwrap();
    }
    #[test]
    fn salted_password_validates_mechanism_and_iteration_range() {
        assert!(salted_password(SaslMechanism::Plain, b"pw", b"salt", MIN_ITERATIONS).is_err());
        assert!(
            salted_password(
                SaslMechanism::ScramSha256,
                b"pw",
                b"salt",
                MIN_ITERATIONS - 1
            )
            .is_err()
        );
        assert!(
            salted_password(
                SaslMechanism::ScramSha256,
                b"pw",
                b"salt",
                MAX_ITERATIONS + 1
            )
            .is_err()
        );
        assert_eq!(
            salted_password(SaslMechanism::ScramSha256, b"pw", b"salt", MIN_ITERATIONS)
                .unwrap()
                .len(),
            32
        );
        assert_eq!(
            salted_password(SaslMechanism::ScramSha512, b"pw", b"salt", MIN_ITERATIONS)
                .unwrap()
                .len(),
            64
        );
    }
}