use std::collections::HashMap;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use crate::backend::auth::{hmac_sha256, pbkdf2_hmac_sha256, sha256};
/// A stored SCRAM-SHA-256 credential: the PBKDF2 salt and iteration count
/// plus the derived `StoredKey` and `ServerKey`.
///
/// These are the only values the server needs to check a client proof and to
/// sign its final message; the plaintext password is not kept.
#[derive(Debug, Clone)]
pub struct ScramVerifier {
    /// Salt passed to PBKDF2; sent to the client base64-encoded as `s=`.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count; sent to the client as `i=`.
    pub iterations: u32,
    /// `sha256(hmac_sha256(SaltedPassword, "Client Key"))`.
    pub stored_key: [u8; 32],
    /// `hmac_sha256(SaltedPassword, "Server Key")`.
    pub server_key: [u8; 32],
}

impl ScramVerifier {
    /// Derives a verifier from a plaintext `password`.
    ///
    /// `SaltedPassword = pbkdf2_hmac_sha256(password, salt, iterations)`; the
    /// client key `hmac_sha256(SaltedPassword, "Client Key")` is hashed into
    /// `stored_key`, and `server_key` is `hmac_sha256(SaltedPassword, "Server Key")`.
    /// `iterations` is used as given and is not validated here.
    pub fn from_password(password: &str, salt: Vec<u8>, iterations: u32) -> Self {
        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let stored_key = sha256(&client_key);
        let server_key = hmac_sha256(&salted, b"Server Key");
        Self {
            salt,
            iterations,
            stored_key,
            server_key,
        }
    }

    /// Parses a PostgreSQL-style secret
    /// `SCRAM-SHA-256$<iterations>:<salt>$<StoredKey>:<ServerKey>`,
    /// where the salt and both keys are standard base64.
    ///
    /// Returns `None` when the prefix or a separator is missing, the
    /// iteration count is not a positive `u32`, a base64 field fails to
    /// decode, or either key is not exactly 32 bytes.
    pub fn parse(s: &str) -> Option<Self> {
        let rest = s.strip_prefix("SCRAM-SHA-256$")?;
        let (params, keys) = rest.split_once('$')?;
        let (iter_str, salt_b64) = params.split_once(':')?;
        let (stored_b64, server_b64) = keys.split_once(':')?;
        let iterations: u32 = iter_str.parse().ok()?;
        // The count is advertised verbatim as `i=` in server-first-message, and
        // conforming clients (e.g. libpq) reject an iteration count below 1.
        // Refusing it here surfaces the misconfiguration when the secret is
        // loaded instead of as a client-side failure on every login.
        if iterations == 0 {
            return None;
        }
        let salt = BASE64.decode(salt_b64.trim()).ok()?;
        let stored = BASE64.decode(stored_b64.trim()).ok()?;
        let server = BASE64.decode(server_b64.trim()).ok()?;
        if stored.len() != 32 || server.len() != 32 {
            return None;
        }
        let mut stored_key = [0u8; 32];
        stored_key.copy_from_slice(&stored);
        let mut server_key = [0u8; 32];
        server_key.copy_from_slice(&server);
        Some(Self {
            salt,
            iterations,
            stored_key,
            server_key,
        })
    }
}
/// Table of SCRAM verifiers keyed by user name, built from an auth file.
#[derive(Debug, Clone, Default)]
pub struct AuthFile {
    users: HashMap<String, ScramVerifier>,
}

impl AuthFile {
    /// Iteration count used when deriving a verifier from a plaintext secret.
    const PLAINTEXT_ITERATIONS: u32 = 4096;
    /// Number of leading `sha256(user)` bytes used as the salt for plaintext secrets.
    const PLAINTEXT_SALT_LEN: usize = 16;

    /// Reads the auth file at `path` and parses it with [`AuthFile::parse_str`].
    ///
    /// # Errors
    /// Returns a message naming `path` if the file cannot be read as UTF-8
    /// text, or any error produced by [`AuthFile::parse_str`].
    pub fn load(path: &str) -> Result<Self, String> {
        let data = std::fs::read_to_string(path)
            .map_err(|e| format!("reading auth_file {}: {}", path, e))?;
        Self::parse_str(&data, path)
    }

