use base64::Engine;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
/// HMAC instantiated with SHA-256; used for key derivation and signatures below.
type HmacSha256 = Hmac<Sha256>;
/// Base64 engine shared by the helpers below: the `STANDARD` general-purpose
/// configuration from the `base64` crate.
const B64: base64::engine::general_purpose::GeneralPurpose =
base64::engine::general_purpose::STANDARD;
/// Channel-binding mode the client advertises in its GS2 header and echoes in
/// the `c=` attribute of the client-final message.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file
// section — presumably one variant is not constructed in every build; confirm.
#[derive(Clone)]
#[allow(dead_code)]
pub enum ChannelBinding {
/// No channel binding: the GS2 header is `n,,`.
None,
/// `tls-server-end-point` binding: the GS2 header is
/// `p=tls-server-end-point,,` and the carried bytes are appended to it before
/// base64-encoding. The consumer calls these bytes `cert_hash`, so this is
/// presumably the hash of the server's TLS certificate — confirm with the caller.
TlsServerEndPoint(Vec<u8>),
}
/// Client-side state kept between sending the client-first message and
/// building the client-final message.
pub struct ScramClient {
/// Password fed into the salted-password derivation.
password: String,
/// Client nonce generated in `new`; the server's nonce must start with it.
nonce: String,
/// The `n=,r=<nonce>` portion of the client-first message (without the GS2
/// header); reused as the first component of the auth message.
client_first_bare: String,
/// Channel-binding mode chosen at construction; selects the GS2 header and
/// the contents of the `c=` attribute.
channel_binding: ChannelBinding,
}
impl ScramClient {
    /// Returns the GS2 header advertised for `channel_binding`.
    ///
    /// The same header is sent at the start of the client-first message and is
    /// later echoed (base64-encoded) in the `c=` attribute, so both call sites
    /// share this helper to keep the two values in lockstep.
    fn gs2_header(channel_binding: &ChannelBinding) -> &'static str {
        match channel_binding {
            ChannelBinding::None => "n,,",
            ChannelBinding::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
        }
    }

    /// Starts a SCRAM exchange.
    ///
    /// Generates a fresh client nonce and returns the client state together
    /// with the bytes of the client-first message (GS2 header followed by
    /// `n=,r=<nonce>`). The username attribute is sent empty.
    pub fn new(password: &str, channel_binding: ChannelBinding) -> (Self, Vec<u8>) {
        let nonce = generate_nonce();
        let client_first_bare = format!("n=,r={nonce}");
        let client_first_msg = format!(
            "{}{client_first_bare}",
            Self::gs2_header(&channel_binding)
        );
        (
            ScramClient {
                password: password.to_string(),
                nonce,
                client_first_bare,
                channel_binding,
            },
            client_first_msg.into_bytes(),
        )
    }

    /// Consumes the server-first message and builds the client-final message.
    ///
    /// Parses the comma-separated `r=` (nonce), `s=` (base64 salt) and `i=`
    /// (iteration count) attributes, derives the salted password and client
    /// proof, and returns `c=<channel binding>,r=<nonce>,p=<proof>` as bytes.
    /// Attributes with any other prefix are ignored.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` when:
    /// - the message is not valid UTF-8;
    /// - `r=`, `s=` or `i=` is missing;
    /// - the iteration count does not parse as `u32` or is zero;
    /// - the server nonce does not start with the client nonce;
    /// - the salt is not valid base64;
    /// - key derivation or HMAC computation fails.
    pub fn process_server_first(&self, server_first: &[u8]) -> Result<Vec<u8>, String> {
        let server_first = std::str::from_utf8(server_first)
            .map_err(|e| format!("Invalid UTF-8 in server-first: {e}"))?;

        // Track each attribute as `Option` so that an absent attribute is
        // distinguishable from an empty one and cannot silently fall back to
        // an empty salt or a zero iteration count.
        let mut server_nonce = None;
        let mut salt_b64 = None;
        let mut iterations = None;
        for part in server_first.split(',') {
            if let Some(v) = part.strip_prefix("r=") {
                server_nonce = Some(v);
            } else if let Some(v) = part.strip_prefix("s=") {
                salt_b64 = Some(v);
            } else if let Some(v) = part.strip_prefix("i=") {
                iterations = Some(v.parse::<u32>().map_err(|_| "Bad iteration count")?);
            }
        }

        let server_nonce = server_nonce.ok_or("Missing nonce (r=) in server-first")?;
        if !server_nonce.starts_with(&self.nonce) {
            return Err("Server nonce doesn't start with client nonce".into());
        }

        let salt_b64 = salt_b64.ok_or("Missing salt (s=) in server-first")?;
        let salt = B64
            .decode(salt_b64)
            .map_err(|e| format!("Bad salt base64: {e}"))?;

        let iterations = match iterations {
            Some(n) if n > 0 => n,
            Some(_) => return Err("Iteration count must be at least 1".into()),
            None => return Err("Missing iteration count (i=) in server-first".into()),
        };

        let salted_password = hi(&self.password, &salt, iterations)?;
        let client_key = hmac_sha256(&salted_password, b"Client Key")?;
        let stored_key = sha256(&client_key);

        // The `c=` attribute carries base64(GS2 header || channel-binding data).
        let mut cbind_input = Self::gs2_header(&self.channel_binding)
            .as_bytes()
            .to_vec();
        match &self.channel_binding {
            ChannelBinding::None => {}
            ChannelBinding::TlsServerEndPoint(cert_hash) => {
                cbind_input.extend_from_slice(cert_hash);
            }
        }
        let channel_binding_data = B64.encode(&cbind_input);

        let client_final_without_proof = format!("c={channel_binding_data},r={server_nonce}");
        // AuthMessage = client-first-bare "," server-first "," client-final-without-proof
        let auth_message = format!(
            "{},{server_first},{client_final_without_proof}",
            self.client_first_bare
        );
        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes())?;

        // ClientProof = ClientKey XOR ClientSignature
        let proof: Vec<u8> = client_key
            .iter()
            .zip(client_signature.iter())
            .map(|(a, b)| a ^ b)
            .collect();
        let proof_b64 = B64.encode(&proof);
        let client_final = format!("{client_final_without_proof},p={proof_b64}");
        Ok(client_final.into_bytes())
    }
}
fn hi(password: &str, salt: &[u8], iterations: u32) -> Result<Vec<u8>, String> {
let mut mac = HmacSha256::new_from_slice(password.as_bytes())
.map_err(|e| format!("HMAC key error: {e}"))?;
mac.update(salt);
mac.update(&1u32.to_be_bytes());
let mut u = mac.finalize().into_bytes().to_vec();
let mut result = u.clone();
for _ in 1..iterations {
let mut mac = HmacSha256::new_from_slice(password.as_bytes())
.map_err(|e| format!("HMAC key error: {e}"))?;
mac.update(&u);
u = mac.finalize().into_bytes().to_vec();
for (r, x) in result.iter_mut().zip(u.iter()) {
*r ^= x;
}
}
Ok(result)
}
/// Returns the HMAC-SHA-256 tag of `data` under `key`.
///
/// # Errors
///
/// Returns `"HMAC key error: …"` if the MAC cannot be constructed from `key`.
fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>, String> {
    match HmacSha256::new_from_slice(key) {
        Ok(mut signer) => {
            signer.update(data);
            let tag = signer.finalize().into_bytes();
            Ok(tag.to_vec())
        }
        Err(e) => Err(format!("HMAC key error: {e}")),
    }
}
/// Returns the SHA-256 digest of `data` as an owned byte vector.
fn sha256(data: &[u8]) -> Vec<u8> {
    let mut hasher = Sha256::new();
    hasher.update(data);
    let digest = hasher.finalize();
    digest.to_vec()
}
/// Produces a fresh client nonce: 24 random bytes drawn one at a time from
/// `rand::rng()` and encoded with the shared base64 engine.
fn generate_nonce() -> String {
    use rand::Rng;
    let mut rng = rand::rng();
    let mut raw = [0u8; 24];
    for slot in raw.iter_mut() {
        *slot = rng.random::<u8>();
    }
    B64.encode(&raw[..])
}