use crate::asn1::sig::encode_sig;
use crate::sm2::curve::{b, Fn, NMod, PMod, GX_HEX, GY_HEX};
use crate::sm2::private_key::Sm2PrivateKey;
use crate::sm2::public_key::Sm2PublicKey;
use crate::sm2::scalar_mul::mul_g;
use crate::sm3::{Sm3, DIGEST_SIZE};
use alloc::vec::Vec;
use crypto_bigint::modular::ConstMontyParams;
use crypto_bigint::{Invert, U256};
use rand_core::{CryptoRng, RngCore};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption};
/// Conventional default signer ID: the 16-byte ASCII string `1234567812345678`.
///
/// NOTE(review): presumably the default user ID from the SM2 standards — confirm.
pub const DEFAULT_SIGNER_ID: &[u8; 16] = b"1234567812345678";

/// Longest signer ID, in bytes, whose bit length fits the two-byte `ENTL_A` field
/// hashed into `Z_A`.
///
/// Equals `u16::MAX / 8`, i.e. 8191.
pub const MAX_ID_LEN: usize = u16::MAX as usize / 8;
/// Computes the SM2 signer-identity digest `Z_A` for `public` and `id`.
///
/// The digest is `SM3(ENTL_A ‖ ID_A ‖ a ‖ b ‖ x_G ‖ y_G ‖ x_A ‖ y_A)`, where:
///
/// * `ENTL_A` is the bit length of `id` as a two-byte big-endian integer;
/// * `a` is taken as `p − 3`, with `p` being `PMod::MODULUS`;
/// * `b`, `x_G`, `y_G` are the curve constants;
/// * `(x_A, y_A)` are the affine coordinates of `public`.
///
/// Every 256-bit value is hashed as its 32-byte big-endian encoding.
///
/// # Panics
///
/// * If `id.len() > MAX_ID_LEN`, because the bit length would not fit in `ENTL_A`.
/// * If `public`'s point cannot be converted to affine coordinates.
#[must_use]
pub fn compute_z(public: &Sm2PublicKey, id: &[u8]) -> [u8; DIGEST_SIZE] {
    assert!(
        id.len() <= MAX_ID_LEN,
        "id.len() exceeds MAX_ID_LEN — ENTL_A would silently wrap"
    );
    // The assertion bounds `id.len() * 8` by `MAX_ID_LEN * 8 <= u16::MAX`.
    // The product therefore cannot overflow, and the checked conversion cannot fail.
    // Unlike an `as` cast, it could never truncate silently.
    let entl = u16::try_from(id.len() * 8)
        .expect("ENTL_A fits in u16 after the MAX_ID_LEN check");

    let mut h = Sm3::new();
    h.update(&entl.to_be_bytes());
    h.update(id);

    // Curve parameters: a = p - 3, then b and the base point G = (x_G, y_G).
    let a = PMod::MODULUS.get().wrapping_sub(&U256::from_u64(3));
    h.update(&a.to_be_bytes());
    h.update(&b().retrieve().to_be_bytes());
    h.update(&U256::from_be_hex(GX_HEX).to_be_bytes());
    h.update(&U256::from_be_hex(GY_HEX).to_be_bytes());

    // The signer's public key, in affine coordinates.
    let (px, py) = public.point().to_affine().expect("public key is finite");
    h.update(&px.retrieve().to_be_bytes());
    h.update(&py.retrieve().to_be_bytes());

    h.finalize()
}
/// How many signing attempts `sign_raw_with_id` makes before reporting failure.
pub(crate) const SIGN_RETRY_BUDGET: usize = 2;

