use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use hmac::{Hmac, Mac};
use rand::Rng;
use sha2::{Digest, Sha256};
// HMAC instantiated with SHA-256; used for every keyed hash in this file.
type HmacSha256 = Hmac<Sha256>;
// GS2 header selected when the client is built without channel-binding data.
// It prefixes the client-first message and is base64-encoded into the `c=`
// attribute of the client-final message (`"n,,"` encodes to `biws`).
const GS2_HEADER_NO_CHANNEL_BINDING: &str = "n,,";
// GS2 header selected when the client is built with channel-binding data.
// In the client-final message the raw binding bytes are appended to this
// header before base64 encoding.
const GS2_HEADER_TLS_SERVER_END_POINT: &str = "p=tls-server-end-point,,";
/// Client-side state for one SCRAM-SHA-256 authentication exchange.
///
/// The exchange is driven in three steps: `client_first_message`,
/// `process_server_first` (which returns the client-final message) and
/// `verify_server_final`. The `Option` fields start as `None` and are filled
/// in by `process_server_first`.
pub struct ScramClient {
/// User name placed in the `n=` attribute of the client-first message.
username: String,
/// Password fed to PBKDF2 when deriving the salted password.
password: String,
/// Random alphanumeric nonce generated at construction; sent as `r=` in the
/// client-first message.
client_nonce: String,
/// The `r=` value received in the server-first message (it must start with
/// `client_nonce`). Stored for reference; not read by the visible code.
combined_nonce: Option<String>,
/// Decoded `s=` salt from the server-first message. Stored for reference;
/// not read by the visible code.
salt: Option<Vec<u8>>,
/// `i=` iteration count from the server-first message. Stored for
/// reference; not read by the visible code.
iterations: Option<u32>,
/// `client-first-bare,server-first,client-final-without-proof`, kept so the
/// server signature can be checked in `verify_server_final`.
auth_message: Option<String>,
/// PBKDF2 output for the password, kept so the server key can be derived in
/// `verify_server_final`.
salted_password: Option<Vec<u8>>,
/// Raw channel-binding bytes appended to the GS2 header inside the `c=`
/// attribute; `None` when channel binding is not used.
channel_binding_data: Option<Vec<u8>>,
/// GS2 header chosen at construction according to whether
/// `channel_binding_data` is present.
gs2_header: &'static str,
}
impl ScramClient {
    /// Creates a client that authenticates without channel binding.
    ///
    /// The client-first message will carry the GS2 header `n,,`.
    pub fn new(username: &str, password: &str) -> Self {
        Self::new_inner(username, password, None)
    }

    /// Creates a client that uses `tls-server-end-point` channel binding.
    ///
    /// `channel_binding_data` is appended verbatim to the GS2 header inside the
    /// `c=` attribute of the client-final message. It is not validated here;
    /// per RFC 5929 the caller is expected to pass the hash of the server's
    /// TLS certificate.
    pub fn new_with_tls_server_end_point(
        username: &str,
        password: &str,
        channel_binding_data: Vec<u8>,
    ) -> Self {
        Self::new_inner(username, password, Some(channel_binding_data))
    }

    /// Shared constructor: generates the client nonce and selects the GS2
    /// header that matches the presence of channel-binding data.
    fn new_inner(username: &str, password: &str, channel_binding_data: Option<Vec<u8>>) -> Self {
        // Alphanumeric only, so the nonce can never contain ',' (the SCRAM
        // attribute separator).
        const NONCE_ALPHABET: &[u8] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        const NONCE_LEN: usize = 24;

        let mut rng = rand::rng();
        let client_nonce: String = (0..NONCE_LEN)
            .map(|_| char::from(NONCE_ALPHABET[rng.random_range(0..NONCE_ALPHABET.len())]))
            .collect();

        let gs2_header = if channel_binding_data.is_some() {
            GS2_HEADER_TLS_SERVER_END_POINT
        } else {
            GS2_HEADER_NO_CHANNEL_BINDING
        };

        Self {
            username: username.to_string(),
            password: password.to_string(),
            client_nonce,
            combined_nonce: None,
            salt: None,
            iterations: None,
            auth_message: None,
            salted_password: None,
            channel_binding_data,
            gs2_header,
        }
    }

    /// Returns the client-first message: the GS2 header followed by the bare
    /// message (`n=<user>,r=<client nonce>`).
    pub fn client_first_message(&self) -> Vec<u8> {
        // Built from `client_first_message_bare` so the bytes on the wire and
        // the bytes hashed into the auth message can never diverge.
        format!("{}{}", self.gs2_header, self.client_first_message_bare()).into_bytes()
    }

