use sha2::{Digest, Sha256};
use crate::storage::encryption::pbkdf2::pbkdf2_sha256;
use crate::storage::encryption::pbkdf2::Pbkdf2Params;
/// PBKDF2 iteration count presumably used for newly created verifiers
/// (not referenced within this file — see callers).
pub const DEFAULT_ITER: u32 = 16384;
/// Smallest PBKDF2 iteration count intended to be accepted; equals the 4096
/// minimum recommended by RFC 7677.
/// NOTE(review): nothing in this file enforces this bound — confirm callers check it.
pub const MIN_ITER: u32 = 4_096;
/// Server-side SCRAM credential derived from a password.
///
/// Only values derived from the password are kept; the password itself is not stored.
///
/// NOTE(review): the derived `Debug` prints `stored_key` and `server_key` verbatim,
/// so avoid `{:?}`-logging values of this type.
#[derive(Debug, Clone)]
pub struct ScramVerifier {
    /// Salt that was fed to PBKDF2 when this verifier was built.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count used together with `salt`.
    pub iter: u32,
    /// `SHA-256(HMAC(SaltedPassword, "Client Key"))`.
    pub stored_key: [u8; 32],
    /// `HMAC(SaltedPassword, "Server Key")`.
    pub server_key: [u8; 32],
}

