1use crate::asn1::sig::encode_sig;
4use crate::sm2::curve::{b, Fn, NMod, PMod, GX_HEX, GY_HEX};
5use crate::sm2::private_key::Sm2PrivateKey;
6use crate::sm2::public_key::Sm2PublicKey;
7use crate::sm2::scalar_mul::mul_g;
8use crate::sm3::{Sm3, DIGEST_SIZE};
9use alloc::vec::Vec;
10use crypto_bigint::modular::ConstMontyParams;
11use crypto_bigint::{Invert, U256};
12use rand_core::{CryptoRng, RngCore};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption};
14
/// Default signer identifier: the 16 ASCII bytes `1234567812345678`.
///
/// NOTE(review): presumably the conventional default user ID from the SM2
/// specifications — confirm against the standard before relying on it.
pub const DEFAULT_SIGNER_ID: &[u8; 16] = b"1234567812345678";
17
/// Largest signer-identifier length, in bytes, accepted by this module.
///
/// The identifier's length in *bits* is hashed as a 16-bit big-endian value,
/// so the byte length must not exceed `u16::MAX / 8`.
pub const MAX_ID_LEN: usize = (u16::MAX / 8) as usize;
27
28#[must_use]
40pub fn compute_z(public: &Sm2PublicKey, id: &[u8]) -> [u8; DIGEST_SIZE] {
41 assert!(
42 id.len() <= MAX_ID_LEN,
43 "id.len() exceeds MAX_ID_LEN — ENTL_A would silently wrap"
44 );
45 let mut h = Sm3::new();
46
47 #[allow(clippy::cast_possible_truncation)]
50 let entl: u16 = (id.len() as u16) * 8;
51 h.update(&entl.to_be_bytes());
52 h.update(id);
53
54 let three = U256::from_u64(3);
56 let p_minus_three = PMod::MODULUS.get().wrapping_sub(&three);
57 h.update(&p_minus_three.to_be_bytes());
58
59 h.update(&b().retrieve().to_be_bytes());
61
62 h.update(&U256::from_be_hex(GX_HEX).to_be_bytes());
64 h.update(&U256::from_be_hex(GY_HEX).to_be_bytes());
65
66 let (px, py) = public.point().to_affine().expect("public key is finite");
68 h.update(&px.retrieve().to_be_bytes());
69 h.update(&py.retrieve().to_be_bytes());
70
71 h.finalize()
72}
73
/// Number of signing attempts `sign_raw_with_id` performs; every attempt is
/// always executed, and the first valid `(r, s)` pair is the one returned.
pub(crate) const SIGN_RETRY_BUDGET: usize = 2;
76
/// Error returned by the signing entry points in this module.
///
/// A single opaque variant is used for every failure cause (over-long signer
/// identifier, or no attempt producing a valid signature pair).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignError {
    /// Signing did not produce a signature.
    Failed,
}
83
84pub fn sign_with_id<R: CryptoRng + RngCore>(
103 key: &Sm2PrivateKey,
104 id: &[u8],
105 message: &[u8],
106 rng: &mut R,
107) -> Result<Vec<u8>, SignError> {
108 let (r, s) = sign_raw_with_id(key, id, message, rng)?;
109 Ok(encode_sig(&r, &s))
110}
111
112#[doc(hidden)]
131pub fn sign_raw_with_id<R: CryptoRng + RngCore>(
132 key: &Sm2PrivateKey,
133 id: &[u8],
134 message: &[u8],
135 rng: &mut R,
136) -> Result<(U256, U256), SignError> {
137 if id.len() > MAX_ID_LEN {
138 return Err(SignError::Failed);
139 }
140 let public = Sm2PublicKey::from_point(key.public_key());
141 let z = compute_z(&public, id);
142
143 let e_bytes = {
144 let mut h = Sm3::new();
145 h.update(&z);
146 h.update(message);
147 h.finalize()
148 };
149 let e_scalar = Fn::new(&U256::from_be_slice(&e_bytes));
150
151 let mut chosen: CtOption<RsPair> = CtOption::new(RsPair::default(), Choice::from(0));
152 for _ in 0..SIGN_RETRY_BUDGET {
153 let candidate = try_sign_once(key, &e_scalar, rng);
154 chosen = ct_or_else(chosen, candidate);
155 }
156
157 let pair: Option<RsPair> = chosen.into();
158 let pair = pair.ok_or(SignError::Failed)?;
159 Ok((pair.r, pair.s))
160}
161
162#[derive(Clone, Copy, Debug, Default)]
163struct RsPair {
164 r: U256,
165 s: U256,
166}
167
168impl ConditionallySelectable for RsPair {
169 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
170 Self {
171 r: U256::conditional_select(&a.r, &b.r, choice),
172 s: U256::conditional_select(&a.s, &b.s, choice),
173 }
174 }
175}
176
177#[allow(clippy::similar_names, clippy::many_single_char_names)]
178fn try_sign_once<R: CryptoRng + RngCore>(
179 key: &Sm2PrivateKey,
180 e: &Fn,
181 rng: &mut R,
182) -> CtOption<RsPair> {
183 let k = sample_nonzero_scalar(rng);
184 let kg = mul_g(&k);
185 let (x1, _y1) = kg.to_affine().expect("k·G is finite for k != 0");
186
187 let x1_in_n = Fn::new(&x1.retrieve());
188 let r = *e + x1_in_n;
189
190 let r_u = r.retrieve();
191 let r_plus_k = (r + k).retrieve();
192 let r_zero: Choice = r_u.ct_eq(&U256::ZERO);
193 let rk_zero: Choice = r_plus_k.ct_eq(&U256::ZERO);
194 let bad_r = r_zero | rk_zero;
195
196 let d = key.scalar();
197 let one = Fn::new(&U256::ONE);
198 let one_plus_d = one + *d;
199 let one_plus_d_inv = one_plus_d.invert();
200 let rd = r * *d;
201 let k_minus_rd = k - rd;
202
203 let inv_unwrapped: Fn = one_plus_d_inv.unwrap_or(Fn::new(&U256::ONE));
204 let inv_ok: Choice = one_plus_d_inv.is_some();
205
206 let s = inv_unwrapped * k_minus_rd;
207 let s_u = s.retrieve();
208 let s_zero: Choice = s_u.ct_eq(&U256::ZERO);
209
210 let valid = !bad_r & !s_zero & inv_ok;
211 CtOption::new(RsPair { r: r_u, s: s_u }, valid)
212}
213
214fn sample_nonzero_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Fn {
215 let n = NMod::MODULUS.get();
216 loop {
217 let mut buf = [0u8; 32];
218 rng.fill_bytes(&mut buf);
219 let candidate = U256::from_be_slice(&buf);
220 let valid = !candidate.ct_eq(&U256::ZERO) & candidate.ct_lt(&n);
221 if bool::from(valid) {
222 return Fn::new(&candidate);
223 }
224 }
225}
226
227fn ct_or_else<T: ConditionallySelectable + Default>(a: CtOption<T>, b: CtOption<T>) -> CtOption<T> {
228 let a_some = a.is_some();
229 let b_some = b.is_some();
230 let a_val = a.unwrap_or_else(T::default);
231 let b_val = b.unwrap_or_else(T::default);
232 let chosen = T::conditional_select(&b_val, &a_val, a_some);
233 let some = a_some | b_some;
234 CtOption::new(chosen, some)
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm2::private_key::Sm2PrivateKey;
    use crypto_bigint::modular::ConstMontyParams;
    use rand_core::Error;

    /// Private scalar shared by the tests in this module.
    const TEST_D_HEX: &str = "3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8";

    /// Builds the private key for [`TEST_D_HEX`].
    fn test_key() -> Sm2PrivateKey {
        Sm2PrivateKey::new(U256::from_be_hex(TEST_D_HEX)).expect("valid scalar")
    }

    /// Deterministic RNG that hands out `values` in order, one 32-byte
    /// big-endian value per `fill_bytes` call, and counts the calls in `index`.
    struct SequenceRng {
        values: [U256; 2],
        index: usize,
    }

    impl RngCore for SequenceRng {
        fn next_u32(&mut self) -> u32 {
            0
        }

        fn next_u64(&mut self) -> u64 {
            0
        }

        fn fill_bytes(&mut self, dst: &mut [u8]) {
            assert_eq!(dst.len(), 32);
            let encoded = self.values[self.index].to_be_bytes();
            self.index += 1;
            dst.copy_from_slice(&encoded);
        }

        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Error> {
            self.fill_bytes(dst);
            Ok(())
        }
    }

    impl CryptoRng for SequenceRng {}

    /// `compute_z` reproduces the expected digest for the fixed key and the
    /// identifier `ALICE123@YAHOO.COM`.
    #[test]
    fn z_appendix_a2() {
        let key = test_key();
        let z = compute_z(
            &Sm2PublicKey::from_point(key.public_key()),
            b"ALICE123@YAHOO.COM",
        );

        let mut z_hex = alloc::string::String::with_capacity(2 * DIGEST_SIZE);
        for byte in &z {
            z_hex.push_str(&alloc::format!("{byte:02x}"));
        }

        assert_eq!(
            z_hex,
            "26db4bc1839bd22e97e1dab667ec5e0a730d5e16521398b4435c576a93afd7ed"
        );
    }

    /// An identifier one byte longer than `MAX_ID_LEN` yields an error
    /// instead of reaching the assertion inside `compute_z`.
    #[test]
    fn sign_over_long_id_rejected() {
        use rand_core::OsRng;

        let oversized_id = alloc::vec![0u8; MAX_ID_LEN + 1];
        let outcome = sign_with_id(&test_key(), &oversized_id, b"msg", &mut OsRng);

        assert_eq!(outcome, Err(SignError::Failed));
    }

    /// A first draw of `n + 1` is discarded and the second draw is returned,
    /// consuming exactly two RNG outputs.
    #[test]
    fn sample_nonzero_scalar_rejects_candidates_above_order() {
        let two = U256::from_u64(2);
        let above_order = NMod::MODULUS.get().wrapping_add(&U256::ONE);
        let mut rng = SequenceRng {
            values: [above_order, two],
            index: 0,
        };

        let drawn = sample_nonzero_scalar(&mut rng).retrieve();

        assert_eq!(drawn, U256::from_u64(2));
        assert_eq!(rng.index, 2);
    }
}
321
#[cfg(test)]
mod sign_tests {
    use super::*;
    use rand_core::Error;

