use super::error::{BackendError, BackendResult};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
type HmacSha256 = Hmac<Sha256>;
/// Builds the MD5 password response for `user` / `password` and the 4-byte `salt`.
///
/// The result is `"md5"` followed by the 32 lowercase hex digits of
/// `md5(hex(md5(password || user)) || salt)`, terminated by a NUL byte.
pub fn md5_password_response(user: &str, password: &str, salt: &[u8; 4]) -> Vec<u8> {
    // Inner hash: hex digest of the password immediately followed by the user name.
    let credentials_hex = md5_hex(&[password.as_bytes(), user.as_bytes()].concat());

    // Outer hash: hex digest of the inner hex string followed by the raw salt bytes.
    let mut salted_input = credentials_hex.into_bytes();
    salted_input.extend_from_slice(salt);
    let outer_hex = md5_hex(&salted_input);

    let mut message = Vec::with_capacity(3 + outer_hex.len() + 1);
    message.extend_from_slice(b"md5");
    message.extend_from_slice(outer_hex.as_bytes());
    message.push(0);
    message
}
/// Returns the MD5 digest of `bytes` as a lowercase hexadecimal string (two digits per byte).
fn md5_hex(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let digest = md5::Md5::digest(bytes);
    // Emit the high nibble then the low nibble of every byte, avoiding a
    // temporary `String` per byte.
    digest
        .iter()
        .flat_map(|&byte| {
            [
                HEX_DIGITS[usize::from(byte >> 4)],
                HEX_DIGITS[usize::from(byte & 0x0f)],
            ]
        })
        .map(char::from)
        .collect()
}
/// Client-side state of a single SCRAM-SHA-256 exchange.
///
/// Created by `Scram::client_first`; `client_final` fills in the key and
/// auth-message fields, which `verify_server` later checks against.
pub struct Scram {
    /// Client nonce sent in the client-first message.
    nonce: String,
    /// The client-first message without its GS2 header (`n=,r=<nonce>`).
    client_first_bare: String,
    /// `client-first-bare,server-first,client-final-without-proof`; empty until `client_final`.
    auth_message: String,
    /// HMAC(SaltedPassword, "Server Key"); all zeroes until `client_final`.
    server_key: [u8; 32],
    /// Set once `client_final` has completed successfully.
    finalised: bool,
}

/// Raw bytes of a SCRAM message produced for the server.
#[derive(Debug)]
pub struct ScramMessage(pub Vec<u8>);
impl Scram {
    /// Starts a SCRAM-SHA-256 exchange using `nonce` as the client nonce.
    ///
    /// Returns the fresh client state together with the initial message bytes:
    /// the NUL-terminated mechanism name `SCRAM-SHA-256`, a big-endian `u32`
    /// byte length, and the client-first message `n,,n=,r=<nonce>`. The GS2
    /// header `n,,` declares that channel binding is not used, and the username
    /// attribute is left empty.
    ///
    /// NOTE(review): `nonce` is embedded verbatim. RFC 5802 forbids `,` in a
    /// nonce, and that is not checked here; callers must supply a safe value.
    ///
    /// # Panics
    ///
    /// Panics if the client-first message is longer than `u32::MAX` bytes.
    pub fn client_first(nonce: impl Into<String>) -> (Self, ScramMessage) {
        let nonce = nonce.into();
        let client_first_bare = format!("n=,r={}", nonce);
        let client_first = format!("n,,{}", client_first_bare);
        let mech = b"SCRAM-SHA-256\0";
        // The length prefix is a 32-bit field; refuse to silently truncate it.
        let client_first_len = u32::try_from(client_first.len())
            .expect("client-first message length fits in u32");
        let mut out = Vec::with_capacity(mech.len() + 4 + client_first.len());
        out.extend_from_slice(mech);
        out.extend_from_slice(&client_first_len.to_be_bytes());
        out.extend_from_slice(client_first.as_bytes());
        (
            Self {
                client_first_bare,
                nonce,
                server_key: [0u8; 32],
                auth_message: String::new(),
                finalised: false,
            },
            ScramMessage(out),
        )
    }

