1use byteorder::{BigEndian, WriteBytesExt};
2use gm_sm3::sm3_hash;
3
4use crate::error::{Sm2Error, Sm2Result};
5use crate::fields::FieldModOperation;
6use crate::fields::fp64::{from_mont, SM2_G_X, SM2_G_Y, SM2_MODP_MONT_A, SM2_MODP_MONT_B};
7use crate::p256_ecc::Point;
8
/// Default signer identity `ID_A`: the 16 ASCII bytes `"1234567812345678"` (128 bits).
///
/// NOTE(review): presumably passed to `compute_za` when the caller supplies no identity;
/// the value matches the default user ID recommended by GM/T 0009 — confirm at call sites.
pub(crate) const DEFAULT_ID: &str = "1234567812345678";
10
11
12pub fn compute_za(id: &str, pk: &Point) -> Sm2Result<[u8; 32]> {
13 if !pk.is_valid() {
14 return Err(Sm2Error::InvalidPublic);
15 }
16 let mut prepend: Vec<u8> = Vec::new();
17 if id.len() * 8 > 65535 {
18 return Err(Sm2Error::IdTooLong);
19 }
20 prepend
21 .write_u16::<BigEndian>((id.len() * 8) as u16)
22 .unwrap();
23 for c in id.bytes() {
24 prepend.push(c);
25 }
26
27 prepend.extend_from_slice(&from_mont(&SM2_MODP_MONT_A).to_byte_be());
28 prepend.extend_from_slice(&from_mont(&SM2_MODP_MONT_B).to_byte_be());
29 prepend.extend_from_slice(&SM2_G_X.to_byte_be());
30 prepend.extend_from_slice(&SM2_G_Y.to_byte_be());
31
32 let pk_affine = pk.to_affine_point();
33 prepend.extend_from_slice(&from_mont(&pk_affine.x).to_byte_be());
34 prepend.extend_from_slice(&from_mont(&pk_affine.y).to_byte_be());
35
36 Ok(sm3_hash(&prepend))
37}
38
/// Returns the byte-wise XOR of two equal-length slices.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    assert_eq!(a.len(), b.len());
    // zip + collect sizes the output once and avoids per-element bounds checks.
    a.iter().zip(b).map(|(x, y)| x ^ y).collect()
}
48
49#[inline]
50pub fn kdf(z: &[u8], klen: usize) -> Vec<u8> {
51 let mut ct = 0x00000001u32;
52 let bound = ((klen as f64) / 32.0).ceil() as u32;
53 let mut h_a = Vec::new();
54 for _i in 1..bound {
55 let mut prepend = Vec::new();
56 prepend.extend_from_slice(z);
57 prepend.extend_from_slice(&ct.to_be_bytes());
58
59 let h_a_i = sm3_hash(&prepend[..]);
60 h_a.extend_from_slice(&h_a_i);
61 ct += 1;
62 }
63
64 let mut prepend = Vec::new();
65 prepend.extend_from_slice(z);
66 prepend.extend_from_slice(&ct.to_be_bytes());
67
68 let last = sm3_hash(&prepend[..]);
69 if klen % 32 == 0 {
70 h_a.extend_from_slice(&last);
71 } else {
72 h_a.extend_from_slice(&last[0..(klen % 32)]);
73 }
74 h_a
75}
76
/// Adds two 256-bit integers held as eight `u32` limbs, index 0 being the most significant.
///
/// Returns the low 256 bits of `a + b` and whether a carry left the most significant limb.
#[inline(always)]
pub const fn add_raw(a: &[u32; 8], b: &[u32; 8]) -> ([u32; 8], bool) {
    let mut out = [0u32; 8];
    let mut carry_in = false;
    // Walk from the least significant limb (index 7) towards index 0.
    let mut limb = 8;
    while limb > 0 {
        limb -= 1;
        let (partial, overflow_ab) = a[limb].overflowing_add(b[limb]);
        let (total, overflow_c) = partial.overflowing_add(carry_in as u32);
        out[limb] = total;
        carry_in = overflow_ab | overflow_c;
    }
    (out, carry_in)
}
97
/// Adds two 256-bit integers held as four `u64` limbs, index 0 being the most significant.
///
/// Returns the low 256 bits of `a + b` and whether a carry left the most significant limb.
#[inline(always)]
pub const fn add_raw_u64(a: &[u64; 4], b: &[u64; 4]) -> ([u64; 4], bool) {
    let mut out = [0u64; 4];
    let mut carry_in = false;
    // Walk from the least significant limb (index 3) towards index 0.
    let mut limb = 4;
    while limb > 0 {
        limb -= 1;
        let (partial, overflow_ab) = a[limb].overflowing_add(b[limb]);
        let (total, overflow_c) = partial.overflowing_add(carry_in as u64);
        out[limb] = total;
        carry_in = overflow_ab | overflow_c;
    }
    (out, carry_in)
}
118
/// Subtracts two 256-bit integers held as four `u64` limbs, index 0 being the most significant.
///
/// Returns `a - b` modulo 2^256 and whether a borrow left the most significant limb
/// (which happens exactly when `a < b`).
#[inline(always)]
pub const fn sub_raw_u64(a: &[u64; 4], b: &[u64; 4]) -> ([u64; 4], bool) {
    let mut out = [0u64; 4];
    let mut borrow_in = false;
    // Least significant limb first; the incoming borrow is taken off before `b` is.
    let mut limb = 4;
    while limb > 0 {
        limb -= 1;
        let (minuend, under_borrow) = a[limb].overflowing_sub(borrow_in as u64);
        let (diff, under_sub) = minuend.overflowing_sub(b[limb]);
        out[limb] = diff;
        borrow_in = under_borrow | under_sub;
    }
    (out, borrow_in)
}
140
/// Subtracts two 256-bit integers held as eight `u32` limbs, index 0 being the most significant.
///
/// Returns `a - b` modulo 2^256 and whether a borrow left the most significant limb
/// (which happens exactly when `a < b`).
#[inline(always)]
pub const fn sub_raw(a: &[u32; 8], b: &[u32; 8]) -> ([u32; 8], bool) {
    let mut out = [0u32; 8];
    let mut borrow_in = false;
    // Least significant limb first; the incoming borrow is taken off before `b` is.
    let mut limb = 8;
    while limb > 0 {
        limb -= 1;
        let (minuend, under_borrow) = a[limb].overflowing_sub(borrow_in as u32);
        let (diff, under_sub) = minuend.overflowing_sub(b[limb]);
        out[limb] = diff;
        borrow_in = under_borrow | under_sub;
    }
    (out, borrow_in)
}
162
/// Multiplies two 256-bit integers held as eight `u32` limbs (index 0 most significant) and
/// returns the full 512-bit product as sixteen limbs in the same order.
///
/// The product is built column by column: column `col` collects every `a_i * b_j` with
/// `i + j == col` (indices counted from the least significant limb).
#[inline(always)]
pub const fn mul_raw(a: &[u32; 8], b: &[u32; 8]) -> [u32; 16] {
    let mut out = [0u32; 16];
    // Running sum of the low product halves for the current column, seeded with the carry
    // left over from the previous column.
    let mut acc_lo: u64 = 0;
    let mut col = 0;
    while col < 15 {
        // High halves carry weight 2^32, so they are pushed into the next column.
        let mut acc_hi: u64 = 0;
        let first = if col > 7 { col - 7 } else { 0 };
        let last = if col < 7 { col } else { 7 };
        let mut i = first;
        while i <= last {
            let prod = (a[7 - i] as u64) * (b[7 - (col - i)] as u64);
            acc_lo += prod & 0xffff_ffff;
            acc_hi += prod >> 32;
            i += 1;
        }
        out[15 - col] = acc_lo as u32;
        acc_lo = acc_hi + (acc_lo >> 32);
        col += 1;
    }
    out[0] = acc_lo as u32;
    out
}
199
/// Multiplies two 256-bit integers held as four `u64` limbs (index 0 most significant) and
/// returns the full 512-bit product as eight limbs in the same order.
///
/// The product is built column by column: column `col` collects every `a_i * b_j` with
/// `i + j == col` (indices counted from the least significant limb).
#[inline(always)]
pub const fn mul_raw_u64(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] {
    let mut out = [0u64; 8];
    // Running sum of the low product halves for the current column, seeded with the carry
    // left over from the previous column.
    let mut acc_lo: u128 = 0;
    let mut col = 0;
    while col < 7 {
        // High halves carry weight 2^64, so they are pushed into the next column.
        let mut acc_hi: u128 = 0;
        let first = if col > 3 { col - 3 } else { 0 };
        let last = if col < 3 { col } else { 3 };
        let mut i = first;
        while i <= last {
            let prod = (a[3 - i] as u128) * (b[3 - (col - i)] as u128);
            acc_lo += prod & 0xffff_ffff_ffff_ffff;
            acc_hi += prod >> 64;
            i += 1;
        }
        out[7 - col] = acc_lo as u64;
        acc_lo = acc_hi + (acc_lo >> 64);
        col += 1;
    }
    out[0] = acc_lo as u64;
    out
}
236
#[cfg(test)]
mod test_operation {
    use num_bigint::BigUint;
    use num_traits::Num;

