#[cfg(feature = "scram")]
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
#[cfg(feature = "scram")]
use hmac::{Hmac, Mac};
#[cfg(feature = "scram")]
use rand::RngCore;
#[cfg(feature = "scram")]
use sha2::{Digest, Sha256};
use crate::error::{PgWireError, Result};
/// HMAC instantiated with SHA-256; `hmac_sha256` below builds every MAC in this
/// module from it.
#[cfg(feature = "scram")]
type HmacSha256 = Hmac<Sha256>;
/// Client-side state for one SCRAM-SHA-256 exchange, captured when the
/// client-first message is built and reused to produce the client-final message.
#[cfg(feature = "scram")]
#[derive(Debug, Clone)]
pub struct ScramClient {
/// Client nonce encoded as standard base64; the server's nonce (`r=`) must
/// start with this value.
pub client_nonce_b64: String,
/// `n=<escaped username>,r=<client nonce>`; included verbatim as the first
/// component of the AuthMessage.
pub client_first_bare: String,
/// The complete client-first message: the `n,,` header followed by
/// `client_first_bare`.
pub client_first: String,
}
#[cfg(feature = "scram")]
impl ScramClient {
    /// Starts a SCRAM-SHA-256 exchange for `username` with a fresh client nonce.
    ///
    /// The nonce is 18 random bytes encoded as standard base64. The client-first
    /// message uses the `n,,` header: no channel binding and no authorization id.
    pub fn new(username: &str) -> ScramClient {
        let mut nonce = [0u8; 18];
        rand::rng().fill_bytes(&mut nonce);
        Self::from_encoded_nonce(username, B64.encode(nonce))
    }

    /// Builds a client with a caller-chosen nonce so tests can replay known exchanges.
    #[cfg(test)]
    pub(crate) fn with_nonce(username: &str, nonce_b64: &str) -> ScramClient {
        Self::from_encoded_nonce(username, nonce_b64.to_string())
    }

    /// Assembles the client-first message pieces around an already-encoded nonce.
    fn from_encoded_nonce(username: &str, nonce_b64: String) -> ScramClient {
        let user = sasl_escape_username(username);
        let client_first_bare = format!("n={user},r={nonce_b64}");
        let client_first = format!("n,,{client_first_bare}");
        ScramClient {
            client_nonce_b64: nonce_b64,
            client_first_bare,
            client_first,
        }
    }