impl ScramVerifier {
    /// Derives a verifier from a cleartext `password`, a `salt`, and a PBKDF2
    /// iteration count `iter`.
    ///
    /// The password is used as its raw UTF-8 bytes; no normalization happens here.
    /// NOTE(review): RFC 5802 expects SASLprep'd passwords — confirm callers handle it.
    pub fn from_password(password: &str, salt: Vec<u8>, iter: u32) -> Self {
        let salted = salted_password(password.as_bytes(), &salt, iter);
        let server_key = hmac_sha256(&salted, b"Server Key");
        let stored_key = sha256(&hmac_sha256(&salted, b"Client Key"));
        ScramVerifier {
            salt,
            iter,
            stored_key,
            server_key,
        }
    }
}
/// Computes SCRAM's `SaltedPassword` by running `pbkdf2_sha256` over
/// `password` and `salt` with `iter` iterations (all other PBKDF2 parameters
/// left at `Pbkdf2Params::default()`), keeping the first 32 bytes.
///
/// # Panics
///
/// Panics if `pbkdf2_sha256` yields fewer than 32 bytes.
pub fn salted_password(password: &[u8], salt: &[u8], iter: u32) -> [u8; 32] {
    let params = Pbkdf2Params {
        iterations: iter,
        ..Pbkdf2Params::default()
    };
    // Pass the params by reference (the original text had a mis-encoded `&params`).
    let derived = pbkdf2_sha256(password, salt, &params);
    let mut out = [0u8; 32];
    out.copy_from_slice(&derived[..32]);
    out
}
pub fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
crate::crypto::hmac_sha256(key, data)
}
/// Returns the SHA-256 digest of `data` as a fixed 32-byte array.
pub fn sha256(data: &[u8]) -> [u8; 32] {
    let digest = Sha256::digest(data);
    let mut bytes = [0u8; 32];
    bytes.copy_from_slice(&digest);
    bytes
}
/// Byte-wise XOR of `a` and `b`.
///
/// The output is as long as the shorter input; any trailing bytes of the longer
/// input are ignored.
pub fn xor(a: &[u8], b: &[u8]) -> Vec<u8> {
    let mut mixed = Vec::with_capacity(a.len().min(b.len()));
    mixed.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs ^ rhs));
    mixed
}
/// Builds the SCRAM `AuthMessage`: the three message parts joined by commas,
/// i.e. `client-first-bare "," server-first "," client-final-without-proof`,
/// returned as raw bytes.
pub fn auth_message(
    client_first_bare: &str,
    server_first: &str,
    client_final_no_proof: &str,
) -> Vec<u8> {
    [client_first_bare, server_first, client_final_no_proof]
        .join(",")
        .into_bytes()
}
/// Builds a SCRAM `ClientProof`: `ClientKey XOR HMAC(StoredKey, AuthMessage)`.
///
/// The result has the length of the shorter of `client_key` and the 32-byte
/// signature, because [`xor`] stops at the shorter input.
pub fn client_proof(stored_key: &[u8], auth_message: &[u8], client_key: &[u8]) -> Vec<u8> {
    xor(client_key, &hmac_sha256(stored_key, auth_message))
}
/// Checks a client's `ClientProof` against a stored [`ScramVerifier`].
///
/// The check works as follows:
/// - Proofs that are not exactly 32 bytes are rejected immediately.
/// - Otherwise it recomputes `ClientSignature = HMAC(StoredKey, AuthMessage)` and
///   recovers the candidate `ClientKey = proof XOR ClientSignature`.
/// - It accepts iff `SHA-256(ClientKey)` matches the verifier's `stored_key`. That
///   comparison goes through `crate::crypto::constant_time_eq`.
pub fn verify_client_proof(
    verifier: &ScramVerifier,
    auth_message: &[u8],
    presented_proof: &[u8],
) -> bool {
    if presented_proof.len() == 32 {
        let client_signature = hmac_sha256(&verifier.stored_key, auth_message);
        let recovered_client_key = xor(presented_proof, &client_signature);
        let recovered_stored_key: [u8; 32] = sha256(&recovered_client_key);
        crate::crypto::constant_time_eq(&recovered_stored_key, &verifier.stored_key)
    } else {
        false
    }
}
pub fn server_signature(server_key: &[u8], auth_message: &[u8]) -> [u8; 32] {
hmac_sha256(server_key, auth_message)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal standard-alphabet base64 decoder, so RFC test vectors can be
    /// written exactly as published without adding a dev-dependency.
    fn b64(s: &str) -> Vec<u8> {
        const ALPHABET: &[u8] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut out = Vec::with_capacity(s.len() * 3 / 4);
        let mut acc: u32 = 0;
        let mut bits: u32 = 0;
        for c in s.bytes().filter(|&c| c != b'=') {
            let v = ALPHABET
                .iter()
                .position(|&a| a == c)
                .expect("test vector must be valid base64") as u32;
            acc = (acc << 6) | v;
            bits += 6;
            if bits >= 8 {
                bits -= 8;
                out.push((acc >> bits) as u8);
                // Keep only the bits that have not been emitted yet.
                acc &= (1 << bits) - 1;
            }
        }
        out
    }

    #[test]
    fn verifier_is_deterministic() {
        let salt = b"reddb-test-salt".to_vec();
        let v1 = ScramVerifier::from_password("hunter2", salt.clone(), 4096);
        let v2 = ScramVerifier::from_password("hunter2", salt, 4096);
        assert_eq!(v1.stored_key, v2.stored_key);
        assert_eq!(v1.server_key, v2.server_key);
    }

    #[test]
    fn full_round_trip() {
        let salt = b"reddb-rt-salt".to_vec();
        let iter = 4096;
        let verifier = ScramVerifier::from_password("correct horse", salt.clone(), iter);
        let client_first_bare = "n=alice,r=cnonce";
        let server_first = "r=cnonce+snonce,s=cmVkZGItcnQtc2FsdA==,i=4096";
        let client_final_no_proof = "c=biws,r=cnonce+snonce";
        let am = auth_message(client_first_bare, server_first, client_final_no_proof);
        let salted = salted_password(b"correct horse", &salt, iter);
        let client_key = hmac_sha256(&salted, b"Client Key");
        // The client derives StoredKey independently; it must match the server's.
        assert_eq!(sha256(&client_key), verifier.stored_key);
        let proof = client_proof(&verifier.stored_key, &am, &client_key);
        assert!(verify_client_proof(&verifier, &am, &proof));

        // A proof is bound to its AuthMessage and must not verify for another one.
        let other_am = auth_message(client_first_bare, server_first, "c=biws,r=other");
        assert!(!verify_client_proof(&verifier, &other_am, &proof));

        let salted_bad = salted_password(b"wrong password", &salt, iter);
        let client_key_bad = hmac_sha256(&salted_bad, b"Client Key");
        let proof_bad = client_proof(&verifier.stored_key, &am, &client_key_bad);
        assert!(!verify_client_proof(&verifier, &am, &proof_bad));
    }

    #[test]
    fn rejects_proof_of_wrong_length() {
        let v = ScramVerifier::from_password("p", b"salt".to_vec(), 4096);
        let am = auth_message("n=u,r=a", "r=ab,s=c2FsdA==,i=4096", "c=biws,r=ab");
        assert!(!verify_client_proof(&v, &am, &[]));
        assert!(!verify_client_proof(&v, &am, &[0u8; 31]));
        assert!(!verify_client_proof(&v, &am, &[0u8; 33]));
    }

    /// Known-answer test against the SCRAM-SHA-256 example exchange in
    /// RFC 7677 section 3 (user "user", password "pencil"). Guards against
    /// symmetric derivation bugs that self-consistency tests cannot catch.
    #[test]
    fn rfc7677_test_vector() {
        let salt = b64("W22ZaJ0SNY7soEsUEjb6gQ==");
        let iter = 4096;
        let verifier = ScramVerifier::from_password("pencil", salt.clone(), iter);
        let client_first_bare = "n=user,r=rOprNGfwEbeRWgbNEkqO";
        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
                            s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
        let client_final_no_proof = "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0";
        let am = auth_message(client_first_bare, server_first, client_final_no_proof);

        let salted = salted_password(b"pencil", &salt, iter);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let proof = client_proof(&verifier.stored_key, &am, &client_key);
        assert_eq!(proof, b64("dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="));
        assert!(verify_client_proof(&verifier, &am, &proof));

        let sig = server_signature(&verifier.server_key, &am);
        assert_eq!(
            sig.to_vec(),
            b64("6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=")
        );
    }

    #[test]
    fn server_signature_round_trip() {
        let v = ScramVerifier::from_password("p", b"s".to_vec(), 4096);
        let am = b"some auth message".to_vec();
        let sig = server_signature(&v.server_key, &am);
        let again = server_signature(&v.server_key, &am);
        assert_eq!(sig, again);
        let other = server_signature(&v.server_key, b"different");
        assert_ne!(sig, other);
    }

    #[test]
    fn xor_basic() {
        assert_eq!(
            xor(&[0xff, 0x00, 0xaa], &[0x0f, 0xff, 0x55]),
            vec![0xf0, 0xff, 0xff]
        );
    }
}