/// Error returned by the SM2 signing functions.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SignError {
    /// No signature was produced.
    ///
    /// Either the signer ID was too long, or every signing attempt was rejected.
    Failed,
}
/// Signs `message` on behalf of signer `id` and returns the DER-encoded signature.
///
/// The raw `(r, s)` pair comes from `sign_raw_with_id` and is serialized with
/// `encode_sig`.
///
/// # Errors
///
/// Returns [`SignError::Failed`] in two cases:
///
/// * `id` is longer than [`MAX_ID_LEN`];
/// * no signing attempt produced a valid signature.
pub fn sign_with_id<R: CryptoRng + RngCore>(
    key: &Sm2PrivateKey,
    id: &[u8],
    message: &[u8],
    rng: &mut R,
) -> Result<Vec<u8>, SignError> {
    sign_raw_with_id(key, id, message, rng).map(|(r, s)| encode_sig(&r, &s))
}
/// Produces a raw SM2 signature `(r, s)` over `message` for signer `id`.
///
/// How it works:
///
/// * It hashes `e = SM3(Z_A ‖ message)`, with `Z_A` from [`compute_z`].
/// * It makes exactly `SIGN_RETRY_BUDGET` signing attempts, each drawing a fresh
///   nonce from `rng`.
/// * Every attempt runs regardless of earlier outcomes; the earliest valid one is
///   returned.
///
/// # Errors
///
/// Returns [`SignError::Failed`] in two cases:
///
/// * `id.len() > MAX_ID_LEN`;
/// * none of the attempts yielded a valid signature.
#[doc(hidden)]
pub fn sign_raw_with_id<R: CryptoRng + RngCore>(
    key: &Sm2PrivateKey,
    id: &[u8],
    message: &[u8],
    rng: &mut R,
) -> Result<(U256, U256), SignError> {
    // Reject oversized IDs up front so `compute_z` never reaches its assertion.
    if id.len() > MAX_ID_LEN {
        return Err(SignError::Failed);
    }

    let z = compute_z(&Sm2PublicKey::from_point(key.public_key()), id);
    let mut hasher = Sm3::new();
    hasher.update(&z);
    hasher.update(message);
    let digest = hasher.finalize();
    let e = Fn::new(&U256::from_be_slice(&digest));

    // Run every attempt unconditionally; `ct_or_else` keeps the earliest success.
    let nothing_yet = CtOption::new(RsPair::default(), Choice::from(0));
    let outcome = (0..SIGN_RETRY_BUDGET).fold(nothing_yet, |best, _| {
        ct_or_else(best, try_sign_once(key, &e, &mut *rng))
    });

    Option::<RsPair>::from(outcome)
        .map(|pair| (pair.r, pair.s))
        .ok_or(SignError::Failed)
}
/// A candidate signature pair `(r, s)`, held as plain 256-bit integers.
///
/// It implements `ConditionallySelectable` so candidates can be merged inside a
/// `CtOption` without branching on which one is valid.
#[derive(Debug, Default, Clone, Copy)]
struct RsPair {
    // First signature component.
    r: U256,
    // Second signature component.
    s: U256,
}