    /// Parses a server-first message into `(nonce, salt_b64, iterations)`.
    ///
    /// Attributes are split on `,` and may appear in any order; unknown optional
    /// extensions are ignored. If an attribute repeats, the last occurrence wins.
    ///
    /// # Errors
    ///
    /// Returns [`PgWireError::Auth`] if:
    /// - the message carries a mandatory extension (`m=`), which RFC 5802 requires
    ///   a client to reject;
    /// - the nonce (`r=`) or salt (`s=`) is missing;
    /// - the iteration count (`i=`) is missing, not a `u32`, or zero.
    pub fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
        let mut nonce = None;
        let mut salt = None;
        let mut iterations = None;
        for part in server_first.split(',') {
            if part.starts_with("m=") {
                return Err(PgWireError::Auth(
                    "SCRAM server-first requires unsupported mandatory extension (m=)".into(),
                ));
            }
            if let Some(v) = part.strip_prefix("r=") {
                nonce = Some(v.to_string());
            } else if let Some(v) = part.strip_prefix("s=") {
                salt = Some(v.to_string());
            } else if let Some(v) = part.strip_prefix("i=") {
                // Hi() is only defined for an iteration count >= 1, so a zero count
                // is rejected exactly like a missing or malformed one.
                iterations = v.parse::<u32>().ok().filter(|&n| n > 0);
            }
        }
        let nonce = nonce
            .ok_or_else(|| PgWireError::Auth("SCRAM server-first missing nonce (r=)".into()))?;
        let salt = salt
            .ok_or_else(|| PgWireError::Auth("SCRAM server-first missing salt (s=)".into()))?;
        let iterations = iterations.ok_or_else(|| {
            PgWireError::Auth("SCRAM server-first missing or invalid iteration count (i=)".into())
        })?;
        Ok((nonce, salt, iterations))
    }

    /// Computes the client-final message for `password` in reply to `server_first`.
    ///
    /// Returns `(client_final, auth_message, salted_password)`. Keep the last two to
    /// check the server's reply with [`ScramClient::verify_server_final`].
    ///
    /// Channel binding is not used (`c=biws`, the base64 form of the `n,,` header).
    /// The password is fed to Hi() as its raw UTF-8 bytes; no SASLprep
    /// normalization is applied.
    ///
    /// # Errors
    ///
    /// Returns [`PgWireError::Auth`] if:
    /// - `server_first` fails [`ScramClient::parse_server_first`];
    /// - the server nonce does not start with this client's nonce;
    /// - the salt is not valid base64.
    pub fn client_final(
        &self,
        password: &str,
        server_first: &str,
    ) -> Result<(String, String, Vec<u8>)> {
        let (rnonce, salt_b64, iters) = Self::parse_server_first(server_first)?;
        // The combined nonce must extend ours, otherwise the reply belongs to
        // another exchange.
        if !rnonce.starts_with(&self.client_nonce_b64) {
            return Err(PgWireError::Auth(
                "SCRAM nonce mismatch: server nonce doesn't include client nonce".into(),
            ));
        }
        let salt = B64
            .decode(salt_b64.as_bytes())
            .map_err(|e| PgWireError::Auth(format!("SCRAM invalid salt base64: {e}")))?;
        let channel_binding = "biws";
        let client_final_wo_proof = format!("c={channel_binding},r={rnonce}");
        let auth_message = format!(
            "{},{},{}",
            self.client_first_bare, server_first, client_final_wo_proof
        );
        let salted_password = hi_sha256(password.as_bytes(), &salt, iters);
        let client_key = hmac_sha256(&salted_password, b"Client Key");
        let stored_key = Sha256::digest(&client_key);
        let client_sig = hmac_sha256(stored_key.as_slice(), auth_message.as_bytes());
        let proof = xor_bytes(&client_key, &client_sig);
        let proof_b64 = B64.encode(proof);
        let client_final = format!("{client_final_wo_proof},p={proof_b64}");
        Ok((client_final, auth_message, salted_password))
    }

    /// Verifies the server-final message against the expected server signature.
    ///
    /// The expected signature is
    /// `HMAC(HMAC(salted_password, "Server Key"), auth_message)`.
    ///
    /// # Errors
    ///
    /// Returns [`PgWireError::Auth`] if:
    /// - the server reported an error (`e=`);
    /// - the `v=` attribute is missing or not valid base64;
    /// - the signature does not match.
    pub fn verify_server_final(
        server_final: &str,
        salted_password: &[u8],
        auth_message: &str,
    ) -> Result<()> {
        if let Some(err) = server_final.split(',').find_map(|p| p.strip_prefix("e=")) {
            return Err(PgWireError::Auth(format!("SCRAM server error: {err}")));
        }
        let v = server_final
            .split(',')
            .find_map(|p| p.strip_prefix("v="))
            .ok_or_else(|| PgWireError::Auth("SCRAM server-final missing signature (v=)".into()))?;
        let server_sig = B64.decode(v.trim().as_bytes()).map_err(|e| {
            PgWireError::Auth(format!("SCRAM invalid server signature base64: {e}"))
        })?;
        let server_key = hmac_sha256(salted_password, b"Server Key");
        let expected = hmac_sha256(&server_key, auth_message.as_bytes());
        if !constant_time_eq(&server_sig, &expected) {
            return Err(PgWireError::Auth(
                "SCRAM server signature mismatch: server may not know the password".into(),
            ));
        }
        Ok(())
    }
}
#[cfg(feature = "scram")]
/// Escapes a username for the SCRAM `n=` attribute.
///
/// `=` becomes `=3D` and `,` becomes `=2C`; every other character is copied
/// through unchanged.
fn sasl_escape_username(u: &str) -> String {
    let mut escaped = String::with_capacity(u.len());
    for ch in u.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// RFC 5802 `Hi()` with HMAC-SHA-256, producing a single HMAC-sized block.
///
/// `U1 = HMAC(password, salt || INT(1))` and `Uk = HMAC(password, U(k-1))`.
/// The result is `U1 ^ U2 ^ ... ^ U(iters)`. An `iters` of 0 yields `U1`, the
/// same as 1.
#[cfg(feature = "scram")]
fn hi_sha256(password: &[u8], salt: &[u8], iters: u32) -> Vec<u8> {
    // salt followed by the big-endian block index 1
    let first_input: Vec<u8> = salt.iter().copied().chain(1u32.to_be_bytes()).collect();
    let mut prev_block = hmac_sha256(password, &first_input);
    let mut acc = prev_block.clone();
    let mut remaining = iters.saturating_sub(1);
    while remaining > 0 {
        prev_block = hmac_sha256(password, &prev_block);
        acc.iter_mut().zip(&prev_block).for_each(|(a, p)| *a ^= *p);
        remaining -= 1;
    }
    acc
}
/// Computes HMAC-SHA-256 of `msg` keyed by `key` and returns the tag bytes.
#[cfg(feature = "scram")]
fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
    // The `expect` encodes the assumption that HMAC accepts keys of any length.
    let keyed = HmacSha256::new_from_slice(key).expect("HMAC key length is always valid");
    let tag = keyed.chain_update(msg).finalize();
    tag.into_bytes().to_vec()
}
#[cfg(feature = "scram")]
/// Byte-wise XOR of two slices.
///
/// The operands are expected to have equal length. This is checked only in
/// debug builds; otherwise the result is as long as the shorter operand.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len(), "XOR operands must have equal length");
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}
#[cfg(feature = "scram")]
/// Compares two byte slices for equality without stopping at the first difference.
///
/// Slices of different length compare unequal immediately. For equal-length
/// slices, every byte pair is folded into one accumulator before the result is
/// decided.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (x, y) in a.iter().zip(b) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
#[cfg(test)]
#[cfg(feature = "scram")]
mod tests {
    use super::*;

