// gm_rs/sm2/p256_ecc.rs

1use lazy_static::lazy_static;
2use num_bigint::{BigInt, BigUint};
3use num_traits::{Num, One};
4
5use crate::sm2::error::{Sm2Error, Sm2Result};
6use crate::sm2::formulas::*;
7use crate::sm2::key::CompressModle;
8use crate::sm2::p256_field::{FieldElement, ECC_P};
9use crate::sm2::p256_pre_table::{PRE_TABLE_1, PRE_TABLE_2};
10
11lazy_static! {
12    pub static ref P256C_PARAMS: CurveParameters = CurveParameters::new_default();
13}
14
15/// ecc equation: y^2 == x^3 +ax + b (mod p)
16#[derive(Debug, Clone)]
17pub struct CurveParameters {
18    /// p:大于3的素数
19    pub p: FieldElement,
20
21    /// n:基点G的阶(n是#E(Fq)的素因子)
22    pub n: BigUint,
23
24    /// a:Fq中的元素,它们定义Fq上的一条椭圆曲线E
25    pub a: FieldElement,
26
27    /// b:Fq中的元素,它们定义Fq上的一条椭圆曲线E
28    pub b: FieldElement,
29
30    /// The Cofactor, the recommended value is 1
31    /// 余因子,h = #E(Fq)/n,其中n是基点G的阶
32    pub h: BigUint,
33
34    /// G:椭圆曲线的一个基点,其阶为素数
35    pub g_point: Point,
36
37    pub rr: BigInt,
38    pub rr_pp: BigInt,
39    pub r: BigInt,
40    pub q: BigInt,
41}
42
43impl Default for CurveParameters {
44    fn default() -> Self {
45        CurveParameters::new_default()
46    }
47}
48
49impl CurveParameters {
50    /// 生成椭圆曲线参数
51    ///
52    pub fn generate() -> Self {
53        unimplemented!()
54    }
55
56    /// 验证椭圆曲线参数
57    ///
58    pub fn verify(&self) -> bool {
59        unimplemented!()
60    }
61
62    pub fn new_default() -> CurveParameters {
63        let p = FieldElement::new(ECC_P);
64        let n = BigUint::from_str_radix(
65            "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123",
66            16,
67        )
68        .unwrap();
69        // a = p - 3
70        let a = FieldElement::new([
71            0xffff_fffe,
72            0xffff_ffff,
73            0xffff_ffff,
74            0xffff_ffff,
75            0xffff_ffff,
76            0x0000_0000,
77            0xffff_ffff,
78            0xffff_fffc,
79        ]);
80        let b = FieldElement::new([
81            0x28e9_fa9e,
82            0x9d9f_5e34,
83            0x4d5a_9e4b,
84            0xcf65_09a7,
85            0xf397_89f5,
86            0x15ab_8f92,
87            0xddbc_bd41,
88            0x4d94_0e93,
89        ]);
90
91        let g_x = FieldElement::new([
92            0x32c4_ae2c,
93            0x1f19_8119,
94            0x5f99_0446,
95            0x6a39_c994,
96            0x8fe3_0bbf,
97            0xf266_0be1,
98            0x715a_4589,
99            0x334c_74c7,
100        ]);
101        let g_y = FieldElement::new([
102            0xbc37_36a2,
103            0xf4f6_779c,
104            0x59bd_cee3,
105            0x6b69_2153,
106            0xd0a9_877c,
107            0xc62a_4740,
108            0x02df_32e5,
109            0x2139_f0a0,
110        ]);
111
112        let r = BigInt::from_str_radix(
113            "010000000000000000000000000000000000000000000000000000000000000000",
114            16,
115        )
116        .unwrap();
117        let ctx = CurveParameters {
118            p,
119            n,
120            a,
121            b,
122            h: BigUint::one(), // The Cofactor, the recommended value is 1
123            g_point: Point {
124                x: g_x,
125                y: g_y,
126                z: FieldElement::one(),
127            },
128            rr: &r * &r,
129            rr_pp: BigInt::from_str_radix(
130                "400000002000000010000000100000002ffffffff0000000200000003",
131                16,
132            )
133            .unwrap(),
134            r,
135            q: BigInt::from_str_radix(
136                "-3fffffffe00000001ffffffff00000000fffffffeffffffffffffffff",
137                16,
138            )
139            .unwrap(),
140        };
141        ctx
142    }
143}
144
145#[derive(Debug, Clone, Eq, PartialEq, Copy)]
146pub struct Point {
147    pub x: FieldElement,
148    pub y: FieldElement,
149    pub z: FieldElement,
150}
151
152impl Point {
153    pub fn to_affine(&self) -> Point {
154        let z_inv = &self.z.modinv();
155        let z_inv2 = z_inv * z_inv;
156        let z_inv3 = z_inv2 * z_inv;
157        let x = &self.x * z_inv2;
158        let y = &self.y * z_inv3;
159        Point {
160            x,
161            y,
162            z: FieldElement::one(),
163        }
164    }
165
166    pub fn to_byte(&self, compress_modle: CompressModle) -> Vec<u8> {
167        let p_affine = self.to_affine();
168        let mut x_vec = p_affine.x.to_bytes_be();
169        let mut y_vec = p_affine.y.to_bytes_be();
170        let mut ret: Vec<u8> = Vec::new();
171        match compress_modle {
172            CompressModle::Compressed => {
173                if y_vec[y_vec.len() - 1] & 0x01 == 0 {
174                    ret.push(0x02);
175                } else {
176                    ret.push(0x03);
177                }
178                ret.append(&mut x_vec);
179            }
180            CompressModle::Uncompressed => {
181                ret.push(0x04);
182                ret.append(&mut x_vec);
183                ret.append(&mut y_vec);
184            }
185            CompressModle::Mixed => {
186                if y_vec[y_vec.len() - 1] & 0x01 == 0 {
187                    ret.push(0x06);
188                } else {
189                    ret.push(0x07);
190                }
191                ret.append(&mut x_vec);
192                ret.append(&mut y_vec);
193            }
194        }
195        ret
196    }
197
198    pub(crate) fn from_byte(b: &[u8], compress_modle: CompressModle) -> Sm2Result<Point> {
199        return match compress_modle {
200            CompressModle::Compressed => {
201                if b.len() != 33 {
202                    return Err(Sm2Error::InvalidPublic);
203                }
204                let y_q;
205                if b[0] == 0x02 {
206                    y_q = 0;
207                } else if b[0] == 0x03 {
208                    y_q = 1
209                } else {
210                    return Err(Sm2Error::InvalidPublic);
211                }
212                let x = FieldElement::from_bytes_be(&b[1..])?;
213                let xxx = &x * &x * &x;
214                let ax = &P256C_PARAMS.a * &x;
215                let yy = &xxx + &ax + &P256C_PARAMS.b;
216
217                let mut y = yy.sqrt()?;
218                let y_vec = y.to_bytes_be();
219                if y_vec[y_vec.len() - 1] & 0x01 != y_q {
220                    y = &P256C_PARAMS.p - y;
221                }
222                Ok(Point {
223                    x,
224                    y,
225                    z: FieldElement::one(),
226                })
227            }
228            CompressModle::Uncompressed | CompressModle::Mixed => {
229                if b.len() != 65 {
230                    return Err(Sm2Error::InvalidPublic);
231                }
232                let x = FieldElement::from_bytes_be(&b[1..33])?;
233                let y = FieldElement::from_bytes_be(&b[33..65])?;
234                Ok(Point {
235                    x,
236                    y,
237                    z: FieldElement::one(),
238                })
239            }
240        };
241    }
242}
243
244impl Point {
245    pub fn is_zero(&self) -> bool {
246        self.z.is_zero()
247    }
248
249    pub fn is_valid(&self) -> bool {
250        if self.is_zero() {
251            true
252        } else {
253            // y^2 = x * (x^2 + a * z^4) + b * z^6
254            let yy = &self.y * &self.y;
255            let xx = &self.x * &self.x;
256            let z2 = &self.z * &self.z;
257            let z4 = &z2 * &z2;
258            let z6 = &z4 * &z2;
259            let z6_b = &P256C_PARAMS.b * &z6;
260            let a_z4 = &P256C_PARAMS.a * &z4;
261            let xx_a_4z = &xx + &a_z4;
262            let xxx_a_4z = &xx_a_4z * &self.x;
263            let exp = &xxx_a_4z + &z6_b;
264            yy.eq(&exp)
265        }
266    }
267
268    pub fn is_valid_affine(&self) -> bool {
269        // y^2 = x * (x^2 + a) + b
270        let yy = &self.y * &self.y;
271        let xx = &self.x * &self.x;
272        let xx_a = &P256C_PARAMS.a + &xx;
273        let xxx_a = &self.x * &xx_a;
274        let b = &P256C_PARAMS.b;
275        let exp = &xxx_a + b;
276        yy.eq(&exp)
277    }
278
279    pub fn zero() -> Point {
280        Point {
281            x: FieldElement::one(),
282            y: FieldElement::one(),
283            z: FieldElement::zero(),
284        }
285    }
286
287    pub fn neg(&self) -> Point {
288        Point {
289            x: self.x.clone(),
290            y: &P256C_PARAMS.p - &self.y,
291            z: self.z.clone(),
292        }
293    }
294
295    pub fn double(&self) -> Point {
296        double_1998_cmo(self)
297    }
298
299    pub fn add(&self, p2: &Point) -> Point {
300        // 0 + p2 = p2
301        if self.is_zero() {
302            return p2.clone();
303        }
304        // p1 + 0 = p1
305        if p2.is_zero() {
306            return self.clone();
307        }
308        let x1 = &self.x;
309        let y1 = &self.y;
310        let z1 = &self.z;
311
312        let x2 = &p2.x;
313        let y2 = &p2.y;
314        let z2 = &p2.z;
315        // p1 = p2
316        if x1 == x2 && y1 == y2 && z1 == z2 {
317            return self.double();
318        } else {
319            add_1998_cmo(self, &p2)
320        }
321    }
322}
323
/// Returns bit `i` of `n` as 0 or 1.
#[inline(always)]
const fn ith_bit(n: u32, i: i32) -> u32 {
    if n & (1 << i) == 0 {
        0
    } else {
        1
    }
}

