use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2;
use rand::{rngs::OsRng, Rng};
use sha2::{Digest, Sha256};
use std::fmt;
/// Largest PBKDF2 iteration count `client_final` will accept from a server; larger values
/// are rejected before any key derivation runs.
const MAX_SCRAM_ITERATIONS: u32 = 1_000_000;
/// HMAC instantiated with SHA-256, used for PBKDF2 and every keyed hash in the exchange.
type HmacSha256 = Hmac<Sha256>;
/// Errors produced while running the client side of a SCRAM exchange.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ScramError {
    /// The server's final signature did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or carried unacceptable parameters.
    InvalidServerMessage(String),
    /// In this file, returned when HMAC key setup fails (the name does not reflect that use).
    Utf8Error(String),
    /// Base64 data received from the server could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for ScramError {}
/// Per-exchange state produced by `ScramClient::client_final` and consumed by
/// `ScramClient::verify_server_final`.
///
/// `Debug` is implemented by hand so the secret `server_key` is never written to logs.
#[derive(Clone)]
pub struct ScramState {
    /// The AuthMessage: `client-first-bare,server-first,client-final-without-proof`.
    auth_message: Vec<u8>,
    /// ServerKey derived from the salted password; secret key material.
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &String::from_utf8_lossy(&self.auth_message))
            // Anyone holding the ServerKey can forge server signatures, so never print it.
            .field("server_key", &"<redacted>")
            .finish()
    }
}
/// Client side of a SCRAM-SHA-256 exchange for one username/password pair.
pub struct ScramClient {
    /// Username sent in the `n=` attribute (`=` and `,` are escaped before sending).
    username: String,
    /// Password used as the PBKDF2 input when deriving the client and server keys.
    password: String,
    /// Client nonce, generated once at construction and reused by every message.
    nonce: String,
}
/// GS2 header of the client-first message: `n` (no channel binding) followed by an empty
/// authorization identity.
const GS2_HEADER: &str = "n,,";

impl ScramClient {
    /// Creates a client for `username`/`password` with a fresh client nonce.
    ///
    /// The nonce is 24 bytes from [`OsRng`], Base64-encoded. It is fixed for the lifetime of
    /// this client, so use a new `ScramClient` for each authentication exchange.
    pub fn new(username: String, password: String) -> Self {
        let mut nonce_bytes = [0u8; 24];
        let mut rng = OsRng;
        rng.fill(&mut nonce_bytes[..]);
        Self {
            username,
            password,
            nonce: BASE64.encode(nonce_bytes),
        }
    }

    /// Returns the client-first message: `n,,n=<escaped user>,r=<client nonce>`.
    ///
    /// `=` and `,` in the username are escaped as `=3D` and `=2C`. No SASLprep
    /// normalization is applied.
    pub fn client_first(&self) -> String {
        format!("{}{}", GS2_HEADER, self.client_first_bare())
    }

    /// Processes the server-first message and returns the client-final message together
    /// with the state needed by [`Self::verify_server_final`].
    ///
    /// Takes `&mut self` for API compatibility; the client itself is not modified.
    ///
    /// # Errors
    ///
    /// * [`ScramError::InvalidServerMessage`] if `r`/`s`/`i` is missing, a mandatory
    ///   extension (`m=`) is present, the server nonce does not start with the client
    ///   nonce, or the iteration count is not in `1..=MAX_SCRAM_ITERATIONS`.
    /// * [`ScramError::Base64Error`] if the salt is not valid Base64.
    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
        // RFC 5802: the server nonce is the client nonce with the server's part appended.
        if !server_nonce.starts_with(&self.nonce) {
            return Err(ScramError::InvalidServerMessage(
                "server nonce doesn't contain client nonce".to_string(),
            ));
        }
        let salt_bytes = BASE64
            .decode(&salt)
            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
        let iterations = iterations
            .parse::<u32>()
            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
        // RFC 5802 grammar requires a positive iteration count.
        if iterations == 0 {
            return Err(ScramError::InvalidServerMessage(
                "server iteration count must be positive".to_string(),
            ));
        }
        if iterations > MAX_SCRAM_ITERATIONS {
            return Err(ScramError::InvalidServerMessage(format!(
                "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
            )));
        }

        // `c=` carries the Base64 of the GS2 header sent in client-first ("biws" for "n,,").
        let channel_binding = BASE64.encode(GS2_HEADER);
        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
        let auth_message = format!(
            "{},{},{}",
            self.client_first_bare(),
            server_first,
            client_final_without_proof
        );

