gm_rs/sm2/
operation.rs

1use std::io::Cursor;
2
3use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4use num_bigint::{BigUint, ModInverse};
5use num_integer::Integer;
6
7use crate::sm2::p256_field::{Conversion, Fe, FieldElement};
8use crate::sm2::util::{add_raw, mul_raw, sub_raw};
9use crate::sm2::FeOperation;
10
11impl Conversion for Fe {
12    fn fe_to_bigunit(&self) -> BigUint {
13        let mut ret: Vec<u8> = Vec::new();
14        for i in 0..8 {
15            ret.write_u32::<BigEndian>(self[i]).unwrap();
16        }
17        BigUint::from_bytes_be(&ret[..])
18    }
19
20    fn bigunit_fe(&self) -> Fe {
21        unimplemented!()
22    }
23}
24
25impl Conversion for BigUint {
26    fn fe_to_bigunit(&self) -> BigUint {
27        unimplemented!()
28    }
29
30    fn bigunit_fe(&self) -> Fe {
31        let v = self.to_bytes_be();
32        let mut num_v = [0u8; 32];
33        num_v[32 - v.len()..32].copy_from_slice(&v[..]);
34        let mut elem = [0u32; 8];
35        let mut c = Cursor::new(num_v);
36        for i in 0..8 {
37            let x = c.read_u32::<BigEndian>().unwrap();
38            elem[i] = x;
39        }
40        elem
41    }
42}
43
44impl FeOperation for Fe {
45    #[inline]
46    fn mod_add(&self, other: &Self, modulus: &Self) -> Self {
47        let (raw_sum, carry) = add_raw(self, other);
48        if carry || raw_sum >= *modulus {
49            let (sum, _borrow) = sub_raw(&raw_sum, &modulus);
50            sum
51        } else {
52            raw_sum
53        }
54    }
55
56    #[inline]
57    fn mod_sub(&self, other: &Self, modulus: &Self) -> Self {
58        let (raw_diff, borrow) = sub_raw(&self, other);
59        if borrow {
60            let (modulus_complete, _) = sub_raw(&[0; 8], &modulus);
61            let (diff, _borrow) = sub_raw(&raw_diff, &modulus_complete);
62            diff
63        } else {
64            raw_diff
65        }
66    }
67
68    #[inline]
69    fn mod_mul(&self, other: &Self, modulus: &Self) -> Self {
70        let raw_prod = mul_raw(self, other);
71        fast_reduction(&raw_prod, &modulus)
72    }
73
74    #[inline]
75    fn inv(&self, modulus: &Self) -> Self {
76        let mut ru = *self;
77        let mut rv = *modulus;
78        let mut ra = FieldElement::from_number(1).inner;
79        let mut rc = [0; 8];
80        while ru != [0; 8] {
81            if ru[7] & 0x01 == 0 {
82                ru = ru.right_shift(0);
83                if ra[7] & 0x01 == 0 {
84                    ra = ra.right_shift(0);
85                } else {
86                    let (sum, car) = add_raw(&ra, &modulus);
87                    ra = sum.right_shift(car as u32);
88                }
89            }
90
91            if rv[7] & 0x01 == 0 {
92                rv = rv.right_shift(0);
93                if rc[7] & 0x01 == 0 {
94                    rc = rc.right_shift(0);
95                } else {
96                    let (sum, car) = add_raw(&rc, &modulus);
97                    rc = sum.right_shift(car as u32);
98                }
99            }
100
101            if ru >= rv {
102                ru = ru.mod_sub(&rv, &modulus);
103                ra = ra.mod_sub(&rc, &modulus);
104            } else {
105                rv = rv.mod_sub(&ru, &modulus);
106                rc = rc.mod_sub(&ra, &modulus);
107            }
108        }
109        rc
110    }
111
112    #[inline]
113    fn right_shift(&self, carry: u32) -> Self {
114        let mut ret = [0; 8];
115        let mut carry = carry;
116        let mut i = 0;
117        while i < 8 {
118            ret[i] = (carry << 31) + (self[i] >> 1);
119            carry = self[i] & 0x01;
120            i += 1;
121        }
122        ret
123    }
124}
125
126// a quick algorithm to reduce elements on SCA-256 field
127// Reference:
128// http://ieeexplore.ieee.org/document/7285166/ for details
129// 国密SM2的快速约减算法详细描述
130// S0 = (m7, m6, m5, m4, m3, m2, m1, m0)
131// S1 = (m15, 0, 0, 0, 0, 0, m15, m14)
132// S2 = (m14, 0, 0, 0, 0, 0, m14, m13)
133// S3 = (m13, 0, 0, 0, 0, 0, 0, m15)
134// S4 = (m12, 0, m15, m14, m13, 0, 0, m15)
135// S5 = (m15, m15, m14, m13, m12, 0, m11, m10)
136// S6 = (m11, m14, m13, m12, m11, 0, m10, m9)
137// S7 = (m10, m11, m10, m9, m8, 0, m13, m12)
138// S8 = (m9, 0, 0, m15, m14, 0, m9, m8)
139// S9 = (m8, 0, 0, 0, m15, 0, m12, m11)
140//
141#[inline(always)]
142pub fn fast_reduction(a: &[u32; 16], modulus: &[u32; 8]) -> [u32; 8] {
143    let mut s: [[u32; 8]; 10] = [[0; 8]; 10];
144    let mut m: [u32; 16] = [0; 16];
145
146    let mut i = 0;
147    while i < 16 {
148        m[i] = a[15 - i];
149        i += 1;
150    }
151
152    s[0] = [m[7], m[6], m[5], m[4], m[3], m[2], m[1], m[0]];
153
154    s[1] = [m[15], 0, 0, 0, 0, 0, m[15], m[14]];
155    s[2] = [m[14], 0, 0, 0, 0, 0, m[14], m[13]];
156    s[3] = [m[13], 0, 0, 0, 0, 0, 0, 0];
157    s[4] = [m[12], 0, m[15], m[14], m[13], 0, 0, m[15]];
158
159    s[5] = [m[15], m[15], m[14], m[13], m[12], 0, m[11], m[10]];
160    s[6] = [m[11], m[14], m[13], m[12], m[11], 0, m[10], m[9]];
161    s[7] = [m[10], m[11], m[10], m[9], m[8], 0, m[13], m[12]];
162    s[8] = [m[9], 0, 0, m[15], m[14], 0, m[9], m[8]];
163    s[9] = [m[8], 0, 0, 0, m[15], 0, m[12], m[11]];
164
165    let mut carry: i32 = 0;
166    let mut ret = [0; 8];
167
168    // part1: 2 * (s1+s2+s3+s4)
169    let (rt, rc) = add_raw(&ret, &s[1]);
170    ret = rt;
171    carry += rc as i32;
172    let (rt, rc) = add_raw(&ret, &s[2]);
173    ret = rt;
174    carry += rc as i32;
175    let (rt, rc) = add_raw(&ret, &s[3]);
176    ret = rt;
177    carry += rc as i32;
178    let (rt, rc) = add_raw(&ret, &s[4]);
179    ret = rt;
180    carry += rc as i32;
181    let (rt, rc) = add_raw(&ret, &ret);
182    ret = rt;
183    carry = carry * 2 + rc as i32;
184
185    // part2: s0+s5+s6+s7+s8+s9
186    let (rt, rc) = add_raw(&ret, &s[5]);
187    ret = rt;
188    carry += rc as i32;
189    let (rt, rc) = add_raw(&ret, &s[6]);
190    ret = rt;
191    carry += rc as i32;
192    let (rt, rc) = add_raw(&ret, &s[7]);
193    ret = rt;
194    carry += rc as i32;
195    let (rt, rc) = add_raw(&ret, &s[8]);
196    ret = rt;
197    carry += rc as i32;
198    let (rt, rc) = add_raw(&ret, &s[9]);
199    ret = rt;
200    carry += rc as i32;
201    let (rt, rc) = add_raw(&ret, &s[0]);
202    ret = rt;
203    carry += rc as i32;
204
205    // part3:  m8+m9+m13+m14
206    let mut part3 = [0; 8];
207    let subtra: u64 = u64::from(m[8]) + u64::from(m[9]) + u64::from(m[13]) + u64::from(m[14]);
208    part3[5] = (subtra & 0xffff_ffff) as u32;
209    part3[4] = (subtra >> 32) as u32;
210
211    // part1 + part2 - part3
212    let (rt, rc) = sub_raw(&ret, &part3);
213    ret = rt;
214    carry -= rc as i32;
215
216    while carry > 0 || ret >= *modulus {
217        let (rs, rb) = sub_raw(&ret, modulus);
218        ret = rs;
219        carry -= rb as i32;
220    }
221    ret
222}
223
224impl FeOperation for BigUint {
225    fn mod_add(&self, other: &Self, modulus: &Self) -> BigUint {
226        (self + other) % modulus
227    }
228
229    fn mod_sub(&self, other: &Self, modulus: &Self) -> BigUint {
230        if self >= other {
231            (self - other) % modulus
232        } else {
233            // 负数取模
234            let d = other - self;
235            let e = d.div_ceil(modulus);
236            e * modulus - d
237        }
238    }
239
240    fn mod_mul(&self, other: &Self, modulus: &Self) -> BigUint {
241        (self * other) % modulus
242    }
243
244    fn inv(&self, modulus: &Self) -> BigUint {
245        self.mod_inverse(modulus).unwrap().to_biguint().unwrap()
246    }
247
248    fn right_shift(&self, carry: u32) -> BigUint {
249        let mut ret = self.clone();
250        ret = ret >> (carry as i32) as usize;
251        ret
252    }
253}
254
// Cross-checks the `BigUint` implementation of `FeOperation` against the
// field-element operators, using a randomly chosen entry of `PRE_TABLE_1`
// as input for every test.
#[cfg(test)]
mod test_op {
    use num_bigint::{BigUint, ModInverse};
    use rand::{thread_rng, Rng};

