// Source path: gm_sm2/util.rs

1use byteorder::{BigEndian, WriteBytesExt};
2use gm_sm3::sm3_hash;
3
4use crate::error::{Sm2Error, Sm2Result};
5use crate::fields::FieldModOperation;
6use crate::fields::fp64::{from_mont, SM2_G_X, SM2_G_Y, SM2_MODP_MONT_A, SM2_MODP_MONT_B};
7use crate::p256_ecc::Point;
8
/// Default user identifier string for this crate.
///
/// NOTE(review): presumably the default ID `"1234567812345678"` recommended by
/// the SM2 usage specification (GM/T 0009) for hashing into `Z_A` — confirm
/// against the call sites.
pub(crate) const DEFAULT_ID: &str = "1234567812345678";
10
11
12pub fn compute_za(id: &str, pk: &Point) -> Sm2Result<[u8; 32]> {
13    if !pk.is_valid() {
14        return Err(Sm2Error::InvalidPublic);
15    }
16    let mut prepend: Vec<u8> = Vec::new();
17    if id.len() * 8 > 65535 {
18        return Err(Sm2Error::IdTooLong);
19    }
20    prepend
21        .write_u16::<BigEndian>((id.len() * 8) as u16)
22        .unwrap();
23    for c in id.bytes() {
24        prepend.push(c);
25    }
26
27    prepend.extend_from_slice(&from_mont(&SM2_MODP_MONT_A).to_byte_be());
28    prepend.extend_from_slice(&from_mont(&SM2_MODP_MONT_B).to_byte_be());
29    prepend.extend_from_slice(&SM2_G_X.to_byte_be());
30    prepend.extend_from_slice(&SM2_G_Y.to_byte_be());
31
32    let pk_affine = pk.to_affine_point();
33    prepend.extend_from_slice(&from_mont(&pk_affine.x).to_byte_be());
34    prepend.extend_from_slice(&from_mont(&pk_affine.y).to_byte_be());
35
36    Ok(sm3_hash(&prepend))
37}
38
/// Returns the byte-wise XOR of two equal-length slices.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    // Ensure both inputs have the same length: `zip` alone would silently
    // truncate to the shorter input and drop trailing bytes.
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).map(|(x, y)| x ^ y).collect()
}
48
49#[inline]
50pub fn kdf(z: &[u8], klen: usize) -> Vec<u8> {
51    let mut ct = 0x00000001u32;
52    let bound = ((klen as f64) / 32.0).ceil() as u32;
53    let mut h_a = Vec::new();
54    for _i in 1..bound {
55        let mut prepend = Vec::new();
56        prepend.extend_from_slice(z);
57        prepend.extend_from_slice(&ct.to_be_bytes());
58
59        let h_a_i = sm3_hash(&prepend[..]);
60        h_a.extend_from_slice(&h_a_i);
61        ct += 1;
62    }
63
64    let mut prepend = Vec::new();
65    prepend.extend_from_slice(z);
66    prepend.extend_from_slice(&ct.to_be_bytes());
67
68    let last = sm3_hash(&prepend[..]);
69    if klen % 32 == 0 {
70        h_a.extend_from_slice(&last);
71    } else {
72        h_a.extend_from_slice(&last[0..(klen % 32)]);
73    }
74    h_a
75}
76
/// Adds two 256-bit integers held as eight `u32` limbs in big-endian limb
/// order (index 0 is the most significant limb).
///
/// Returns the low 256 bits of `a + b` together with the carry out of the
/// most significant limb.
#[inline(always)]
pub const fn add_raw(a: &[u32; 8], b: &[u32; 8]) -> ([u32; 8], bool) {
    let mut out = [0u32; 8];
    let mut carry = false;
    let mut k = 8;
    // Walk from the least significant limb (index 7) up to index 0.
    while k > 0 {
        k -= 1;
        let (partial, ovf_ab) = a[k].overflowing_add(b[k]);
        let (limb, ovf_c) = partial.overflowing_add(carry as u32);
        out[k] = limb;
        carry = ovf_ab | ovf_c;
    }
    (out, carry)
}
97
/// Adds two 256-bit integers held as four `u64` limbs in big-endian limb
/// order (index 0 is the most significant limb).
///
/// Returns the low 256 bits of `a + b` together with the carry out of the
/// most significant limb.
#[inline(always)]
pub const fn add_raw_u64(a: &[u64; 4], b: &[u64; 4]) -> ([u64; 4], bool) {
    let mut out = [0u64; 4];
    let mut carry = false;
    let mut k = 4;
    // Walk from the least significant limb (index 3) up to index 0.
    while k > 0 {
        k -= 1;
        let (partial, ovf_ab) = a[k].overflowing_add(b[k]);
        let (limb, ovf_c) = partial.overflowing_add(carry as u64);
        out[k] = limb;
        carry = ovf_ab | ovf_c;
    }
    (out, carry)
}
118
/// Subtracts `b` from `a`, both 256-bit integers held as four `u64` limbs in
/// big-endian limb order (index 0 is the most significant limb).
///
/// Returns `a - b` modulo 2^256 and whether a borrow was needed out of the
/// most significant limb (i.e. whether `a < b`).
#[inline(always)]
pub const fn sub_raw_u64(a: &[u64; 4], b: &[u64; 4]) -> ([u64; 4], bool) {
    let mut out = [0u64; 4];
    let mut borrow = false;
    let mut k = 4;
    // Least significant limb first; the incoming borrow is taken off `a`
    // before `b` is subtracted.
    while k > 0 {
        k -= 1;
        let (minuend, ovf_borrow) = a[k].overflowing_sub(borrow as u64);
        let (limb, ovf_sub) = minuend.overflowing_sub(b[k]);
        out[k] = limb;
        borrow = ovf_borrow | ovf_sub;
    }
    (out, borrow)
}
140
/// Subtracts `b` from `a`, both 256-bit integers held as eight `u32` limbs in
/// big-endian limb order (index 0 is the most significant limb).
///
/// Returns `a - b` modulo 2^256 and whether a borrow was needed out of the
/// most significant limb (i.e. whether `a < b`).
#[inline(always)]
pub const fn sub_raw(a: &[u32; 8], b: &[u32; 8]) -> ([u32; 8], bool) {
    let mut out = [0u32; 8];
    let mut borrow = false;
    let mut k = 8;
    // Least significant limb first; the incoming borrow is taken off `a`
    // before `b` is subtracted.
    while k > 0 {
        k -= 1;
        let (minuend, ovf_borrow) = a[k].overflowing_sub(borrow as u32);
        let (limb, ovf_sub) = minuend.overflowing_sub(b[k]);
        out[k] = limb;
        borrow = ovf_borrow | ovf_sub;
    }
    (out, borrow)
}
162
/// Multiplies two 256-bit integers held as eight `u32` limbs in big-endian
/// limb order, returning the full 512-bit product as sixteen `u32` limbs
/// (also big-endian).
///
/// Column-wise (product-scanning) schoolbook multiplication: for each output
/// column `col` (counted from the least significant end) all partial products
/// `a_i * b_j` with `i + j == col` are summed. Low halves go into the current
/// column and high halves into the next one.
#[inline(always)]
pub const fn mul_raw(a: &[u32; 8], b: &[u32; 8]) -> [u32; 16] {
    let mut out = [0u32; 16];
    // Running sum for the current column (includes the carry from the previous one).
    let mut acc: u64 = 0;
    let mut col = 0;
    while col < 15 {
        // High halves and overflow destined for column `col + 1`.
        let mut next: u64 = 0;
        // Little-endian limb index `i` of `a`, constrained so that the partner
        // index `col - i` of `b` stays within 0..8.
        let mut i = if col > 7 { col - 7 } else { 0 };
        let last = if col < 7 { col } else { 7 };
        while i <= last {
            let prod = (a[7 - i] as u64) * (b[7 - (col - i)] as u64);
            acc += prod & 0xffff_ffff;
            next += prod >> 32;
            i += 1;
        }
        next += acc >> 32;
        out[15 - col] = acc as u32;
        acc = next;
        col += 1;
    }
    // Whatever remains is the most significant limb.
    out[0] = acc as u32;
    out
}
199
/// Multiplies two 256-bit integers held as four `u64` limbs in big-endian
/// limb order, returning the full 512-bit product as eight `u64` limbs
/// (also big-endian).
///
/// Column-wise (product-scanning) schoolbook multiplication: for each output
/// column `col` (counted from the least significant end) all partial products
/// `a_i * b_j` with `i + j == col` are summed. Low halves go into the current
/// column and high halves into the next one.
#[inline(always)]
pub const fn mul_raw_u64(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] {
    let mut out = [0u64; 8];
    // Running sum for the current column (includes the carry from the previous one).
    let mut acc: u128 = 0;
    let mut col = 0;
    while col < 7 {
        // High halves and overflow destined for column `col + 1`.
        let mut next: u128 = 0;
        // Little-endian limb index `i` of `a`, constrained so that the partner
        // index `col - i` of `b` stays within 0..4.
        let mut i = if col > 3 { col - 3 } else { 0 };
        let last = if col < 3 { col } else { 3 };
        while i <= last {
            let prod = (a[3 - i] as u128) * (b[3 - (col - i)] as u128);
            acc += prod & 0xffff_ffff_ffff_ffff;
            next += prod >> 64;
            i += 1;
        }
        next += acc >> 64;
        out[7 - col] = acc as u64;
        acc = next;
        col += 1;
    }
    // Whatever remains is the most significant limb.
    out[0] = acc as u64;
    out
}
236
#[cfg(test)]
mod test_operation {
    use num_bigint::BigUint;
    use num_traits::Num;

