use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use hmac::{Hmac, Mac};
use md5::{Digest, Md5};
use pbkdf2::pbkdf2_hmac;
use rand::RngExt;
use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing};
use super::error::{Error, Result};
/// Builds an MD5 password response.
///
/// The result is the literal prefix `"md5"` followed by the lowercase hex
/// encoding of `md5(hex(md5(password || user)) || salt)`.
#[must_use]
pub fn compute_md5_password(user: &str, password: &str, salt: &[u8]) -> String {
    // Inner digest: password bytes followed by the user name, rendered as hex text.
    let inner_hex = {
        let mut inner = Md5::new();
        inner.update(password.as_bytes());
        inner.update(user.as_bytes());
        hex_encode(&inner.finalize())
    };
    // Outer digest: the inner hex text followed by the raw salt bytes.
    let mut outer = Md5::new();
    outer.update(inner_hex.as_bytes());
    outer.update(salt);
    let outer_hex = hex_encode(&outer.finalize());
    format!("md5{outer_hex}")
}
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// Writes into a single preallocated buffer instead of formatting (and
/// allocating) a separate `String` for every byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble; both are < 16, so the lookups are in bounds.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
/// Client-side state carried between the steps of a SCRAM-SHA-256 exchange.
///
/// Created by [`scram_client_first`], completed by [`scram_client_final`] and
/// consumed by [`scram_verify_server`].
///
/// `Debug` is implemented by hand so that the password and the derived
/// ServerKey are never written to logs or panic messages (`Zeroizing`
/// forwards `Debug` to the wrapped secret).
pub struct AuthState {
    /// Plain-text password; wiped on drop by `Zeroizing`.
    password: Zeroizing<String>,
    /// Nonce sent in the client-first-message.
    client_nonce: String,
    /// client-first-message without its GS2 header, i.e. `n=,r=<nonce>`.
    client_first_bare: String,
    /// Raw server-first-message, recorded by [`scram_client_final`].
    #[allow(
        dead_code,
        reason = "retained for future re-authentication flows that replay the SCRAM exchange"
    )]
    server_first: Option<String>,
    /// AuthMessage (`client-first-bare,server-first,client-final-without-proof`).
    auth_message: Option<String>,
    /// ServerKey derived from the salted password; wiped on drop.
    server_key: Option<Zeroizing<Vec<u8>>>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("AuthState")
            .field("password", &REDACTED)
            .field("client_nonce", &self.client_nonce)
            .field("client_first_bare", &self.client_first_bare)
            .field("server_first", &self.server_first)
            .field("auth_message", &self.auth_message)
            // Only reveal whether the key has been derived, never its bytes.
            .field("server_key", &self.server_key.as_ref().map(|_| REDACTED))
            .finish()
    }
}
/// Starts a SCRAM-SHA-256 exchange.
///
/// Generates a fresh client nonce and returns the state needed by
/// [`scram_client_final`] together with the client-first-message bytes
/// `n,,n=,r=<nonce>`. The `n,,` GS2 header declares that no channel binding is
/// used, and the user name attribute is left empty.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn scram_client_first(password: &str) -> Result<(AuthState, Vec<u8>)> {
    let nonce = generate_nonce();
    let bare = format!("n=,r={nonce}");
    // Render the full message before `bare` is moved into the state.
    let message = format!("n,,{bare}").into_bytes();
    Ok((
        AuthState {
            password: Zeroizing::new(password.to_string()),
            client_nonce: nonce,
            client_first_bare: bare,
            server_first: None,
            auth_message: None,
            server_key: None,
        },
        message,
    ))
}
pub fn scram_client_final(
mut state: AuthState,
server_first: &[u8],
) -> Result<(AuthState, Vec<u8>)> {
let server_first_str = std::str::from_utf8(server_first)
.map_err(|_| Error::authentication("invalid UTF-8 in server-first message"))?;
let mut server_nonce = None;
let mut salt_b64 = None;
let mut iterations = None;
for part in server_first_str.split(',') {
if let Some(value) = part.strip_prefix("r=") {
server_nonce = Some(value);
} else if let Some(value) = part.strip_prefix("s=") {
salt_b64 = Some(value);
} else if let Some(value) = part.strip_prefix("i=") {
iterations = Some(value.parse::<u32>().map_err(|_| {
Error::authentication("invalid iteration count in server-first message")
})?);
}
}
let server_nonce = server_nonce
.ok_or_else(|| Error::authentication("missing nonce in server-first message"))?;
let salt_b64 =
salt_b64.ok_or_else(|| Error::authentication("missing salt in server-first message"))?;
let iterations = iterations
.ok_or_else(|| Error::authentication("missing iterations in server-first message"))?;
let client_nonce_bytes = state.client_nonce.as_bytes();
let server_nonce_bytes = server_nonce.as_bytes();
let prefix_match = server_nonce_bytes.len() >= client_nonce_bytes.len() && {
let mut diff: u8 = 0;
for (a, b) in server_nonce_bytes
.iter()
.zip(client_nonce_bytes.iter())
.take(client_nonce_bytes.len())
{
diff |= a ^ b;
}
diff == 0
};
if !prefix_match {
return Err(Error::authentication(
"server nonce doesn't match client nonce",
));
}
let salt = BASE64
.decode(salt_b64)
.map_err(|_| Error::authentication("invalid base64 in salt"))?;
let salted_password: Zeroizing<Vec<u8>> =
Zeroizing::new(pbkdf2_sha256(&state.password, &salt, iterations));
let client_key: Zeroizing<Vec<u8>> =
Zeroizing::new(hmac_sha256(&salted_password, b"Client Key"));
let server_key: Zeroizing<Vec<u8>> =
Zeroizing::new(hmac_sha256(&salted_password, b"Server Key"));
let stored_key: Zeroizing<Vec<u8>> = Zeroizing::new(sha256(&client_key));
let channel_binding_b64 = BASE64.encode(b"n,,");
let client_final_without_proof = format!("c={channel_binding_b64},r={server_nonce}");
let auth_message = format!(
"{},{},{}",
state.client_first_bare, server_first_str, client_final_without_proof
);
let client_signature: Zeroizing<Vec<u8>> =
Zeroizing::new(hmac_sha256(&stored_key, auth_message.as_bytes()));
let mut client_proof: Zeroizing<Vec<u8>> = Zeroizing::new(
client_key
.iter()
.zip(client_signature.iter())
.map(|(k, s)| k ^ s)
.collect(),
);
let client_final = format!(
"{},p={}",
client_final_without_proof,
BASE64.encode(client_proof.as_slice())
);
client_proof.zeroize();
state.server_first = Some(server_first_str.to_string());
state.auth_message = Some(auth_message);
state.server_key = Some(server_key);
Ok((state, client_final.into_bytes()))
}
/// Verifies the server-final-message of a SCRAM-SHA-256 exchange.
///
/// Accepts `v=<base64 ServerSignature>`, optionally followed by `,`-separated
/// extensions. The signature is checked against `HMAC(ServerKey, AuthMessage)`
/// using the values stored by [`scram_client_final`].
///
/// # Errors
///
/// Returns an authentication error in any of these cases:
/// - the message is not UTF-8;
/// - the server reports an error (`e=`);
/// - the message does not start with `v=`;
/// - the signature is not valid base64;
/// - `state` lacks the ServerKey or AuthMessage;
/// - the signature does not match.
pub fn scram_verify_server(state: AuthState, server_final: &[u8]) -> Result<()> {
    let server_final_str = std::str::from_utf8(server_final)
        .map_err(|_| Error::authentication("invalid UTF-8 in server-final message"))?;
    // server-final-message = (server-error / verifier) ["," extensions] (RFC 5802).
    // Base64 never contains ',', so the first attribute is everything before it.
    let first_attr = server_final_str
        .split_once(',')
        .map_or(server_final_str, |(head, _)| head);
    if first_attr.starts_with("e=") {
        return Err(Error::authentication(
            "server reported an error in server-final message",
        ));
    }
    let server_sig_b64 = first_attr
        .strip_prefix("v=")
        .ok_or_else(|| Error::authentication("invalid server-final message format"))?;
    let server_sig: Zeroizing<Vec<u8>> = Zeroizing::new(
        BASE64
            .decode(server_sig_b64)
            .map_err(|_| Error::authentication("invalid base64 in server signature"))?,
    );
    let server_key = state
        .server_key
        .ok_or_else(|| Error::authentication("missing server key in auth state"))?;
    let auth_message = state
        .auth_message
        .ok_or_else(|| Error::authentication("missing auth message in auth state"))?;
    let expected_sig: Zeroizing<Vec<u8>> =
        Zeroizing::new(hmac_sha256(&server_key, auth_message.as_bytes()));
    // Accumulate differences instead of using `!=` so the comparison time does
    // not depend on the position of the first mismatching byte.
    let signatures_match = server_sig.len() == expected_sig.len()
        && server_sig
            .iter()
            .zip(expected_sig.iter())
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0;
    if !signatures_match {
        return Err(Error::authentication(
            "server signature verification failed",
        ));
    }
    Ok(())
}
/// Returns a fresh client nonce: 18 random bytes encoded as standard base64.
///
/// 18 is a multiple of 3, so the result is exactly 24 characters with no `=`
/// padding. The base64 alphabet contains no `,`, which is the SCRAM attribute
/// separator.
fn generate_nonce() -> String {
    let raw: [u8; 18] = rand::rng().random();
    BASE64.encode(raw)
}
/// Derives a 32-byte key with PBKDF2-HMAC-SHA-256 (SCRAM's `Hi` function).
///
/// The key is written straight into the returned heap buffer. The previous
/// version derived into a stack array and copied it out, which left an
/// un-zeroized copy of the salted password on the stack. Callers are expected
/// to wrap the result in `Zeroizing`.
fn pbkdf2_sha256(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
    let mut derived = vec![0u8; 32];
    pbkdf2_hmac::<Sha256>(password.as_bytes(), salt, iterations, &mut derived);
    derived
}
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
type HmacSha256 = Hmac<Sha256>;
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
mac.update(data);
mac.finalize().into_bytes().to_vec()
}
/// Returns the SHA-256 digest of `data` as an owned byte vector.
fn sha256(data: &[u8]) -> Vec<u8> {
    // One-shot `Digest::digest` is equivalent to new() + update() + finalize().
    Sha256::digest(data).to_vec()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// server-first-message from the RFC 7677 §3 example exchange.
    const RFC7677_SERVER_FIRST: &[u8] =
        b"r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";

    /// Client state matching the RFC 7677 §3 example (user "user", password "pencil").
    fn rfc7677_state() -> AuthState {
        AuthState {
            password: Zeroizing::new("pencil".to_string()),
            client_nonce: "rOprNGfwEbeRWgbNEkqO".to_string(),
            client_first_bare: "n=user,r=rOprNGfwEbeRWgbNEkqO".to_string(),
            server_first: None,
            auth_message: None,
            server_key: None,
        }
    }

    #[test]
    fn test_md5_password() {
        let salt = [0x01, 0x02, 0x03, 0x04];
        let result = compute_md5_password("user", "password", &salt);
        assert!(result.starts_with("md5"));
        // "md5" prefix plus 32 lowercase hex digits of the MD5 digest.
        assert_eq!(result.len(), 35);
        assert!(result[3..]
            .bytes()
            .all(|b| b.is_ascii_hexdigit() && !b.is_ascii_uppercase()));
        assert_eq!(result, compute_md5_password("user", "password", &salt));
        assert_ne!(
            result,
            compute_md5_password("user", "password", &[0x04, 0x03, 0x02, 0x01])
        );
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex_encode(&[0x00, 0xff, 0x12, 0xab]), "00ff12ab");
        assert_eq!(hex_encode(&[]), "");
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();
        assert_ne!(nonce1, nonce2);
        assert!(!nonce1.is_empty());
        // ',' separates SCRAM attributes and must never appear in a nonce.
        assert!(!nonce1.contains(','));
    }

    #[test]
    fn test_scram_client_first_format() {
        let Ok((state, message)) = scram_client_first("pencil") else {
            panic!("scram_client_first failed");
        };
        let message = String::from_utf8(message).expect("client-first is ASCII");
        assert_eq!(message, format!("n,,n=,r={}", state.client_nonce));
    }

    #[test]
    fn test_scram_rfc7677_vector() {
        let Ok((state, client_final)) = scram_client_final(rfc7677_state(), RFC7677_SERVER_FIRST)
        else {
            panic!("scram_client_final rejected the RFC 7677 server-first-message");
        };
        assert_eq!(
            String::from_utf8(client_final).expect("client-final is ASCII"),
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );
        assert!(
            scram_verify_server(state, b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjRl4G4=").is_ok()
        );
    }

    #[test]
    fn test_scram_rejects_bad_server_signature() {
        let Ok((state, _)) = scram_client_final(rfc7677_state(), RFC7677_SERVER_FIRST) else {
            panic!("scram_client_final rejected the RFC 7677 server-first-message");
        };
        assert!(
            scram_verify_server(state, b"v=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=").is_err()
        );
    }

    #[test]
    fn test_scram_rejects_foreign_nonce() {
        let Ok((state, _)) = scram_client_first("pencil") else {
            panic!("scram_client_first failed");
        };
        let result = scram_client_final(state, b"r=not-our-nonce,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096");
        assert!(result.is_err());
    }
}