    /// RNG that returns the same 32 bytes on every `fill_bytes` call, pinning
    /// the nonce so the signature is reproducible.
    struct FixedScalarRng {
        k_bytes: [u8; 32],
    }

    impl FixedScalarRng {
        /// Builds the RNG from the big-endian hex encoding of the nonce.
        fn new(k_hex: &str) -> Self {
            Self {
                k_bytes: U256::from_be_hex(k_hex).to_be_bytes(),
            }
        }
    }

    impl RngCore for FixedScalarRng {
        fn next_u32(&mut self) -> u32 {
            0
        }

        fn next_u64(&mut self) -> u64 {
            0
        }

        fn fill_bytes(&mut self, dst: &mut [u8]) {
            assert_eq!(dst.len(), 32);
            dst.copy_from_slice(&self.k_bytes);
        }

        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Error> {
            self.fill_bytes(dst);
            Ok(())
        }
    }

    impl CryptoRng for FixedScalarRng {}

    /// With a fixed key, identifier, message and nonce, signing followed by
    /// decoding returns the expected `(r, s)` pair.
    #[test]
    fn gbt32918_appendix_a2_fixed_k() {
        let key = Sm2PrivateKey::new(U256::from_be_hex(
            "3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8",
        ))
        .expect("valid scalar");
        let mut rng =
            FixedScalarRng::new("59276E27D506861A16680F3AD9C02DCFBFBF904F533DA0AC2EE1C9A45B58FF85");

        let der = sign_with_id(&key, b"ALICE123@YAHOO.COM", b"message digest", &mut rng)
            .expect("sign succeeds");
        let (r, s) = crate::asn1::sig::decode_sig(&der).expect("our own DER decodes");

        let expected_r =
            U256::from_be_hex("88348A09A3E324C4FE946843123E40C175468F3E36481885844A144D2167EA4C");
        let expected_s =
            U256::from_be_hex("0AD2CE552FD33EAB792E5A2805E0504D014C96135F8E03891087132ABB24D48D");
        assert_eq!(r, expected_r, "r mismatch");
        assert_eq!(s, expected_s, "s mismatch");
    }
}