    /// Returns the user name encoded as an RFC 5802 `saslname`.
    ///
    /// `,` separates SCRAM attributes and `=` introduces an escape, so both
    /// must be sent as `=2C` and `=3D`. SASLprep normalisation is not applied.
    fn escaped_username(&self) -> String {
        let mut escaped = String::with_capacity(self.username.len());
        for ch in self.username.chars() {
            match ch {
                '=' => escaped.push_str("=3D"),
                ',' => escaped.push_str("=2C"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Returns the client-first message without the GS2 header; this is the
    /// first component of the auth message.
    fn client_first_message_bare(&self) -> String {
        format!("n={},r={}", self.escaped_username(), self.client_nonce)
    }

    /// Returns the bytes that are base64-encoded into the `c=` attribute:
    /// the GS2 header followed by the channel-binding data, if any.
    fn channel_binding_input(&self) -> Vec<u8> {
        let binding: &[u8] = self.channel_binding_data.as_deref().unwrap_or(&[]);
        let mut input = Vec::with_capacity(self.gs2_header.len() + binding.len());
        input.extend_from_slice(self.gs2_header.as_bytes());
        input.extend_from_slice(binding);
        input
    }

    /// Consumes the server-first message and returns the client-final message.
    ///
    /// On success the nonce, salt, iteration count, salted password and auth
    /// message are stored so that `verify_server_final` can run afterwards.
    /// On failure the client state is left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error when the message is not UTF-8, when `r=`, `s=` or `i=`
    /// is missing or malformed, when the iteration count is outside
    /// `4096..=100000`, or when the server nonce does not start with the
    /// client nonce.
    pub fn process_server_first(&mut self, server_msg: &[u8]) -> Result<Vec<u8>, String> {
        let server_str =
            std::str::from_utf8(server_msg).map_err(|_| "Invalid UTF-8 in server message")?;

        let mut nonce: Option<&str> = None;
        let mut salt = None;
        let mut iterations = None;
        for part in server_str.split(',') {
            if let Some(value) = part.strip_prefix("r=") {
                nonce = Some(value);
            } else if let Some(value) = part.strip_prefix("s=") {
                salt = Some(BASE64.decode(value).map_err(|_| "Invalid salt base64")?);
            } else if let Some(value) = part.strip_prefix("i=") {
                iterations = Some(
                    value
                        .parse::<u32>()
                        .map_err(|_| "Invalid iteration count")?,
                );
            }
        }
        let nonce = nonce.ok_or("Missing nonce in server message")?;
        let salt = salt.ok_or("Missing salt in server message")?;
        let iterations = iterations.ok_or("Missing iterations in server message")?;

        // The lower bound rejects weak key stretching; the upper bound keeps a
        // hostile server from making the client burn CPU in PBKDF2.
        if iterations < 4096 {
            return Err(format!(
                "SCRAM iteration count too low: {} (minimum 4096)",
                iterations,
            ));
        }
        if iterations > 100_000 {
            return Err(format!(
                "SCRAM iteration count too high: {} (maximum 100000)",
                iterations,
            ));
        }
        if !nonce.starts_with(self.client_nonce.as_str()) {
            return Err("Server nonce doesn't contain client nonce".to_string());
        }

        let salted_password = self.derive_salted_password(&salt, iterations);
        let client_key = self.hmac(&salted_password, b"Client Key")?;
        let stored_key = Self::sha256(&client_key);

        let channel_binding_b64 = BASE64.encode(self.channel_binding_input());
        let client_final_without_proof = format!("c={},r={}", channel_binding_b64, nonce);
        let auth_message = format!(
            "{},{},{}",
            self.client_first_message_bare(),
            server_str,
            client_final_without_proof
        );

        // ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage)
        let client_signature = self.hmac(&stored_key, auth_message.as_bytes())?;
        let client_proof: Vec<u8> = client_key
            .iter()
            .zip(client_signature.iter())
            .map(|(key_byte, sig_byte)| key_byte ^ sig_byte)
            .collect();
        let client_final = format!(
            "{},p={}",
            client_final_without_proof,
            BASE64.encode(&client_proof)
        );

        // Commit state only after every fallible step succeeded; values are
        // moved in rather than cloned.
        self.combined_nonce = Some(nonce.to_string());
        self.salt = Some(salt);
        self.iterations = Some(iterations);
        self.salted_password = Some(salted_password);
        self.auth_message = Some(auth_message);

        Ok(client_final.into_bytes())
    }

    /// Checks the `v=` server signature in the server-final message.
    ///
    /// Must be called after a successful `process_server_first`.
    ///
    /// # Errors
    ///
    /// Returns an error when the message is not UTF-8, when the server sent an
    /// `e=` error attribute, when `v=` is missing or not valid base64, when
    /// `process_server_first` has not stored its state yet, or when the
    /// signature does not match.
    pub fn verify_server_final(&self, server_msg: &[u8]) -> Result<(), String> {
        let server_str =
            std::str::from_utf8(server_msg).map_err(|_| "Invalid UTF-8 in server final message")?;
        // RFC 5802 lets the server answer with `e=<error>` instead of a
        // verifier; surface the server's reason rather than "missing verifier".
        if let Some(server_error) = server_str.strip_prefix("e=") {
            return Err(format!("Server reported SCRAM error: {}", server_error));
        }
        let verifier = server_str
            .strip_prefix("v=")
            .ok_or("Missing verifier in server final message")?;
        let expected_signature = BASE64
            .decode(verifier)
            .map_err(|_| "Invalid base64 in server signature")?;
        let salted_password = self
            .salted_password
            .as_deref()
            .ok_or("Missing salted password")?;
        let auth_message = self
            .auth_message
            .as_deref()
            .ok_or("Missing auth message")?;

        let server_key = self.hmac(salted_password, b"Server Key")?;
        let mut mac = HmacSha256::new_from_slice(&server_key)
            .map_err(|_| "HMAC init failed for SCRAM-SHA-256".to_string())?;
        mac.update(auth_message.as_bytes());
        // `verify_slice` compares in constant time (and rejects a wrong
        // length), unlike `!=` on byte vectors.
        mac.verify_slice(&expected_signature)
            .map_err(|_| "Server signature verification failed".to_string())
    }

    /// Derives the 32-byte salted password with PBKDF2-HMAC-SHA-256.
    fn derive_salted_password(&self, salt: &[u8], iterations: u32) -> Vec<u8> {
        let mut output = [0u8; 32];
        pbkdf2::pbkdf2_hmac::<Sha256>(self.password.as_bytes(), salt, iterations, &mut output);
        output.to_vec()
    }

    /// Computes HMAC-SHA-256 of `data` under `key`.
    ///
    /// # Errors
    ///
    /// Returns an error if the MAC cannot be initialised from `key`; this is
    /// not expected in practice.
    fn hmac(&self, key: &[u8], data: &[u8]) -> Result<Vec<u8>, String> {
        let mut mac = HmacSha256::new_from_slice(key)
            .map_err(|_| "HMAC init failed for SCRAM-SHA-256".to_string())?;
        mac.update(data);
        Ok(mac.finalize().into_bytes().to_vec())
    }

    /// Computes the SHA-256 digest of `data`.
    fn sha256(data: &[u8]) -> Vec<u8> {
        let mut hasher = Sha256::new();
        hasher.update(data);
        hasher.finalize().to_vec()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a protocol message produced by the client as UTF-8 text.
    fn as_text(message: Vec<u8>) -> String {
        String::from_utf8(message).unwrap()
    }

    /// Builds a server-first message that echoes the client's nonce with a
    /// server suffix, a fixed salt and the minimum accepted iteration count.
    fn server_first_for(client: &ScramClient) -> String {
        let combined_nonce = format!("{}ServerPart", client.client_nonce);
        let encoded_salt = BASE64.encode(b"randomsalt");
        format!("r={},s={},i=4096", combined_nonce, encoded_salt)
    }

    /// Runs the server-first step and returns the client-final message as text.
    fn client_final_text(client: &mut ScramClient) -> String {
        let server_first = server_first_for(client);
        let client_final = client
            .process_server_first(server_first.as_bytes())
            .unwrap();
        as_text(client_final)
    }

    #[test]
    fn test_client_first_message() {
        let first = as_text(ScramClient::new("user", "password").client_first_message());
        assert!(first.starts_with("n,,n=user,r="));
    }

    #[test]
    fn test_scram_flow() {
        let mut client = ScramClient::new("testuser", "testpass");
        assert!(as_text(client.client_first_message()).contains("n=testuser"));

        let client_final = client_final_text(&mut client);
        // "biws" is base64 for the "n,," GS2 header.
        assert!(client_final.starts_with("c=biws,r="));
        assert!(client_final.contains(",p="));
    }

    #[test]
    fn test_client_first_message_plus() {
        let binding = vec![1, 2, 3, 4];
        let client = ScramClient::new_with_tls_server_end_point("user", "password", binding);
        let first = as_text(client.client_first_message());
        assert!(first.starts_with("p=tls-server-end-point,,n=user,r="));
    }

    #[test]
    fn test_scram_plus_final_channel_binding_payload() {
        let cb_data = vec![0xde, 0xad, 0xbe, 0xef];
        let mut client =
            ScramClient::new_with_tls_server_end_point("testuser", "testpass", cb_data.clone());

        let client_final = client_final_text(&mut client);
        let encoded_cb = client_final
            .split(',')
            .find_map(|attribute| attribute.strip_prefix("c="))
            .unwrap();
        let decoded = BASE64.decode(encoded_cb).unwrap();

        // The payload is the GS2 header followed by the raw binding bytes.
        let expected: Vec<u8> = GS2_HEADER_TLS_SERVER_END_POINT
            .bytes()
            .chain(cb_data.iter().copied())
            .collect();
        assert_eq!(decoded, expected);
    }
}