    /// Parses auth-file contents. `path` is used only to label error messages.
    ///
    /// Each line that is not blank and does not start with `#` must be
    /// `user:secret`, split at the first `:`. Whitespace and one pair of
    /// surrounding double quotes are stripped from each side.
    ///
    /// A secret beginning with `SCRAM-SHA-256$` is parsed with
    /// [`ScramVerifier::parse`]. Any other secret is treated as a plaintext
    /// password. It is converted with a salt taken from the first 16 bytes of
    /// `sha256(user)` and 4096 iterations, so the same file always yields the
    /// same verifiers.
    ///
    /// If a user appears more than once, the last entry wins.
    ///
    /// # Errors
    /// Returns `"<path>:<line>: ..."` when a line has no `:`, the user name or
    /// secret is empty, or a SCRAM verifier is malformed.
    pub fn parse_str(data: &str, path: &str) -> Result<Self, String> {
        let mut users = HashMap::new();
        for (idx, raw) in data.lines().enumerate() {
            let lineno = idx + 1;
            let line = raw.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            let (user, secret) = line
                .split_once(':')
                .ok_or_else(|| format!("{}:{}: expected `user:secret`", path, lineno))?;
            // `unquote` trims surrounding whitespace itself.
            let user = unquote(user);
            let secret = unquote(secret);
            if user.is_empty() {
                return Err(format!("{}:{}: empty user name", path, lineno));
            }
            // An empty secret would otherwise become a verifier for the empty
            // password, authenticating any client that sends an empty password.
            if secret.is_empty() {
                return Err(format!(
                    "{}:{}: empty secret for user `{}`",
                    path, lineno, user
                ));
            }
            let verifier = if secret.starts_with("SCRAM-SHA-256$") {
                ScramVerifier::parse(&secret)
                    .ok_or_else(|| format!("{}:{}: malformed SCRAM verifier", path, lineno))?
            } else {
                let salt = sha256(user.as_bytes())[..Self::PLAINTEXT_SALT_LEN].to_vec();
                ScramVerifier::from_password(&secret, salt, Self::PLAINTEXT_ITERATIONS)
            };
            users.insert(user, verifier);
        }
        Ok(Self { users })
    }

    /// Looks up the verifier for `user`, if the file defined one.
    pub fn get(&self, user: &str) -> Option<&ScramVerifier> {
        self.users.get(user)
    }

    /// Returns `true` when no users were loaded.
    pub fn is_empty(&self) -> bool {
        self.users.is_empty()
    }
}
/// Trims `s` and, if the result is wrapped in a pair of double quotes,
/// removes that one pair.
///
/// Values quoted on only one side, or consisting of a lone `"`, come back
/// trimmed but otherwise unchanged. Escape sequences are not interpreted.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();
    let inner = trimmed
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));
    match inner {
        Some(body) => body.to_string(),
        None => trimmed.to_string(),
    }
}
/// Server side of a single SCRAM-SHA-256 exchange, without channel binding.
///
/// [`ScramServer::start`] consumes the client-first-message and produces the
/// server-first-message. [`ScramServer::finish`] checks the
/// client-final-message and produces the server-final-message (`v=...`).
pub struct ScramServer {
    verifier: ScramVerifier,
    /// Client nonce immediately followed by the server nonce, as sent in `r=`.
    combined_nonce: String,
    /// client-first-message without its GS2 header; part of AuthMessage.
    client_first_bare: String,
    /// server-first-message exactly as returned by `start`; part of AuthMessage.
    server_first: String,
    /// GS2 header from client-first (e.g. `n,,`). The client must echo its
    /// base64 encoding in the `c=` attribute of client-final.
    gs2_header: String,
}