/// Gathers bit `i` from each of the first eight limbs of `v` into an 8-bit index.
///
/// `v[0]` supplies the most significant index bit and `v[7]` the least significant one.
/// This matches how `g_mul` selects entries of the precomputed comb tables.
#[inline(always)]
const fn compose_index(v: &[u32], i: i32) -> u32 {
    let mut index = 0;
    let mut limb = 0;
    while limb < 8 {
        index = (index << 1) | ith_bit(v[limb], i);
        limb += 1;
    }
    index
}
340
341pub fn g_mul(m: &BigUint) -> Point {
342    let k = FieldElement::from_biguint(&m).unwrap();
343    let mut q = Point::zero();
344    let mut i = 15;
345    while i >= 0 {
346        q = q.double();
347        let low_index = compose_index(&k.inner, i);
348        let high_index = compose_index(&k.inner, i + 16);
349        let p1 = &PRE_TABLE_1[low_index as usize];
350        let p2 = &PRE_TABLE_2[high_index as usize];
351        q = q.add(p1).add(p2);
352        i -= 1;
353    }
354    q
355}
356
357//
358// P = [k]G
359pub fn scalar_mul(m: &BigUint, p: &Point) -> Point {
360    mul_naf(m, p)
361}
362
363// Montgomery ladder based scalar multiplication (MLSM)
364// Input: integer k and point P, m = bit length of k
365// 1: Initial: Q1 = Q0 = 0, QT = P, i = 0
366// 2: While i < m, do:
367// 3: Q1 = Q0 + QT , Q2 = 2QT
368// 4: If(ki = 1) Switch(Q0, Q1)
369// 5: QT = Q2, i = i + 1
370// 6: end While
371// TODO fixme: The mlsm_mul cause signature verify failed
372pub fn mlsm_mul(k: &BigUint, p: &Point) -> Point {
373    let bi = k.to_bytes_be();
374    let mut q0 = Point::zero();
375    let mut qt = p.clone();
376    let mut i = 0;
377    while i < bi.len() {
378        let q1 = q0.add(&qt);
379        let q2 = qt.double();
380        if bi[i] & 0x1 == 1 {
381            q0 = q1;
382        }
383        qt = q2;
384        i += 1;
385    }
386    q0
387}
388
389// 滑动窗法
390fn mul_naf(m: &BigUint, p: &Point) -> Point {
391    // 预处理计算
392    let p1 = p.clone();
393    let p2 = p.double();
394    let mut pre_table = vec![];
395    for _ in 0..32 {
396        pre_table.push(Point::zero());
397    }
398    let offset = 16;
399    pre_table[1 + offset] = p1;
400    pre_table[offset - 1] = pre_table[1 + offset].neg();
401    for i in 1..8 {
402        pre_table[2 * i + offset + 1] = p2.add(&pre_table[2 * i + offset - 1]);
403        pre_table[offset - 2 * i - 1] = pre_table[2 * i + offset + 1].neg();
404    }
405
406    let k = FieldElement::from_biguint(m).unwrap();
407    let mut l = 256;
408    let naf = w_naf(&k.inner, 5, &mut l);
409    let mut q = Point::zero();
410    loop {
411        q = q.double();
412        if naf[l] != 0 {
413            let index = (naf[l] + 16) as usize;
414            q = q.add(&pre_table[index]);
415        }
416        if l == 0 {
417            break;
418        }
419        l -= 1;
420    }
421    q
422}
423
/// Computes the width-`w` NAF of the 256-bit integer held in `k[..8]`.
///
/// `k` is read as eight 32-bit limbs with `k[7]` least significant (bits 0..32), matching
/// how the digit positions are mapped below.
///
/// Returns 257 signed digits where digit `i` weights `2^i`. Each non-zero digit is odd with
/// absolute value below `2^(w-1)`, and index 256 receives a final carry.
///
/// `lst` is overwritten with the highest index holding a non-zero digit. It is left
/// untouched when `k` is zero.
///
/// # Panics
/// Panics if `k` has fewer than eight limbs.
#[inline(always)]
fn w_naf(k: &[u32], w: usize, lst: &mut usize) -> [i8; 257] {
    let mask: u32 = (1 << w) - 1;

    // words[0] is a zero guard above the most significant limb, so windows that run past
    // bit 255 read zeros.
    let mut words = [0u32; 9];
    words[1..].copy_from_slice(&k[..8]);

    let mut digits = [0i8; 257];
    let mut carry: u32 = 0;
    let mut pos: usize = 0;
    while pos < 256 {
        let limb = 8 - pos / 32;
        let shift = pos % 32;
        let current = words[limb] >> shift;

        // Bit plus incoming carry is even: emit a zero digit and move one position.
        if (current & 1) == carry {
            pos += 1;
            continue;
        }

        // Take the next `w` bits, borrowing from the next-more-significant limb when the
        // window straddles a limb boundary.
        let raw = if shift <= 32 - w {
            current & mask
        } else {
            (current | (words[limb - 1] << (32 - shift))) & mask
        };

        let value = raw + carry;
        carry = (value >> (w - 1)) & 1;
        digits[pos] = value as i8 - (carry << w) as i8;

        *lst = pos;
        pos += w;
    }

    if carry == 1 {
        digits[256] = 1;
        *lst = 256;
    }
    digits
}
466
/// Spreads the low 8 bits of `n` over eight limbs: limb `7 - i` receives bit `i` of `n`.
///
/// Used by `test_g_table` to enumerate the scalars behind `PRE_TABLE_1`. This mirrors
/// `compose_index` with bit position 0. Test-only, so it does not trigger dead-code warnings
/// in normal builds.
#[cfg(test)]
fn pre_vec_gen(n: u32) -> [u32; 8] {
    let mut limbs = [0u32; 8];
    for (j, limb) in limbs.iter_mut().enumerate() {
        *limb = (n >> (7 - j)) & 0x01;
    }
    limbs
}