    /// Consumes the server-first message and produces the client-final message.
    ///
    /// Steps:
    /// - Reads the combined nonce (`r=`), the base64 salt (`s=`) and the
    ///   iteration count (`i=`) from `server_first`.
    /// - Derives the SCRAM keys from `password`.
    /// - Remembers the ServerKey and AuthMessage for [`Scram::verify_server`].
    /// - Returns `c=biws,r=<nonce>,p=<base64 ClientProof>`.
    ///
    /// `password` is used as its raw UTF-8 bytes; no SASLprep normalisation is
    /// applied.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::Auth`] if:
    /// - `server_first` is not UTF-8;
    /// - it carries a mandatory extension (`m=`);
    /// - `r=`, `s=` or a numeric `i=` is missing;
    /// - the server nonce does not start with the client nonce;
    /// - the iteration count is 0;
    /// - the salt is not valid base64.
    pub fn client_final(
        &mut self,
        server_first: &[u8],
        password: &str,
    ) -> BackendResult<ScramMessage> {
        let server_first_str = std::str::from_utf8(server_first).map_err(|e| {
            BackendError::Auth(format!("server-first is not UTF-8: {}", e))
        })?;

        let mut server_nonce = None;
        let mut salt_b64 = None;
        let mut iterations: Option<u32> = None;
        for field in server_first_str.split(',') {
            if field.starts_with("m=") {
                // RFC 5802 §5.1: the reserved `m` attribute MUST cause
                // authentication to fail; no mandatory extensions are supported.
                return Err(BackendError::Auth(
                    "server-first contains unsupported mandatory extension (m=)".into(),
                ));
            } else if let Some(rest) = field.strip_prefix("r=") {
                server_nonce = Some(rest);
            } else if let Some(rest) = field.strip_prefix("s=") {
                salt_b64 = Some(rest);
            } else if let Some(rest) = field.strip_prefix("i=") {
                iterations = rest.parse().ok();
            }
        }

        let server_nonce = server_nonce
            .ok_or_else(|| BackendError::Auth("missing r= in server-first".into()))?;
        let salt_b64 = salt_b64
            .ok_or_else(|| BackendError::Auth("missing s= in server-first".into()))?;
        let iterations = iterations
            .ok_or_else(|| BackendError::Auth("missing/invalid i= in server-first".into()))?;

        // The server must echo our nonce and append its own.
        if !server_nonce.starts_with(&self.nonce) {
            return Err(BackendError::Auth(
                "server nonce does not extend client nonce".into(),
            ));
        }
        // NOTE(review): there is no upper bound on the iteration count, so a
        // hostile server can make the PBKDF2 step below arbitrarily slow.
        if iterations == 0 {
            return Err(BackendError::Auth("iteration count must be >= 1".into()));
        }

        let salt = BASE64
            .decode(salt_b64)
            .map_err(|e| BackendError::Auth(format!("bad salt base64: {}", e)))?;

        let salted_password = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let client_key = hmac_sha256(&salted_password, b"Client Key");
        let stored_key = sha256(&client_key);
        self.server_key = hmac_sha256(&salted_password, b"Server Key");

        // `c=` carries the base64 GS2 header ("n,," -> "biws") since no channel binding is used.
        let channel_binding = BASE64.encode(b"n,,");
        let client_final_without_proof =
            format!("c={},r={}", channel_binding, server_nonce);
        self.auth_message = format!(
            "{},{},{}",
            self.client_first_bare, server_first_str, client_final_without_proof
        );

        // ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage)
        let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes());
        let mut client_proof = [0u8; 32];
        for ((proof, key), sig) in client_proof
            .iter_mut()
            .zip(&client_key)
            .zip(&client_signature)
        {
            *proof = key ^ sig;
        }

