1use crate::asn1::sig::encode_sig;
4use crate::sm2::curve::{Fn, Fp, GX_HEX, GY_HEX, b};
5use crate::sm2::private_key::Sm2PrivateKey;
6use crate::sm2::public_key::Sm2PublicKey;
7use crate::sm2::scalar_mul::mul_g;
8use crate::sm3::{DIGEST_SIZE, Sm3};
9use alloc::vec::Vec;
10use crypto_bigint::U256;
11use rand_core::{CryptoRng, Rng};
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption};
13
/// Default signer identifier: the 16 ASCII bytes `1234567812345678`.
pub const DEFAULT_SIGNER_ID: &[u8; 16] = b"1234567812345678";

/// Largest signer-id length, in bytes, whose bit length (`len * 8`) still
/// fits in the 16-bit `ENTL_A` field hashed by [`compute_z`].
pub const MAX_ID_LEN: usize = (u16::MAX / 8) as usize;
26
27#[must_use]
39pub fn compute_z(public: &Sm2PublicKey, id: &[u8]) -> [u8; DIGEST_SIZE] {
40 assert!(
41 id.len() <= MAX_ID_LEN,
42 "id.len() exceeds MAX_ID_LEN — ENTL_A would silently wrap"
43 );
44 let mut h = Sm3::new();
45
46 #[allow(clippy::cast_possible_truncation)]
49 let entl: u16 = (id.len() as u16) * 8;
50 h.update(&entl.to_be_bytes());
51 h.update(id);
52
53 let three = U256::from_u64(3);
55 let p_minus_three = Fp::MODULUS.as_ref().wrapping_sub(&three);
56 h.update(&p_minus_three.to_be_bytes());
57
58 h.update(&b().retrieve().to_be_bytes());
60
61 h.update(&U256::from_be_hex(GX_HEX).to_be_bytes());
63 h.update(&U256::from_be_hex(GY_HEX).to_be_bytes());
64
65 let (px, py) = public.point().to_affine().expect("public key is finite");
67 h.update(&px.retrieve().to_be_bytes());
68 h.update(&py.retrieve().to_be_bytes());
69
70 h.finalize()
71}
72
/// Number of signing attempts made by each `sign_raw_with_id` call.
pub(crate) const SIGN_RETRY_BUDGET: usize = 2;