impl ScramServer {
    /// Parses `client_first` (GS2 header followed by client-first-message-bare)
    /// and returns the server state together with the server-first-message
    /// `r=<client nonce><server_nonce>,s=<base64 salt>,i=<iterations>`.
    ///
    /// Only the `n` and `y` channel-binding flags are accepted. This server
    /// has no channel-binding data, so a `p=...` request is refused.
    /// NOTE(review): accepting `y` assumes the caller advertises plain
    /// `SCRAM-SHA-256` and not `SCRAM-SHA-256-PLUS`; confirm against the
    /// mechanism list sent to clients.
    ///
    /// # Errors
    /// Returns a description if the message is malformed, requests channel
    /// binding, requires a SCRAM extension (`m=`), or carries a missing or
    /// empty client nonce.
    pub fn start(
        verifier: ScramVerifier,
        client_first: &str,
        server_nonce: &str,
    ) -> Result<(Self, String), String> {
        let mut parts = client_first.splitn(3, ',');
        let cbind_flag = parts.next().unwrap_or("");
        // The authzid is not interpreted. It stays covered by the `c=` check
        // in `finish`, because it is part of the GS2 header.
        let _authzid = parts.next();
        let bare = parts
            .next()
            .ok_or_else(|| "malformed client-first (no bare part)".to_string())?;
        // `splitn` yields the remainder as its last item, so `bare` is a
        // suffix of `client_first`. Everything before it is the GS2 header,
        // including its trailing comma.
        let gs2_header = &client_first[..client_first.len() - bare.len()];
        match cbind_flag {
            "n" | "y" => {}
            flag if flag.starts_with("p=") => {
                return Err("channel binding requested but not supported".to_string());
            }
            _ => return Err("malformed client-first (bad channel-binding flag)".to_string()),
        }
        // RFC 5802: a mandatory extension (`m=`) must cause authentication to fail.
        if bare.starts_with("m=") {
            return Err("client-first requires unsupported extension (m=)".to_string());
        }
        let client_nonce = Self::attr(bare, "r=")
            .ok_or_else(|| "client-first missing r=".to_string())?;
        if client_nonce.is_empty() {
            return Err("empty client nonce".to_string());
        }
        let combined_nonce = format!("{}{}", client_nonce, server_nonce);
        let salt_b64 = BASE64.encode(&verifier.salt);
        let server_first = format!(
            "r={},s={},i={}",
            combined_nonce, salt_b64, verifier.iterations
        );
        Ok((
            Self {
                verifier,
                combined_nonce,
                client_first_bare: bare.to_string(),
                server_first: server_first.clone(),
                gs2_header: gs2_header.to_string(),
            },
            server_first,
        ))
    }

    /// Verifies `client_final` (`c=...,r=...,p=<proof>`). On success it returns
    /// the server-final-message `v=<base64 ServerSignature>`.
    ///
    /// # Errors
    /// Returns a description in any of these cases:
    /// - an attribute is missing;
    /// - `c=` does not match the GS2 header seen in `start`;
    /// - the nonce differs from the one issued by `start`;
    /// - the proof is not 32 bytes of base64;
    /// - the proof does not match the stored key.
    pub fn finish(&self, client_final: &str) -> Result<String, String> {
        let proof_pos = client_final
            .rfind(",p=")
            .ok_or_else(|| "client-final missing p=".to_string())?;
        let without_proof = &client_final[..proof_pos];
        let proof_b64 = &client_final[proof_pos + 3..];

        // The GS2 header is not part of AuthMessage; only its echo in `c=` is
        // covered by the proof. Comparing the two is what detects tampering
        // with the header (for example, a stripped channel-binding request).
        let cbind = Self::attr(without_proof, "c=")
            .ok_or_else(|| "client-final missing c=".to_string())?;
        if cbind != BASE64.encode(self.gs2_header.as_bytes()) {
            return Err("channel-binding attribute mismatch".to_string());
        }

        let echoed_nonce = Self::attr(without_proof, "r=")
            .ok_or_else(|| "client-final missing r=".to_string())?;
        if echoed_nonce != self.combined_nonce {
            return Err("nonce mismatch".to_string());
        }
        let proof = BASE64
            .decode(proof_b64.trim())
            .map_err(|e| format!("bad proof base64: {}", e))?;
        if proof.len() != 32 {
            return Err("proof wrong length".to_string());
        }

        // AuthMessage = client-first-bare "," server-first "," client-final-without-proof
        let auth_message = format!(
            "{},{},{}",
            self.client_first_bare, self.server_first, without_proof
        );
        let client_signature = hmac_sha256(&self.verifier.stored_key, auth_message.as_bytes());
        // ClientKey = ClientProof XOR ClientSignature. The proof is valid iff
        // sha256(ClientKey) equals the stored key.
        let mut client_key = [0u8; 32];
        for ((out, p), s) in client_key
            .iter_mut()
            .zip(&proof)
            .zip(client_signature.iter())
        {
            *out = p ^ s;
        }
        let derived_stored = sha256(&client_key);
        if !constant_time_eq(&derived_stored, &self.verifier.stored_key) {
            return Err("authentication failed (proof mismatch)".to_string());
        }
        let server_signature = hmac_sha256(&self.verifier.server_key, auth_message.as_bytes());
        Ok(format!("v={}", BASE64.encode(server_signature)))
    }