    // Example SCRAM-SHA-256 exchange published in RFC 7677, section 3
    // (username "user", password "pencil", no channel binding).
    const RFC7677_CLIENT_NONCE: &str = "rOprNGfwEbeRWgbNEkqO";
    const RFC7677_SERVER_FIRST: &str =
        "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
    const RFC7677_CLIENT_FINAL: &str = "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=";
    const RFC7677_SERVER_FINAL: &str = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjRl4R4=";

    #[test]
    fn scram_builds_first_message() {
        let c = ScramClient::new("user");
        assert!(c.client_first.starts_with("n,,n=user,r="));
        assert!(c.client_first_bare.starts_with("n=user,r="));
        assert!(!c.client_nonce_b64.is_empty());
    }

    #[test]
    fn scram_escapes_special_chars_in_username() {
        let c = ScramClient::new("user=name,test");
        assert!(c.client_first.contains("n=user=3Dname=2Ctest,r="));
    }

    #[test]
    fn scram_unique_nonces() {
        let c1 = ScramClient::new("user");
        let c2 = ScramClient::new("user");
        assert_ne!(c1.client_nonce_b64, c2.client_nonce_b64);
    }

    #[test]
    fn parse_server_first_valid() {
        let (r, s, i) = ScramClient::parse_server_first("r=abc123,s=c2FsdA==,i=4096").unwrap();
        assert_eq!(r, "abc123");
        assert_eq!(s, "c2FsdA==");
        assert_eq!(i, 4096);
    }

    #[test]
    fn parse_server_first_different_order() {
        let (r, s, i) = ScramClient::parse_server_first("i=1000,s=Zm9v,r=xyz").unwrap();
        assert_eq!(r, "xyz");
        assert_eq!(s, "Zm9v");
        assert_eq!(i, 1000);
    }

    #[test]
    fn parse_server_first_with_extensions() {
        let (r, s, i) =
            ScramClient::parse_server_first("r=nonce,s=c2FsdA==,i=4096,x=unknown").unwrap();
        assert_eq!(r, "nonce");
        assert_eq!(s, "c2FsdA==");
        assert_eq!(i, 4096);
    }

    #[test]
    fn parse_server_first_missing_nonce() {
        let err = ScramClient::parse_server_first("s=c2FsdA==,i=4096").unwrap_err();
        assert!(err.to_string().contains("nonce"));
    }

    #[test]
    fn parse_server_first_missing_salt() {
        let err = ScramClient::parse_server_first("r=abc,i=4096").unwrap_err();
        assert!(err.to_string().contains("salt"));
    }

    #[test]
    fn parse_server_first_missing_iterations() {
        let err = ScramClient::parse_server_first("r=abc,s=c2FsdA==").unwrap_err();
        assert!(err.to_string().contains("iteration"));
    }

    #[test]
    fn parse_server_first_invalid_iterations() {
        let err = ScramClient::parse_server_first("r=abc,s=c2FsdA==,i=notanumber").unwrap_err();
        assert!(err.to_string().contains("iteration"));
    }

    #[test]
    fn client_final_computes_proof() {
        let client = ScramClient::with_nonce("user", RFC7677_CLIENT_NONCE);
        let (client_final, auth_message, salted_password) =
            client.client_final("pencil", RFC7677_SERVER_FIRST).unwrap();
        assert!(client_final.starts_with("c=biws,r="));
        assert!(client_final.contains(",p="));
        assert!(auth_message.contains(&client.client_first_bare));
        assert!(auth_message.contains(RFC7677_SERVER_FIRST));
        assert_eq!(salted_password.len(), 32);
    }

    #[test]
    fn client_final_matches_rfc7677_vector() {
        let client = ScramClient::with_nonce("user", RFC7677_CLIENT_NONCE);
        assert_eq!(client.client_first, "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");
        let (client_final, auth_message, _) =
            client.client_final("pencil", RFC7677_SERVER_FIRST).unwrap();
        assert_eq!(client_final, RFC7677_CLIENT_FINAL);
        let (without_proof, _) = RFC7677_CLIENT_FINAL.rsplit_once(",p=").unwrap();
        assert_eq!(
            auth_message,
            format!(
                "{},{},{}",
                client.client_first_bare, RFC7677_SERVER_FIRST, without_proof
            )
        );
    }

