#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use crypto_bigint::U256;
use crate::sm3::Sm3Hasher;
use crate::sm9::fields::fp::GROUP_ORDER;
/// Domain-separation prefix byte of the SM9 `H1` hash function.
const H1_PREFIX: u8 = 0x01;

/// Domain-separation prefix byte of the SM9 `H2` hash function.
const H2_PREFIX: u8 = 0x02;

/// Output size of SM3 in bytes.
const SM3_DIGEST_BYTES: usize = 32;

/// Length of the intermediate hash string `Ha` in bytes.
///
/// The standard sets `hlen = 8 * ceil(5 * log2(n) / 32)` bits, i.e. 320 bits for
/// the ~256-bit SM9 group order. The extra 64 bits beyond the size of `n` make the
/// bias of the final `mod (n - 1)` reduction negligible.
const HLEN_BYTES: usize = 40;

/// Number of SM3 blocks needed to cover `HLEN_BYTES` (ceiling division).
const HA_BLOCKS: usize = (HLEN_BYTES + SM3_DIGEST_BYTES - 1) / SM3_DIGEST_BYTES;

/// SM9 hash function `H1(ID || hid, N)`.
///
/// Hashes `0x01 || z || hid || ct` (for `ct = 1, 2`) with SM3 and maps the result
/// into `[1, n - 1]`, where `n` is the SM9 group order.
///
/// * `z` — the bare user identity. `hid` is appended by this function, so callers
///   must not append it themselves.
/// * `hid` — the one-byte function identifier placed after the identity.
pub fn sm9_h1(z: &[u8], hid: u8) -> U256 {
    let hid = [hid];
    hash_to_range(H1_PREFIX, &[z, &hid[..]], &GROUP_ORDER)
}

/// SM9 hash function `H2(M || w, N)`.
///
/// Hashes `0x02 || m || w || ct` (for `ct = 1, 2`) with SM3 and maps the result into
/// `[1, n - 1]`. Both `m` and `w` are hashed in full, whatever their length.
pub fn sm9_h2(m: &[u8], w: &[u8]) -> U256 {
    hash_to_range(H2_PREFIX, &[m, w], &GROUP_ORDER)
}