/// Error returned by the signing entry points.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignError {
    /// Signing did not produce a signature: either the signer id was longer
    /// than `MAX_ID_LEN`, or no attempt within the retry budget yielded a
    /// valid `(r, s)` pair.
    Failed,
}
82
83pub fn sign_with_id<R: CryptoRng + Rng>(
102 key: &Sm2PrivateKey,
103 id: &[u8],
104 message: &[u8],
105 rng: &mut R,
106) -> Result<Vec<u8>, SignError> {
107 let (r, s) = sign_raw_with_id(key, id, message, rng)?;
108 Ok(encode_sig(&r, &s))
109}
110
111#[doc(hidden)]
130pub fn sign_raw_with_id<R: CryptoRng + Rng>(
131 key: &Sm2PrivateKey,
132 id: &[u8],
133 message: &[u8],
134 rng: &mut R,
135) -> Result<(U256, U256), SignError> {
136 if id.len() > MAX_ID_LEN {
137 return Err(SignError::Failed);
138 }
139 let public = Sm2PublicKey::from_point(key.public_key());
140 let z = compute_z(&public, id);
141
142 let e_bytes = {
143 let mut h = Sm3::new();
144 h.update(&z);
145 h.update(message);
146 h.finalize()
147 };
148 let e_scalar = Fn::new(&U256::from_be_slice(&e_bytes));
149
150 let mut chosen: CtOption<RsPair> = CtOption::new(RsPair::default(), Choice::from(0));
151 for _ in 0..SIGN_RETRY_BUDGET {
152 let candidate = try_sign_once(key, &e_scalar, rng);
153 chosen = ct_or_else(chosen, candidate);
154 }
155
156 let pair: Option<RsPair> = chosen.into();
157 let pair = pair.ok_or(SignError::Failed)?;
158 Ok((pair.r, pair.s))
159}
160
161#[derive(Clone, Copy, Debug, Default)]
162struct RsPair {
163 r: U256,
164 s: U256,
165}
166
167impl ConditionallySelectable for RsPair {
168 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
169 Self {
170 r: U256::conditional_select(&a.r, &b.r, choice),
171 s: U256::conditional_select(&a.s, &b.s, choice),
172 }
173 }
174}
175
176#[allow(clippy::similar_names, clippy::many_single_char_names)]
177fn try_sign_once<R: CryptoRng + Rng>(key: &Sm2PrivateKey, e: &Fn, rng: &mut R) -> CtOption<RsPair> {
178 let k = sample_nonzero_scalar(rng);
179 let kg = mul_g(&k);
180 let (x1, _y1) = kg.to_affine().expect("k·G is finite for k != 0");
181
182 let x1_in_n = Fn::new(&x1.retrieve());
183 let r = *e + x1_in_n;
184
185 let r_u = r.retrieve();
186 let r_plus_k = (r + k).retrieve();
187 let r_zero: Choice = r_u.ct_eq(&U256::ZERO);
188 let rk_zero: Choice = r_plus_k.ct_eq(&U256::ZERO);
189 let bad_r = r_zero | rk_zero;
190
191 let d = key.scalar();
192 let one = Fn::new(&U256::ONE);
193 let one_plus_d = one + *d;
194 let one_plus_d_inv = one_plus_d.invert();
195 let rd = r * *d;
196 let k_minus_rd = k - rd;
197
198 let inv_unwrapped: Fn = one_plus_d_inv.unwrap_or(Fn::new(&U256::ONE));
199 let inv_ok: Choice = one_plus_d_inv.is_some().into();
200
201 let s = inv_unwrapped * k_minus_rd;
202 let s_u = s.retrieve();
203 let s_zero: Choice = s_u.ct_eq(&U256::ZERO);
204
205 let valid = !bad_r & !s_zero & inv_ok;
206 CtOption::new(RsPair { r: r_u, s: s_u }, valid)
207}
208
209pub(crate) fn sample_nonzero_scalar<R: CryptoRng + Rng>(rng: &mut R) -> Fn {
210 let n = *Fn::MODULUS.as_ref();
211 loop {
212 let mut buf = [0u8; 32];
213 rng.fill_bytes(&mut buf);
214 let candidate = U256::from_be_slice(&buf);
215 let valid = !candidate.ct_eq(&U256::ZERO) & candidate.ct_lt(&n);
216 if bool::from(valid) {
217 return Fn::new(&candidate);
218 }
219 }
220}
221
222fn ct_or_else<T: ConditionallySelectable + Default>(a: CtOption<T>, b: CtOption<T>) -> CtOption<T> {
223 let a_some = a.is_some();
224 let b_some = b.is_some();
225 let a_val = a.unwrap_or_else(T::default);
226 let b_val = b.unwrap_or_else(T::default);
227 let chosen = T::conditional_select(&b_val, &a_val, a_some);
228 let some = a_some | b_some;
229 CtOption::new(chosen, some)
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm2::private_key::Sm2PrivateKey;
    use core::convert::Infallible;
    use getrandom::SysRng;
    use rand_core::{TryCryptoRng, TryRng, UnwrapErr};

    /// Private scalar shared by the key-based tests in this module.
    const TEST_SCALAR_HEX: &str =
        "3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8";

    /// Builds the private key for `TEST_SCALAR_HEX`.
    fn test_key() -> Sm2PrivateKey {
        Sm2PrivateKey::new(U256::from_be_hex(TEST_SCALAR_HEX)).expect("valid scalar")
    }

    /// RNG that hands out two pre-set 32-byte values in order and counts how
    /// many have been consumed.
    struct SequenceRng {
        values: [U256; 2],
        index: usize,
    }

    impl TryRng for SequenceRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            Ok(0)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            Ok(0)
        }

        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
            assert_eq!(dst.len(), 32);
            let next = self.values[self.index];
            self.index += 1;
            dst.copy_from_slice(&next.to_be_bytes());
            Ok(())
        }
    }

    impl TryCryptoRng for SequenceRng {}

    /// `compute_z` reproduces the expected digest for the test key and id.
    #[test]
    fn z_appendix_a2() {
        use core::fmt::Write as _;

        let key = test_key();
        let public = Sm2PublicKey::from_point(key.public_key());
        let z = compute_z(&public, b"ALICE123@YAHOO.COM");

        let mut z_hex = alloc::string::String::with_capacity(2 * DIGEST_SIZE);
        for byte in z {
            write!(z_hex, "{byte:02x}").expect("writing to a String cannot fail");
        }

        assert_eq!(
            z_hex,
            "26db4bc1839bd22e97e1dab667ec5e0a730d5e16521398b4435c576a93afd7ed"
        );
    }

    /// An id one byte past `MAX_ID_LEN` is reported as an error, not a panic.
    #[test]
    fn sign_over_long_id_rejected() {
        let key = test_key();
        let oversized_id = alloc::vec![0u8; MAX_ID_LEN + 1];
        let mut rng = UnwrapErr(SysRng);

        let outcome = sign_with_id(&key, &oversized_id, b"msg", &mut rng);

        assert_eq!(outcome, Err(SignError::Failed));
    }

    /// A first draw of `n + 1` is discarded and the second draw is returned.
    #[test]
    fn sample_nonzero_scalar_rejects_candidates_above_order() {
        let above_order = Fn::MODULUS.as_ref().wrapping_add(&U256::ONE);
        let accepted = U256::from_u64(2);
        let mut rng = SequenceRng {
            values: [above_order, accepted],
            index: 0,
        };

        let sampled = sample_nonzero_scalar(&mut rng).retrieve();

        assert_eq!(sampled, accepted);
        assert_eq!(rng.index, 2);
    }
}
315
#[cfg(test)]
mod sign_tests {
    use super::*;
    use core::convert::Infallible;
    use rand_core::{TryCryptoRng, TryRng};

