use crate::hmac::hmac_sm3;
use crate::sm3::DIGEST_SIZE;
use alloc::vec::Vec;
use zeroize::Zeroize;
/// Fills `output` with key material derived from `password` and `salt` using
/// the PBKDF2 construction with HMAC-SM3 as the pseudo-random function.
///
/// The output is produced in blocks of `DIGEST_SIZE` bytes. For the block with
/// 1-based index `i`, the first PRF call is made over `salt || i` (with `i`
/// encoded as 4 big-endian bytes); each later call is made over the previous
/// PRF result. The block is the XOR of all `iterations` PRF results. The final
/// block is truncated to whatever space is left in `output`.
///
/// # Returns
///
/// * `Some(())` once `output` has been completely filled.
/// * `None`, leaving `output` untouched, when `iterations` is zero, when
///   `output` is empty, or when `output` is longer than
///   `DIGEST_SIZE * u32::MAX` bytes (the block counter is a `u32`).
///
/// The internal scratch buffers (salt-with-counter message, XOR accumulator
/// and last PRF result) are zeroized before the successful return.
#[must_use]
pub fn pbkdf2_hmac_sm3(
    password: &[u8],
    salt: &[u8],
    iterations: u32,
    output: &mut [u8],
) -> Option<()> {
    if output.is_empty() || iterations == 0 {
        return None;
    }

    // The block counter is 32 bits wide, so at most u32::MAX blocks exist.
    let longest_allowed = u64::from(u32::MAX) * (DIGEST_SIZE as u64);
    if (output.len() as u64) > longest_allowed {
        return None;
    }

    // Message for the first PRF call of every block: the salt followed by a
    // 4-byte counter slot that is overwritten for each block.
    let counter_at = salt.len();
    let mut first_message: Vec<u8> = Vec::with_capacity(counter_at + 4);
    first_message.extend_from_slice(salt);
    first_message.extend_from_slice(&[0u8; 4]);

    let mut accumulator = [0u8; DIGEST_SIZE];
    let mut chained = [0u8; DIGEST_SIZE];

    // The length check above guarantees there are no more than u32::MAX
    // chunks, so the inclusive counter range never runs out first.
    for (counter, destination) in (1..=u32::MAX).zip(output.chunks_mut(DIGEST_SIZE)) {
        first_message[counter_at..].copy_from_slice(&counter.to_be_bytes());

        chained = hmac_sm3(password, &first_message);
        accumulator = chained;

        // The first PRF call is already accounted for, hence `1..iterations`.
        for _ in 1..iterations {
            chained = hmac_sm3(password, &chained);
            accumulator
                .iter_mut()
                .zip(chained.iter())
                .for_each(|(acc, byte)| *acc ^= *byte);
        }

        // The last chunk may be shorter than a full digest.
        let wanted = destination.len();
        destination.copy_from_slice(&accumulator[..wanted]);
    }

    first_message.zeroize();
    accumulator.zeroize();
    chained.zeroize();
    Some(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
    fn to_hex(bytes: &[u8]) -> alloc::string::String {
        use alloc::string::String;
        use core::fmt::Write;
        bytes
            .iter()
            .fold(String::with_capacity(bytes.len() * 2), |mut hex, b| {
                // Writing into a `String` is not expected to fail; ignore the result.
                let _ = write!(hex, "{b:02x}");
                hex
            })
    }

    /// Derives `N` bytes and panics if the derivation is rejected.
    fn derive<const N: usize>(password: &[u8], salt: &[u8], iterations: u32) -> [u8; N] {
        let mut dk = [0u8; N];
        pbkdf2_hmac_sm3(password, salt, iterations, &mut dk).expect("derive");
        dk
    }

    // Known-answer tests. NOTE(review): the test names suggest the expected
    // values come from GmSSL — confirm the provenance of these vectors.

    #[test]
    fn gmssl_iter10000_out32() {
        let dk: [u8; 32] = derive(b"password", b"salt", 10_000);
        let expected = "738c8c432372d98a73350bc252209e4cf2acdde7cc816730b9812bdfd55c1265";
        assert_eq!(to_hex(&dk), expected);
    }

    #[test]
    fn gmssl_iter10000_out20() {
        let dk: [u8; 20] = derive(b"password", b"salt", 10_000);
        let expected = "738c8c432372d98a73350bc252209e4cf2acdde7";
        assert_eq!(to_hex(&dk), expected);
    }

    #[test]
    fn gmssl_iter10000_out40() {
        let dk: [u8; 40] = derive(b"password", b"salt", 10_000);
        let expected =
            "738c8c432372d98a73350bc252209e4cf2acdde7cc816730b9812bdfd55c126522b2c8a59d829331";
        assert_eq!(to_hex(&dk), expected);
    }

    #[test]
    fn gmssl_iter10000_out64() {
        let dk: [u8; 64] = derive(b"password", b"salt", 10_000);
        let expected = "738c8c432372d98a73350bc252209e4cf2acdde7cc816730b9812bdfd55c126522b2c8a59d8293314c29c1d7be95ca4a2b757103fba96c502b4adb39449b4807";
        assert_eq!(to_hex(&dk), expected);
    }

    #[test]
    fn gmssl_iter100000_out32() {
        let dk: [u8; 32] = derive(b"password", b"salt", 100_000);
        let expected = "9b27884dd1aef333a412d92d9fba434dc2394091335a1d0bd172942377bbcec2";
        assert_eq!(to_hex(&dk), expected);
    }

    // Parameter validation.

    #[test]
    fn rejects_zero_iterations() {
        let mut dk = [0u8; 32];
        let outcome = pbkdf2_hmac_sm3(b"password", b"salt", 0, &mut dk);
        assert!(outcome.is_none());
    }

    #[test]
    fn rejects_empty_output() {
        let mut dk: [u8; 0] = [];
        let outcome = pbkdf2_hmac_sm3(b"password", b"salt", 1, &mut dk);
        assert!(outcome.is_none());
    }

    // Sensitivity and structural properties.

    #[test]
    fn different_passwords_different_keys() {
        let dk_a: [u8; 32] = derive(b"password-a", b"salt", 1000);
        let dk_b: [u8; 32] = derive(b"password-b", b"salt", 1000);
        assert_ne!(dk_a, dk_b);
    }

    #[test]
    fn different_salts_different_keys() {
        let dk_a: [u8; 32] = derive(b"password", b"salt-a", 1000);
        let dk_b: [u8; 32] = derive(b"password", b"salt-b", 1000);
        assert_ne!(dk_a, dk_b);
    }

    #[test]
    fn shorter_output_is_prefix_of_longer() {
        // Blocks are derived independently of the requested length, so a
        // shorter request must equal the leading bytes of a longer one.
        let dk_long: [u8; 64] = derive(b"password", b"salt", 1000);
        let dk_short: [u8; 20] = derive(b"password", b"salt", 1000);
        assert!(dk_long.starts_with(&dk_short));
    }
}