/// Shared core of `H1` / `H2`: maps `Z = parts[0] || parts[1] || ...` into `[1, n - 1]`.
///
/// 1. `Ha_i = SM3(prefix || Z || ct)` for `ct = 1, 2, ...` (big-endian `u32`);
/// 2. `Ha` = the leftmost `HLEN_BYTES` bytes of `Ha_1 || Ha_2 || ...`, read as a
///    big-endian integer;
/// 3. returns `(Ha mod (n - 1)) + 1`.
///
/// `Z` is streamed into SM3 piece by piece, so no concatenation buffer is needed
/// and inputs are never truncated. The reduction is a bit-serial long division
/// built from `subtle` selects, so it has no data-dependent branches.
/// `n` must be at least 2.
fn hash_to_range(prefix: u8, parts: &[&[u8]], n: &U256) -> U256 {
    use subtle::{Choice, ConditionallySelectable, ConstantTimeLess};

    // Step 1: Ha_1 || Ha_2 || ... with the counter starting at 1.
    let mut ha = [0u8; HA_BLOCKS * SM3_DIGEST_BYTES];
    for (ct, block) in (1u32..).zip(ha.chunks_exact_mut(SM3_DIGEST_BYTES)) {
        let mut h = Sm3Hasher::new();
        h.update(&[prefix]);
        for &part in parts {
            h.update(part);
        }
        h.update(&ct.to_be_bytes());
        let digest = h.finalize();
        block.copy_from_slice(&digest);
    }

    // Steps 2-3: r = Ha mod (n - 1), consuming Ha one bit at a time, MSB first.
    // Invariant: r < modulus. Hence 2r + bit < 2 * modulus, and one conditional
    // subtraction per bit keeps the invariant. 2r + bit may need 257 bits; `carry`
    // records the lost top bit. In that case the true value certainly exceeds
    // modulus, and the wrapping subtraction yields the correct (< modulus) result.
    let modulus = n.wrapping_sub(&U256::ONE);
    let mut r = U256::ZERO;
    for &byte in &ha[..HLEN_BYTES] {
        for shift in (0..8u32).rev() {
            let bit = Choice::from((byte >> shift) & 1);
            let doubled = r.wrapping_add(&r);
            // Doubling overflowed iff the wrapped result is smaller than r.
            let carry = doubled.ct_lt(&r);
            // `doubled` is even, so adding the bit can never overflow.
            let v = U256::conditional_select(&doubled, &doubled.wrapping_add(&U256::ONE), bit);
            let need_reduce = carry | !v.ct_lt(&modulus);
            r = U256::conditional_select(&v, &v.wrapping_sub(&modulus), need_reduce);
        }
    }
    r.wrapping_add(&U256::ONE)
}
#[cfg(feature = "alloc")]
/// SM9 key derivation function: concatenates `SM3(z || ct)` for `ct = 1, 2, ...`
/// (32-bit big-endian counter) and keeps exactly the first `klen` bytes.
///
/// Returns an empty vector when `klen == 0`.
pub fn sm9_kdf(z: &[u8], klen: usize) -> Vec<u8> {
    let mut key = Vec::with_capacity(klen);
    let mut counter: u32 = 1;
    loop {
        let missing = klen - key.len();
        if missing == 0 {
            break key;
        }
        let mut hasher = Sm3Hasher::new();
        hasher.update(z);
        hasher.update(&counter.to_be_bytes());
        let block = hasher.finalize();
        // The last block is truncated so the output is exactly `klen` bytes.
        let take = missing.min(block.len());
        key.extend_from_slice(&block[..take]);
        counter += 1;
    }
}
#[cfg(feature = "alloc")]
/// Derives `klen` bytes of key material for SM9 encryption.
///
/// The KDF input is `C1 || w || ID`, concatenated in that order. `c1_bytes` is the
/// 128-byte encoded ciphertext point, `w_bytes` the 384-byte encoded `Fp12` value
/// and `id` the recipient identity.
pub fn sm9_enc_kdf(w_bytes: &[u8; 384], c1_bytes: &[u8; 128], id: &[u8], klen: usize) -> Vec<u8> {
    let kdf_input: Vec<u8> = [&c1_bytes[..], &w_bytes[..], id].concat();
    sm9_kdf(&kdf_input, klen)
}
/// Serializes an `Fp12` element into the fixed 384-byte encoding used as the `w`
/// component of the SM9 KDF input.
///
/// Delegates to `fields::fp12::fp12_to_bytes`.
pub fn fp12_to_bytes_for_kdf(w: &crate::sm9::fields::fp12::Fp12) -> [u8; 384] {
    use crate::sm9::fields::fp12::fp12_to_bytes as encode;
    encode(w)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::Zero;

    /// H1 must land in [1, n-1].
    #[test]
    fn test_sm9_h1_nonzero() {
        let id = b"Alice";
        let h = sm9_h1(id, 0x01);
        assert!(!bool::from(h.is_zero()), "H1 结果不应为零");
        assert!(h < GROUP_ORDER, "H1 结果应在 [1, n-1]");
    }

    /// H1 is a deterministic function of (identity, hid), and distinct hids give
    /// distinct outputs.
    #[test]
    fn test_sm9_h1_deterministic_and_hid_separated() {
        let id = b"Alice";
        assert_eq!(sm9_h1(id, 0x01), sm9_h1(id, 0x01));
        assert_ne!(sm9_h1(id, 0x01), sm9_h1(id, 0x03));
    }

    /// H2 must not be zero.
    #[test]
    fn test_sm9_h2_nonzero() {
        let m = b"message";
        let w = [0x42u8; 32];
        let h = sm9_h2(m, &w);
        assert!(!bool::from(h.is_zero()), "H2 结果不应为零");
    }

    /// H2 must also land in [1, n-1] for a full-size (384-byte) `w`.
    #[test]
    fn test_sm9_h2_in_range() {
        let h = sm9_h2(b"message", &[0x42u8; 384]);
        assert!(!bool::from(h.is_zero()));
        assert!(h < GROUP_ORDER);
    }

    /// The KDF returns exactly `klen` bytes for a multiple of the digest size.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_sm9_kdf_length() {
        let z = b"test input";
        let klen = 64;
        let k = sm9_kdf(z, klen);
        assert_eq!(k.len(), klen);
    }

    /// The KDF is deterministic.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_sm9_kdf_deterministic() {
        let z = b"test";
        let k1 = sm9_kdf(z, 32);
        let k2 = sm9_kdf(z, 32);
        assert_eq!(k1, k2);
    }

    /// A shorter derivation (including a partial last block) is a prefix of a
    /// longer one.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_sm9_kdf_prefix_consistent() {
        let z = b"test";
        let long = sm9_kdf(z, 64);
        let short = sm9_kdf(z, 33);
        assert_eq!(short.len(), 33);
        assert_eq!(&long[..33], &short[..]);
    }

    /// Requesting zero bytes yields an empty key.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_sm9_kdf_zero_length() {
        assert!(sm9_kdf(b"test", 0).is_empty());
    }
}