use base64::{Engine, engine::general_purpose::STANDARD as B64};
use hmac::{Hmac, Mac};
use md5::Md5;
use sha2::{Digest, Sha256};
use crate::DriverError;
type HmacSha256 = Hmac<Sha256>;
pub fn md5_password(user: &str, password: &str, salt: &[u8; 4]) -> [u8; 36] {
let mut hasher = Md5::new();
hasher.update(password.as_bytes());
hasher.update(user.as_bytes());
let inner = hex_encode_fixed(&hasher.finalize());
let mut hasher = Md5::new();
hasher.update(inner);
hasher.update(salt);
let outer = hex_encode_fixed(&hasher.finalize());
let mut result = [0u8; 36];
result[0] = b'm';
result[1] = b'd';
result[2] = b'5';
result[3..35].copy_from_slice(&outer);
result[35] = 0;
result
}
pub struct ScramClient {
password: String,
nonce: String,
client_first_bare: String,
server_first: String,
salted_password: [u8; 32],
auth_message: String,
}
impl ScramClient {
pub fn new(user: &str, password: &str) -> Result<Self, DriverError> {
let nonce = generate_nonce()?;
let client_first_bare = format!("n={user},r={nonce}");
Ok(Self {
password: password.to_owned(),
nonce,
client_first_bare,
server_first: String::new(),
salted_password: [0u8; 32],
auth_message: String::new(),
})
}
pub fn client_first_message(&self) -> Vec<u8> {
format!("n,,{}", self.client_first_bare).into_bytes()
}
pub fn process_server_first(&mut self, server_first: &[u8]) -> Result<(), DriverError> {
let server_first_str = std::str::from_utf8(server_first)
.map_err(|_| DriverError::Auth("server-first is not valid UTF-8".into()))?;
self.server_first = server_first_str.to_owned();
let mut server_nonce = None;
let mut salt_b64 = None;
let mut iterations = None;
for part in server_first_str.split(',') {
if let Some(val) = part.strip_prefix("r=") {
server_nonce = Some(val);
} else if let Some(val) = part.strip_prefix("s=") {
salt_b64 = Some(val);
} else if let Some(val) = part.strip_prefix("i=") {
iterations = val.parse::<u32>().ok();
}
}
let server_nonce = server_nonce
.ok_or_else(|| DriverError::Auth("missing nonce in server-first".into()))?;
let salt_b64 =
salt_b64.ok_or_else(|| DriverError::Auth("missing salt in server-first".into()))?;
let iterations = iterations
.ok_or_else(|| DriverError::Auth("missing iterations in server-first".into()))?;
if !server_nonce.starts_with(&self.nonce) {
return Err(DriverError::Auth(
"server nonce does not start with client nonce".into(),
));
}
let salt = B64
.decode(salt_b64)
.map_err(|_| DriverError::Auth("invalid base64 salt".into()))?;
pbkdf2::pbkdf2_hmac::<Sha256>(
self.password.as_bytes(),
&salt,
iterations,
&mut self.salted_password,
);
self.password.clear();
self.password.shrink_to(0);
let client_final_without_proof = format!("c=biws,r={server_nonce}");
self.auth_message = format!(
"{},{},{}",
self.client_first_bare, self.server_first, client_final_without_proof
);
Ok(())
}
pub fn client_final_message(&self) -> Result<Vec<u8>, DriverError> {
let client_key = hmac_sha256(&self.salted_password, b"Client Key")?;
let stored_key = Sha256::digest(client_key);
let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes())?;
let mut proof = client_key;
for (p, s) in proof.iter_mut().zip(client_signature.iter()) {
*p ^= s;
}
let proof_b64 = B64.encode(proof);
let server_nonce = self
.server_first
.split(',')
.find_map(|p| p.strip_prefix("r="))
.ok_or_else(|| DriverError::Auth("missing nonce for final message".into()))?;
let msg = format!("c=biws,r={server_nonce},p={proof_b64}");
Ok(msg.into_bytes())
}
pub fn verify_server_final(&self, server_final: &[u8]) -> Result<(), DriverError> {
let server_final_str = std::str::from_utf8(server_final)
.map_err(|_| DriverError::Auth("server-final is not valid UTF-8".into()))?;
let server_sig_b64 = server_final_str
.strip_prefix("v=")
.ok_or_else(|| DriverError::Auth("server-final missing 'v=' prefix".into()))?;
let server_sig = B64
.decode(server_sig_b64)
.map_err(|_| DriverError::Auth("invalid base64 in server signature".into()))?;
let server_key = hmac_sha256(&self.salted_password, b"Server Key")?;
let expected = hmac_sha256(&server_key, self.auth_message.as_bytes())?;
if !constant_time_eq(&server_sig, &expected) {
return Err(DriverError::Auth("server signature mismatch".into()));
}
Ok(())
}
}
fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<[u8; 32], DriverError> {
let mut mac = HmacSha256::new_from_slice(key)
.map_err(|_| DriverError::Auth("HMAC computation failed".into()))?;
mac.update(data);
Ok(mac.finalize().into_bytes().into())
}
fn generate_nonce() -> Result<String, DriverError> {
use rand::TryRngCore;
let mut bytes = [0u8; 24];
rand::rngs::OsRng
.try_fill_bytes(&mut bytes)
.map_err(|e| DriverError::Auth(format!("OS RNG failed: {e}")))?;
Ok(B64.encode(bytes))
}
#[inline(never)]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
let max_len = a.len().max(b.len());
let mut diff: u32 = 0;
diff |= (a.len() ^ b.len()) as u32;
for i in 0..max_len {
let x = if i < a.len() { a[i] } else { 0 };
let y = if i < b.len() { b[i] } else { 0 };
diff |= (x ^ y) as u32;
}
diff == 0
}
fn hex_encode_fixed(bytes: &[u8]) -> [u8; 32] {
const HEX: &[u8; 16] = b"0123456789abcdef";
let mut out = [0u8; 32];
for (i, &b) in bytes.iter().enumerate() {
out[i * 2] = HEX[(b >> 4) as usize];
out[i * 2 + 1] = HEX[(b & 0x0f) as usize];
}
out
}
#[cfg(test)]
fn hex_encode(bytes: &[u8]) -> String {
const HEX: &[u8; 16] = b"0123456789abcdef";
let mut out = String::with_capacity(bytes.len() * 2);
for &b in bytes {
out.push(HEX[(b >> 4) as usize] as char);
out.push(HEX[(b & 0x0f) as usize] as char);
}
out
}
pub fn parse_sasl_mechanisms(data: &[u8]) -> smallvec::SmallVec<[&str; 2]> {
let mut mechanisms = smallvec::SmallVec::new();
let mut pos = 0;
while pos < data.len() {
if data[pos] == 0 {
break;
}
if let Some(end) = data[pos..].iter().position(|&b| b == 0) {
if let Ok(s) = std::str::from_utf8(&data[pos..pos + end]) {
if !s.is_empty() {
mechanisms.push(s);
}
}
pos += end + 1;
} else {
break;
}
}
mechanisms
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn md5_password_known_value() {
let result = md5_password("testuser", "testpass", &[0x01, 0x02, 0x03, 0x04]);
assert!(result.starts_with(b"md5"));
assert_eq!(result[35], 0); }
#[test]
fn md5_password_format() {
let result = md5_password("user", "pass", &[0xAA, 0xBB, 0xCC, 0xDD]);
let s = std::str::from_utf8(&result[..35]).unwrap();
assert!(s.starts_with("md5"));
assert!(s[3..].chars().all(|c| c.is_ascii_hexdigit()));
}
#[test]
fn scram_client_first_message_format() {
let client = ScramClient::new("testuser", "testpass").unwrap();
let msg = client.client_first_message();
let s = std::str::from_utf8(&msg).unwrap();
assert!(s.starts_with("n,,n=testuser,r="));
}
#[test]
fn scram_nonce_is_unique() {
let n1 = generate_nonce().unwrap();
let n2 = generate_nonce().unwrap();
assert_ne!(n1, n2);
}
#[test]
fn constant_time_eq_works() {
assert!(constant_time_eq(b"hello", b"hello"));
assert!(!constant_time_eq(b"hello", b"world"));
assert!(!constant_time_eq(b"hello", b"hell"));
}
#[test]
fn hex_encode_works() {
assert_eq!(hex_encode(&[0xDE, 0xAD, 0xBE, 0xEF]), "deadbeef");
assert_eq!(hex_encode(&[0x00, 0xFF]), "00ff");
}
#[test]
fn parse_sasl_mechanisms_works() {
let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
let mechs = parse_sasl_mechanisms(data);
assert_eq!(mechs.as_slice(), &["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
}
#[test]
fn parse_sasl_mechanisms_empty() {
let data = b"\0";
let mechs = parse_sasl_mechanisms(data);
assert!(mechs.is_empty());
}
#[test]
fn scram_roundtrip() {
let mut client = ScramClient::new("user", "pencil").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let salt = B64.encode(b"salt1234salt5678");
let server_first = format!("r={server_nonce},s={salt},i=4096");
client
.process_server_first(server_first.as_bytes())
.unwrap();
let final_msg = client.client_final_message().unwrap();
let s = std::str::from_utf8(&final_msg).unwrap();
assert!(s.starts_with("c=biws,r="));
assert!(s.contains(",p="));
}
#[test]
fn scram_rejects_bad_nonce() {
let mut client = ScramClient::new("user", "pass").unwrap();
let _first = client.client_first_message();
let result = client.process_server_first(b"r=wrongnonce,s=c2FsdA==,i=4096");
assert!(result.is_err());
}
#[test]
fn constant_time_eq_different_lengths() {
assert!(!constant_time_eq(b"ab", b"abc"));
assert!(!constant_time_eq(b"abc", b"ab"));
assert!(!constant_time_eq(b"", b"a"));
assert!(!constant_time_eq(b"a", b""));
assert!(constant_time_eq(b"", b""));
}
#[test]
fn constant_time_eq_sha256_length() {
let a = [0xAAu8; 32];
let b = [0xAAu8; 32];
let c = [0xBBu8; 32];
assert!(constant_time_eq(&a, &b));
assert!(!constant_time_eq(&a, &c));
}
#[test]
fn scram_missing_salt_error() {
let mut client = ScramClient::new("user", "pass").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let server_first = format!("r={server_nonce},i=4096"); let result = client.process_server_first(server_first.as_bytes());
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("salt"), "should mention salt: {err}");
}
#[test]
fn scram_missing_iterations_error() {
let mut client = ScramClient::new("user", "pass").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let salt = B64.encode(b"salt1234");
let server_first = format!("r={server_nonce},s={salt}"); let result = client.process_server_first(server_first.as_bytes());
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("iterations"),
"should mention iterations: {err}"
);
}
#[test]
fn scram_non_numeric_iterations_error() {
let mut client = ScramClient::new("user", "pass").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let salt = B64.encode(b"salt1234");
let server_first = format!("r={server_nonce},s={salt},i=notanumber");
let result = client.process_server_first(server_first.as_bytes());
assert!(result.is_err());
}
#[test]
fn scram_invalid_base64_salt_error() {
let mut client = ScramClient::new("user", "pass").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let server_first = format!("r={server_nonce},s=!@#$not_base64,i=4096");
let result = client.process_server_first(server_first.as_bytes());
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("base64") || err.contains("salt"),
"should mention base64 or salt: {err}"
);
}
#[test]
fn scram_verify_server_final_mismatch() {
let mut client = ScramClient::new("user", "pencil").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let salt = B64.encode(b"salt1234salt5678");
let server_first = format!("r={server_nonce},s={salt},i=4096");
client
.process_server_first(server_first.as_bytes())
.unwrap();
let _final_msg = client.client_final_message().unwrap();
let wrong_sig = B64.encode(b"wrongwrongwrongwrongwrongwrongww"); let server_final = format!("v={wrong_sig}");
let result = client.verify_server_final(server_final.as_bytes());
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("mismatch"), "should mention mismatch: {err}");
}
#[test]
fn scram_verify_server_final_missing_prefix() {
let mut client = ScramClient::new("user", "pencil").unwrap();
let _first = client.client_first_message();
let server_nonce = format!("{}serverpart", client.nonce);
let salt = B64.encode(b"salt1234salt5678");
let server_first = format!("r={server_nonce},s={salt},i=4096");
client
.process_server_first(server_first.as_bytes())
.unwrap();
let result = client.verify_server_final(b"no_v_prefix_here");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("v="),
"should mention missing v= prefix: {err}"
);
}
#[test]
fn constant_time_eq_both_empty_true() {
assert!(constant_time_eq(b"", b""));
}
#[test]
fn constant_time_eq_diff_lengths_false() {
assert!(!constant_time_eq(b"a", b"ab"));
assert!(!constant_time_eq(b"ab", b"a"));
assert!(!constant_time_eq(b"", b"x"));
}
#[test]
fn parse_sasl_mechanisms_unsupported_only() {
let data = b"SCRAM-SHA-512\0SCRAM-SHA-256-PLUS\0\0";
let mechs = parse_sasl_mechanisms(data);
assert_eq!(mechs.len(), 2);
assert!(!mechs.contains(&"SCRAM-SHA-256"));
}
}