1use crate::sm2::error::{Sm2Error, Sm2Result};
2use crate::sm2::key::Sm2PublicKey;
3use crate::sm2::p256_ecc::P256C_PARAMS;
4use crate::sm3;
5use crate::sm3::sm3_hash;
6use byteorder::{BigEndian, WriteBytesExt};
7use num_bigint::BigUint;
8use num_traits::{One, Zero};
9use rand::RngCore;
10
/// Default 16-byte distinguishing identifier (`ID_A`) for SM2, the ASCII string
/// `"1234567812345678"` (the default user ID given in GM/T 0009-2012).
pub(crate) const DEFAULT_ID: &str = "1234567812345678";
12
13#[inline]
14pub fn random_uint() -> BigUint {
15 let n = &P256C_PARAMS.n;
16 let mut rng = rand::thread_rng();
17 let mut buf: [u8; 32] = [0; 32];
18 let mut ret;
19 loop {
20 rng.fill_bytes(&mut buf[..]);
21 ret = BigUint::from_bytes_be(&buf[..]);
22 if ret < n - BigUint::one() && ret != BigUint::zero() {
23 break;
24 }
25 }
26 ret
27}
28
29pub fn compute_za(id: &str, pk: &Sm2PublicKey) -> Sm2Result<[u8; 32]> {
31 if !pk.is_valid() {
32 return Err(Sm2Error::InvalidPublic);
33 }
34 let mut prepend: Vec<u8> = Vec::new();
35 if id.len() * 8 > 65535 {
36 return Err(Sm2Error::IdTooLong);
37 }
38 prepend
39 .write_u16::<BigEndian>((id.len() * 8) as u16)
40 .unwrap();
41 for c in id.bytes() {
42 prepend.push(c);
43 }
44
45 prepend.extend_from_slice(&P256C_PARAMS.a.to_bytes_be());
46 prepend.extend_from_slice(&P256C_PARAMS.b.to_bytes_be());
47 prepend.extend_from_slice(&P256C_PARAMS.g_point.x.to_bytes_be());
48 prepend.extend_from_slice(&P256C_PARAMS.g_point.y.to_bytes_be());
49
50 let pk_affine = pk.value().to_affine();
51 prepend.extend_from_slice(&pk_affine.x.to_bytes_be());
52 prepend.extend_from_slice(&pk_affine.y.to_bytes_be());
53
54 Ok(sm3_hash(&prepend))
55}
56
57#[inline]
58pub fn kdf(z: &[u8], klen: usize) -> Vec<u8> {
59 let mut ct = 0x00000001u32;
60 let bound = ((klen as f64) / 32.0).ceil() as u32;
61 let mut h_a = Vec::new();
62 for _i in 1..bound {
63 let mut prepend = Vec::new();
64 prepend.extend_from_slice(z);
65 prepend.extend_from_slice(&ct.to_be_bytes());
66
67 let h_a_i = sm3_hash(&prepend[..]);
68 h_a.extend_from_slice(&h_a_i);
69 ct += 1;
70 }
71
72 let mut prepend = Vec::new();
73 prepend.extend_from_slice(z);
74 prepend.extend_from_slice(&ct.to_be_bytes());
75
76 let last = sm3::sm3_hash(&prepend[..]);
77 if klen % 32 == 0 {
78 h_a.extend_from_slice(&last);
79 } else {
80 h_a.extend_from_slice(&last[0..(klen % 32)]);
81 }
82 h_a
83}
84
/// Adds `a + b + carry` in 128-bit precision.
///
/// Returns `(low, high)`: the low 64 bits of the sum, and the bits above them
/// (the carry-out).
#[inline(always)]
pub const fn adc_u64(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let wide = a as u128 + b as u128 + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}
90
/// Wrapping addition of `a + b + carry` on 32-bit words.
///
/// Returns the wrapped sum and `true` if either partial addition overflowed.
#[inline(always)]
pub const fn add_u32(a: u32, b: u32, carry: u32) -> (u32, bool) {
    let (partial, overflow_ab) = a.overflowing_add(b);
    let (total, overflow_carry) = partial.overflowing_add(carry);
    (total, overflow_ab | overflow_carry)
}
97
/// Packs eight 32-bit words into four 64-bit words.
///
/// Adjacent pairs are combined so that `fe32[2*i]` is the low half and
/// `fe32[2*i + 1]` is the high half of output word `i`.
#[inline(always)]
pub const fn fe32_to_fe64(fe32: &[u32; 8]) -> [u64; 4] {
    let mut packed = [0u64; 4];
    let mut i = 0;
    while i < 4 {
        let low = fe32[2 * i] as u64;
        let high = fe32[2 * i + 1] as u64;
        packed[i] = (high << 32) | low;
        i += 1;
    }
    packed
}
107
/// Splits four 64-bit words into eight 32-bit words.
///
/// Each input word `fe64[i]` produces its low half at index `2*i` and its high
/// half at index `2*i + 1`.
#[inline(always)]
pub const fn fe64_to_fe32(fe64: &[u64; 4]) -> [u32; 8] {
    let mut halves = [0u32; 8];
    let mut i = 0;
    while i < 4 {
        let word = fe64[i];
        halves[2 * i] = word as u32;
        halves[2 * i + 1] = (word >> 32) as u32;
        i += 1;
    }
    halves
}
122
/// Adds two 256-bit values stored as eight 32-bit limbs.
///
/// Index 7 is the least-significant limb, so carries ripple toward index 0.
/// Returns the wrapped 256-bit sum and the final carry-out.
#[inline(always)]
pub const fn add_raw(a: &[u32; 8], b: &[u32; 8]) -> ([u32; 8], bool) {
    let mut sum = [0u32; 8];
    let mut carry = false;
    let mut idx = 8;
    while idx > 0 {
        idx -= 1;
        let (partial, c1) = a[idx].overflowing_add(b[idx]);
        let (limb, c2) = partial.overflowing_add(carry as u32);
        sum[idx] = limb;
        carry = c1 | c2;
    }
    (sum, carry)
}
139
/// Subtracts `b` from `a`, both 256-bit values stored as eight 32-bit limbs.
///
/// Index 7 is the least-significant limb, so borrows ripple toward index 0.
/// Returns the wrapped difference and `true` if the subtraction underflowed.
#[inline(always)]
pub const fn sub_raw(a: &[u32; 8], b: &[u32; 8]) -> ([u32; 8], bool) {
    let mut diff = [0u32; 8];
    let mut borrow = false;
    let mut idx = 8;
    while idx > 0 {
        idx -= 1;
        // Take the incoming borrow first, then subtract the limb of `b`.
        let (minus_borrow, b1) = a[idx].overflowing_sub(borrow as u32);
        let (limb, b2) = minus_borrow.overflowing_sub(b[idx]);
        diff[idx] = limb;
        borrow = b1 | b2;
    }
    (diff, borrow)
}
154
/// Wrapping subtraction `a - b - borrow` on 32-bit words.
///
/// Returns the wrapped difference and `true` when `a < b + borrow`, i.e.
/// when the subtraction needs to borrow from the next limb.
#[inline(always)]
pub const fn sub_u32(a: u32, b: u32, borrow: bool) -> (u32, bool) {
    let subtrahend = b as u64 + borrow as u64;
    let diff = a.wrapping_sub(b).wrapping_sub(borrow as u32);
    (diff, (a as u64) < subtrahend)
}
161
/// Multiplies two 32-bit words into a 64-bit product.
///
/// Returns `(high, low)`, each holding 32 bits of the product widened to `u64`.
#[inline(always)]
pub const fn mul_u32(a: u32, b: u32) -> (u64, u64) {
    let product = a as u64 * b as u64;
    let high = product >> 32;
    let low = product & u32::MAX as u64;
    (high, low)
}
169
/// Multiplies two 256-bit values (eight 32-bit limbs each) into the full
/// 512-bit product (sixteen 32-bit limbs).
///
/// Limbs are most-significant first: index 7 of each input and index 15 of the
/// output hold the least-significant words. The limb pair `a[i]`, `b[j]`
/// contributes to output index `i + j + 1`.
#[inline(always)]
pub const fn mul_raw(a: &[u32; 8], b: &[u32; 8]) -> [u32; 16] {
    let mut product = [0u32; 16];
    let mut i = 8;
    while i > 0 {
        i -= 1;
        let mut carry: u64 = 0;
        let mut j = 8;
        while j > 0 {
            j -= 1;
            let k = i + j + 1;
            // Fits in u64: (2^32-1)^2 + 2 * (2^32-1) == 2^64 - 1.
            let acc = (a[i] as u64) * (b[j] as u64) + product[k] as u64 + carry;
            product[k] = acc as u32;
            carry = acc >> 32;
        }
        // Rows with larger `i` only wrote indices >= i + 1, so this slot is still zero.
        product[i] = carry as u32;
    }
    product
}
201