        // PBKDF2 dominates the cost of this call, so SaltedPassword is derived once and reused
        // for both the ClientProof and the ServerKey.
        let salted_password = derive_salted_password(&self.password, &salt_bytes, iterations);
        let proof = calculate_client_proof(&salted_password, auth_message.as_bytes())?;
        let server_key = calculate_server_key(&salted_password)?;

        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(proof));
        let state = ScramState {
            auth_message: auth_message.into_bytes(),
            server_key,
        };
        Ok((client_final, state))
    }

    /// Verifies the server-final message against the state produced by
    /// [`Self::client_final`].
    ///
    /// Accepts `v=<signature>` optionally followed by `,`-separated extensions.
    ///
    /// # Errors
    ///
    /// * [`ScramError::InvalidServerMessage`] if the server reported an error (`e=`) or the
    ///   message does not start with `v=`.
    /// * [`ScramError::Base64Error`] if the signature is not valid Base64.
    /// * [`ScramError::InvalidServerProof`] if the signature does not match.
    pub fn verify_server_final(
        &self,
        server_final: &str,
        state: &ScramState,
    ) -> Result<(), ScramError> {
        // server-final-message = (server-error / verifier) ["," extensions]
        let leading_attr = server_final
            .split_once(',')
            .map_or(server_final, |(head, _)| head);
        if let Some(server_error) = leading_attr.strip_prefix("e=") {
            return Err(ScramError::InvalidServerMessage(format!(
                "server reported error: {server_error}"
            )));
        }
        let server_sig_encoded = leading_attr
            .strip_prefix("v=")
            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
            ScramError::Base64Error("invalid server signature encoding".to_string())
        })?;
        let expected_signature =
            calculate_server_signature(&state.server_key, &state.auth_message)?;
        if constant_time_compare(&server_signature, &expected_signature) {
            Ok(())
        } else {
            Err(ScramError::InvalidServerProof(
                "server signature verification failed".to_string(),
            ))
        }
    }

    /// Returns `client-first-message-bare` (`n=<escaped user>,r=<client nonce>`).
    ///
    /// Shared by the message sent to the server and the AuthMessage so the two can never
    /// drift apart.
    fn client_first_bare(&self) -> String {
        format!("n={},r={}", escape_username(&self.username), self.nonce)
    }
}

/// Escapes a username for the `n=` attribute (`=` becomes `=3D`, `,` becomes `=2C`).
///
/// `=` must be replaced first; replacing `,` first would re-escape the `=` of `=2C`.
fn escape_username(username: &str) -> String {
    username.replace('=', "=3D").replace(',', "=2C")
}

/// Extracts the nonce (`r=`), salt (`s=`) and iteration count (`i=`) from a server-first
/// message, returned as raw strings in that order.
///
/// Unknown attributes are ignored. If an attribute repeats, the last occurrence wins.
///
/// # Errors
///
/// [`ScramError::InvalidServerMessage`] if a mandatory extension (`m=`) is present (RFC 5802
/// requires authentication to fail in that case) or if any of `r`/`s`/`i` is missing or empty.
fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
    let mut nonce = String::new();
    let mut salt = String::new();
    let mut iterations = String::new();
    for part in msg.split(',') {
        if part.starts_with("m=") {
            return Err(ScramError::InvalidServerMessage(
                "unsupported mandatory extension in server first message".to_string(),
            ));
        }
        if let Some(value) = part.strip_prefix("r=") {
            nonce = value.to_string();
        } else if let Some(value) = part.strip_prefix("s=") {
            salt = value.to_string();
        } else if let Some(value) = part.strip_prefix("i=") {
            iterations = value.to_string();
        }
    }
    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
        return Err(ScramError::InvalidServerMessage(
            "missing required fields in server first message".to_string(),
        ));
    }
    Ok((nonce, salt, iterations))
}

/// Computes `SaltedPassword := Hi(password, salt, iterations)`: PBKDF2 with HMAC-SHA-256,
/// producing one 32-byte block.
///
/// The password bytes are used as given; no SASLprep normalization is applied.
fn derive_salted_password(password: &str, salt: &[u8], iterations: u32) -> [u8; 32] {
    let mut salted_password = [0u8; 32];
    // `pbkdf2` can only fail if the PRF rejects the password as a key; HMAC accepts keys of
    // any length, so the result is deliberately ignored.
    let _ = pbkdf2::<HmacSha256>(password.as_bytes(), salt, iterations, &mut salted_password);
    salted_password
}

/// Computes `ClientProof := ClientKey XOR HMAC(H(ClientKey), AuthMessage)`, where
/// `ClientKey := HMAC(SaltedPassword, "Client Key")`.
///
/// # Errors
///
/// [`ScramError::Utf8Error`] if HMAC key setup fails, which HMAC's acceptance of any key
/// length makes unreachable in practice.
fn calculate_client_proof(
    salted_password: &[u8],
    auth_message: &[u8],
) -> Result<Vec<u8>, ScramError> {
    let mut client_key_mac = HmacSha256::new_from_slice(salted_password)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    client_key_mac.update(b"Client Key");
    let client_key = client_key_mac.finalize().into_bytes();

    let stored_key = Sha256::digest(client_key.as_slice());
    let mut client_sig_mac = HmacSha256::new_from_slice(&stored_key)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    client_sig_mac.update(auth_message);
    let client_signature = client_sig_mac.finalize().into_bytes();

    Ok(client_key
        .iter()
        .zip(client_signature.iter())
        .map(|(key_byte, sig_byte)| key_byte ^ sig_byte)
        .collect())
}