    use crate::util::{add_raw_u64, mul_raw_u64, sub_raw_u64};

    /// Converts `v` into exactly `n` big-endian `u64` limbs (most significant
    /// first), the layout used by the `*_raw_u64` helpers.
    fn to_be_limbs(v: &BigUint, n: usize) -> Vec<u64> {
        // `to_u64_digits` is little-endian and omits leading zero limbs.
        let mut limbs = v.to_u64_digits();
        assert!(limbs.len() <= n, "value needs more than {} limbs", n);
        limbs.resize(n, 0);
        limbs.reverse();
        limbs
    }

    /// Builds a `BigUint` from big-endian `u64` limbs.
    fn from_be_limbs(limbs: &[u64]) -> BigUint {
        let hex: String = limbs.iter().map(|l| format!("{:016x}", l)).collect();
        BigUint::from_str_radix(&hex, 16).unwrap()
    }

    #[test]
    fn test_raw_add_u64() {
        let a: [u64; 4] = [
            0xF9B7213BAF82D65B,
            0xEE265948D19C17AB,
            0xD2AAB97FD34EC120,
            0x3722755292130B08,
        ];

        let b: [u64; 4] = [
            0x54806C11D8806141,
            0xF1DD2C190F5E93C4,
            0x597B6027B441A01F,
            0x85AEF3D078640C98,
        ];

        let a1 = BigUint::from_str_radix(
            "F9B7213BAF82D65BEE265948D19C17ABD2AAB97FD34EC1203722755292130B08",
            16,
        )
        .unwrap();
        let b1 = BigUint::from_str_radix(
            "54806C11D8806141F1DD2C190F5E93C4597B6027B441A01F85AEF3D078640C98",
            16,
        )
        .unwrap();

        // The limb arrays and the hex strings must describe the same numbers.
        assert_eq!(from_be_limbs(&a), a1);
        assert_eq!(from_be_limbs(&b), b1);

        // a + b: the returned carry is the 2^256 bit of the exact sum.
        let (r, c) = add_raw_u64(&a, &b);
        let sum = to_be_limbs(&(&a1 + &b1), 5);
        assert_eq!(sum[0], c as u64);
        assert_eq!(&sum[1..], &r[..]);

        // a - b with a > b: exact difference, no borrow.
        let (r, borrow) = sub_raw_u64(&a, &b);
        assert!(!borrow);
        assert_eq!(to_be_limbs(&(&a1 - &b1), 4), r.to_vec());

        // b - a wraps modulo 2^256 and reports a borrow.
        let (r, borrow) = sub_raw_u64(&b, &a);
        assert!(borrow);
        let two_256 = BigUint::from(1u8) << 256usize;
        assert_eq!(to_be_limbs(&(&two_256 + &b1 - &a1), 4), r.to_vec());

        // a * b is the full 512-bit product.
        let r = mul_raw_u64(&a, &b);
        assert_eq!(to_be_limbs(&(&a1 * &b1), 8), r.to_vec());
    }

    #[test]
    fn test_raw_u64_extremes() {
        let max = [u64::MAX; 4];
        let one = [0, 0, 0, 1];
        assert_eq!(add_raw_u64(&max, &one), ([0; 4], true));
        assert_eq!(sub_raw_u64(&[0; 4], &one), (max, true));
        // (2^256 - 1)^2 = 2^512 - 2^257 + 1
        assert_eq!(
            mul_raw_u64(&max, &max),
            [u64::MAX, u64::MAX, u64::MAX, u64::MAX - 1, 0, 0, 0, 1]
        );
    }
}