1use std::io::Cursor;
2
3use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4use num_bigint::{BigUint, ModInverse};
5use num_integer::Integer;
6
7use crate::sm2::p256_field::{Conversion, Fe, FieldElement};
8use crate::sm2::util::{add_raw, mul_raw, sub_raw};
9use crate::sm2::FeOperation;
10
11impl Conversion for Fe {
12 fn fe_to_bigunit(&self) -> BigUint {
13 let mut ret: Vec<u8> = Vec::new();
14 for i in 0..8 {
15 ret.write_u32::<BigEndian>(self[i]).unwrap();
16 }
17 BigUint::from_bytes_be(&ret[..])
18 }
19
20 fn bigunit_fe(&self) -> Fe {
21 unimplemented!()
22 }
23}
24
25impl Conversion for BigUint {
26 fn fe_to_bigunit(&self) -> BigUint {
27 unimplemented!()
28 }
29
30 fn bigunit_fe(&self) -> Fe {
31 let v = self.to_bytes_be();
32 let mut num_v = [0u8; 32];
33 num_v[32 - v.len()..32].copy_from_slice(&v[..]);
34 let mut elem = [0u32; 8];
35 let mut c = Cursor::new(num_v);
36 for i in 0..8 {
37 let x = c.read_u32::<BigEndian>().unwrap();
38 elem[i] = x;
39 }
40 elem
41 }
42}
43
44impl FeOperation for Fe {
45 #[inline]
46 fn mod_add(&self, other: &Self, modulus: &Self) -> Self {
47 let (raw_sum, carry) = add_raw(self, other);
48 if carry || raw_sum >= *modulus {
49 let (sum, _borrow) = sub_raw(&raw_sum, &modulus);
50 sum
51 } else {
52 raw_sum
53 }
54 }
55
56 #[inline]
57 fn mod_sub(&self, other: &Self, modulus: &Self) -> Self {
58 let (raw_diff, borrow) = sub_raw(&self, other);
59 if borrow {
60 let (modulus_complete, _) = sub_raw(&[0; 8], &modulus);
61 let (diff, _borrow) = sub_raw(&raw_diff, &modulus_complete);
62 diff
63 } else {
64 raw_diff
65 }
66 }
67
68 #[inline]
69 fn mod_mul(&self, other: &Self, modulus: &Self) -> Self {
70 let raw_prod = mul_raw(self, other);
71 fast_reduction(&raw_prod, &modulus)
72 }
73
74 #[inline]
75 fn inv(&self, modulus: &Self) -> Self {
76 let mut ru = *self;
77 let mut rv = *modulus;
78 let mut ra = FieldElement::from_number(1).inner;
79 let mut rc = [0; 8];
80 while ru != [0; 8] {
81 if ru[7] & 0x01 == 0 {
82 ru = ru.right_shift(0);
83 if ra[7] & 0x01 == 0 {
84 ra = ra.right_shift(0);
85 } else {
86 let (sum, car) = add_raw(&ra, &modulus);
87 ra = sum.right_shift(car as u32);
88 }
89 }
90
91 if rv[7] & 0x01 == 0 {
92 rv = rv.right_shift(0);
93 if rc[7] & 0x01 == 0 {
94 rc = rc.right_shift(0);
95 } else {
96 let (sum, car) = add_raw(&rc, &modulus);
97 rc = sum.right_shift(car as u32);
98 }
99 }
100
101 if ru >= rv {
102 ru = ru.mod_sub(&rv, &modulus);
103 ra = ra.mod_sub(&rc, &modulus);
104 } else {
105 rv = rv.mod_sub(&ru, &modulus);
106 rc = rc.mod_sub(&ra, &modulus);
107 }
108 }
109 rc
110 }
111
112 #[inline]
113 fn right_shift(&self, carry: u32) -> Self {
114 let mut ret = [0; 8];
115 let mut carry = carry;
116 let mut i = 0;
117 while i < 8 {
118 ret[i] = (carry << 31) + (self[i] >> 1);
119 carry = self[i] & 0x01;
120 i += 1;
121 }
122 ret
123 }
124}
125
126#[inline(always)]
142pub fn fast_reduction(a: &[u32; 16], modulus: &[u32; 8]) -> [u32; 8] {
143 let mut s: [[u32; 8]; 10] = [[0; 8]; 10];
144 let mut m: [u32; 16] = [0; 16];
145
146 let mut i = 0;
147 while i < 16 {
148 m[i] = a[15 - i];
149 i += 1;
150 }
151
152 s[0] = [m[7], m[6], m[5], m[4], m[3], m[2], m[1], m[0]];
153
154 s[1] = [m[15], 0, 0, 0, 0, 0, m[15], m[14]];
155 s[2] = [m[14], 0, 0, 0, 0, 0, m[14], m[13]];
156 s[3] = [m[13], 0, 0, 0, 0, 0, 0, 0];
157 s[4] = [m[12], 0, m[15], m[14], m[13], 0, 0, m[15]];
158
159 s[5] = [m[15], m[15], m[14], m[13], m[12], 0, m[11], m[10]];
160 s[6] = [m[11], m[14], m[13], m[12], m[11], 0, m[10], m[9]];
161 s[7] = [m[10], m[11], m[10], m[9], m[8], 0, m[13], m[12]];
162 s[8] = [m[9], 0, 0, m[15], m[14], 0, m[9], m[8]];
163 s[9] = [m[8], 0, 0, 0, m[15], 0, m[12], m[11]];
164
165 let mut carry: i32 = 0;
166 let mut ret = [0; 8];
167
168 let (rt, rc) = add_raw(&ret, &s[1]);
170 ret = rt;
171 carry += rc as i32;
172 let (rt, rc) = add_raw(&ret, &s[2]);
173 ret = rt;
174 carry += rc as i32;
175 let (rt, rc) = add_raw(&ret, &s[3]);
176 ret = rt;
177 carry += rc as i32;
178 let (rt, rc) = add_raw(&ret, &s[4]);
179 ret = rt;
180 carry += rc as i32;
181 let (rt, rc) = add_raw(&ret, &ret);
182 ret = rt;
183 carry = carry * 2 + rc as i32;
184
185 let (rt, rc) = add_raw(&ret, &s[5]);
187 ret = rt;
188 carry += rc as i32;
189 let (rt, rc) = add_raw(&ret, &s[6]);
190 ret = rt;
191 carry += rc as i32;
192 let (rt, rc) = add_raw(&ret, &s[7]);
193 ret = rt;
194 carry += rc as i32;
195 let (rt, rc) = add_raw(&ret, &s[8]);
196 ret = rt;
197 carry += rc as i32;
198 let (rt, rc) = add_raw(&ret, &s[9]);
199 ret = rt;
200 carry += rc as i32;
201 let (rt, rc) = add_raw(&ret, &s[0]);
202 ret = rt;
203 carry += rc as i32;
204
205 let mut part3 = [0; 8];
207 let subtra: u64 = u64::from(m[8]) + u64::from(m[9]) + u64::from(m[13]) + u64::from(m[14]);
208 part3[5] = (subtra & 0xffff_ffff) as u32;
209 part3[4] = (subtra >> 32) as u32;
210
211 let (rt, rc) = sub_raw(&ret, &part3);
213 ret = rt;
214 carry -= rc as i32;
215
216 while carry > 0 || ret >= *modulus {
217 let (rs, rb) = sub_raw(&ret, modulus);
218 ret = rs;
219 carry -= rb as i32;
220 }
221 ret
222}
223
224impl FeOperation for BigUint {
225 fn mod_add(&self, other: &Self, modulus: &Self) -> BigUint {
226 (self + other) % modulus
227 }
228
229 fn mod_sub(&self, other: &Self, modulus: &Self) -> BigUint {
230 if self >= other {
231 (self - other) % modulus
232 } else {
233 let d = other - self;
235 let e = d.div_ceil(modulus);
236 e * modulus - d
237 }
238 }
239
240 fn mod_mul(&self, other: &Self, modulus: &Self) -> BigUint {
241 (self * other) % modulus
242 }
243
244 fn inv(&self, modulus: &Self) -> BigUint {
245 self.mod_inverse(modulus).unwrap().to_biguint().unwrap()
246 }
247
248 fn right_shift(&self, carry: u32) -> BigUint {
249 let mut ret = self.clone();
250 ret = ret >> (carry as i32) as usize;
251 ret
252 }
253}
254
#[cfg(test)]
mod test_op {
    use num_bigint::ModInverse;
    use rand::{thread_rng, Rng};