    use crate::util::{add_raw, add_raw_u64, mul_raw, mul_raw_u64, sub_raw, sub_raw_u64};

    const A_HEX: &str = "F9B7213BAF82D65BEE265948D19C17ABD2AAB97FD34EC1203722755292130B08";
    const B_HEX: &str = "54806C11D8806141F1DD2C190F5E93C4597B6027B441A01F85AEF3D078640C98";

    /// Parses the two shared 256-bit test operands.
    fn operands() -> (BigUint, BigUint) {
        let a = BigUint::from_str_radix(A_HEX, 16).unwrap();
        let b = BigUint::from_str_radix(B_HEX, 16).unwrap();
        (a, b)
    }

    /// 2^256, the modulus at which the fixed-width raw helpers wrap around.
    fn modulus() -> BigUint {
        BigUint::from(1u8) << 256usize
    }

    /// Big-endian `N`-limb view of `x`; limbs above `N` are dropped (`x mod 2^(64·N)`).
    fn limbs_u64<const N: usize>(x: &BigUint) -> [u64; N] {
        let mut out = [0u64; N];
        // `to_u64_digits` is little-endian while the raw helpers are big-endian.
        for (slot, digit) in out.iter_mut().rev().zip(x.to_u64_digits()) {
            *slot = digit;
        }
        out
    }

