use crypto_bigint::U256;
use crate::sm3::Sm3Hasher;
use crate::sm9::fields::fp::{
fp_add, fp_inv, fp_is_square, fp_mul, fp_neg, fp_sqrt, fp_square, fp_sub, Fp,
};
use crate::sm9::groups::g1::{G1Affine, G1Jacobian};
/// SvdW parameter `Z = -1 (mod p)`, i.e. the SM9 BN256 base-field prime minus one.
///
/// With `A = 0` and `B = 5` this gives `g(Z) = (-1)^3 + 5 = 4`, the value stored in [`C1`].
const Z: Fp =
    Fp::new(&U256::from_be_hex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457C"));
/// SvdW constant `c1 = g(Z)`, where `g(x) = x^3 + 5` and `Z = -1`, so `c1 = 4`.
const C1: Fp = Fp::new(&U256::from_u64(4));
/// `expand_message_xmd` from RFC 9380 §5.3.1, instantiated with SM3
/// (`b_in_bytes = 32`, `s_in_bytes = 64`).
///
/// Produces `len_in_bytes` pseudorandom bytes from `msg` under the domain-separation tag `dst`.
/// A `dst` longer than 255 bytes is first compressed to `SM3("H2C-OVERSIZE-DST-" || dst)`, as
/// prescribed by RFC 9380 §5.3.3, so that its length always fits the one-byte suffix of
/// `DST_prime`. A `len_in_bytes` of zero yields an empty vector.
///
/// # Panics
///
/// Panics if `len_in_bytes > 255 * 32` (i.e. `ell > 255`), a case RFC 9380 requires to abort
/// because the block index is encoded in a single byte.
pub fn expand_message_xmd(msg: &[u8], dst: &[u8], len_in_bytes: usize) -> alloc::vec::Vec<u8> {
    // Output size of SM3 in bytes.
    const B_IN_BYTES: usize = 32;
    // Input block size of SM3 in bytes (length of Z_pad).
    const R_IN_BYTES: usize = 64;
    // Largest DST length representable in the single length byte of DST_prime.
    const MAX_DST_LEN: usize = 255;
    // Largest number of output blocks: the block index `i` is encoded in one byte.
    const MAX_ELL: usize = 255;

    let ell = len_in_bytes.div_ceil(B_IN_BYTES);
    assert!(
        ell <= MAX_ELL,
        "expand_message_xmd: len_in_bytes must be at most 255 * 32 = 8160"
    );

    // RFC 9380 §5.3.3: an oversized DST is replaced by a hash of itself instead of having its
    // length silently truncated to one byte.
    let oversize_dst;
    let dst: &[u8] = if dst.len() > MAX_DST_LEN {
        let mut h = Sm3Hasher::new();
        h.update(b"H2C-OVERSIZE-DST-");
        h.update(dst);
        oversize_dst = h.finalize();
        &oversize_dst[..]
    } else {
        dst
    };

    // DST_prime = DST || I2OSP(len(DST), 1); lossless because dst.len() <= MAX_DST_LEN here.
    let mut dst_prime = alloc::vec::Vec::with_capacity(dst.len() + 1);
    dst_prime.extend_from_slice(dst);
    dst_prime.push(dst.len() as u8);

    let z_pad = [0u8; R_IN_BYTES];
    // I2OSP(len_in_bytes, 2); lossless because len_in_bytes <= 255 * 32 < 2^16.
    let l_i_b_str = (len_in_bytes as u16).to_be_bytes();

    // b_0 = H(Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime)
    let mut h = Sm3Hasher::new();
    h.update(&z_pad);
    h.update(msg);
    h.update(&l_i_b_str);
    h.update(&[0u8]);
    h.update(&dst_prime);
    let b_0 = h.finalize();

    // b_1 = H(b_0 || I2OSP(1, 1) || DST_prime)
    let mut h = Sm3Hasher::new();
    h.update(&b_0);
    h.update(&[1u8]);
    h.update(&dst_prime);
    let b_1 = h.finalize();

    let mut uniform_bytes: alloc::vec::Vec<u8> =
        alloc::vec::Vec::with_capacity(ell.max(1) * B_IN_BYTES);
    uniform_bytes.extend_from_slice(&b_1);
    let mut b_prev = b_1;

    // b_i = H(strxor(b_0, b_(i-1)) || I2OSP(i, 1) || DST_prime)
    for i in 2..=ell {
        let mut xored = [0u8; B_IN_BYTES];
        for (out, (&x, &y)) in xored.iter_mut().zip(b_0.iter().zip(b_prev.iter())) {
            *out = x ^ y;
        }
        let mut h = Sm3Hasher::new();
        h.update(&xored);
        // Lossless: i <= ell <= MAX_ELL.
        h.update(&[i as u8]);
        h.update(&dst_prime);
        let b_i = h.finalize();
        uniform_bytes.extend_from_slice(&b_i);
        b_prev = b_i;
    }

    // Keep only the requested prefix; this also handles len_in_bytes == 0.
    uniform_bytes.truncate(len_in_bytes);
    uniform_bytes
}
fn hash_to_field(bytes48: &[u8; 48]) -> Fp {
let high_16: [u8; 16] = bytes48[..16].try_into().unwrap();
let low_32: [u8; 32] = bytes48[16..].try_into().unwrap();
let low_fp = Fp::new(&U256::from_be_slice(&low_32));
const TWO_256_MOD_P: U256 =
U256::from_be_hex("49BFFFFFFFD5C590E9FC54B00A7138BAE0D6CB4E4E858125179110D21CAEBA83");
let mut high_bytes = [0u8; 32];
high_bytes[16..].copy_from_slice(&high_16);
let high_u256 = U256::from_be_slice(&high_bytes);
let high_fp = Fp::new(&high_u256);
let two256_fp = Fp::new(&TWO_256_MOD_P);
fp_add(&fp_mul(&high_fp, &two256_fp), &low_fp)
}
/// RFC 9380 `sgn0` for a prime field (`m = 1`).
///
/// Returns the parity (`0` or `1`) of the integer produced by `a.retrieve()`.
fn sgn0(a: &Fp) -> u8 {
    let encoded = a.retrieve().to_be_bytes();
    // In a big-endian encoding the least-significant byte is the last one.
    let least_significant = encoded[encoded.len() - 1];
    least_significant % 2
}
pub fn map_to_curve_svdw(u: &Fp) -> G1Affine {
let two = Fp::new(&U256::from_be_hex(
"0000000000000000000000000000000000000000000000000000000000000002",
));
let c2 = fp_inv(&two).unwrap();
let twelve = Fp::new(&U256::from_be_hex(
"000000000000000000000000000000000000000000000000000000000000000C",
));
let neg12 = fp_neg(&twelve);
let c3 = fp_sqrt(&neg12).expect("SvdW: -12 在 BN256 Fp 上应有平方根");
let sixteen = Fp::new(&U256::from_be_hex(
"0000000000000000000000000000000000000000000000000000000000000010",
));
let three = Fp::new(&U256::from_be_hex(
"0000000000000000000000000000000000000000000000000000000000000003",
));
let c4 = fp_mul(&fp_neg(&sixteen), &fp_inv(&three).unwrap());
let b = Fp::new(&U256::from_be_hex(
"0000000000000000000000000000000000000000000000000000000000000005",
));
let tv1 = fp_mul(&fp_square(u), &C1);
let tv2 = fp_add(&Fp::ONE, &tv1);
let tv1 = fp_sub(&Fp::ONE, &tv1);
let tv3 = fp_mul(&tv1, &tv2);
let tv3 = fp_inv(&tv3).unwrap_or(Fp::ZERO);
let tv4 = fp_mul(&fp_mul(&fp_mul(u, &tv1), &tv3), &c3);
let x1 = fp_sub(&c2, &tv4);
let x2 = fp_add(&c2, &tv4);
let tv2_sq = fp_square(&tv2);
let inner = fp_mul(&tv2_sq, &tv3);
let x3 = fp_add(&Z, &fp_mul(&c4, &fp_square(&inner)));
let g = |x: &Fp| -> Fp {
let x3 = fp_mul(&fp_square(x), x);
fp_add(&x3, &b)
};
let g1 = g(&x1);
let g2 = g(&x2);
let g3 = g(&x3);
let (x, gx) = if fp_is_square(&g1) {
(x1, g1)
} else if fp_is_square(&g2) {
(x2, g2)
} else {
(x3, g3)
};
let mut y = fp_sqrt(&gx).expect("SvdW: g(x) 应为二次剩余");
if sgn0(&y) != sgn0(u) {
y = fp_neg(&y);
}
G1Affine { x, y }
}
/// Hashes `msg` to a point of G1 under the domain-separation tag `dst` (random-oracle variant).
///
/// The message is expanded to `2 * 48` bytes with [`expand_message_xmd`]. Each 48-byte half is
/// reduced to a field element and mapped with [`map_to_curve_svdw`], and the two points are
/// added in Jacobian coordinates. No cofactor clearing is applied; this relies on G1 of the
/// SM9 BN256 curve having cofactor 1.
pub fn hash_to_g1(msg: &[u8], dst: &[u8]) -> G1Jacobian {
    // Bytes per field element: ceil((256 + 128) / 8).
    const L: usize = 48;

    let expanded = expand_message_xmd(msg, dst, 2 * L);
    let (first_half, second_half) = expanded.split_at(L);

    // One field element per half, mapped to the curve and lifted to Jacobian form.
    let to_point = |chunk: &[u8]| -> G1Jacobian {
        let block: &[u8; L] = chunk.try_into().unwrap();
        G1Jacobian::from_affine(&map_to_curve_svdw(&hash_to_field(block)))
    };

    let q0 = to_point(first_half);
    let q1 = to_point(second_half);
    G1Jacobian::add(&q0, &q1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm9::fields::fp::fp_to_bytes;

    /// Domain-separation tag shared by the tests below.
    const DST: &[u8] = b"BLS_SIG_SM9G1_XMD:SM3_SVDW_RO_NUL_";

    /// Builds a field element from a small integer.
    fn fp_from_u64(v: u64) -> Fp {
        Fp::new(&U256::from_u64(v))
    }

    /// Checks `y^2 == x^3 + 5` directly, independently of `G1Affine::is_on_curve`.
    fn satisfies_curve_equation(p: &G1Affine) -> bool {
        let lhs = fp_square(&p.y);
        let rhs = fp_add(&fp_mul(&fp_square(&p.x), &p.x), &fp_from_u64(5));
        lhs == rhs
    }

    #[test]
    fn test_expand_message_xmd_length() {
        let bytes = expand_message_xmd(b"hello", DST, 96);
        assert_eq!(bytes.len(), 96);
        // Lengths that are exactly one block and not a multiple of the block size.
        for len in [32usize, 50] {
            assert_eq!(expand_message_xmd(b"hello", DST, len).len(), len);
        }
    }

    #[test]
    fn test_expand_message_xmd_deterministic() {
        let a = expand_message_xmd(b"test", DST, 96);
        let b = expand_message_xmd(b"test", DST, 96);
        assert_eq!(a, b, "相同输入应产生相同输出");
    }

    #[test]
    fn test_expand_message_xmd_different_msgs() {
        let a = expand_message_xmd(b"msg1", DST, 96);
        let b = expand_message_xmd(b"msg2", DST, 96);
        assert_ne!(a, b, "不同消息应产生不同输出");
    }

    #[test]
    fn test_expand_message_xmd_different_dsts() {
        let a = expand_message_xmd(b"msg", DST, 96);
        let b = expand_message_xmd(b"msg", b"ANOTHER_DST", 96);
        assert_ne!(a, b, "different DSTs must give different outputs");
    }

    #[test]
    fn test_map_to_curve_output_on_curve() {
        let u = fp_from_u64(7);
        let p = map_to_curve_svdw(&u);
        let lhs = fp_square(&p.y);
        let rhs = fp_add(&fp_mul(&fp_square(&p.x), &p.x), &fp_from_u64(5));
        assert_eq!(lhs, rhs, "映射的点应在曲线上");
    }

    #[test]
    fn test_map_to_curve_exceptional_inputs() {
        // u = ±1/2 makes u^2 * g(Z) = 1, so tv1 = 0 and the inv0 branch is taken; u = 0 is the
        // other degenerate input.
        let half = fp_inv(&fp_from_u64(2)).unwrap();
        for u in [fp_from_u64(0), half, fp_neg(&half)].iter() {
            let p = map_to_curve_svdw(u);
            assert!(satisfies_curve_equation(&p), "exceptional input must map onto the curve");
        }
    }

    #[test]
    fn test_map_to_curve_sign_matches_input() {
        for v in 1u64..=8 {
            let u = fp_from_u64(v);
            let p = map_to_curve_svdw(&u);
            assert!(satisfies_curve_equation(&p));
            assert_eq!(sgn0(&p.y), sgn0(&u), "sgn0(y) must equal sgn0(u)");
        }
    }

    #[test]
    fn test_hash_to_g1_deterministic() {
        let a1 = hash_to_g1(b"hello", DST).to_affine().unwrap();
        let a2 = hash_to_g1(b"hello", DST).to_affine().unwrap();
        assert_eq!(fp_to_bytes(&a1.x), fp_to_bytes(&a2.x));
        assert_eq!(fp_to_bytes(&a1.y), fp_to_bytes(&a2.y));
    }

    #[test]
    fn test_hash_to_g1_different_msgs() {
        let p1 = hash_to_g1(b"msg1", DST).to_affine().unwrap();
        let p2 = hash_to_g1(b"msg2", DST).to_affine().unwrap();
        assert_ne!(
            fp_to_bytes(&p1.x),
            fp_to_bytes(&p2.x),
            "不同消息应映射到不同点"
        );
    }

    #[test]
    fn test_hash_to_g1_output_on_curve() {
        let a = hash_to_g1(b"test message", DST).to_affine().unwrap();
        assert!(a.is_on_curve(), "hash_to_g1 的输出应在 G1 曲线上");
        assert!(satisfies_curve_equation(&a));
    }
}