#![allow(clippy::doc_markdown, clippy::uninlined_format_args)]
use spg_crypto::{base64, hmac, sha256};
use spg_engine::ScramSecrets;
/// Reasons the server side of a SCRAM exchange can fail.
#[derive(Debug)]
// `NonceMismatch` is never constructed in this module; presumably the caller
// that compares nonces raises it — confirm before dropping this allow.
#[allow(dead_code)]
pub enum ScramError {
    /// The client-first message was malformed or unsupported; carries a detail string.
    BadInitial(String),
    /// The client-final message was malformed; carries a detail string.
    BadFinal(String),
    /// The nonce sent back by the client does not match the server's.
    NonceMismatch,
    /// The client proof did not verify against the stored key.
    ProofMismatch,
}
impl core::fmt::Display for ScramError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::BadInitial(s) => write!(f, "SCRAM: bad client-first ({s})"),
Self::BadFinal(s) => write!(f, "SCRAM: bad client-final ({s})"),
Self::NonceMismatch => f.write_str("SCRAM: server nonce mismatch"),
Self::ProofMismatch => f.write_str("SCRAM: invalid client proof"),
}
}
}
/// Parsed form of a SCRAM client-first message.
#[derive(Debug)]
pub struct ClientFirst {
    /// The message with its `n,,` GS2 header stripped
    /// (`client-first-message-bare`). It is kept verbatim because it forms the
    /// first part of the AuthMessage.
    pub bare: String,
    /// Value of the `r=` attribute: the client's contribution to the nonce.
    pub client_nonce: String,
}
/// Parses a SCRAM client-first message (RFC 5802 §7).
///
/// Only the GS2 header `n,,` is accepted: no channel binding and no authzid.
/// On success, returns the `client-first-message-bare` (everything after the
/// header, needed verbatim for the AuthMessage) and the client nonce taken from
/// the first `r=` attribute.
///
/// # Errors
///
/// Returns [`ScramError::BadInitial`] in these cases:
/// - the GS2 header is not `n,,`;
/// - the reserved `m=` mandatory-extension attribute is present (RFC 5802
///   §5.1 requires authentication to fail);
/// - the `r=` attribute is missing;
/// - the nonce is empty or contains bytes outside printable ASCII.
pub fn parse_client_first(msg: &str) -> Result<ClientFirst, ScramError> {
    let bare = msg
        .strip_prefix("n,,")
        .ok_or_else(|| ScramError::BadInitial("only 'n,,' GS2 header supported".into()))?;

    // Attribute values never contain a raw ',' (usernames escape it as `=2C`),
    // so splitting on ',' yields exactly one attribute per piece.
    if bare.split(',').any(|attr| attr.starts_with("m=")) {
        return Err(ScramError::BadInitial(
            "unsupported mandatory extension (m=)".into(),
        ));
    }

    let client_nonce = bare
        .split(',')
        .find_map(|attr| attr.strip_prefix("r="))
        .ok_or_else(|| ScramError::BadInitial("missing r= (client nonce) attribute".into()))?;

    // The nonce is echoed back in server-first and prefixes the combined nonce,
    // so require the RFC's non-empty printable-ASCII form (',' is already
    // excluded by the split above).
    if client_nonce.is_empty() || !client_nonce.bytes().all(|b| b.is_ascii_graphic()) {
        return Err(ScramError::BadInitial(
            "client nonce must be non-empty printable ASCII".into(),
        ));
    }

    Ok(ClientFirst {
        bare: bare.to_string(),
        client_nonce: client_nonce.to_string(),
    })
}
/// Builds the SCRAM server-first message in the form
/// `r=<combined nonce>,s=<base64 salt>,i=<iteration count>`.
///
/// `combined_nonce` is inserted unchanged. The salt (base64-encoded) and the
/// iteration count come from the user's stored `secrets`.
pub fn build_server_first(combined_nonce: &str, secrets: &ScramSecrets) -> String {
    format!(
        "r={},s={},i={}",
        combined_nonce,
        base64::encode(&secrets.salt),
        secrets.iters
    )
}
/// Parsed form of a SCRAM client-final message.
#[derive(Debug)]
pub struct ClientFinal {
    /// The message up to, but not including, the `,p=` proof attribute.
    /// It forms the last part of the AuthMessage.
    pub without_proof: String,
    /// Value of the `r=` attribute, expected to be the client nonce followed
    /// by the server nonce.
    pub combined_nonce: String,
    /// The base64-decoded `p=` proof, exactly one SHA-256 output in length.
    pub client_proof: [u8; sha256::OUT_LEN],
}
/// Parses a SCRAM client-final message (RFC 5802 §7) of the form
/// `c=<channel binding>,r=<combined nonce>[,...],p=<base64 proof>`.
///
/// Returns three things:
/// - the message with the trailing `,p=...` removed (needed verbatim for the
///   AuthMessage);
/// - the nonce from the first `r=` attribute;
/// - the decoded proof.
///
/// `parse_client_first` only accepts the `n,,` GS2 header, so the only valid
/// channel-binding value is `c=biws` (the base64 encoding of `n,,`). The GS2
/// header itself is not covered by the client proof, so this check is what
/// detects a client-first header that was altered in transit.
///
/// # Errors
///
/// Returns [`ScramError::BadFinal`] in these cases:
/// - `p=` is missing, is not valid base64, or has the wrong length;
/// - `c=` is missing or is not `biws`;
/// - `r=` is missing.
pub fn parse_client_final(msg: &str) -> Result<ClientFinal, ScramError> {
    // The proof must be the final attribute, and values never contain ',',
    // so the last ",p=" marks where the proof starts.
    let p_idx = msg
        .rfind(",p=")
        .ok_or_else(|| ScramError::BadFinal("missing p= (proof) attribute".into()))?;
    let without_proof = &msg[..p_idx];
    let proof_b64 = &msg[p_idx + 3..];

    let decoded = base64::decode(proof_b64)
        .map_err(|_| ScramError::BadFinal("proof not valid base64".into()))?;
    if decoded.len() != sha256::OUT_LEN {
        return Err(ScramError::BadFinal(format!(
            "proof length {} ≠ expected {}",
            decoded.len(),
            sha256::OUT_LEN
        )));
    }
    let mut client_proof = [0u8; sha256::OUT_LEN];
    client_proof.copy_from_slice(&decoded);

    let channel_binding = without_proof
        .split(',')
        .find_map(|attr| attr.strip_prefix("c="))
        .ok_or_else(|| ScramError::BadFinal("missing c= (channel binding) attribute".into()))?;
    if channel_binding != "biws" {
        return Err(ScramError::BadFinal(
            "channel binding must be c=biws for GS2 header 'n,,'".into(),
        ));
    }

    let combined_nonce = without_proof
        .split(',')
        .find_map(|attr| attr.strip_prefix("r="))
        .ok_or_else(|| ScramError::BadFinal("missing r= attribute in without-proof".into()))?
        .to_string();

    Ok(ClientFinal {
        without_proof: without_proof.to_string(),
        combined_nonce,
        client_proof,
    })
}
/// Verifies the client's proof and, on success, returns the server-final
/// message `v=<base64 ServerSignature>`.
///
/// The AuthMessage is `client_first_bare`, `server_first` and
/// `client_final_without_proof`, joined with ','. The proof is checked in
/// three steps:
/// 1. recover `ClientKey = ClientProof XOR HMAC(StoredKey, AuthMessage)`;
/// 2. hash it with SHA-256;
/// 3. compare the result with the stored `StoredKey`.
///
/// The server signature is `HMAC(ServerKey, AuthMessage)`.
///
/// # Errors
///
/// Returns [`ScramError::ProofMismatch`] when the recovered key does not hash
/// to the stored key.
pub fn verify_and_sign(
    secrets: &ScramSecrets,
    client_first_bare: &str,
    server_first: &str,
    client_final_without_proof: &str,
    client_proof: &[u8; sha256::OUT_LEN],
) -> Result<String, ScramError> {
    let auth_message = [client_first_bare, server_first, client_final_without_proof].join(",");
    let auth_bytes = auth_message.as_bytes();

    // Strip the client's HMAC mask off the proof to get back ClientKey.
    let client_signature = hmac::hmac_sha256(&secrets.stored_key, auth_bytes);
    let mut client_key = *client_proof;
    for (key_byte, sig_byte) in client_key.iter_mut().zip(client_signature.iter()) {
        *key_byte ^= *sig_byte;
    }

    if constant_time_eq(&sha256::hash(&client_key), &secrets.stored_key) {
        let server_signature = hmac::hmac_sha256(&secrets.server_key, auth_bytes);
        Ok(format!("v={}", base64::encode(&server_signature)))
    } else {
        Err(ScramError::ProofMismatch)
    }
}
/// Returns whether two equal-length byte arrays are identical, without
/// exiting early at the first differing byte.
///
/// Every byte pair is XOR-ed and OR-accumulated, so the loop always covers all
/// `N` bytes. This avoids leaking the position of the first mismatch through
/// timing. Rust and LLVM give no formal constant-time guarantee for this
/// pattern; it is a best-effort mitigation.
///
/// The function is generic over the array length `N`, so it is not tied to
/// one digest size. Callers passing two `[u8; 32]` arrays infer `N = 32`.
fn constant_time_eq<const N: usize>(a: &[u8; N], b: &[u8; N]) -> bool {
    let diff = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use spg_engine::users::compute_scram_secrets;

    /// Computes the ClientProof an honest client would send for `password`.
    /// It derives the salted password (4096 PBKDF2-SHA-256 iterations over
    /// `salt`), then signs `auth_message`.
    fn proof_for(password: &[u8], salt: &[u8; 16], auth_message: &str) -> [u8; 32] {
        let salted = spg_crypto::pbkdf2::pbkdf2_sha256_32(password, salt, 4096);
        let client_key = hmac::hmac_sha256(&salted, b"Client Key");
        let client_signature =
            hmac::hmac_sha256(&sha256::hash(&client_key), auth_message.as_bytes());
        let mut proof = [0u8; 32];
        for (idx, out) in proof.iter_mut().enumerate() {
            *out = client_key[idx] ^ client_signature[idx];
        }
        proof
    }

    #[test]
    fn parse_client_first_extracts_bare_and_nonce() {
        let parsed = parse_client_first("n,,n=alice,r=clientnonce123").unwrap();
        assert_eq!(parsed.bare, "n=alice,r=clientnonce123");
        assert_eq!(parsed.client_nonce, "clientnonce123");
    }

    #[test]
    fn parse_client_first_rejects_channel_binding() {
        let result = parse_client_first("p=tls-unique,,n=alice,r=nonce");
        assert!(matches!(result, Err(ScramError::BadInitial(_))));
    }

    #[test]
    fn parse_client_final_round_trip() {
        let proof = [7u8; 32];
        let msg = format!("c=biws,r=combined,p={}", base64::encode(&proof));
        let parsed = parse_client_final(&msg).unwrap();
        assert_eq!(parsed.without_proof, "c=biws,r=combined");
        assert_eq!(parsed.combined_nonce, "combined");
        assert_eq!(parsed.client_proof, proof);
    }

    #[test]
    fn full_exchange_round_trip() {
        let salt = [11u8; 16];
        let secrets = compute_scram_secrets("hunter2", salt, 4096);
        let client_nonce = "client-noncenoncenonce";
        let server_nonce = "server-noncenoncenonce";
        let combined_nonce = format!("{client_nonce}{server_nonce}");

        let client_first_bare = format!("n=alice,r={client_nonce}");
        let server_first = build_server_first(&combined_nonce, &secrets);
        let client_final_without_proof = format!("c=biws,r={combined_nonce}");
        let auth_message = [
            client_first_bare.as_str(),
            server_first.as_str(),
            client_final_without_proof.as_str(),
        ]
        .join(",");

        let client_proof = proof_for(b"hunter2", &salt, &auth_message);
        let reply = verify_and_sign(
            &secrets,
            &client_first_bare,
            &server_first,
            &client_final_without_proof,
            &client_proof,
        )
        .expect("verify must succeed for a real proof");
        assert!(reply.starts_with("v="));
    }

    #[test]
    fn wrong_password_fails_verify() {
        let salt = [5u8; 16];
        let secrets = compute_scram_secrets("correct", salt, 4096);
        // The AuthMessage is the three parts passed to verify_and_sign below,
        // joined with ','.
        let client_proof = proof_for(b"wrong", &salt, "n=u,r=x,r=x,c=biws,r=x");
        let result = verify_and_sign(&secrets, "n=u,r=x", "r=x", "c=biws,r=x", &client_proof);
        assert!(matches!(result, Err(ScramError::ProofMismatch)));
    }
}