    /// Returns the value of the first comma-separated attribute in `msg` that
    /// starts with `prefix` (e.g. `"r="`).
    fn attr<'a>(msg: &'a str, prefix: &str) -> Option<&'a str> {
        msg.split(',').find_map(|field| field.strip_prefix(prefix))
    }
}
/// Compares two 32-byte digests for equality.
///
/// Every byte pair is XOR-ed and OR-ed into an accumulator, with no early
/// exit on the first difference. This is intended to keep the comparison
/// time independent of where the inputs differ.
fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
// Unit tests: verifier parse round-trip, a full handshake against the crate's
// SCRAM client (`crate::backend::auth::Scram`), wrong-password rejection,
// and auth-file parsing.
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::auth::Scram;
// Encodes a freshly derived verifier in the
// `SCRAM-SHA-256$<iter>:<salt>$<stored>:<server>` layout and checks that
// `ScramVerifier::parse` reproduces every field.
#[test]
fn parse_pg_verifier_roundtrips_from_password() {
let v = ScramVerifier::from_password("s3cret", b"0123456789abcdef".to_vec(), 4096);
let s = format!(
"SCRAM-SHA-256${}:{}${}:{}",
v.iterations,
BASE64.encode(&v.salt),
BASE64.encode(v.stored_key),
BASE64.encode(v.server_key),
);
let p = ScramVerifier::parse(&s).expect("parses");
assert_eq!(p.iterations, v.iterations);
assert_eq!(p.salt, v.salt);
assert_eq!(p.stored_key, v.stored_key);
assert_eq!(p.server_key, v.server_key);
}
// Drives a complete exchange: client-first -> ScramServer::start ->
// client-final -> ScramServer::finish. The client then checks the server's
// `v=` signature.
#[test]
fn full_scram_handshake_client_vs_server() {
let password = "correct horse battery staple";
let verifier = ScramVerifier::from_password(password, b"saltsaltsaltsalt".to_vec(), 4096);
let (mut client, init) = Scram::client_first("clientNONCE123");
let data = &init.0;
// NOTE(review): assumes `init.0` is a SASLInitialResponse body — a
// NUL-terminated mechanism name, a 4-byte length, then the
// client-first-message. Verify against `Scram::client_first`.
let mech_end = data.iter().position(|&b| b == 0).unwrap() + 1;
let client_first = std::str::from_utf8(&data[mech_end + 4..]).unwrap();
let (server, server_first) =
ScramServer::start(verifier.clone(), client_first, "serverNONCE456").unwrap();
let client_final = client
.client_final(server_first.as_bytes(), password)
.unwrap();
let server_final = server
.finish(std::str::from_utf8(&client_final.0).unwrap())
.unwrap();
client.verify_server(server_final.as_bytes()).unwrap();
}
// Same flow as above, but the client proves a different password than the
// verifier was built from, so `finish` must return an error.
#[test]
fn wrong_password_is_rejected() {
let verifier = ScramVerifier::from_password("rightpw", b"saltsaltsaltsalt".to_vec(), 4096);
let (mut client, init) = Scram::client_first("nonceAAA");
let data = &init.0;
// Skip the mechanism name (through its NUL) and the 4-byte length field.
let mech_end = data.iter().position(|&b| b == 0).unwrap() + 1;
let client_first = std::str::from_utf8(&data[mech_end + 4..]).unwrap();
let (server, server_first) =
ScramServer::start(verifier, client_first, "nonceBBB").unwrap();
let client_final = client
.client_final(server_first.as_bytes(), "wrongpw")
.unwrap();
let res = server.finish(std::str::from_utf8(&client_final.0).unwrap());
assert!(res.is_err(), "wrong password must be rejected");
}
// The auth file mixes a comment line, a blank line, a plaintext secret, a
// quoted plaintext secret, and a stored SCRAM verifier. All three users must
// load, and an absent user must not.
#[test]
fn auth_file_parses_plaintext_and_verifier() {
let v = ScramVerifier::from_password("pw", b"0123456789abcdef".to_vec(), 4096);
let verifier_line = format!(
"carol:SCRAM-SHA-256${}:{}${}:{}",
v.iterations,
BASE64.encode(&v.salt),
BASE64.encode(v.stored_key),
BASE64.encode(v.server_key),
);
let body = format!(
"# comment\nalice:secret\n\nbob:\"quoted\"\n{}\n",
verifier_line
);
let af = AuthFile::parse_str(&body, "test").unwrap();
assert!(af.get("alice").is_some());
assert!(af.get("bob").is_some());
assert!(af.get("carol").is_some());
assert!(af.get("dave").is_none());
}
}