use hkdf::Hkdf;
use pbkdf2::pbkdf2_hmac;
use sha2::Sha256;
/// Derives the `(w0s, w1s)` passcode-verifier byte strings from a numeric passcode.
///
/// The passcode is encoded as its 4-byte little-endian representation and stretched
/// with PBKDF2-HMAC-SHA256 over `salt` for `iterations` rounds into 80 bytes. The
/// first 40 bytes are returned as `w0s` and the remaining 40 bytes as `w1s`.
///
/// No constraints are placed on `salt` length or on the passcode value; callers that
/// need protocol-level validation must perform it themselves.
///
/// # Errors
///
/// Returns `Err("PBKDF2 iterations must be > 0")` when `iterations` is zero.
pub fn derive_passcode_verifier(
    passcode: u32,
    salt: &[u8],
    iterations: u32,
) -> Result<(Vec<u8>, Vec<u8>), &'static str> {
    // Size in bytes of each returned half; PBKDF2 produces exactly two of them.
    const HALF_LEN: usize = 40;

    if iterations == 0 {
        return Err("PBKDF2 iterations must be > 0");
    }

    let mut derived = [0u8; 2 * HALF_LEN];
    pbkdf2_hmac::<Sha256>(&passcode.to_le_bytes(), salt, iterations, &mut derived);

    let (first, second) = derived.split_at(HALF_LEN);
    Ok((first.to_vec(), second.to_vec()))
}
/// Runs HKDF-SHA256 (extract, then expand) and returns `length` bytes of output keying material.
///
/// `label` is used verbatim as the HKDF `info` input (its UTF-8 bytes). No length
/// prefix or other framing is added, despite the function's name. An empty `salt`
/// is handed to HKDF as "no salt".
///
/// # Panics
///
/// Panics if `length` exceeds `255 * 32` bytes, the maximum output of HKDF-SHA256.
pub fn hkdf_expand_label(ikm: &[u8], salt: &[u8], label: &str, length: usize) -> Vec<u8> {
    let salt = match salt {
        [] => None,
        bytes => Some(bytes),
    };
    let kdf = Hkdf::<Sha256>::new(salt, ikm);

    let mut okm = vec![0u8; length];
    kdf.expand(label.as_bytes(), &mut okm)
        .expect("HKDF expand: output length must be <= 255 * hash_len");
    okm
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn derive_passcode_verifier_length() {
        let salt = b"SPAKE2P+";
        let (w0s, w1s) = derive_passcode_verifier(20202021, salt, 1000).unwrap();
        assert_eq!(w0s.len(), 40, "w0s must be 40 bytes");
        assert_eq!(w1s.len(), 40, "w1s must be 40 bytes");
        assert_ne!(w0s, w1s, "w0s and w1s should differ");
    }

    #[test]
    fn derive_passcode_verifier_deterministic() {
        let salt = b"test-salt-16byte";
        let (a0, a1) = derive_passcode_verifier(20202021, salt, 1000).unwrap();
        let (b0, b1) = derive_passcode_verifier(20202021, salt, 1000).unwrap();
        assert_eq!(a0, b0);
        assert_eq!(a1, b1);
    }

    #[test]
    fn derive_passcode_verifier_different_passcode() {
        let salt = b"test-salt-16byte";
        let (a0, _) = derive_passcode_verifier(20202021, salt, 1000).unwrap();
        let (b0, _) = derive_passcode_verifier(11111111, salt, 1000).unwrap();
        assert_ne!(a0, b0, "different passcodes must produce different w0s");
    }

    #[test]
    fn derive_passcode_verifier_zero_iterations_err() {
        assert!(derive_passcode_verifier(20202021, b"salt", 0).is_err());
    }

    // Pins the exact layout: the passcode is fed to PBKDF2 as 4 little-endian bytes,
    // and the 80-byte output is split into w0s = bytes 0..40 and w1s = bytes 40..80.
    #[test]
    fn derive_passcode_verifier_layout() {
        let salt = b"SPAKE2P Key Salt";
        let passcode: u32 = 20202021;
        let mut dk = [0u8; 80];
        pbkdf2::pbkdf2_hmac::<sha2::Sha256>(&passcode.to_le_bytes(), salt, 1000, &mut dk);
        let (w0s, w1s) = derive_passcode_verifier(passcode, salt, 1000).unwrap();
        assert_eq!(w0s.as_slice(), &dk[..40]);
        assert_eq!(w1s.as_slice(), &dk[40..]);
    }

    #[test]
    fn hkdf_expand_label_length() {
        let ikm = [0u8; 32];
        let out = hkdf_expand_label(&ikm, b"salt", "SessionKeys", 48);
        assert_eq!(out.len(), 48);
    }

    #[test]
    fn hkdf_expand_label_deterministic() {
        let ikm = [1u8; 32];
        let a = hkdf_expand_label(&ikm, b"", "TestLabel", 32);
        let b = hkdf_expand_label(&ikm, b"", "TestLabel", 32);
        assert_eq!(a, b);
    }

    #[test]
    fn hkdf_expand_label_different_labels() {
        let ikm = [2u8; 32];
        let a = hkdf_expand_label(&ikm, b"s", "LabelA", 32);
        let b = hkdf_expand_label(&ikm, b"s", "LabelB", 32);
        assert_ne!(a, b);
    }

    // Cross-checks the wrapper against calling the `hkdf` crate directly with a
    // non-empty salt. This is a consistency check, not a known-answer vector.
    #[test]
    fn hkdf_matches_reference_crate() {
        let ikm = vec![0x0bu8; 22];
        let salt: Vec<u8> = (0x00u8..=0x0cu8).collect();
        let info_str = "SessionKeys";
        let hk = hkdf::Hkdf::<sha2::Sha256>::new(Some(salt.as_slice()), &ikm);
        let mut expected = vec![0u8; 32];
        hk.expand(info_str.as_bytes(), &mut expected).unwrap();
        let got = hkdf_expand_label(&ikm, &salt, info_str, 32);
        assert_eq!(got, expected);
    }

    // RFC 5869 Appendix A.3 (SHA-256, zero-length salt and info). This case is chosen
    // because its empty `info` is representable through the `&str` label parameter,
    // and it exercises the empty-salt path of `hkdf_expand_label`.
    #[test]
    fn hkdf_rfc5869_test_case_3() {
        let ikm = [0x0bu8; 22];
        let expected: [u8; 42] = [
            0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f, 0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c,
            0x5a, 0x31, 0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e, 0xc3, 0x45, 0x4e, 0x5f,
            0x3c, 0x73, 0x8d, 0x2d, 0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a, 0x96, 0xc8,
        ];
        let got = hkdf_expand_label(&ikm, b"", "", 42);
        assert_eq!(got.as_slice(), &expected[..]);
    }

    // Per RFC 5869, an absent salt behaves as HashLen zero bytes; an empty salt must
    // therefore agree with an explicit 32-byte all-zero salt.
    #[test]
    fn hkdf_expand_label_empty_salt_equals_zero_salt() {
        let ikm = [3u8; 32];
        let a = hkdf_expand_label(&ikm, b"", "Label", 32);
        let b = hkdf_expand_label(&ikm, &[0u8; 32], "Label", 32);
        assert_eq!(a, b);
    }

    #[test]
    fn hkdf_expand_label_max_length_ok() {
        let out = hkdf_expand_label(&[4u8; 32], b"s", "Label", 255 * 32);
        assert_eq!(out.len(), 255 * 32);
    }

    #[test]
    #[should_panic(expected = "HKDF expand")]
    fn hkdf_expand_label_over_max_length_panics() {
        let _ = hkdf_expand_label(&[4u8; 32], b"s", "Label", 255 * 32 + 1);
    }
}