mod client;
mod server;
pub use client::ScramClientExchange;
pub use server::{ScramServerExchange, StepResult};
use crate::SaslMechanism;
use hmac::{Hmac, KeyInit, Mac};
use ring::rand::{SecureRandom, SystemRandom};
use sha2::{Digest, Sha256, Sha512};
/// Stored SCRAM verifier material for one password.
///
/// Values of this type are produced by `hash_scram_password` and
/// `hash_scram_password_with_salt`; it holds the derived keys rather than the
/// password itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScramCredential {
/// SCRAM mechanism the keys were derived for (selects SHA-256 or SHA-512).
pub mechanism: SaslMechanism,
/// Salt that was fed to PBKDF2 together with the password.
pub salt: Vec<u8>,
/// Hash of the "Client Key" HMAC computed over the salted password.
pub stored_key: Vec<u8>,
/// "Server Key" HMAC computed over the salted password.
pub server_key: Vec<u8>,
/// PBKDF2 iteration count used to produce the salted password.
pub iterations: u32,
}
/// Returns the digest size, in bytes, of the hash behind a SCRAM mechanism.
///
/// # Panics
///
/// Panics when `mechanism` is not one of the SCRAM variants.
#[must_use]
pub fn scram_hash_len(mechanism: SaslMechanism) -> usize {
    // Kept exhaustive on purpose: a new `SaslMechanism` variant must be
    // classified here explicitly rather than falling into a wildcard.
    match mechanism {
        SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
            panic!("scram_hash_len called with non-SCRAM mechanism {mechanism:?}")
        }
        SaslMechanism::ScramSha512 => 64,
        SaslMechanism::ScramSha256 => 32,
    }
}
/// Derives a [`ScramCredential`] for `password` using a freshly generated
/// random 16-byte salt.
///
/// # Panics
///
/// Panics if `iterations` is zero, if the system random number generator
/// fails, or if `mechanism` is not a SCRAM mechanism.
#[must_use]
pub fn hash_scram_password(
    password: &[u8],
    mechanism: SaslMechanism,
    iterations: u32,
) -> ScramCredential {
    /// Number of random salt bytes generated per credential.
    const SALT_LEN: usize = 16;

    assert!(iterations > 0, "iterations must be > 0");

    let rng = SystemRandom::new();
    let mut salt_bytes = [0u8; SALT_LEN];
    rng.fill(&mut salt_bytes).expect("system RNG must succeed");

    hash_scram_password_with_salt(password, mechanism, iterations, salt_bytes.to_vec())
}
/// Derives a [`ScramCredential`] for `password` with a caller-supplied salt.
///
/// The password is stretched with PBKDF2-HMAC (SHA-256 or SHA-512, chosen by
/// `mechanism`) over `salt` for `iterations` rounds; the stored key and
/// server key are then derived from that salted password. The same inputs
/// always yield the same credential.
///
/// # Panics
///
/// Panics if `iterations` is zero or if `mechanism` is not a SCRAM mechanism.
#[must_use]
pub fn hash_scram_password_with_salt(
    password: &[u8],
    mechanism: SaslMechanism,
    iterations: u32,
    salt: Vec<u8>,
) -> ScramCredential {
    // Same precondition as `hash_scram_password` and `pbkdf2_salted`: the
    // count is stored verbatim in the credential, and zero is not a valid
    // SCRAM iteration count.
    assert!(iterations > 0, "iterations must be > 0");

    let (stored_key, server_key) = match mechanism {
        SaslMechanism::ScramSha512 => {
            let salted: [u8; 64] =
                pbkdf2::pbkdf2_hmac_array::<Sha512, 64>(password, &salt, iterations);
            derive_keys_sha512(&salted)
        }
        SaslMechanism::ScramSha256 => {
            let salted: [u8; 32] =
                pbkdf2::pbkdf2_hmac_array::<Sha256, 32>(password, &salt, iterations);
            derive_keys_sha256(&salted)
        }
        SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
            panic!("hash_scram_password called with non-SCRAM mechanism {mechanism:?}");
        }
    };

    ScramCredential {
        mechanism,
        salt,
        stored_key,
        server_key,
        iterations,
    }
}
/// Computes the PBKDF2-HMAC salted password for a SCRAM mechanism.
///
/// The output is as long as the mechanism's hash: 32 bytes for SCRAM-SHA-256
/// and 64 bytes for SCRAM-SHA-512.
///
/// # Panics
///
/// Panics if `iterations` is zero or if `mechanism` is not a SCRAM mechanism.
#[must_use]
pub fn pbkdf2_salted(
    password: &[u8],
    mechanism: SaslMechanism,
    iterations: u32,
    salt: &[u8],
) -> Vec<u8> {
    assert!(iterations > 0, "iterations must be > 0");

    match mechanism {
        SaslMechanism::ScramSha256 => {
            pbkdf2::pbkdf2_hmac_array::<Sha256, 32>(password, salt, iterations).to_vec()
        }
        SaslMechanism::ScramSha512 => {
            pbkdf2::pbkdf2_hmac_array::<Sha512, 64>(password, salt, iterations).to_vec()
        }
        SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
            panic!("pbkdf2_salted called with non-SCRAM mechanism {mechanism:?}");
        }
    }
}
#[must_use]
pub fn derive_keys_from_salted(mechanism: SaslMechanism, salted: &[u8]) -> (Vec<u8>, Vec<u8>) {
match mechanism {
SaslMechanism::ScramSha512 => derive_keys_sha512(salted),
SaslMechanism::ScramSha256 => derive_keys_sha256(salted),
SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
panic!("derive_keys_from_salted called with non-SCRAM mechanism {mechanism:?}");
}
}
}
/// Derives `(stored_key, server_key)` with HMAC-SHA-512 keyed by `salted`.
///
/// The stored key is the SHA-512 digest of the HMAC over `"Client Key"`; the
/// server key is the HMAC over `"Server Key"`.
fn derive_keys_sha512(salted: &[u8]) -> (Vec<u8>, Vec<u8>) {
    // HMAC-SHA-512 of a fixed label, keyed with the salted password.
    let labelled_mac = |label: &[u8]| {
        let mut mac =
            <Hmac<Sha512>>::new_from_slice(salted).expect("hmac accepts any key length");
        mac.update(label);
        mac.finalize().into_bytes()
    };

    let server_key = labelled_mac(b"Server Key").to_vec();
    // Only the hash of the client key is kept, never the client key itself.
    let stored_key = Sha512::digest(labelled_mac(b"Client Key")).to_vec();
    (stored_key, server_key)
}
/// Derives `(stored_key, server_key)` with HMAC-SHA-256 keyed by `salted`.
///
/// The stored key is the SHA-256 digest of the HMAC over `"Client Key"`; the
/// server key is the HMAC over `"Server Key"`.
fn derive_keys_sha256(salted: &[u8]) -> (Vec<u8>, Vec<u8>) {
    // HMAC-SHA-256 of a fixed label, keyed with the salted password.
    let labelled_mac = |label: &[u8]| {
        let mut mac =
            <Hmac<Sha256>>::new_from_slice(salted).expect("hmac accepts any key length");
        mac.update(label);
        mac.finalize().into_bytes()
    };

    let server_key = labelled_mac(b"Server Key").to_vec();
    // Only the hash of the client key is kept, never the client key itself.
    let stored_key = Sha256::digest(labelled_mac(b"Client Key")).to_vec();
    (stored_key, server_key)
}
// Unit tests for credential derivation and for the client/server SCRAM
// exchange built on top of it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scram::client::ScramClientExchange;
    use crate::scram::server::{ScramServerExchange, StepResult};
    use assert2::assert;
    use base64::Engine;
    use base64::engine::general_purpose::STANDARD as B64;
    use sha2::{Digest, Sha512};

    // The SHA-512 credential has the expected shape and its stored key matches
    // an independent PBKDF2 -> HMAC("Client Key") -> SHA-512 computation.
    #[test]
    fn hash_scram_password_produces_expected_keys() {
        let password = b"pencil";
        let cred = hash_scram_password(password, SaslMechanism::ScramSha512, 4096);
        assert!(cred.mechanism == SaslMechanism::ScramSha512);
        assert!(cred.salt.len() == 16, "salt must be 16 bytes");
        assert!(cred.stored_key.len() == 64, "SHA-512 output is 64 bytes");
        assert!(cred.server_key.len() == 64);
        assert!(cred.iterations == 4096);
        let salted =
            pbkdf2::pbkdf2_hmac_array::<sha2::Sha512, 64>(password, &cred.salt, cred.iterations);
        let client_key = {
            use hmac::{Hmac, KeyInit, Mac};
            let mut m = <Hmac<Sha512>>::new_from_slice(&salted).unwrap();
            m.update(b"Client Key");
            m.finalize().into_bytes()
        };
        let expected_stored = Sha512::digest(client_key);
        assert!(cred.stored_key == expected_stored.as_slice());
    }

    // Same check as above for the SHA-256 variant.
    #[test]
    fn hash_scram_password_sha256_produces_expected_keys() {
        let password = b"pencil";
        let cred = hash_scram_password(password, SaslMechanism::ScramSha256, 4096);
        assert!(cred.mechanism == SaslMechanism::ScramSha256);
        assert!(cred.salt.len() == 16);
        assert!(cred.stored_key.len() == 32, "SHA-256 output is 32 bytes");
        assert!(cred.server_key.len() == 32);
        let salted =
            pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(password, &cred.salt, cred.iterations);
        let client_key = {
            use hmac::{Hmac, KeyInit, Mac};
            let mut m = <Hmac<sha2::Sha256>>::new_from_slice(&salted).unwrap();
            m.update(b"Client Key");
            m.finalize().into_bytes()
        };
        let expected_stored = sha2::Sha256::digest(client_key);
        assert!(cred.stored_key == expected_stored.as_slice());
    }

    // `hash_scram_password` draws a new random salt on every call, so two
    // credentials for the same password must not share a salt.
    #[test]
    fn hash_scram_password_uses_fresh_salt_each_call() {
        let a = hash_scram_password(b"x", SaslMechanism::ScramSha512, 4096);
        let b = hash_scram_password(b"x", SaslMechanism::ScramSha512, 4096);
        assert!(a.salt != b.salt, "fresh salt each call");
    }

    // With the salt pinned, derivation is a pure function of its inputs.
    #[test]
    fn hash_scram_password_with_salt_is_deterministic_given_salt() {
        let salt: Vec<u8> = (0..16).collect();
        let a =
            hash_scram_password_with_salt(b"x", SaslMechanism::ScramSha512, 4096, salt.clone());
        let b = hash_scram_password_with_salt(b"x", SaslMechanism::ScramSha512, 4096, salt);
        assert!(a == b, "same inputs must yield the same credential");
    }

    // Full SHA-512 handshake: client-first, server-first, client-final,
    // server-final, with the client verifying the server signature at the end.
    #[test]
    fn scram_server_and_client_round_trip() {
        let password = b"hunter2";
        let cred = hash_scram_password_with_salt(
            password,
            SaslMechanism::ScramSha512,
            4096,
            (0..16).collect::<Vec<u8>>(),
        );
        let mut server = ScramServerExchange::new("alice".to_string(), cred);
        let mut client = ScramClientExchange::new(
            "alice".to_string(),
            password.to_vec(),
            SaslMechanism::ScramSha512,
        );
        let c1 = client.client_first().expect("client first");
        let s1 = match server.step(&c1) {
            StepResult::Continue(b) => b,
            other => panic!("server step 1 must continue, got {other:?}"),
        };
        let c2 = client.step(&s1).expect("client final");
        let (principal, s2) = match server.step(&c2) {
            StepResult::Done(p, b) => (p, b),
            other => panic!("server step 2 must Done, got {other:?}"),
        };
        assert!(principal.name == "alice");
        assert!(principal.auth_method == crate::AuthMethod::SaslScramSha512);
        let final_check = client.verify_server_final(&s2);
        assert!(final_check.is_ok(), "server signature must verify");
    }

    // Full SHA-256 handshake, mirroring the SHA-512 round trip.
    #[test]
    fn scram_server_and_client_round_trip_sha256() {
        let password = b"hunter2";
        let cred = hash_scram_password_with_salt(
            password,
            SaslMechanism::ScramSha256,
            4096,
            (0..16).collect::<Vec<u8>>(),
        );
        let mut server = ScramServerExchange::new("alice".to_string(), cred);
        let mut client = ScramClientExchange::new(
            "alice".to_string(),
            password.to_vec(),
            SaslMechanism::ScramSha256,
        );
        let c1 = client.client_first().expect("client first");
        let s1 = match server.step(&c1) {
            StepResult::Continue(b) => b,
            other => panic!("server step 1 must continue, got {other:?}"),
        };
        let c2 = client.step(&s1).expect("client final");
        let (principal, s2) = match server.step(&c2) {
            StepResult::Done(p, b) => (p, b),
            other => panic!("server step 2 must Done, got {other:?}"),
        };
        assert!(principal.name == "alice");
        assert!(principal.auth_method == crate::AuthMethod::SaslScramSha256);
        let final_check = client.verify_server_final(&s2);
        assert!(final_check.is_ok(), "server signature must verify");
    }

    // `pbkdf2_salted` + `derive_keys_from_salted` must reproduce what
    // `hash_scram_password_with_salt` computes in one step (SHA-512).
    #[test]
    fn pbkdf2_salted_matches_hash_scram_password_intermediate_sha512() {
        let password = b"pencil";
        let salt: Vec<u8> = (0..16).collect();
        let cred =
            hash_scram_password_with_salt(password, SaslMechanism::ScramSha512, 4096, salt.clone());
        let salted = pbkdf2_salted(password, SaslMechanism::ScramSha512, 4096, &salt);
        assert!(salted.len() == 64);
        let (stored_key, server_key) = derive_keys_from_salted(SaslMechanism::ScramSha512, &salted);
        assert!(stored_key == cred.stored_key);
        assert!(server_key == cred.server_key);
    }

    // Same two-step equivalence for SHA-256.
    #[test]
    fn pbkdf2_salted_matches_hash_scram_password_intermediate_sha256() {
        let password = b"pencil";
        let salt: Vec<u8> = (0..16).collect();
        let cred =
            hash_scram_password_with_salt(password, SaslMechanism::ScramSha256, 4096, salt.clone());
        let salted = pbkdf2_salted(password, SaslMechanism::ScramSha256, 4096, &salt);
        assert!(salted.len() == 32);
        let (stored_key, server_key) = derive_keys_from_salted(SaslMechanism::ScramSha256, &salted);
        assert!(stored_key == cred.stored_key);
        assert!(server_key == cred.server_key);
    }

    // `derive_keys_from_salted` fed with a salted password computed directly
    // via the pbkdf2 crate agrees with the one-step derivation (SHA-512).
    #[test]
    fn derive_keys_from_salted_matches_hash_scram_password_sha512() {
        let password = b"hunter2";
        let salt: Vec<u8> = (0..16).collect();
        let cred =
            hash_scram_password_with_salt(password, SaslMechanism::ScramSha512, 4096, salt.clone());
        let salted: [u8; 64] = pbkdf2::pbkdf2_hmac_array::<sha2::Sha512, 64>(password, &salt, 4096);
        let (stored_key, server_key) = derive_keys_from_salted(SaslMechanism::ScramSha512, &salted);
        assert!(stored_key == cred.stored_key);
        assert!(server_key == cred.server_key);
        assert!(stored_key.len() == 64);
        assert!(server_key.len() == 64);
    }

    // Same direct-derivation agreement for SHA-256.
    #[test]
    fn derive_keys_from_salted_matches_hash_scram_password_sha256() {
        let password = b"hunter2";
        let salt: Vec<u8> = (0..16).collect();
        let cred =
            hash_scram_password_with_salt(password, SaslMechanism::ScramSha256, 4096, salt.clone());
        let salted: [u8; 32] = pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(password, &salt, 4096);
        let (stored_key, server_key) = derive_keys_from_salted(SaslMechanism::ScramSha256, &salted);
        assert!(stored_key == cred.stored_key);
        assert!(server_key == cred.server_key);
        assert!(stored_key.len() == 32);
        assert!(server_key.len() == 32);
    }

    // A server built with an explicit principal must report that principal on
    // success, even though the SCRAM username on the wire is different.
    #[test]
    fn scram_server_with_principal_override_yields_override_on_done() {
        let password = b"hunter2";
        let cred = hash_scram_password_with_salt(
            password,
            SaslMechanism::ScramSha256,
            4096,
            (0..16).collect::<Vec<u8>>(),
        );
        let override_principal = crate::Principal {
            name: "alice".to_string(),
            auth_method: crate::AuthMethod::SaslScramSha256,
            groups: vec![],
        };
        let mut server = ScramServerExchange::new_with_principal(
            "tok-uuid".to_string(),
            cred,
            override_principal.clone(),
        );
        let mut client = ScramClientExchange::new(
            "tok-uuid".to_string(),
            password.to_vec(),
            SaslMechanism::ScramSha256,
        );
        let c1 = client.client_first().expect("client first");
        let s1 = match server.step(&c1) {
            StepResult::Continue(b) => b,
            other => panic!("server step 1 must continue, got {other:?}"),
        };
        let c2 = client.step(&s1).expect("client final");
        let (principal, _s2) = match server.step(&c2) {
            StepResult::Done(p, b) => (p, b),
            other => panic!("server step 2 must Done, got {other:?}"),
        };
        assert!(principal == override_principal);
        assert!(principal.name == "alice");
    }

    // A client that derives its proof from the wrong password is rejected
    // with `BadProof`.
    #[test]
    fn scram_server_rejects_bad_proof() {
        let cred = hash_scram_password_with_salt(
            b"correct",
            SaslMechanism::ScramSha512,
            4096,
            vec![0u8; 16],
        );
        let mut server = ScramServerExchange::new("alice".to_string(), cred);
        let mut client = ScramClientExchange::new(
            "alice".to_string(),
            b"wrong".to_vec(),
            SaslMechanism::ScramSha512,
        );
        let c1 = client.client_first().unwrap();
        let StepResult::Continue(s1) = server.step(&c1) else {
            panic!();
        };
        let c2 = client.step(&s1).unwrap();
        match server.step(&c2) {
            StepResult::Failed(crate::AuthError::BadProof) => {}
            other => panic!("expected BadProof, got {other:?}"),
        }
    }

    // Appending bytes to the `r=` nonce in client-final must make the server
    // fail with `MalformedMessage`.
    #[test]
    fn scram_server_rejects_wrong_nonce() {
        let password: Vec<u8> = (b'A'..=b'Z').collect();
        let cred = hash_scram_password_with_salt(
            &password,
            SaslMechanism::ScramSha256,
            4096,
            (0..16).collect::<Vec<u8>>(),
        );
        let mut server = ScramServerExchange::new("alice".to_string(), cred);
        let mut client = ScramClientExchange::new(
            "alice".to_string(),
            password.clone(),
            SaslMechanism::ScramSha256,
        );
        let c1 = client.client_first().unwrap();
        let StepResult::Continue(s1) = server.step(&c1) else {
            panic!("server step 1 must continue");
        };
        let c2 = client.step(&s1).unwrap();
        let c2_str = String::from_utf8(c2).unwrap();
        let combined = c2_str
            .split(',')
            .find_map(|a| a.strip_prefix("r="))
            .expect("client-final has r=");
        let tampered = c2_str.replacen(
            &format!("r={combined}"),
            &format!("r={combined}deadbeef"),
            1,
        );
        match server.step(tampered.as_bytes()) {
            StepResult::Failed(crate::AuthError::MalformedMessage) => {}
            other => panic!("expected MalformedMessage for wrong nonce, got {other:?}"),
        }
    }

    // Replacing the `c=biws` channel-binding attribute in client-final with a
    // different base64 value must make the server fail with
    // `MalformedMessage`.
    #[test]
    fn scram_server_rejects_wrong_channel_binding() {
        let password: Vec<u8> = (b'A'..=b'Z').collect();
        let cred = hash_scram_password_with_salt(
            &password,
            SaslMechanism::ScramSha256,
            4096,
            (0..16).collect::<Vec<u8>>(),
        );
        let mut server = ScramServerExchange::new("alice".to_string(), cred);
        let mut client = ScramClientExchange::new(
            "alice".to_string(),
            password.clone(),
            SaslMechanism::ScramSha256,
        );
        let c1 = client.client_first().unwrap();
        let StepResult::Continue(s1) = server.step(&c1) else {
            panic!("server step 1 must continue");
        };
        let c2 = client.step(&s1).unwrap();
        let c2_str = String::from_utf8(c2).unwrap();
        assert!(c2_str.starts_with("c=biws,"), "client emits c=biws");
        let wrong_cb = B64.encode(b"y,,");
        let tampered = c2_str.replacen("c=biws", &format!("c={wrong_cb}"), 1);
        match server.step(tampered.as_bytes()) {
            StepResult::Failed(crate::AuthError::MalformedMessage) => {}
            other => panic!("expected MalformedMessage for wrong channel binding, got {other:?}"),
        }
    }
}