use aws_lc_rs::{digest, hmac, pbkdf2, rand};
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
/// Errors produced while running the client side of a SCRAM-SHA-256 exchange.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScramError {
    /// The nonce returned by the server does not begin with the client's nonce.
    InvalidServerNonce,
    /// Problem with the server-first message (or its handling); carries a description.
    InvalidServerFirst(String),
    /// Problem with the server-final message (or its handling); carries a description.
    InvalidServerFinal(String),
    /// The signature sent by the server did not match the expected value.
    ServerSignatureMismatch,
    /// A base64 field failed to decode; carries the decoder's message.
    Base64Error(String),
    /// The iteration count in the server-first message is not acceptable.
    InvalidIterationCount,
}
impl std::fmt::Display for ScramError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ScramError::InvalidServerNonce => {
write!(f, "server nonce does not contain client nonce prefix")
}
ScramError::InvalidServerFirst(msg) => {
write!(f, "failed to parse server-first message: {msg}")
}
ScramError::InvalidServerFinal(msg) => {
write!(f, "failed to parse server-final message: {msg}")
}
ScramError::ServerSignatureMismatch => write!(f, "server signature mismatch"),
ScramError::Base64Error(msg) => write!(f, "base64 error: {msg}"),
ScramError::InvalidIterationCount => write!(f, "invalid iteration count"),
}
}
}
impl std::error::Error for ScramError {}
pub struct ScramClient {
username: String,
password: String,
client_nonce: String,
client_first_bare: String,
server_first: Option<String>,
auth_message: Option<String>,
server_key: Option<[u8; 32]>,
}
impl ScramClient {
pub fn new(username: &str, password: &str) -> Self {
let client_nonce = generate_nonce();
Self {
username: username.to_string(),
password: password.to_string(),
client_nonce,
client_first_bare: String::new(),
server_first: None,
auth_message: None,
server_key: None,
}
}
#[cfg(test)]
pub fn with_nonce(username: &str, password: &str, nonce: &str) -> Self {
Self {
username: username.to_string(),
password: password.to_string(),
client_nonce: nonce.to_string(),
client_first_bare: String::new(),
server_first: None,
auth_message: None,
server_key: None,
}
}
pub fn mechanism() -> &'static str {
"SCRAM-SHA-256"
}
pub fn client_first(&mut self) -> String {
let gs2_header = "n,,";
let escaped_username = escape_username(&self.username);
self.client_first_bare = format!("n={},r={}", escaped_username, self.client_nonce);
format!("{}{}", gs2_header, self.client_first_bare)
}
pub fn client_final(&mut self, server_first: &str) -> Result<String, ScramError> {
self.server_first = Some(server_first.to_string());
let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
if !server_nonce.starts_with(&self.client_nonce) {
return Err(ScramError::InvalidServerNonce);
}
let salt_bytes = BASE64
.decode(&salt)
.map_err(|e| ScramError::Base64Error(e.to_string()))?;
let salted_password = pbkdf2_sha256(self.password.as_bytes(), &salt_bytes, iterations);
let client_key = hmac_sha256(&salted_password, b"Client Key");
let server_key = hmac_sha256(&salted_password, b"Server Key");
self.server_key = Some(server_key);
let stored_key = sha256(&client_key);
let client_final_without_proof = format!("c=biws,r={}", server_nonce);
let auth_message = format!(
"{},{},{}",
self.client_first_bare, server_first, client_final_without_proof
);
self.auth_message = Some(auth_message.clone());
let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
let client_proof = xor_bytes(&client_key, &client_signature);
let client_proof_b64 = BASE64.encode(client_proof);
Ok(format!(
"{},p={}",
client_final_without_proof, client_proof_b64
))
}
pub fn verify_server(&self, server_final: &str) -> Result<(), ScramError> {
let server_key = self
.server_key
.as_ref()
.ok_or(ScramError::InvalidServerFinal(
"client_final not called".into(),
))?;
let auth_message = self
.auth_message
.as_ref()
.ok_or(ScramError::InvalidServerFinal(
"client_final not called".into(),
))?;
let server_signature_b64 = parse_server_final(server_final)?;
let server_signature = BASE64
.decode(&server_signature_b64)
.map_err(|e| ScramError::Base64Error(e.to_string()))?;
let expected_signature = hmac_sha256(server_key, auth_message.as_bytes());
if server_signature != expected_signature {
return Err(ScramError::ServerSignatureMismatch);
}
Ok(())
}
}
/// Returns a fresh client nonce: 18 random bytes, standard-base64 encoded.
///
/// 18 bytes encode to exactly 24 base64 characters with no `=` padding. The
/// standard alphabet never produces `,`, so the value can be embedded directly
/// as an `r=` attribute.
///
/// # Panics
///
/// Panics if the random number generator reports a failure.
fn generate_nonce() -> String {
    const NONCE_BYTES: usize = 18;
    let mut raw = [0u8; NONCE_BYTES];
    rand::fill(&mut raw)
        .map(|()| BASE64.encode(raw))
        .expect("random generation failed")
}
/// Escapes a username for the `n=` attribute of the client-first message.
///
/// `=` becomes `=3D` and `,` becomes `=2C`; every other character is copied
/// unchanged.
fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Parses a server-first message into `(nonce, base64_salt, iterations)`.
///
/// Attributes are recognised by their `r=`, `s=` and `i=` prefixes in any order.
/// Other optional attributes are ignored; a later duplicate overrides an earlier one.
///
/// # Errors
///
/// - [`ScramError::InvalidServerFirst`] if a mandatory extension (`m=`) is
///   present, or if `r`, `s` or `i` is missing. RFC 5802 requires the client
///   to fail when it sees `m=`.
/// - [`ScramError::InvalidIterationCount`] if `i` is not a positive `u32`.
///   Zero is rejected here because the PBKDF2 step needs at least one round.
fn parse_server_first(msg: &str) -> Result<(String, String, u32), ScramError> {
    let mut nonce = None;
    let mut salt = None;
    let mut iterations = None;
    for part in msg.split(',') {
        if part.starts_with("m=") {
            return Err(ScramError::InvalidServerFirst(
                "unsupported mandatory extension".into(),
            ));
        } else if let Some(value) = part.strip_prefix("r=") {
            nonce = Some(value.to_string());
        } else if let Some(value) = part.strip_prefix("s=") {
            salt = Some(value.to_string());
        } else if let Some(value) = part.strip_prefix("i=") {
            let count = value
                .parse::<u32>()
                .map_err(|_| ScramError::InvalidIterationCount)?;
            if count == 0 {
                return Err(ScramError::InvalidIterationCount);
            }
            iterations = Some(count);
        }
    }
    match (nonce, salt, iterations) {
        (Some(n), Some(s), Some(i)) => Ok((n, s, i)),
        _ => Err(ScramError::InvalidServerFirst(
            "missing required fields".into(),
        )),
    }
}
/// Extracts the base64 server signature (`v=` attribute) from a server-final message.
///
/// # Errors
///
/// Returns [`ScramError::InvalidServerFinal`] in two cases:
/// - The server reported an error (`e=<reason>`). The message is `"server error: <reason>"`.
/// - No `v=` attribute is present.
fn parse_server_final(msg: &str) -> Result<String, ScramError> {
    for part in msg.split(',') {
        if let Some(value) = part.strip_prefix("v=") {
            return Ok(value.to_string());
        }
        // RFC 5802 lets the server answer with an error instead of a verifier;
        // surface its reason rather than a generic "missing verifier".
        if let Some(reason) = part.strip_prefix("e=") {
            return Err(ScramError::InvalidServerFinal(format!(
                "server error: {reason}"
            )));
        }
    }
    Err(ScramError::InvalidServerFinal(
        "missing verifier field".into(),
    ))
}
/// Derives the 32-byte salted password with PBKDF2-HMAC-SHA256.
///
/// # Panics
///
/// Panics if `iterations` is zero. PBKDF2 needs at least one round, so callers
/// are expected to reject a zero count before calling.
fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
    // Converting u32 -> NonZeroU32 can only fail on zero, so say that in the message.
    let rounds = std::num::NonZeroU32::new(iterations)
        .expect("PBKDF2 iteration count must be non-zero");
    let mut salted = [0u8; 32];
    pbkdf2::derive(
        pbkdf2::PBKDF2_HMAC_SHA256,
        rounds,
        salt,
        password,
        &mut salted,
    );
    salted
}
/// Computes HMAC-SHA256 of `data` under `key` and returns the tag as a fixed-size array.
fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
    let tag = hmac::sign(&hmac::Key::new(hmac::HMAC_SHA256, key), data);
    <[u8; 32]>::try_from(tag.as_ref()).expect("HMAC-SHA256 is 32 bytes")
}
/// Returns the SHA-256 digest of `data` as a fixed-size array.
fn sha256(data: &[u8]) -> [u8; 32] {
    let hashed = digest::digest(&digest::SHA256, data);
    <[u8; 32]>::try_from(hashed.as_ref()).expect("SHA256 is 32 bytes")
}
/// Returns the byte-wise XOR of two 32-byte arrays.
fn xor_bytes(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
    let mut result = [0u8; 32];
    // Iterator zip instead of an index loop: no bounds checks and no
    // `clippy::needless_range_loop` warning.
    for ((out, x), y) in result.iter_mut().zip(a).zip(b) {
        *out = x ^ y;
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    // Example exchange from RFC 7677 section 3 (user "user", password "pencil").
    const RFC7677_CLIENT_NONCE: &str = "rOprNGfwEbeRWgbNEkqO";
    const RFC7677_SERVER_FIRST: &str = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
    const RFC7677_CLIENT_FINAL: &str = "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=";
    const RFC7677_SERVER_FINAL: &str = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=";

    #[test]
    fn test_mechanism() {
        assert_eq!(ScramClient::mechanism(), "SCRAM-SHA-256");
    }

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username("user"), "user");
        assert_eq!(escape_username("user=test"), "user=3Dtest");
        assert_eq!(escape_username("user,test"), "user=2Ctest");
        assert_eq!(escape_username("a=b,c"), "a=3Db=2Cc");
    }

    #[test]
    fn test_parse_server_first() {
        let msg = "r=clientnonce+servernonce,s=c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(msg).unwrap();
        assert_eq!(nonce, "clientnonce+servernonce");
        assert_eq!(salt, "c2FsdA==");
        assert_eq!(iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_rejects_bad_iterations() {
        assert_eq!(
            parse_server_first("r=abc,s=c2FsdA==,i=many"),
            Err(ScramError::InvalidIterationCount)
        );
    }

    #[test]
    fn test_parse_server_first_missing_fields() {
        assert!(matches!(
            parse_server_first("r=abc,i=4096"),
            Err(ScramError::InvalidServerFirst(_))
        ));
    }

    #[test]
    fn test_parse_server_final() {
        let msg = "v=cm1GM3pydXVYNWhKNDZlcm5yL2RLbTdrSzg0cXdqRS8=";
        let verifier = parse_server_final(msg).unwrap();
        assert_eq!(verifier, "cm1GM3pydXVYNWhKNDZlcm5yL2RLbTdrSzg0cXdqRS8=");
    }

    #[test]
    fn test_parse_server_final_rejects_server_error() {
        assert!(matches!(
            parse_server_final("e=invalid-proof"),
            Err(ScramError::InvalidServerFinal(_))
        ));
    }

    #[test]
    fn test_client_first() {
        let mut client = ScramClient::with_nonce("user", "password", "rOprNGfwEbeRWgbNEkqO");
        let client_first = client.client_first();
        assert!(client_first.starts_with("n,,"));
        assert!(client_first.contains("n=user"));
        assert!(client_first.contains("r=rOprNGfwEbeRWgbNEkqO"));
    }

    #[test]
    fn test_client_first_escapes_username() {
        let mut client = ScramClient::with_nonce("a=b,c", "pw", "nonce");
        assert_eq!(client.client_first(), "n,,n=a=3Db=2Cc,r=nonce");
    }

    #[test]
    fn test_scram_full_flow() {
        let mut client = ScramClient::with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        let client_first = client.client_first();
        assert_eq!(client_first, "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");
        let client_final = client.client_final(RFC7677_SERVER_FIRST).unwrap();
        assert_eq!(client_final, RFC7677_CLIENT_FINAL);
        assert_eq!(client.verify_server(RFC7677_SERVER_FINAL), Ok(()));
    }

    #[test]
    fn test_server_signature_mismatch() {
        let mut client = ScramClient::with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        client.client_first();
        client.client_final(RFC7677_SERVER_FIRST).unwrap();
        // Valid base64 for 32 zero bytes: right length, wrong signature.
        let forged = format!("v={}=", "A".repeat(43));
        assert_eq!(
            client.verify_server(&forged),
            Err(ScramError::ServerSignatureMismatch)
        );
        assert!(matches!(
            client.verify_server("e=invalid-proof"),
            Err(ScramError::InvalidServerFinal(_))
        ));
    }

    #[test]
    fn test_verify_server_before_client_final() {
        let client = ScramClient::with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        assert!(matches!(
            client.verify_server(RFC7677_SERVER_FINAL),
            Err(ScramError::InvalidServerFinal(_))
        ));
    }

    #[test]
    fn test_invalid_server_nonce() {
        let mut client = ScramClient::with_nonce("user", "password", "clientnonce");
        client.client_first();
        let server_first = "r=differentnonce,s=c2FsdA==,i=4096";
        let result = client.client_final(server_first);
        assert!(matches!(result, Err(ScramError::InvalidServerNonce)));
    }

    #[test]
    fn test_invalid_salt_base64() {
        let mut client = ScramClient::with_nonce("user", "pencil", "clientnonce");
        client.client_first();
        let result = client.client_final("r=clientnonceSERVER,s=!!!,i=4096");
        assert!(matches!(result, Err(ScramError::Base64Error(_))));
    }
}