/// Like [`pre_vec_gen`], but each selected bit is placed at bit position 16 of its limb.
///
/// This enumerates the scalars behind `PRE_TABLE_2`. Test-only.
#[cfg(test)]
fn pre_vec_gen2(n: u32) -> [u32; 8] {
    let mut limbs = pre_vec_gen(n);
    for limb in limbs.iter_mut() {
        *limb <<= 16;
    }
    limbs
}
486
#[cfg(test)]
mod test {
    use crate::sm2::p256_ecc::{pre_vec_gen, pre_vec_gen2, scalar_mul, Point, P256C_PARAMS};
    use crate::sm2::p256_field::FieldElement;
    use num_bigint::{BigInt, BigUint};
    use num_traits::{FromPrimitive, Num, Pow};

    /// Multiplies G by the scalar `gen(i)` for every 8-bit index `i`.
    fn build_table(gen: fn(u32) -> [u32; 8]) -> Vec<Point> {
        (0..256u32)
            .map(|i| {
                let k = FieldElement::from_slice(&gen(i));
                scalar_mul(&k.to_biguint(), &P256C_PARAMS.g_point)
            })
            .collect()
    }

    /// Regenerates the comb tables consumed by `g_mul` and prints them for pasting into
    /// `p256_pre_table`, after sanity-checking the result.
    #[test]
    fn test_g_table() {
        let table_1 = build_table(pre_vec_gen);
        let table_2 = build_table(pre_vec_gen2);

        // Index 0 encodes the scalar 0; index 1 of table_1 encodes the scalar 1.
        assert!(table_1[0].is_zero());
        assert!(table_2[0].is_zero());
        assert_eq!(table_1[1].to_affine(), P256C_PARAMS.g_point.to_affine());
        assert!(table_1.iter().chain(table_2.iter()).all(Point::is_valid));

        println!("table_1 = {:?}", table_1);
        println!("table_2 = {:?}", table_2);
    }