impl ConditionallySelectable for RsPair {
    // Field-wise select: yields `a` when `choice` is 0 and `b` when it is 1.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let r = U256::conditional_select(&a.r, &b.r, choice);
        let s = U256::conditional_select(&a.s, &b.s, choice);
        Self { r, s }
    }
}
/// Makes a single SM2 signing attempt with a fresh nonce `k` drawn from `rng`.
///
/// Notation: `(x1, _)` are the affine coordinates of `k·G`, `d` is the private scalar,
/// and `e` is the message-digest scalar. All arithmetic is done in `Fn`.
/// NOTE(review): assumes `Fn` is the scalar field modulo the group order `n`, as
/// suggested by `NMod` — confirm.
///
/// The attempt computes:
///
/// * `r = e + x1`
/// * `s = (1 + d)⁻¹ · (k − r·d)`
///
/// The result is `None` when any of these holds:
///
/// * `r == 0`;
/// * `r + k == 0` in `Fn`;
/// * `s == 0`;
/// * `1 + d` has no inverse.
///
/// Each condition is turned into a `Choice` and combined, instead of returning early.
/// As a result, every value in the body is computed on every call.
// The short names (r, s, k, d, e) mirror the notation of the SM2 signing equations.
#[allow(clippy::similar_names, clippy::many_single_char_names)]
fn try_sign_once<R: CryptoRng + RngCore>(
    key: &Sm2PrivateKey,
    e: &Fn,
    rng: &mut R,
) -> CtOption<RsPair> {
    // Nonce k, guaranteed nonzero by `sample_nonzero_scalar`.
    let k = sample_nonzero_scalar(rng);
    let kg = mul_g(&k);
    // Only the x-coordinate of k·G is used.
    let (x1, _y1) = kg.to_affine().expect("k·G is finite for k != 0");
    // Move the x-coordinate from its base-field representation into `Fn`.
    let x1_in_n = Fn::new(&x1.retrieve());
    // r = e + x1
    let r = *e + x1_in_n;
    let r_u = r.retrieve();
    // r + k, used to detect the r + k == 0 rejection case.
    let r_plus_k = (r + k).retrieve();
    let r_zero: Choice = r_u.ct_eq(&U256::ZERO);
    let rk_zero: Choice = r_plus_k.ct_eq(&U256::ZERO);
    let bad_r = r_zero | rk_zero;
    let d = key.scalar();
    let one = Fn::new(&U256::ONE);
    let one_plus_d = one + *d;
    // (1 + d)⁻¹ as a CtOption: `None` when 1 + d is not invertible.
    let one_plus_d_inv = one_plus_d.invert();
    let rd = r * *d;
    let k_minus_rd = k - rd;
    // Placeholder used only when the inverse does not exist. In that case `inv_ok` is
    // false and the attempt is marked invalid, so the placeholder-based `s` is discarded.
    let inv_unwrapped: Fn = one_plus_d_inv.unwrap_or(Fn::new(&U256::ONE));
    let inv_ok: Choice = one_plus_d_inv.is_some();
    // s = (1 + d)⁻¹ · (k − r·d)
    let s = inv_unwrapped * k_minus_rd;
    let s_u = s.retrieve();
    let s_zero: Choice = s_u.ct_eq(&U256::ZERO);
    // Valid only if r passed its checks, s is nonzero and the inverse existed.
    let valid = !bad_r & !s_zero & inv_ok;
    CtOption::new(RsPair { r: r_u, s: s_u }, valid)
}
/// Draws a scalar in `[1, n)` by rejection sampling, where `n` is `NMod::MODULUS`.
///
/// Each draw reads 32 bytes from `rng` as a big-endian integer. A candidate is accepted
/// only if it is nonzero and strictly below `n`; otherwise another draw is made. The
/// result is uniform when `rng` is.
///
/// The loop never ends if `rng` never produces an acceptable value.
fn sample_nonzero_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Fn {
    let order = NMod::MODULUS.get();
    // Every draw overwrites all 32 bytes, so one buffer serves every iteration.
    let mut bytes = [0u8; 32];
    loop {
        rng.fill_bytes(&mut bytes);
        let k = U256::from_be_slice(&bytes);
        let in_range = k.ct_lt(&order) & !k.ct_eq(&U256::ZERO);
        if bool::from(in_range) {
            break Fn::new(&k);
        }
    }
}
/// "First `Some` wins" combinator for two `CtOption`s.
///
/// The result:
///
/// * holds `first`'s value when `first` is `Some`, otherwise `second`'s value;
/// * is `Some` if either input is.
///
/// Both inputs are always unwrapped, with `T::default()` standing in for a `None`. The
/// value is then picked with `conditional_select`. This function therefore never
/// branches on whether its inputs are present.
fn ct_or_else<T: ConditionallySelectable + Default>(
    first: CtOption<T>,
    second: CtOption<T>,
) -> CtOption<T> {
    let use_first = first.is_some();
    let any_present = use_first | second.is_some();
    let primary = first.unwrap_or_else(T::default);
    let fallback = second.unwrap_or_else(T::default);
    // choice == 1 selects the second argument, i.e. `primary`.
    let picked = T::conditional_select(&fallback, &primary, use_first);
    CtOption::new(picked, any_present)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm2::private_key::Sm2PrivateKey;
    use crypto_bigint::modular::ConstMontyParams;
    use rand_core::Error;

    /// Deterministic test RNG that replays a fixed two-element sequence.
    ///
    /// Each `fill_bytes` call writes the next entry of `values` as big-endian bytes.
    struct SequenceRng {
        values: [U256; 2],
        // Index of the next value to hand out. After a run, it equals the number of
        // `fill_bytes` calls made.
        index: usize,
    }
    impl RngCore for SequenceRng {
        // The word-sized methods are stubs: `sample_nonzero_scalar` only calls
        // `fill_bytes`.
        fn next_u32(&mut self) -> u32 {
            0
        }
        fn next_u64(&mut self) -> u64 {
            0
        }
        // Panics if the request is not exactly 32 bytes, or if more than two values
        // are requested.
        fn fill_bytes(&mut self, dst: &mut [u8]) {
            assert_eq!(dst.len(), 32);
            let value = self.values[self.index];
            self.index += 1;
            dst.copy_from_slice(&value.to_be_bytes());
        }
        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Error> {
            self.fill_bytes(dst);
            Ok(())
        }
    }
    // Marker impl so the test RNG satisfies the `CryptoRng` bound. Test use only.
    impl CryptoRng for SequenceRng {}

    /// Checks that `compute_z` reproduces a known `Z_A` for ID `ALICE123@YAHOO.COM`
    /// and the private key below.
    ///
    /// Per the test name, this is the GB/T 32918 Appendix A.2 example — confirm.
    #[test]
    fn z_appendix_a2() {
        let d =
            U256::from_be_hex("3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8");
        let key = Sm2PrivateKey::new(d).expect("valid scalar");
        let public = Sm2PublicKey::from_point(key.public_key());
        let z = compute_z(&public, b"ALICE123@YAHOO.COM");
        // Lowercase hex rendering of the 32-byte digest, for a readable comparison.
        #[allow(clippy::format_collect)]
        let z_hex: alloc::string::String =
            z.iter().map(|byte| alloc::format!("{byte:02x}")).collect();
        assert_eq!(
            z_hex,
            "26db4bc1839bd22e97e1dab667ec5e0a730d5e16521398b4435c576a93afd7ed"
        );
    }

    /// Checks the oversized-ID path.
    ///
    /// An ID one byte over `MAX_ID_LEN` makes `sign_with_id` return
    /// `Err(SignError::Failed)`. It does not reach the assertion in `compute_z`.
    #[test]
    fn sign_over_long_id_rejected() {
        use rand_core::OsRng;
        let d =
            U256::from_be_hex("3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8");
        let key = Sm2PrivateKey::new(d).expect("valid scalar");
        let too_long = alloc::vec![0u8; MAX_ID_LEN + 1];
        let result = sign_with_id(&key, &too_long, b"msg", &mut OsRng);
        assert_eq!(result, Err(SignError::Failed));
    }

    /// Checks that `sample_nonzero_scalar` rejects a candidate above the order.
    ///
    /// The first candidate `n + 1` is rejected. The second candidate, 2, is accepted
    /// after exactly two draws.
    #[test]
    fn sample_nonzero_scalar_rejects_candidates_above_order() {
        let n_plus_one = NMod::MODULUS.get().wrapping_add(&U256::ONE);
        let mut rng = SequenceRng {
            values: [n_plus_one, U256::from_u64(2)],
            index: 0,
        };
        let sampled = sample_nonzero_scalar(&mut rng).retrieve();
        assert_eq!(sampled, U256::from_u64(2));
        assert_eq!(rng.index, 2);
    }
}
#[cfg(test)]
mod sign_tests {
    use super::*;
    use rand_core::Error;