/// Computes `ServerKey := HMAC(SaltedPassword, "Server Key")`.
///
/// # Errors
///
/// [`ScramError::Utf8Error`] if HMAC key setup fails, which HMAC's acceptance of any key
/// length makes unreachable in practice.
fn calculate_server_key(salted_password: &[u8]) -> Result<Vec<u8>, ScramError> {
    let mut server_key_mac = HmacSha256::new_from_slice(salted_password)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    server_key_mac.update(b"Server Key");
    Ok(server_key_mac.finalize().into_bytes().to_vec())
}
/// Computes `ServerSignature := HMAC(ServerKey, AuthMessage)`.
///
/// # Errors
///
/// [`ScramError::Utf8Error`] if HMAC key setup fails, which HMAC's acceptance of any key
/// length makes unreachable in practice.
fn calculate_server_signature(
    server_key: &[u8],
    auth_message: &[u8],
) -> Result<Vec<u8>, ScramError> {
    HmacSha256::new_from_slice(server_key)
        .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))
        .map(|mut mac| {
            mac.update(auth_message);
            mac.finalize().into_bytes().to_vec()
        })
}
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
use subtle::ConstantTimeEq;
a.ct_eq(b).into()
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // an unexpected Err should simply fail the test

    use super::*;

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();
        assert!(first.starts_with("n,,n=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();
        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        let server_first = "r=nonce,s=salt";
        let result = parse_server_first(server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage error, got: {result:?}"
        );
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
        let result = client.client_final(&server_first);
        let (client_final, state) = result.unwrap_or_else(|e| {
            panic!("expected Ok for client_final with valid server message: {e}")
        });
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }

    #[test]
    fn test_scram_iteration_count_too_high_is_rejected() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let excessive_iterations = MAX_SCRAM_ITERATIONS + 1;
        let server_first = format!(
            "r={},s={},i={}",
            server_nonce,
            BASE64.encode(b"salty"),
            excessive_iterations
        );
        let result = client.client_final(&server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage for excessive iterations, got: {result:?}"
        );
    }

    #[test]
    fn test_scram_iteration_count_at_limit_is_accepted() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!(
            "r={},s={},i={}",
            server_nonce,
            BASE64.encode(b"salty"),
            MAX_SCRAM_ITERATIONS
        );
        let result = client.client_final(&server_first);
        // Any error here (not only an iteration-count one) means the limit was not honoured.
        assert!(
            result.is_ok(),
            "iteration count at the limit must be accepted, got: {result:?}"
        );
    }

    #[test]
    fn test_scram_username_escaping_in_auth_message() {
        let mut client = ScramClient::new("user,admin=evil".to_string(), "password".to_string());
        let client_first = client.client_first();
        assert!(
            client_first.contains("user=2Cadmin=3Devil"),
            "client_first should escape ',' and '=' in username, got: {client_first}"
        );
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
        let result = client.client_final(&server_first);
        let (_client_final, state) =
            result.unwrap_or_else(|e| panic!("expected Ok for escaped-username client_final: {e}"));
        let auth_message = String::from_utf8(state.auth_message).unwrap();
        assert!(
            auth_message.contains("user=2Cadmin=3Devil"),
            "auth_message should contain escaped username, got: {auth_message}"
        );
        assert!(
            !auth_message.contains("user,admin=evil"),
            "auth_message must NOT contain raw (unescaped) username, got: {auth_message}"
        );
    }

    /// Known-answer test from RFC 7677 §3 (user "user", password "pencil", fixed nonce).
    #[test]
    fn test_rfc7677_example_exchange() {
        let mut client = ScramClient {
            username: "user".to_string(),
            password: "pencil".to_string(),
            nonce: "rOprNGfwEbeRWgbNEkqO".to_string(),
        };
        assert_eq!(client.client_first(), "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let server_first =
            "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
        let (client_final, state) = client.client_final(server_first).unwrap();
        assert_eq!(
            client_final,
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
             p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        client
            .verify_server_final("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=", &state)
            .unwrap();
    }

    #[test]
    fn test_verify_server_final_rejects_wrong_signature() {
        let mut client = ScramClient::new("user".to_string(), "pencil".to_string());
        let server_first = format!("r={}srv,s={},i=4096", client.nonce, BASE64.encode(b"salty"));
        let (_client_final, state) = client.client_final(&server_first).unwrap();
        let bogus = format!("v={}", BASE64.encode([0u8; 32]));
        let result = client.verify_server_final(&bogus, &state);
        assert!(
            matches!(result, Err(ScramError::InvalidServerProof(_))),
            "expected InvalidServerProof for a wrong signature, got: {result:?}"
        );
    }

    #[test]
    fn test_verify_server_final_rejects_missing_prefix() {
        let mut client = ScramClient::new("user".to_string(), "pencil".to_string());
        let server_first = format!("r={}srv,s={},i=4096", client.nonce, BASE64.encode(b"salty"));
        let (_client_final, state) = client.client_final(&server_first).unwrap();
        let result = client.verify_server_final("x=abc", &state);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage without 'v=', got: {result:?}"
        );
    }
}