    #[test]
    fn client_final_rejects_nonce_mismatch() {
        let client = ScramClient::with_nonce("user", "clientnonce");
        let server_first = "r=differentnonce,s=c2FsdA==,i=4096";
        let err = client.client_final("password", server_first).unwrap_err();
        assert!(err.to_string().contains("nonce mismatch"));
    }

    #[test]
    fn client_final_rejects_invalid_salt_base64() {
        let client = ScramClient::with_nonce("user", "abc");
        let server_first = "r=abcdef,s=!!!invalid!!!,i=4096";
        let err = client.client_final("password", server_first).unwrap_err();
        assert!(err.to_string().contains("base64"));
    }

    #[test]
    fn verify_server_final_accepts_valid_signature() {
        let client = ScramClient::with_nonce("user", "fyko+d2lbbFgONRv9qkxdawL");
        let server_first =
            "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let (_, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();
        let server_key = hmac_sha256(&salted_password, b"Server Key");
        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", B64.encode(&server_sig));
        ScramClient::verify_server_final(&server_final, &salted_password, &auth_message).unwrap();
    }

    #[test]
    fn verify_server_final_accepts_rfc7677_signature() {
        let client = ScramClient::with_nonce("user", RFC7677_CLIENT_NONCE);
        let (_, auth_message, salted_password) =
            client.client_final("pencil", RFC7677_SERVER_FIRST).unwrap();
        ScramClient::verify_server_final(RFC7677_SERVER_FINAL, &salted_password, &auth_message)
            .unwrap();
    }

    #[test]
    fn verify_server_final_rejects_rfc7677_signature_for_wrong_password() {
        let client = ScramClient::with_nonce("user", RFC7677_CLIENT_NONCE);
        let (_, auth_message, salted_password) =
            client.client_final("not-pencil", RFC7677_SERVER_FIRST).unwrap();
        let err =
            ScramClient::verify_server_final(RFC7677_SERVER_FINAL, &salted_password, &auth_message)
                .unwrap_err();
        assert!(err.to_string().contains("signature mismatch"));
    }

    #[test]
    fn verify_server_final_rejects_wrong_signature() {
        let salted_password = vec![0u8; 32];
        let auth_message = "test";
        let server_final = "v=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
        let err = ScramClient::verify_server_final(server_final, &salted_password, auth_message)
            .unwrap_err();
        assert!(err.to_string().contains("signature mismatch"));
    }

    #[test]
    fn verify_server_final_rejects_missing_signature() {
        let err = ScramClient::verify_server_final("", &[], "").unwrap_err();
        assert!(err.to_string().contains("missing signature"));
    }

    #[test]
    fn verify_server_final_handles_server_error() {
        let err = ScramClient::verify_server_final("e=invalid-proof", &[], "").unwrap_err();
        assert!(err.to_string().contains("server error"));
        assert!(err.to_string().contains("invalid-proof"));
    }

    #[test]
    fn verify_server_final_rejects_invalid_base64() {
        let err = ScramClient::verify_server_final("v=!!!invalid!!!", &[], "").unwrap_err();
        assert!(err.to_string().contains("base64"));
    }

    #[test]
    fn sasl_escape_username_escapes_equals() {
        assert_eq!(sasl_escape_username("a=b"), "a=3Db");
    }

    #[test]
    fn sasl_escape_username_escapes_comma() {
        assert_eq!(sasl_escape_username("a,b"), "a=2Cb");
    }

    #[test]
    fn sasl_escape_username_escapes_both() {
        assert_eq!(sasl_escape_username("a=b,c"), "a=3Db=2Cc");
    }

    #[test]
    fn sasl_escape_username_preserves_normal() {
        assert_eq!(sasl_escape_username("normal_user123"), "normal_user123");
    }

    #[test]
    fn hi_sha256_single_iteration() {
        let result = hi_sha256(b"password", b"salt", 1);
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn hi_sha256_multiple_iterations() {
        let result = hi_sha256(b"password", b"salt", 4096);
        assert_eq!(result.len(), 32);
        let result2 = hi_sha256(b"password", b"salt", 1000);
        assert_ne!(result, result2);
    }

    #[test]
    fn hmac_sha256_produces_correct_length() {
        let result = hmac_sha256(b"key", b"message");
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn xor_bytes_works() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0x0F, 0xF0]), vec![0xF0, 0xF0]);
        assert_eq!(xor_bytes(&[0x00], &[0x00]), vec![0x00]);
    }

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(&[1, 2, 3], &[1, 2, 3]));
        assert!(constant_time_eq(&[], &[]));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 4]));
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2]));
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 3, 4]));
    }
}