1use lazy_static::lazy_static;
2use num_bigint::{BigInt, BigUint};
3use num_traits::{Num, One};
4
5use crate::sm2::error::{Sm2Error, Sm2Result};
6use crate::sm2::formulas::*;
7use crate::sm2::key::CompressModle;
8use crate::sm2::p256_field::{FieldElement, ECC_P};
9use crate::sm2::p256_pre_table::{PRE_TABLE_1, PRE_TABLE_2};
10
11lazy_static! {
12 pub static ref P256C_PARAMS: CurveParameters = CurveParameters::new_default();
13}
14
15#[derive(Debug, Clone)]
17pub struct CurveParameters {
18 pub p: FieldElement,
20
21 pub n: BigUint,
23
24 pub a: FieldElement,
26
27 pub b: FieldElement,
29
30 pub h: BigUint,
33
34 pub g_point: Point,
36
37 pub rr: BigInt,
38 pub rr_pp: BigInt,
39 pub r: BigInt,
40 pub q: BigInt,
41}
42
43impl Default for CurveParameters {
44 fn default() -> Self {
45 CurveParameters::new_default()
46 }
47}
48
49impl CurveParameters {
50 pub fn generate() -> Self {
53 unimplemented!()
54 }
55
56 pub fn verify(&self) -> bool {
59 unimplemented!()
60 }
61
62 pub fn new_default() -> CurveParameters {
63 let p = FieldElement::new(ECC_P);
64 let n = BigUint::from_str_radix(
65 "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123",
66 16,
67 )
68 .unwrap();
69 let a = FieldElement::new([
71 0xffff_fffe,
72 0xffff_ffff,
73 0xffff_ffff,
74 0xffff_ffff,
75 0xffff_ffff,
76 0x0000_0000,
77 0xffff_ffff,
78 0xffff_fffc,
79 ]);
80 let b = FieldElement::new([
81 0x28e9_fa9e,
82 0x9d9f_5e34,
83 0x4d5a_9e4b,
84 0xcf65_09a7,
85 0xf397_89f5,
86 0x15ab_8f92,
87 0xddbc_bd41,
88 0x4d94_0e93,
89 ]);
90
91 let g_x = FieldElement::new([
92 0x32c4_ae2c,
93 0x1f19_8119,
94 0x5f99_0446,
95 0x6a39_c994,
96 0x8fe3_0bbf,
97 0xf266_0be1,
98 0x715a_4589,
99 0x334c_74c7,
100 ]);
101 let g_y = FieldElement::new([
102 0xbc37_36a2,
103 0xf4f6_779c,
104 0x59bd_cee3,
105 0x6b69_2153,
106 0xd0a9_877c,
107 0xc62a_4740,
108 0x02df_32e5,
109 0x2139_f0a0,
110 ]);
111
112 let r = BigInt::from_str_radix(
113 "010000000000000000000000000000000000000000000000000000000000000000",
114 16,
115 )
116 .unwrap();
117 let ctx = CurveParameters {
118 p,
119 n,
120 a,
121 b,
122 h: BigUint::one(), g_point: Point {
124 x: g_x,
125 y: g_y,
126 z: FieldElement::one(),
127 },
128 rr: &r * &r,
129 rr_pp: BigInt::from_str_radix(
130 "400000002000000010000000100000002ffffffff0000000200000003",
131 16,
132 )
133 .unwrap(),
134 r,
135 q: BigInt::from_str_radix(
136 "-3fffffffe00000001ffffffff00000000fffffffeffffffffffffffff",
137 16,
138 )
139 .unwrap(),
140 };
141 ctx
142 }
143}
144
145#[derive(Debug, Clone, Eq, PartialEq, Copy)]
146pub struct Point {
147 pub x: FieldElement,
148 pub y: FieldElement,
149 pub z: FieldElement,
150}
151
152impl Point {
153 pub fn to_affine(&self) -> Point {
154 let z_inv = &self.z.modinv();
155 let z_inv2 = z_inv * z_inv;
156 let z_inv3 = z_inv2 * z_inv;
157 let x = &self.x * z_inv2;
158 let y = &self.y * z_inv3;
159 Point {
160 x,
161 y,
162 z: FieldElement::one(),
163 }
164 }
165
166 pub fn to_byte(&self, compress_modle: CompressModle) -> Vec<u8> {
167 let p_affine = self.to_affine();
168 let mut x_vec = p_affine.x.to_bytes_be();
169 let mut y_vec = p_affine.y.to_bytes_be();
170 let mut ret: Vec<u8> = Vec::new();
171 match compress_modle {
172 CompressModle::Compressed => {
173 if y_vec[y_vec.len() - 1] & 0x01 == 0 {
174 ret.push(0x02);
175 } else {
176 ret.push(0x03);
177 }
178 ret.append(&mut x_vec);
179 }
180 CompressModle::Uncompressed => {
181 ret.push(0x04);
182 ret.append(&mut x_vec);
183 ret.append(&mut y_vec);
184 }
185 CompressModle::Mixed => {
186 if y_vec[y_vec.len() - 1] & 0x01 == 0 {
187 ret.push(0x06);
188 } else {
189 ret.push(0x07);
190 }
191 ret.append(&mut x_vec);
192 ret.append(&mut y_vec);
193 }
194 }
195 ret
196 }
197
198 pub(crate) fn from_byte(b: &[u8], compress_modle: CompressModle) -> Sm2Result<Point> {
199 return match compress_modle {
200 CompressModle::Compressed => {
201 if b.len() != 33 {
202 return Err(Sm2Error::InvalidPublic);
203 }
204 let y_q;
205 if b[0] == 0x02 {
206 y_q = 0;
207 } else if b[0] == 0x03 {
208 y_q = 1
209 } else {
210 return Err(Sm2Error::InvalidPublic);
211 }
212 let x = FieldElement::from_bytes_be(&b[1..])?;
213 let xxx = &x * &x * &x;
214 let ax = &P256C_PARAMS.a * &x;
215 let yy = &xxx + &ax + &P256C_PARAMS.b;
216
217 let mut y = yy.sqrt()?;
218 let y_vec = y.to_bytes_be();
219 if y_vec[y_vec.len() - 1] & 0x01 != y_q {
220 y = &P256C_PARAMS.p - y;
221 }
222 Ok(Point {
223 x,
224 y,
225 z: FieldElement::one(),
226 })
227 }
228 CompressModle::Uncompressed | CompressModle::Mixed => {
229 if b.len() != 65 {
230 return Err(Sm2Error::InvalidPublic);
231 }
232 let x = FieldElement::from_bytes_be(&b[1..33])?;
233 let y = FieldElement::from_bytes_be(&b[33..65])?;
234 Ok(Point {
235 x,
236 y,
237 z: FieldElement::one(),
238 })
239 }
240 };
241 }
242}
243
244impl Point {
245 pub fn is_zero(&self) -> bool {
246 self.z.is_zero()
247 }
248
249 pub fn is_valid(&self) -> bool {
250 if self.is_zero() {
251 true
252 } else {
253 let yy = &self.y * &self.y;
255 let xx = &self.x * &self.x;
256 let z2 = &self.z * &self.z;
257 let z4 = &z2 * &z2;
258 let z6 = &z4 * &z2;
259 let z6_b = &P256C_PARAMS.b * &z6;
260 let a_z4 = &P256C_PARAMS.a * &z4;
261 let xx_a_4z = &xx + &a_z4;
262 let xxx_a_4z = &xx_a_4z * &self.x;
263 let exp = &xxx_a_4z + &z6_b;
264 yy.eq(&exp)
265 }
266 }
267
268 pub fn is_valid_affine(&self) -> bool {
269 let yy = &self.y * &self.y;
271 let xx = &self.x * &self.x;
272 let xx_a = &P256C_PARAMS.a + &xx;
273 let xxx_a = &self.x * &xx_a;
274 let b = &P256C_PARAMS.b;
275 let exp = &xxx_a + b;
276 yy.eq(&exp)
277 }
278
279 pub fn zero() -> Point {
280 Point {
281 x: FieldElement::one(),
282 y: FieldElement::one(),
283 z: FieldElement::zero(),
284 }
285 }
286
287 pub fn neg(&self) -> Point {
288 Point {
289 x: self.x.clone(),
290 y: &P256C_PARAMS.p - &self.y,
291 z: self.z.clone(),
292 }
293 }
294
295 pub fn double(&self) -> Point {
296 double_1998_cmo(self)
297 }
298
299 pub fn add(&self, p2: &Point) -> Point {
300 if self.is_zero() {
302 return p2.clone();
303 }
304 if p2.is_zero() {
306 return self.clone();
307 }
308 let x1 = &self.x;
309 let y1 = &self.y;
310 let z1 = &self.z;
311
312 let x2 = &p2.x;
313 let y2 = &p2.y;
314 let z2 = &p2.z;
315 if x1 == x2 && y1 == y2 && z1 == z2 {
317 return self.double();
318 } else {
319 add_1998_cmo(self, &p2)
320 }
321 }
322}
323
/// Returns bit `i` of `n` (bit 0 is the least significant) as `0` or `1`.
#[inline(always)]
const fn ith_bit(n: u32, i: i32) -> u32 {
    let shifted = n >> i;
    shifted & 1
}
328
329#[inline(always)]
330const fn compose_index(v: &[u32], i: i32) -> u32 {
331 ith_bit(v[7], i)
332 + (ith_bit(v[6], i) << 1)
333 + (ith_bit(v[5], i) << 2)
334 + (ith_bit(v[4], i) << 3)
335 + (ith_bit(v[3], i) << 4)
336 + (ith_bit(v[2], i) << 5)
337 + (ith_bit(v[1], i) << 6)
338 + (ith_bit(v[0], i) << 7)
339}
340
341pub fn g_mul(m: &BigUint) -> Point {
342 let k = FieldElement::from_biguint(&m).unwrap();
343 let mut q = Point::zero();
344 let mut i = 15;
345 while i >= 0 {
346 q = q.double();
347 let low_index = compose_index(&k.inner, i);
348 let high_index = compose_index(&k.inner, i + 16);
349 let p1 = &PRE_TABLE_1[low_index as usize];
350 let p2 = &PRE_TABLE_2[high_index as usize];
351 q = q.add(p1).add(p2);
352 i -= 1;
353 }
354 q
355}
356
357pub fn scalar_mul(m: &BigUint, p: &Point) -> Point {
360 mul_naf(m, p)
361}
362
363pub fn mlsm_mul(k: &BigUint, p: &Point) -> Point {
373 let bi = k.to_bytes_be();
374 let mut q0 = Point::zero();
375 let mut qt = p.clone();
376 let mut i = 0;
377 while i < bi.len() {
378 let q1 = q0.add(&qt);
379 let q2 = qt.double();
380 if bi[i] & 0x1 == 1 {
381 q0 = q1;
382 }
383 qt = q2;
384 i += 1;
385 }
386 q0
387}
388
389fn mul_naf(m: &BigUint, p: &Point) -> Point {
391 let p1 = p.clone();
393 let p2 = p.double();
394 let mut pre_table = vec![];
395 for _ in 0..32 {
396 pre_table.push(Point::zero());
397 }
398 let offset = 16;
399 pre_table[1 + offset] = p1;
400 pre_table[offset - 1] = pre_table[1 + offset].neg();
401 for i in 1..8 {
402 pre_table[2 * i + offset + 1] = p2.add(&pre_table[2 * i + offset - 1]);
403 pre_table[offset - 2 * i - 1] = pre_table[2 * i + offset + 1].neg();
404 }
405
406 let k = FieldElement::from_biguint(m).unwrap();
407 let mut l = 256;
408 let naf = w_naf(&k.inner, 5, &mut l);
409 let mut q = Point::zero();
410 loop {
411 q = q.double();
412 if naf[l] != 0 {
413 let index = (naf[l] + 16) as usize;
414 q = q.add(&pre_table[index]);
415 }
416 if l == 0 {
417 break;
418 }
419 l -= 1;
420 }
421 q
422}
423
/// Recodes a 256-bit scalar into signed digits of window width `w`.
///
/// `k` holds the scalar as eight 32-bit words, most significant word first.
/// The returned array has one entry per bit position (plus position 256 for a
/// final carry); `ret[i]` is the signed digit with weight `2^i`, or zero.
/// `*lst` is set to the highest position that received a digit and is left
/// untouched when no digit is emitted.
///
/// Panics if `k` has fewer than eight words.
#[inline(always)]
fn w_naf(k: &[u32], w: usize, lst: &mut usize) -> [i8; 257] {
    let mut digits = [0i8; 257];

    // padded[0] is a zero guard word, so a window that straddles the top of
    // the scalar can read "the next word up" without a special case.
    let mut padded = [0u32; 9];
    padded[1..].copy_from_slice(&k[..8]);

    let mask: u32 = (1 << w) - 1;
    let mut carry: u32 = 0;
    let mut pos: usize = 0;

    while pos < 256 {
        let word_idx = 8 - pos / 32;
        let shift = pos % 32;
        // Number of bits above `pos` inside the current word.
        let bits_above = 31 - shift;
        let aligned = padded[word_idx] >> shift;

        // A bit equal to the pending carry contributes a zero digit.
        if (aligned & 1) == carry {
            pos += 1;
            continue;
        }

        let mut window: u32 = if bits_above >= w - 1 {
            aligned & mask
        } else {
            // The window crosses into the next more significant word.
            (aligned | (padded[word_idx - 1] << (bits_above + 1))) & mask
        };
        window += carry;

        // A set top bit means the digit is taken as negative and one is
        // carried into the following window.
        carry = (window >> (w - 1)) & 1;
        digits[pos] = window as i8 - (carry << w) as i8;

        *lst = pos;
        pos += w;
    }

    if carry == 1 {
        digits[256] = 1;
        *lst = 256;
    }
    digits
}
466
/// Spreads the low eight bits of `n` over eight words: bit `i` of `n` becomes
/// the value (0 or 1) of word `7 - i`. Higher bits of `n` are ignored.
fn pre_vec_gen(n: u32) -> [u32; 8] {
    let mut words = [0u32; 8];
    // Walking the array backwards pairs bit 0 with the last word.
    for (bit, word) in words.iter_mut().rev().enumerate() {
        *word = (n >> bit) & 0x01;
    }
    words
}
476
/// Like `pre_vec_gen`, but each selected bit is placed at bit position 16 of
/// its word: bit `i` of `n` becomes `0` or `1 << 16` in word `7 - i`.
fn pre_vec_gen2(n: u32) -> [u32; 8] {
    let mut words = [0u32; 8];
    // Walking the array backwards pairs bit 0 with the last word.
    for (bit, word) in words.iter_mut().rev().enumerate() {
        *word = ((n >> bit) & 0x01) << 16;
    }
    words
}
486
#[cfg(test)]
mod test {
    use crate::sm2::p256_ecc::{pre_vec_gen, pre_vec_gen2, scalar_mul, Point, P256C_PARAMS};
    use crate::sm2::p256_field::FieldElement;
    use num_bigint::BigUint;
    use num_traits::{FromPrimitive, Num, Pow};