    /// Checks the hard-coded constants in `CurveParameters::new_default`.
    #[test]
    fn test_r() {
        let r_1 = BigUint::from_str_radix(
            "010000000000000000000000000000000000000000000000000000000000000000",
            16,
        )
        .unwrap();

        let p = BigUint::from_str_radix(
            "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF",
            16,
        )
        .unwrap();

        let n = BigUint::from_str_radix(
            "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123",
            16,
        )
        .unwrap();

        let r = BigUint::from_u32(2).unwrap().pow(256u32);
        // Multiply by reference so `r` is not moved by a by-value `Pow` impl.
        let rr = &r * &r;
        println!("r = {:?}", r.to_str_radix(16));
        println!("r1= {:?}", r_1.to_str_radix(16));
        println!("r_p = {:?}", (&r % &p).to_str_radix(16));
        println!("r_n = {:?}", (&r % &n).to_str_radix(16));

        println!("rr_p = {:?}", (&rr % &p).to_str_radix(16));
        println!("rr_n = {:?}", (&rr % &n).to_str_radix(16));

        assert_eq!(r, r_1);
        assert_eq!(P256C_PARAMS.n, n);
        assert_eq!(P256C_PARAMS.r, BigInt::from(r.clone()));
        assert_eq!(P256C_PARAMS.rr, BigInt::from(rr.clone()));
        assert_eq!(P256C_PARAMS.rr_pp, BigInt::from(&rr % &p));
    }
}