// File: csv_rs/crypto/sm.rs

// Copyright (C) Hygon Info Technologies Ltd.
//
// SPDX-License-Identifier: Apache-2.0

//! Interfaces for GuoMi (Chinese national standard) cryptographic algorithms
//! that are not supported by rust-openssl.
6
7use crate::crypto::key::{ecc, group};
8use libc::*;
9use openssl::nid;
10use openssl_sys::*;
11use std::{
12    io::{Error, ErrorKind, Result},
13    ptr,
14};
15
16#[cfg(ossl111)]
17pub const EVP_PKEY_CTRL_SET1_ID: c_int = EVP_PKEY_ALG_CTRL + 11;
18
19const ECDH_KDF_MAX: size_t = 1 << 30;
20
21extern "C" {
22    #[cfg(ossl111)]
23    pub fn EVP_PKEY_set_alias_type(pkey: *mut EVP_PKEY, ttype: c_int) -> c_int;
24
25    pub fn EVP_MD_CTX_set_pkey_ctx(ctx: *mut EVP_MD_CTX, sctx: *mut EVP_PKEY_CTX) -> c_int;
26
27    #[cfg(ossl300)]
28    pub fn EVP_PKEY_CTX_set1_id(ctx: *mut EVP_PKEY_CTX, id: *const c_void, len: c_int) -> c_int;
29
30    /// openssl libc functions
31    pub fn EC_GROUP_get0_order(group: *const EC_GROUP) -> *const BIGNUM;
32    pub fn BN_CTX_start(ctx: *mut BN_CTX) -> c_void;
33    pub fn BN_CTX_get(ctx: *mut BN_CTX) -> *mut BIGNUM;
34    pub fn BN_priv_rand_range(r: *mut BIGNUM, range: *const BIGNUM) -> c_int;
35    pub fn BN_bn2binpad(a: *const BIGNUM, to: *mut c_uchar, tolen: c_int) -> c_int;
36}
37
/// Sets the SM2 distinguishing identifier on `ctx` (OpenSSL 1.1.1 builds).
///
/// In OpenSSL 1.1.1 `EVP_PKEY_CTX_set1_id` is a header macro over
/// `EVP_PKEY_CTX_ctrl` rather than a linkable symbol, so it is reproduced here.
/// Returns whatever `EVP_PKEY_CTX_ctrl` returns.
///
/// # Safety
/// `ctx` must be a valid `EVP_PKEY_CTX` pointer, and `id` must point to at
/// least `id_len` readable bytes.
#[cfg(ossl111)]
#[allow(non_snake_case)]
pub unsafe fn EVP_PKEY_CTX_set1_id(ctx: *mut EVP_PKEY_CTX, id: *const c_void, id_len: c_int) -> c_int {
    // -1 / -1: do not restrict the control to a particular key type or operation.
    let any_key_type: c_int = -1;
    let any_operation: c_int = -1;
    // The ctrl API takes a mutable pointer, but the identifier is only read.
    let id_ptr = id as *mut c_void;
    EVP_PKEY_CTX_ctrl(ctx, any_key_type, any_operation, EVP_PKEY_CTRL_SET1_ID, id_len, id_ptr)
}
54
/// Field-less namespace type for the SM2 operations (verify, generate, sign,
/// KDF, encrypt), which are all associated functions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SM2 {}
57
58impl SM2 {
59    /// use SM2 algorithm to verify a msg's signature
60    pub fn verify(ecc_pubkey: ecc::PubKey, sig: &[u8], id: &[u8], msg: &[u8]) -> Result<bool> {
61        let mut verify_result = false;
62        let pubkey_size = ecc_pubkey.g.size()?;
63
64        unsafe {
65            let eckey = EC_KEY_new_by_curve_name(NID_sm2);
66            let pub_x = &ecc_pubkey.x[..pubkey_size]
67                .iter()
68                .rev()
69                .cloned()
70                .collect::<Vec<_>>();
71            let pub_y = &ecc_pubkey.y[..pubkey_size]
72                .iter()
73                .rev()
74                .cloned()
75                .collect::<Vec<_>>();
76            let bn_x = BN_bin2bn(
77                pub_x.as_ptr() as *const c_uchar,
78                pubkey_size as c_int,
79                ptr::null_mut(),
80            );
81            let bn_y = BN_bin2bn(
82                pub_y.as_ptr() as *const c_uchar,
83                pubkey_size as c_int,
84                ptr::null_mut(),
85            );
86            EC_KEY_set_public_key_affine_coordinates(eckey, bn_x, bn_y);
87            let pkey = EVP_PKEY_new();
88            if EVP_PKEY_assign(pkey, EVP_PKEY_SM2, eckey as *mut c_void) <= 0 {
89                EVP_PKEY_free(pkey);
90                return Err(Error::new(
91                    ErrorKind::InvalidData,
92                    "EVP_KEY_set1_EC_KEY failed",
93                ));
94            }
95
96            #[cfg(ossl111)]
97            EVP_PKEY_set_alias_type(pkey, EVP_PKEY_SM2);
98
99            let mctx: *mut EVP_MD_CTX = EVP_MD_CTX_new();
100            let pctx = EVP_PKEY_CTX_new(pkey, ptr::null_mut());
101            EVP_PKEY_CTX_set1_id(pctx, id.as_ptr() as *mut c_void, id.len() as c_int);
102            EVP_MD_CTX_set_pkey_ctx(mctx, pctx);
103            EVP_DigestVerifyInit(mctx, ptr::null_mut(), EVP_sm3(), ptr::null_mut(), pkey);
104            EVP_DigestVerifyUpdate(mctx, msg.as_ptr() as *mut c_void, msg.len());
105            if EVP_DigestVerifyFinal(mctx, sig.as_ptr(), sig.len()) == 1 {
106                verify_result = true;
107            }
108            EVP_PKEY_CTX_free(pctx);
109            EVP_PKEY_free(pkey);
110            EVP_MD_CTX_free(mctx);
111        }
112        Ok(verify_result)
113    }
114
115    pub fn generate(group: group::Group) -> Result<(ecc::PubKey, *mut EC_KEY)> {
116        let value: nid::Nid = group.try_into()?;
117        let mut qx: Vec<u8> = vec![0; 32];
118        let mut qy: Vec<u8> = vec![0; 32];
119        let eckey: *mut EC_KEY = unsafe { EC_KEY_new() };
120        unsafe {
121            let c = BN_CTX_new();
122            if eckey.is_null() || c.is_null() {
123                return Err(ErrorKind::InvalidData.into());
124            }
125            let x = BN_new();
126            let y = BN_new();
127            if x.is_null() || y.is_null() {
128                return Err(ErrorKind::InvalidData.into());
129            }
130            let g: *mut EC_GROUP = EC_GROUP_new_by_curve_name(value.as_raw());
131            if EC_KEY_set_group(eckey, g) == 0 {
132                EC_KEY_free(eckey);
133                return Err(ErrorKind::InvalidData.into());
134            }
135
136            if 0 == EC_KEY_generate_key(eckey) {
137                EC_KEY_free(eckey);
138                return Err(ErrorKind::InvalidData.into());
139            }
140
141            if 0 == EC_POINT_get_affine_coordinates(g, EC_KEY_get0_public_key(eckey), x, y, c)
142                || BN_bn2binpad(x, qx.as_mut_ptr() as *mut c_uchar, 32) < 0
143                || BN_bn2binpad(y, qy.as_mut_ptr() as *mut c_uchar, 32) < 0
144            {
145                return Err(ErrorKind::InvalidData.into());
146            }
147        }
148        qx.reverse();
149        qy.reverse();
150        qx.resize(72, 0);
151        qy.resize(72, 0);
152        let pubkey = ecc::PubKey {
153            g: group,
154            x: qx.try_into().unwrap(),
155            y: qy.try_into().unwrap(),
156        };
157
158        Ok((pubkey, eckey))
159    }
160
161    pub fn sign(pri_key: *mut EC_KEY, id: &[u8], data: &[u8]) -> Result<Vec<u8>> {
162        let r = unsafe {
163            let pkey = EVP_PKEY_new();
164            if pkey.is_null() {
165                return Err(ErrorKind::InvalidData.into());
166            }
167            if EVP_PKEY_assign(pkey, EVP_PKEY_SM2, pri_key as *mut c_void) <= 0 {
168                EVP_PKEY_free(pkey);
169                return Err(Error::new(ErrorKind::InvalidData, "EVP_PKEY_assign failed"));
170            }
171
172            #[cfg(ossl111)]
173            EVP_PKEY_set_alias_type(pkey, EVP_PKEY_SM2);
174
175            let mctx: *mut EVP_MD_CTX = EVP_MD_CTX_new();
176            let pctx = EVP_PKEY_CTX_new(pkey, ptr::null_mut());
177            EVP_PKEY_CTX_set1_id(pctx, id.as_ptr() as *mut c_void, id.len() as c_int);
178            EVP_MD_CTX_set_pkey_ctx(mctx, pctx);
179
180            let sig_len: *mut size_t = Box::into_raw(Box::new(0));
181            EVP_DigestSignInit(mctx, ptr::null_mut(), EVP_sm3(), ptr::null_mut(), pkey);
182            EVP_DigestSign(mctx, ptr::null_mut(), sig_len, data.as_ptr(), data.len());
183            let mut sig = vec![0; *sig_len].into_boxed_slice();
184            EVP_DigestSign(mctx, sig.as_mut_ptr(), sig_len, data.as_ptr(), data.len());
185            EVP_MD_CTX_free(mctx);
186            EVP_PKEY_CTX_free(pctx);
187            EVP_PKEY_free(pkey);
188
189            let mut result_vec = sig.to_vec();
190            result_vec.truncate(*sig_len);
191            result_vec
192        };
193        Ok(r)
194    }
195
196    /// KDF function of ecdh
197    pub fn ecdh_kdf_x9_63(out: &mut [u8], input: &[u8]) -> Result<()> {
198        let mut outlen = out.len();
199        let mut buf = &mut out[..];
200        let inlen = input.len();
201
202        if outlen > ECDH_KDF_MAX || inlen > ECDH_KDF_MAX {
203            return Err(ErrorKind::InvalidData.into());
204        }
205        unsafe {
206            let md = EVP_sm3();
207            let mctx = EVP_MD_CTX_new();
208            if mctx.is_null() {
209                return Err(ErrorKind::InvalidData.into());
210            }
211            let mdlen = EVP_MD_size(md);
212            let mdlen = mdlen as usize;
213            let mut counter: u32 = 1;
214            while outlen > 0 {
215                let counter_be = counter.to_be_bytes();
216                if 0 == EVP_DigestInit_ex(mctx, md, ptr::null_mut()) {
217                    return Err(ErrorKind::InvalidData.into());
218                }
219                if 0 == EVP_DigestUpdate(mctx, input.as_ptr() as *const c_void, inlen) {
220                    return Err(ErrorKind::InvalidData.into());
221                }
222                if 0 == EVP_DigestUpdate(
223                    mctx,
224                    counter_be.as_ptr() as *const c_void,
225                    counter_be.len(),
226                ) {
227                    return Err(ErrorKind::InvalidData.into());
228                }
229                if 0 == EVP_DigestUpdate(mctx, ptr::null_mut(), 0) {
230                    return Err(ErrorKind::InvalidData.into());
231                }
232                if outlen >= mdlen {
233                    if 0 == EVP_DigestFinal(mctx, buf.as_ptr() as *mut c_uchar, ptr::null_mut()) {
234                        return Err(ErrorKind::InvalidData.into());
235                    }
236                    outlen -= mdlen;
237                    if outlen == 0 {
238                        break;
239                    }
240                    buf = &mut buf[mdlen..];
241                } else {
242                    let mtmp: Vec<u8> = vec![0; 64];
243                    if 0 == EVP_DigestFinal(mctx, mtmp.as_ptr() as *mut c_uchar, ptr::null_mut()) {
244                        return Err(ErrorKind::InvalidData.into());
245                    }
246                    buf.copy_from_slice(&mtmp[..outlen]);
247                    break;
248                }
249                counter += 1;
250            }
251
252            EVP_MD_CTX_free(mctx);
253        }
254        Ok(())
255    }
256
257    /// use SM2 algorithm to encrypt data with pubKey
258    pub fn encrypt(data: &[u8], pub_key: ecc::PubKey) -> Result<Vec<u8>> {
259        let pubkey_size = pub_key.g.size()?;
260        let mut ciphertext_buf: Vec<u8> = Vec::new();
261        unsafe {
262            let eckey = EC_KEY_new_by_curve_name(NID_sm2);
263            let pub_x = &pub_key.x[..pubkey_size]
264                .iter()
265                .rev()
266                .cloned()
267                .collect::<Vec<_>>();
268            let pub_y = &pub_key.y[..pubkey_size]
269                .iter()
270                .rev()
271                .cloned()
272                .collect::<Vec<_>>();
273            let bn_x = BN_bin2bn(
274                pub_x.as_ptr() as *const c_uchar,
275                pubkey_size as c_int,
276                ptr::null_mut(),
277            );
278            let bn_y = BN_bin2bn(
279                pub_y.as_ptr() as *const c_uchar,
280                pubkey_size as c_int,
281                ptr::null_mut(),
282            );
283            EC_KEY_set_public_key_affine_coordinates(eckey, bn_x, bn_y);
284
285            let c3_size: size_t = 32;
286            let group = EC_KEY_get0_group(eckey);
287            let order = EC_GROUP_get0_order(group);
288            let pub_key = EC_KEY_get0_public_key(eckey);
289            let field_size: size_t = 32;
290
291            let k_g = EC_POINT_new(group);
292            let k_p = EC_POINT_new(group);
293            let ctx = BN_CTX_new();
294            if k_g.is_null() || k_p.is_null() || ctx.is_null() {
295                return Err(ErrorKind::InvalidData.into());
296            }
297
298            BN_CTX_start(ctx);
299            let k = BN_CTX_get(ctx);
300            let x1 = BN_CTX_get(ctx);
301            let x2 = BN_CTX_get(ctx);
302            let y1 = BN_CTX_get(ctx);
303            let y2 = BN_CTX_get(ctx);
304
305            if y2.is_null() {
306                return Err(ErrorKind::InvalidData.into());
307            }
308
309            let mut x2_u8: Vec<u8> = vec![0; 32];
310            let mut y2_u8: Vec<u8> = vec![0; 32];
311            let mut c3: Vec<u8> = vec![0; c3_size];
312
313            if BN_priv_rand_range(k, order) == 0 {
314                return Err(ErrorKind::InvalidData.into());
315            }
316
317            if EC_POINT_mul(group, k_g, k, ptr::null_mut(), ptr::null_mut(), ctx) == 0
318                || 0 == EC_POINT_get_affine_coordinates(group, k_g, x1, y1, ctx)
319                || 0 == EC_POINT_mul(group, k_p, ptr::null_mut(), pub_key, k, ctx)
320                || 0 == EC_POINT_get_affine_coordinates(group, k_p, x2, y2, ctx)
321            {
322                return Err(ErrorKind::InvalidData.into());
323            }
324
325            if BN_bn2binpad(x2, x2_u8.as_mut_ptr() as *mut c_uchar, field_size as i32) < 0
326                || BN_bn2binpad(y2, y2_u8.as_mut_ptr() as *mut c_uchar, field_size as i32) < 0
327            {
328                return Err(ErrorKind::InvalidData.into());
329            }
330
331            let mut msg_mask: Vec<u8> = vec![0; data.len()];
332            let mut x2y2: Vec<u8> = Vec::new();
333            x2y2.extend_from_slice(&x2_u8[..]);
334            x2y2.extend_from_slice(&y2_u8[..]);
335            /* X9.63 with no salt happens to match the KDF used in SM2 */
336            SM2::ecdh_kdf_x9_63(&mut msg_mask[..], &x2y2[..])?;
337            for i in 0..data.len() {
338                msg_mask[i] ^= data[i];
339            }
340
341            let md_ctx: *mut EVP_MD_CTX = EVP_MD_CTX_new();
342            if md_ctx.is_null() {
343                return Err(ErrorKind::InvalidData.into());
344            }
345
346            if EVP_DigestInit(md_ctx, EVP_sm3()) == 0
347                || EVP_DigestUpdate(md_ctx, x2_u8.as_ptr() as *const c_void, field_size) == 0
348                || EVP_DigestUpdate(md_ctx, data.as_ptr() as *const c_void, data.len()) == 0
349                || EVP_DigestUpdate(md_ctx, y2_u8.as_ptr() as *const c_void, field_size) == 0
350                || EVP_DigestFinal(md_ctx, c3.as_mut_ptr() as *mut c_uchar, ptr::null_mut()) == 0
351            {
352                return Err(ErrorKind::InvalidData.into());
353            }
354
355            let mut x1_u8: Vec<u8> = vec![0; 32];
356            let mut y1_u8: Vec<u8> = vec![0; 32];
357
358            BN_bn2bin(x1, x1_u8.as_mut_ptr());
359            BN_bn2bin(y1, y1_u8.as_mut_ptr());
360
361            // ciphertext_len = 1 + 32 + 32 + msg.len() + c3_size;
362            ciphertext_buf.push(4);
363            ciphertext_buf.extend_from_slice(&x1_u8[..]);
364            ciphertext_buf.extend_from_slice(&y1_u8[..]);
365            ciphertext_buf.extend_from_slice(&c3[..]);
366            ciphertext_buf.extend_from_slice(&msg_mask[..]); //C2
367        };
368        Ok(ciphertext_buf)
369    }
370}