use hmac::{Hmac, KeyInit, Mac};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use crate::error::AuthError;
use crate::protocol::{base64_decode, base64_encode};
/// Length in bytes of a SHA-256 digest; sizes the fixed arrays that hold keys,
/// signatures and proofs in this module.
pub(crate) const SHA256_LEN: usize = 32;
/// Smallest PBKDF2 iteration count `parse_server_first` accepts from the server.
pub(crate) const MIN_PBKDF2_ITERATIONS: u32 = 4096;
/// Largest PBKDF2 iteration count `parse_server_first` accepts from the server.
// NOTE(review): presumably caps the work a server can force on the client — confirm.
pub(crate) const MAX_PBKDF2_ITERATIONS: u32 = 600_000;
/// Number of random bytes drawn for the client nonce before it is base64-encoded.
pub(crate) const CLIENT_NONCE_LEN: usize = 24;
/// Produces a fresh client nonce: `CLIENT_NONCE_LEN` random bytes, base64-encoded.
///
/// # Errors
/// Returns `AuthError::Other("CSPRNG unavailable")` when `getrandom` cannot fill
/// the buffer.
pub(crate) fn generate_client_nonce() -> Result<String, AuthError> {
    let mut entropy = [0u8; CLIENT_NONCE_LEN];
    if getrandom::getrandom(&mut entropy).is_err() {
        return Err(AuthError::Other("CSPRNG unavailable"));
    }
    let encoded = base64_encode(&entropy);
    Ok(encoded)
}
/// Builds the client-first message `n,,n=<user>,r=<nonce>`.
///
/// The username is passed through `escape_saslname` so that `,` and `=` cannot
/// be mistaken for attribute separators; the nonce is inserted verbatim.
pub(crate) fn build_client_first(username: &str, client_nonce: &str) -> String {
    // "n,," is the fixed header this client always sends.
    let mut message = String::from("n,,n=");
    message.push_str(&escape_saslname(username));
    message.push_str(",r=");
    message.push_str(client_nonce);
    message
}
/// Fields extracted from the server-first message by `parse_server_first`.
#[derive(Debug)]
pub(crate) struct ServerFirst {
/// Value of the `r=` attribute. `parse_server_first` only returns it after
/// checking that it starts with the client nonce.
pub(crate) nonce: String,
/// Decoded bytes of the base64 `s=` attribute; used as the PBKDF2 salt.
pub(crate) salt: Vec<u8>,
/// Value of the `i=` attribute; used as the PBKDF2 iteration count. When
/// produced by `parse_server_first` it lies within
/// `MIN_PBKDF2_ITERATIONS..=MAX_PBKDF2_ITERATIONS`.
pub(crate) iterations: u32,
}
/// Parses and validates the server-first message of the SCRAM exchange.
///
/// The message is split on `,` into `key=value` attributes. `r` (nonce),
/// `s` (base64 salt) and `i` (iteration count) are required; if an attribute
/// repeats, the last occurrence wins. Attributes with any other key are
/// ignored, except `m`, which is always rejected.
///
/// # Errors
/// Returns `AuthError::Other` when:
/// - an attribute has no `=`,
/// - `i=` is not a valid `u32`,
/// - an `m=` attribute is present,
/// - any of `r=`, `s=`, `i=` is missing,
/// - the server nonce does not start with `client_nonce`,
/// - the iteration count is outside
///   `MIN_PBKDF2_ITERATIONS..=MAX_PBKDF2_ITERATIONS`,
/// - the salt is not valid base64.
pub(crate) fn parse_server_first(
    server_first: &str,
    client_nonce: &str,
) -> Result<ServerFirst, AuthError> {
    let mut combined_nonce = None;
    let mut encoded_salt = None;
    let mut rounds = None;

    for field in server_first.split(',') {
        let (name, payload) = match field.split_once('=') {
            Some(pair) => pair,
            None => return Err(AuthError::Other("malformed server-first message")),
        };

        if name == "r" {
            combined_nonce = Some(payload);
        } else if name == "s" {
            encoded_salt = Some(payload);
        } else if name == "i" {
            match payload.parse::<u32>() {
                Ok(count) => rounds = Some(count),
                Err(_) => {
                    return Err(AuthError::Other("server iteration count not a u32"));
                }
            }
        } else if name == "m" {
            return Err(AuthError::Other(
                "server requested an unsupported SCRAM extension",
            ));
        }
        // Any other attribute is deliberately skipped.
    }

    // Presence checks happen in r, s, i order so the first missing one is reported.
    let combined_nonce =
        combined_nonce.ok_or_else(|| AuthError::Other("server-first missing r="))?;
    let encoded_salt = encoded_salt.ok_or_else(|| AuthError::Other("server-first missing s="))?;
    let rounds: u32 = rounds.ok_or_else(|| AuthError::Other("server-first missing i="))?;

    // The server must echo our nonce as the prefix of its combined nonce.
    let echoes_client_nonce = combined_nonce.starts_with(client_nonce);
    if !echoes_client_nonce {
        return Err(AuthError::Other(
            "server nonce does not start with client nonce",
        ));
    }

    let acceptable_rounds = MIN_PBKDF2_ITERATIONS..=MAX_PBKDF2_ITERATIONS;
    if !acceptable_rounds.contains(&rounds) {
        return Err(AuthError::Other(
            "server iteration count outside acceptable range",
        ));
    }

    // Decoding the salt is the last validation step.
    let salt = match base64_decode(encoded_salt) {
        Ok(bytes) => bytes,
        Err(_) => return Err(AuthError::Other("server-first salt is not valid base64")),
    };

    Ok(ServerFirst {
        nonce: combined_nonce.to_string(),
        salt,
        iterations: rounds,
    })
}
/// Output of `compute_client_final`: the message to send and the value the
/// server's reply is checked against.
#[derive(Debug)]
pub(crate) struct ClientFinal {
/// Complete client-final message, including the base64 `p=` proof.
pub(crate) message: String,
/// HMAC-SHA-256 of the auth message under the server key; passed to
/// `verify_server_final` as the expected `v=` value.
pub(crate) expected_server_signature: [u8; SHA256_LEN],
}
/// Computes the client-final message and the server signature expected in reply.
///
/// Steps, all using HMAC-SHA-256 / SHA-256:
/// 1. Build the auth message: client-first-bare, the raw server-first message
///    and the client-final message without proof, joined by commas.
/// 2. Derive the salted password with PBKDF2 from `password`, the server salt
///    and the server iteration count.
/// 3. Derive the client key, stored key (hash of the client key) and server key.
/// 4. The proof is the client key XORed with the client signature (HMAC of the
///    auth message under the stored key).
///
/// `server_first_raw` must be the exact text received from the server, because
/// it is hashed verbatim into the auth message.
///
/// # Panics
/// Panics only if PBKDF2 rejects the output buffer length, which the `expect`
/// treats as impossible for a `SHA256_LEN`-byte buffer.
pub(crate) fn compute_client_final(
    username: &str,
    password: &str,
    client_nonce: &str,
    server_first: &ServerFirst,
    server_first_raw: &str,
) -> ClientFinal {
    // "biws" is base64("n,,"), matching the header sent in the client-first message.
    let without_proof = format!("c=biws,r={}", server_first.nonce);
    let auth_message = format!(
        "n={},r={},{},{}",
        escape_saslname(username),
        client_nonce,
        server_first_raw,
        without_proof,
    );
    let transcript = auth_message.as_bytes();

    let mut salted_password = [0u8; SHA256_LEN];
    pbkdf2::pbkdf2::<Hmac<Sha256>>(
        password.as_bytes(),
        &server_first.salt,
        server_first.iterations,
        &mut salted_password,
    )
    .expect("PBKDF2 with valid output length never fails");

    let client_key = hmac_sha256(&salted_password, b"Client Key");
    let server_key = hmac_sha256(&salted_password, b"Server Key");
    let stored_key: [u8; SHA256_LEN] = Sha256::digest(client_key).into();
    let client_signature = hmac_sha256(&stored_key, transcript);

    // proof = client_key XOR client_signature, byte by byte.
    let mut proof = client_key;
    for (byte, mask) in proof.iter_mut().zip(client_signature.iter()) {
        *byte ^= *mask;
    }

    ClientFinal {
        message: format!("{},p={}", without_proof, base64_encode(&proof)),
        expected_server_signature: hmac_sha256(&server_key, transcript),
    }
}
/// Checks the server-final message against the expected server signature.
///
/// Any `e=` attribute anywhere in the message is treated as a rejection and
/// takes priority over a `v=` attribute. Otherwise the first `v=` attribute is
/// base64-decoded and compared to `expected` in constant time.
///
/// # Errors
/// Returns `AuthError::Other` when the server sent `e=`, when `v=` is missing,
/// not valid base64 or not `SHA256_LEN` bytes long, or when the signature does
/// not match.
pub(crate) fn verify_server_final(
    server_final: &str,
    expected: &[u8; SHA256_LEN],
) -> Result<(), AuthError> {
    // The error text supplied by the server is intentionally not surfaced.
    let server_reported_error = server_final
        .split(',')
        .any(|field| field.starts_with("e="));
    if server_reported_error {
        return Err(AuthError::Other("server rejected SCRAM exchange"));
    }

    let mut encoded_signature = None;
    for field in server_final.split(',') {
        if let Some(rest) = field.strip_prefix("v=") {
            encoded_signature = Some(rest);
            break;
        }
    }
    let encoded_signature = match encoded_signature {
        Some(text) => text,
        None => return Err(AuthError::Other("server-final missing v=")),
    };

    let signature = match base64_decode(encoded_signature) {
        Ok(bytes) => bytes,
        Err(_) => return Err(AuthError::Other("server-final v= is not valid base64")),
    };
    if signature.len() != SHA256_LEN {
        return Err(AuthError::Other("server-final v= is not 32 bytes"));
    }

    // Constant-time comparison avoids leaking how many leading bytes matched.
    let signatures_match: bool = signature.ct_eq(expected.as_slice()).into();
    if !signatures_match {
        return Err(AuthError::Other("server signature did not verify"));
    }
    Ok(())
}
/// Returns the HMAC-SHA-256 tag of `message` under `key` as a fixed-size array.
///
/// # Panics
/// Panics only if the HMAC implementation rejects the key length, which the
/// `expect` treats as impossible.
fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; SHA256_LEN] {
    type HmacSha256 = Hmac<Sha256>;

    let mut hasher = <HmacSha256 as KeyInit>::new_from_slice(key)
        .expect("HMAC accepts arbitrary key length");
    hasher.update(message);
    let tag = hasher.finalize().into_bytes();
    tag.into()
}
/// Escapes a username for use inside a SCRAM attribute value.
///
/// `,` becomes `=2C` and `=` becomes `=3D`; every other character is copied
/// unchanged.
fn escape_saslname(name: &str) -> String {
    name.chars()
        .fold(String::with_capacity(name.len()), |mut escaped, c| {
            if c == ',' {
                escaped.push_str("=2C");
            } else if c == '=' {
                escaped.push_str("=3D");
            } else {
                escaped.push(c);
            }
            escaped
        })
}