    /// RNG that fills every 32-byte request with the same fixed bytes, which
    /// pins the signing nonce to a known value.
    struct FixedScalarRng {
        k_bytes: [u8; 32],
    }

    impl FixedScalarRng {
        /// Builds the RNG from the big-endian hex encoding of the nonce.
        fn new(k_hex: &str) -> Self {
            Self {
                k_bytes: U256::from_be_hex(k_hex).to_be_bytes().into(),
            }
        }
    }

    impl TryRng for FixedScalarRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            Ok(0)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            Ok(0)
        }

        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
            assert_eq!(dst.len(), 32);
            dst.copy_from_slice(&self.k_bytes);
            Ok(())
        }
    }

    impl TryCryptoRng for FixedScalarRng {}

    /// With the nonce fixed, signing yields the expected `(r, s)` pair after
    /// a round trip through the signature encoder and decoder.
    #[test]
    fn gbt32918_appendix_a2_fixed_k() {
        let private_scalar =
            U256::from_be_hex("3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8");
        let expected_r =
            U256::from_be_hex("88348A09A3E324C4FE946843123E40C175468F3E36481885844A144D2167EA4C");
        let expected_s =
            U256::from_be_hex("0AD2CE552FD33EAB792E5A2805E0504D014C96135F8E03891087132ABB24D48D");

        let key = Sm2PrivateKey::new(private_scalar).expect("valid scalar");
        let mut rng =
            FixedScalarRng::new("59276E27D506861A16680F3AD9C02DCFBFBF904F533DA0AC2EE1C9A45B58FF85");

        let der = sign_with_id(&key, b"ALICE123@YAHOO.COM", b"message digest", &mut rng)
            .expect("sign succeeds");
        let (r, s) = crate::asn1::sig::decode_sig(&der).expect("our own DER decodes");

        assert_eq!(r, expected_r, "r mismatch");
        assert_eq!(s, expected_s, "s mismatch");
    }
}
383}