    use crate::sm2::p256_ecc::P256C_PARAMS;
    use crate::sm2::p256_pre_table::PRE_TABLE_1;
    use crate::sm2::FeOperation;

    /// `BigUint::mod_add` must agree with the field element `+` operator.
    #[test]
    fn test_mod_add() {
        let mut rng = thread_rng();
        let n: u32 = rng.gen_range(10..256);

        let modulus = &P256C_PARAMS.p;

        let p = &PRE_TABLE_1[n as usize];
        let x = p.x.to_biguint();
        let y = p.y.to_biguint();

        let ret1 = x.mod_add(&y, &modulus.to_biguint());
        let ret2 = (p.x + p.y).to_biguint();

        assert_eq!(ret2, ret1)
    }

    /// `BigUint::mod_sub` must agree with the field element `-` operator.
    #[test]
    fn test_mod_sub() {
        let mut rng = thread_rng();
        let n: u32 = rng.gen_range(10..256);

        let modulus = &P256C_PARAMS.p;

        let p = &PRE_TABLE_1[n as usize];
        let x = p.x.to_biguint();
        let y = p.y.to_biguint();

        let ret1 = x.mod_sub(&y, &modulus.to_biguint());
        let ret2 = (p.x - p.y).to_biguint();

        assert_eq!(ret2, ret1)
    }