    /// Test RNG that returns the same 32 bytes, a fixed nonce `k`, on every
    /// `fill_bytes` call.
    ///
    /// Each of the `SIGN_RETRY_BUDGET` attempts therefore uses the same `k`.
    struct FixedScalarRng {
        k_bytes: [u8; 32],
    }
    impl FixedScalarRng {
        // Parses `k_hex` as a big-endian 256-bit integer.
        fn new(k_hex: &str) -> Self {
            let k = U256::from_be_hex(k_hex);
            Self {
                k_bytes: k.to_be_bytes(),
            }
        }
    }
    impl RngCore for FixedScalarRng {
        // The word-sized methods are stubs: nonce sampling only calls `fill_bytes`.
        fn next_u32(&mut self) -> u32 {
            0
        }
        fn next_u64(&mut self) -> u64 {
            0
        }
        // Panics unless exactly 32 bytes are requested.
        fn fill_bytes(&mut self, dst: &mut [u8]) {
            assert_eq!(dst.len(), 32);
            dst.copy_from_slice(&self.k_bytes);
        }
        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Error> {
            self.fill_bytes(dst);
            Ok(())
        }
    }
    // Marker impl so the test RNG satisfies the `CryptoRng` bound. Test use only.
    impl CryptoRng for FixedScalarRng {}

    /// Known-answer test for signing with a fixed nonce `k`.
    ///
    /// The inputs are the private key, ID `ALICE123@YAHOO.COM`, message
    /// `message digest` and the fixed nonce below. The test decodes the DER output and
    /// compares `r` and `s` to the expected values (presumably GB/T 32918 Appendix A.2,
    /// per the test name).
    #[test]
    fn gbt32918_appendix_a2_fixed_k() {
        let d =
            U256::from_be_hex("3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8");
        let key = Sm2PrivateKey::new(d).expect("valid scalar");
        let id = b"ALICE123@YAHOO.COM";
        let message = b"message digest";
        let mut rng =
            FixedScalarRng::new("59276E27D506861A16680F3AD9C02DCFBFBF904F533DA0AC2EE1C9A45B58FF85");
        let der = sign_with_id(&key, id, message, &mut rng).expect("sign succeeds");
        // Round-trip through the DER decoder to recover the raw (r, s) integers.
        let (r, s) = crate::asn1::sig::decode_sig(&der).expect("our own DER decodes");
        assert_eq!(
            r,
            U256::from_be_hex("88348A09A3E324C4FE946843123E40C175468F3E36481885844A144D2167EA4C"),
            "r mismatch"
        );
        assert_eq!(
            s,
            U256::from_be_hex("0AD2CE552FD33EAB792E5A2805E0504D014C96135F8E03891087132ABB24D48D"),
            "s mismatch"
        );
    }
}