        let client_final = format!(
            "{},p={}",
            client_final_without_proof,
            BASE64.encode(client_proof)
        );
        self.finalised = true;
        Ok(ScramMessage(client_final.into_bytes()))
    }

    /// Verifies the server-final message against the expected ServerSignature.
    ///
    /// The signature bytes are compared without early exit, so the comparison
    /// time does not reveal how many leading bytes matched.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::Auth`] if:
    /// - this is called before a successful [`Scram::client_final`];
    /// - the message is not UTF-8;
    /// - the server reported `e=<error>`;
    /// - `v=` is missing or is not valid base64;
    /// - the signature does not match.
    pub fn verify_server(&self, server_final: &[u8]) -> BackendResult<()> {
        if !self.finalised {
            return Err(BackendError::Auth(
                "verify_server called before client_final".into(),
            ));
        }
        let s = std::str::from_utf8(server_final).map_err(|e| {
            BackendError::Auth(format!("server-final is not UTF-8: {}", e))
        })?;
        if let Some(err) = s.strip_prefix("e=") {
            return Err(BackendError::Auth(format!("server reported: {}", err)));
        }
        // Any extensions after the verifier are ignored.
        let sig_b64 = s
            .strip_prefix("v=")
            .ok_or_else(|| BackendError::Auth("missing v= in server-final".into()))?
            .split(',')
            .next()
            .unwrap_or("");
        let received = BASE64
            .decode(sig_b64)
            .map_err(|e| BackendError::Auth(format!("bad v= base64: {}", e)))?;

        let expected = hmac_sha256(&self.server_key, self.auth_message.as_bytes());
        // Length is public (always 32), so only the byte contents need a
        // non-short-circuiting comparison.
        let signatures_match = received.len() == expected.len()
            && received
                .iter()
                .zip(expected.iter())
                .fold(0u8, |acc, (a, b)| acc | (a ^ b))
                == 0;
        if signatures_match {
            Ok(())
        } else {
            Err(BackendError::Auth("server signature mismatch".into()))
        }
    }
}
/// One-shot SHA-256 digest of `data`, returned as a fixed 32-byte array.
fn sha256(data: &[u8]) -> [u8; 32] {
    Sha256::digest(data).into()
}
/// Computes HMAC-SHA256 of `data` under `key` and returns the 32-byte tag.
///
/// HMAC accepts keys of any length, so the `expect` cannot fire.
fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
    let mut hmac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
    hmac.update(data);
    hmac.finalize().into_bytes().into()
}
fn pbkdf2_hmac_sha256(password: &[u8], salt: &[u8], iters: u32) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(password)
.expect("HMAC accepts any key length");
mac.update(salt);
mac.update(&1u32.to_be_bytes());
let mut u: [u8; 32] = mac.finalize().into_bytes().into();
let mut out = u;
for _ in 1..iters {
let mut mac = HmacSha256::new_from_slice(password)
.expect("HMAC accepts any key length");
mac.update(&u);
u = mac.finalize().into_bytes().into();
for i in 0..32 {
out[i] ^= u[i];
}
}
out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recomputes the two-stage MD5 by hand and checks the framing ("md5" prefix, NUL end).
    #[test]
    fn test_md5_password_response_known_answer() {
        let got = md5_password_response("alice", "secret", &[0x01, 0x02, 0x03, 0x04]);
        assert_eq!(got.last().copied(), Some(0u8));
        let body = std::str::from_utf8(&got[..got.len() - 1]).unwrap();
        assert!(body.starts_with("md5"));
        assert_eq!(body.len(), 3 + 32);

        let inner = md5_hex(b"secretalice");
        let mut combined = inner.into_bytes();
        combined.extend_from_slice(&[0x01, 0x02, 0x03, 0x04]);
        let outer = md5_hex(&combined);
        assert_eq!(&body[3..], outer);
    }

    /// PBKDF2-HMAC-SHA256("password", "salt", 1, 32) published test vector.
    #[test]
    fn test_pbkdf2_hmac_sha256_rfc_vector() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 1);
        let expected = [
            0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52,
            0x56, 0xc4, 0xf8, 0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48,
            0x08, 0x05, 0x98, 0x7c, 0xb7, 0x0b, 0xe1, 0x7b,
        ];
        assert_eq!(got, expected);
    }

    /// PBKDF2-HMAC-SHA256("password", "salt", 4096, 32) published test vector.
    #[test]
    fn test_pbkdf2_hmac_sha256_high_iters() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 4096);
        let expected = [
            0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, 0xaa, 0x53, 0x0d, 0xb6,
            0x84, 0x5c, 0x4c, 0x8d, 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11,
            0xa4, 0x96, 0x38, 0x73, 0xaa, 0x98, 0x13, 0x4a,
        ];
        assert_eq!(got, expected);
    }

    /// Full client flow against a server simulated inline with the same primitives.
    #[test]
    fn test_scram_roundtrip_against_synthetic_server() {
        let (mut scram, first) = Scram::client_first("fyko+d2lbbFgONRv9qkxdawL");
        let msg = &first.0;
        let mech_end = msg.iter().position(|&b| b == 0).unwrap();
        assert_eq!(&msg[..mech_end], b"SCRAM-SHA-256");
        let len =
            u32::from_be_bytes(msg[mech_end + 1..mech_end + 5].try_into().unwrap()) as usize;
        let cfirst = &msg[mech_end + 5..mech_end + 5 + len];
        let cfirst_str = std::str::from_utf8(cfirst).unwrap();
        assert!(cfirst_str.starts_with("n,,n=,r=fyko+d2lbbFgONRv9qkxdawL"));

        let server_nonce_suffix = "3rfcNHYJY1ZVvWVs7j";
        let combined_nonce = format!("fyko+d2lbbFgONRv9qkxdawL{}", server_nonce_suffix);
        let salt: [u8; 16] = [
            0x41, 0x25, 0xc2, 0x47, 0xe4, 0x3a, 0xb1, 0xe9, 0x3c, 0x6d, 0xff, 0x76,
            0xd1, 0x22, 0x3a, 0x10,
        ];
        let iterations = 4096u32;
        let salt_b64 = BASE64.encode(salt);
        let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, iterations);

        let password = "pencil";
        let client_final = scram
            .client_final(server_first.as_bytes(), password)
            .expect("client_final");
        let cfinal_str = std::str::from_utf8(&client_final.0).unwrap();
        assert!(cfinal_str.starts_with("c=biws,r="));
        assert!(cfinal_str.contains(&format!("r={}", combined_nonce)));
        assert!(cfinal_str.contains(",p="));

        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let server_key = hmac_sha256(&salted, b"Server Key");
        let (cfinal_no_proof, _proof) = {
            let idx = cfinal_str.rfind(",p=").unwrap();
            (&cfinal_str[..idx], &cfinal_str[idx + 3..])
        };
        let auth_message = format!(
            "n=,r=fyko+d2lbbFgONRv9qkxdawL,{},{}",
            server_first, cfinal_no_proof
        );
        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", BASE64.encode(server_sig));
        scram
            .verify_server(server_final.as_bytes())
            .expect("verify_server");
    }

    #[test]
    fn test_scram_rejects_nonce_mismatch() {
        let (mut scram, _) = Scram::client_first("client-nonce");
        let server_first = "r=OTHER-nonce,s=QUJD,i=4096";
        let err = scram.client_final(server_first.as_bytes(), "pw").unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_rejects_zero_iterations() {
        let (mut scram, _) = Scram::client_first("abc");
        let err = scram
            .client_final(b"r=abc-extension,s=QUJD,i=0", "pw")
            .unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_rejects_missing_salt() {
        let (mut scram, _) = Scram::client_first("abc");
        let err = scram
            .client_final(b"r=abc-extension,i=4096", "pw")
            .unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_verify_server_requires_client_final() {
        let (scram, _) = Scram::client_first("abc");
        assert!(scram.verify_server(b"v=AAAA").is_err());
    }

    #[test]
    fn test_scram_rejects_bad_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        let bad_sig = BASE64.encode([0u8; 32]);
        let server_final = format!("v={}", bad_sig);
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_server_error() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        assert!(scram.verify_server(b"e=invalid-proof").is_err());
    }
}