    /// `BigUint::mod_mul` must agree with the field element `*` operator.
    #[test]
    fn test_mod_mul() {
        let mut rng = thread_rng();
        let n: u32 = rng.gen_range(10..256);

        let modulus = &P256C_PARAMS.p;

        let p = &PRE_TABLE_1[n as usize];
        let x = p.x.to_biguint();
        let y = p.y.to_biguint();

        let ret1 = x.mod_mul(&y, &modulus.to_biguint());
        let ret2 = (p.x * p.y).to_biguint();

        assert_eq!(ret2, ret1)
    }

    /// `BigUint::inv` must match `mod_inverse` and must actually be an
    /// inverse.
    #[test]
    fn test_mod_inv() {
        let mut rng = thread_rng();
        let n: u32 = rng.gen_range(10..256);

        let modulus = P256C_PARAMS.p.to_biguint();

        let p = &PRE_TABLE_1[n as usize];
        let x = p.x.to_biguint();

        let ret1 = x.inv(&modulus);
        let ret2 = x.mod_inverse(&modulus).unwrap().to_biguint().unwrap();

        assert_eq!(ret2, ret1);

        // The comparison above repeats the expression `inv` itself wraps, so
        // also verify the defining property: x * x^-1 == 1 (mod p).
        assert_eq!(BigUint::from(1u32), x.mod_mul(&ret1, &modulus));
    }
}