    /// Big-endian `N`-limb view of `x`; limbs above `N` are dropped (`x mod 2^(32·N)`).
    fn limbs_u32<const N: usize>(x: &BigUint) -> [u32; N] {
        let mut out = [0u32; N];
        // `to_u32_digits` is little-endian while the raw helpers are big-endian.
        for (slot, digit) in out.iter_mut().rev().zip(x.to_u32_digits()) {
            *slot = digit;
        }
        out
    }

    #[test]
    fn test_raw_add_u64() {
        let (a1, b1) = operands();
        let a: [u64; 4] = [
            0xF9B7213BAF82D65B,
            0xEE265948D19C17AB,
            0xD2AAB97FD34EC120,
            0x3722755292130B08,
        ];
        let b: [u64; 4] = [
            0x54806C11D8806141,
            0xF1DD2C190F5E93C4,
            0x597B6027B441A01F,
            0x85AEF3D078640C98,
        ];
        // The hand-written limbs must agree with the hex operands.
        assert_eq!(a, limbs_u64::<4>(&a1));
        assert_eq!(b, limbs_u64::<4>(&b1));
        assert!(a1 > b1, "sub expectations below assume a > b");

        let sum = &a1 + &b1;
        let (r, carry) = add_raw_u64(&a, &b);
        assert_eq!(r, limbs_u64::<4>(&sum));
        assert_eq!(carry, sum >= modulus());

        // a > b, so a - b does not borrow ...
        let (r, borrow) = sub_raw_u64(&a, &b);
        assert_eq!(r, limbs_u64::<4>(&(&a1 - &b1)));
        assert!(!borrow);
        // ... while b - a wraps around 2^256 and reports a borrow.
        let (r, borrow) = sub_raw_u64(&b, &a);
        assert_eq!(r, limbs_u64::<4>(&(modulus() + &b1 - &a1)));
        assert!(borrow);

        let r = mul_raw_u64(&a, &b);
        assert_eq!(r, limbs_u64::<8>(&(&a1 * &b1)));
    }

    #[test]
    fn test_raw_ops_u32() {
        let (a1, b1) = operands();
        let a: [u32; 8] = limbs_u32(&a1);
        let b: [u32; 8] = limbs_u32(&b1);

        let sum = &a1 + &b1;
        let (r, carry) = add_raw(&a, &b);
        assert_eq!(r, limbs_u32::<8>(&sum));
        assert_eq!(carry, sum >= modulus());

        let (r, borrow) = sub_raw(&a, &b);
        assert_eq!(r, limbs_u32::<8>(&(&a1 - &b1)));
        assert!(!borrow);
        let (r, borrow) = sub_raw(&b, &a);
        assert_eq!(r, limbs_u32::<8>(&(modulus() + &b1 - &a1)));
        assert!(borrow);

        let r = mul_raw(&a, &b);
        assert_eq!(r, limbs_u32::<16>(&(&a1 * &b1)));
    }
}