    use crate::sm2::p256_ecc::P256C_PARAMS;
    use crate::sm2::p256_pre_table::PRE_TABLE_1;
    use crate::sm2::FeOperation;

    /// Picks a random index in `10..256` used to select an entry of
    /// `PRE_TABLE_1` as test input.
    fn random_table_index() -> usize {
        let picked: u32 = thread_rng().gen_range(10..256);
        picked as usize
    }

    /// `BigUint::mod_add` must agree with `+` on the table coordinates.
    #[test]
    fn test_mod_add() {
        let prime = &P256C_PARAMS.p;
        let point = &PRE_TABLE_1[random_table_index()];

        let lhs = point.x.to_biguint();
        let rhs = point.y.to_biguint();

        let via_biguint = lhs.mod_add(&rhs, &prime.to_biguint());
        let via_field = (point.x + point.y).to_biguint();

        assert_eq!(via_field, via_biguint)
    }

    /// `BigUint::mod_sub` must agree with `-` on the table coordinates.
    #[test]
    fn test_mod_sub() {
        let prime = &P256C_PARAMS.p;
        let point = &PRE_TABLE_1[random_table_index()];

        let lhs = point.x.to_biguint();
        let rhs = point.y.to_biguint();

        let via_biguint = lhs.mod_sub(&rhs, &prime.to_biguint());
        let via_field = (point.x - point.y).to_biguint();

        assert_eq!(via_field, via_biguint)
    }

    /// `BigUint::mod_mul` must agree with `*` on the table coordinates.
    #[test]
    fn test_mod_mul() {
        let prime = &P256C_PARAMS.p;
        let point = &PRE_TABLE_1[random_table_index()];

        let lhs = point.x.to_biguint();
        let rhs = point.y.to_biguint();

        let via_biguint = lhs.mod_mul(&rhs, &prime.to_biguint());
        let via_field = (point.x * point.y).to_biguint();

        assert_eq!(via_field, via_biguint)
    }

    /// `BigUint::inv` must agree with a direct `mod_inverse` call.
    // NOTE(review): this compares `inv` with the same library call it wraps;
    // the limb-based inversion is not exercised here.
    #[test]
    fn test_mod_inv() {
        let prime = &P256C_PARAMS.p;
        let point = &PRE_TABLE_1[random_table_index()];

        let value = point.x.to_biguint();

        let via_trait = value.inv(&prime.to_biguint());
        let via_library = value
            .mod_inverse(&prime.to_biguint())
            .unwrap()
            .to_biguint()
            .unwrap();

        assert_eq!(via_library, via_trait)
    }
}