    /// Recomputes and prints both 256-entry base-point tables.
    ///
    /// This is a generator rather than a check: it makes no assertions.
    #[test]
    fn test_g_table() {
        let table_1: Vec<Point> = (0..256u32)
            .map(|i| {
                let k = FieldElement::from_slice(&pre_vec_gen(i));
                scalar_mul(&k.to_biguint(), &P256C_PARAMS.g_point)
            })
            .collect();

        let table_2: Vec<Point> = (0..256u32)
            .map(|i| {
                let k = FieldElement::from_slice(&pre_vec_gen2(i));
                scalar_mul(&k.to_biguint(), &P256C_PARAMS.g_point)
            })
            .collect();

        println!("table_1 = {:?}", table_1);
        println!("table_2 = {:?}", table_2);
    }

    /// Prints the radix constant and its residues modulo `p` and `n`.
    ///
    /// Like `test_g_table`, this only prints values for inspection.
    #[test]
    fn test_r() {
        let parse_hex = |hex: &str| BigUint::from_str_radix(hex, 16).unwrap();

        let r_1 = parse_hex("010000000000000000000000000000000000000000000000000000000000000000");
        let p = parse_hex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF");
        let n = parse_hex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123");

        let two = BigUint::from_u32(2).unwrap();
        let r = two.pow(256u32);
        let rr = r.pow(2u32);

        println!("r = {:?}", r.to_str_radix(16));
        println!("r1= {:?}", r_1.to_str_radix(16));
        println!("r_p = {:?}", (&r % &p).to_str_radix(16));
        println!("r_n = {:?}", (&r % &n).to_str_radix(16));

        println!("rr_p = {:?}", (&rr % &p).to_str_radix(16));
        println!("rr_n = {:?}", (&rr % &n).to_str_radix(16));
    }
}