// gm_rs/sm2/key.rs

1use num_bigint::BigUint;
2use num_traits::{FromPrimitive, One, Zero};
3
4use crate::sm2::error::{Sm2Error, Sm2Result};
5use crate::sm2::p256_ecc::{Point, P256C_PARAMS};
6use crate::sm2::util::{compute_za, kdf, random_uint, DEFAULT_ID};
7use crate::sm2::{p256_ecc, FeOperation};
8use crate::sm3::sm3_hash;
9
10#[derive(Debug, Clone, Copy)]
11pub struct Sm2PublicKey {
12    value: Point,
13    compress_modle: CompressModle,
14}
15
16impl Sm2PublicKey {
17    pub fn is_valid(&self) -> bool {
18        self.value.is_valid()
19    }
20
21    pub fn encrypt(&self, msg: &[u8]) -> Sm2Result<Vec<u8>> {
22        loop {
23            let klen = msg.len();
24            let k = random_uint();
25            let c1_p = p256_ecc::g_mul(&k);
26            let c1_p = c1_p.to_affine(); // 根据加密算法,z坐标会被丢弃,为保证解密还原回来的坐标在曲线上,则必须转换坐标系到 affine 坐标系
27
28            let s_p = p256_ecc::scalar_mul(&P256C_PARAMS.h, &self.value);
29            if s_p.is_zero() {
30                return Err(Sm2Error::ZeroPoint);
31            }
32
33            let c2_p = p256_ecc::scalar_mul(&k, &self.value).to_affine();
34            let x2_bytes = c2_p.x.to_bytes_be();
35            let y2_bytes = c2_p.y.to_bytes_be();
36            let mut c2_append = vec![];
37            c2_append.extend_from_slice(&x2_bytes);
38            c2_append.extend_from_slice(&y2_bytes);
39
40            let t = kdf(&c2_append[..], klen);
41            let mut flag = true;
42            for elem in &t {
43                if elem != &0 {
44                    flag = false;
45                    break;
46                }
47            }
48            if !flag {
49                let c2 = BigUint::from_bytes_be(msg) ^ BigUint::from_bytes_be(&t[..]);
50                let mut c3_append: Vec<u8> = vec![];
51                c3_append.extend_from_slice(&x2_bytes);
52                c3_append.extend_from_slice(msg);
53                c3_append.extend_from_slice(&y2_bytes);
54                let c3 = sm3_hash(&c3_append);
55
56                let mut c: Vec<u8> = vec![];
57                c.extend_from_slice(&c1_p.to_byte(self.compress_modle));
58                c.extend_from_slice(&c2.to_bytes_be());
59                c.extend_from_slice(&c3);
60                return Ok(c);
61            }
62        }
63    }
64
65    pub fn verify(&self, id: Option<&'static str>, msg: &[u8], sig: &[u8]) -> Sm2Result<()> {
66        let id = match id {
67            None => DEFAULT_ID,
68            Some(u_id) => u_id,
69        };
70        let mut digest = compute_za(id, self)?;
71        digest = sm3_hash(&[digest.to_vec(), msg.to_vec()].concat());
72        self.verify_raw(&digest[..], self, sig)
73    }
74
75    fn verify_raw(&self, digest: &[u8], pk: &Sm2PublicKey, sig: &[u8]) -> Sm2Result<()> {
76        if digest.len() != 32 {
77            return Err(Sm2Error::InvalidDigestLen);
78        }
79        let n = &P256C_PARAMS.n;
80        let r = &BigUint::from_bytes_be(&sig[..32]);
81        let s = &BigUint::from_bytes_be(&sig[32..]);
82        if r.is_zero() || s.is_zero() {
83            return Err(Sm2Error::ZeroSig);
84        }
85
86        if r >= n || s >= n {
87            return Err(Sm2Error::InvalidDigest);
88        }
89
90        let t = s.mod_add(r, n);
91        if t.is_zero() {
92            return Err(Sm2Error::InvalidDigest);
93        }
94
95        let s_g = p256_ecc::g_mul(&s);
96        let t_p = p256_ecc::scalar_mul(&t, &pk.value());
97
98        let p = s_g.add(&t_p).to_affine();
99        let x1 = BigUint::from_bytes_be(&p.x.to_bytes_be());
100        let e = BigUint::from_bytes_be(digest);
101        let r1 = x1.mod_add(&e, n);
102        return if &r1 == r {
103            Ok(())
104        } else {
105            Err(Sm2Error::InvalidDigest)
106        };
107    }
108
109    pub fn to_str_hex(&self) -> String {
110        format!(
111            "{}{}",
112            self.value.x.to_str_radix(16),
113            self.value.y.to_str_radix(16)
114        )
115    }
116    pub fn value(&self) -> &Point {
117        &self.value
118    }
119}
120
121#[derive(Debug, Clone)]
122pub struct Sm2PrivateKey {
123    pub d: BigUint,
124    pub compress_modle: CompressModle,
125    pub public_key: Sm2PublicKey,
126}
127
128impl Sm2PrivateKey {
129    pub fn sign(&self, id: Option<&'static str>, msg: &[u8]) -> Sm2Result<Vec<u8>> {
130        let id = match id {
131            None => DEFAULT_ID,
132            Some(u_id) => u_id,
133        };
134        let mut digest = compute_za(id, &self.public_key)?;
135        digest = sm3_hash(&[digest.to_vec(), msg.to_vec()].concat());
136        self.sign_raw(&digest[..], &self.d)
137    }
138
139    fn sign_raw(&self, digest: &[u8], sk: &BigUint) -> Sm2Result<Vec<u8>> {
140        if digest.len() != 32 {
141            return Err(Sm2Error::InvalidDigestLen);
142        }
143        let e = BigUint::from_bytes_be(&digest);
144        let n = &P256C_PARAMS.n;
145        loop {
146            let k = random_uint();
147            let p_x = p256_ecc::g_mul(&k).to_affine();
148            let x1 = BigUint::from_bytes_be(&p_x.x.to_bytes_be());
149            let r = e.mod_add(&x1, n);
150            if r.is_zero() || &r + &k == *n {
151                continue;
152            }
153
154            let s1 = &(BigUint::one() + sk).modpow(&(n - BigUint::from_u32(2).unwrap()), n);
155
156            let s2_1 = r.mod_mul(&sk, n);
157            let s2 = k.mod_sub(&s2_1, n);
158
159            let s = s1.mod_mul(&s2, n);
160
161            if s.is_zero() {
162                return Err(Sm2Error::ZeroSig);
163            }
164            let mut sig: Vec<u8> = vec![];
165            sig.extend_from_slice(&r.to_bytes_be());
166            sig.extend_from_slice(&s.to_bytes_be());
167            return Ok(sig);
168        }
169    }
170
171    pub fn decrypt(&self, ciphertext: &[u8]) -> Sm2Result<Vec<u8>> {
172        let c1_end_index = match self.compress_modle {
173            CompressModle::Compressed => 33,
174            CompressModle::Uncompressed | CompressModle::Mixed => 65,
175        };
176
177        let c1_bytes = &ciphertext[0..c1_end_index];
178        let c2_bytes = &ciphertext[c1_end_index..(ciphertext.len() - 32)];
179        let c3_bytes = &ciphertext[(ciphertext.len() - 32)..];
180
181        let kelen = c2_bytes.len();
182        let c1_point = Point::from_byte(c1_bytes, self.compress_modle)?;
183        if !c1_point.to_affine().is_valid_affine() {
184            return Err(Sm2Error::CheckPointErr);
185        }
186
187        let s_point = p256_ecc::scalar_mul(&P256C_PARAMS.h, &c1_point);
188        if s_point.is_zero() {
189            return Err(Sm2Error::ZeroPoint);
190        }
191
192        let c2_point = p256_ecc::scalar_mul(&self.d, &c1_point).to_affine();
193        let x2_bytes = c2_point.x.to_bytes_be();
194        let y2_bytes = c2_point.y.to_bytes_be();
195        let mut prepend: Vec<u8> = vec![];
196        prepend.extend_from_slice(&x2_bytes);
197        prepend.extend_from_slice(&y2_bytes);
198        let t = kdf(&prepend, kelen);
199        let mut flag = true;
200        for elem in &t {
201            if elem != &0 {
202                flag = false;
203                break;
204            }
205        }
206        if flag {
207            return Err(Sm2Error::ZeroData);
208        }
209
210        let m = BigUint::from_bytes_be(c2_bytes) ^ BigUint::from_bytes_be(&t);
211        let mut prepend: Vec<u8> = vec![];
212        prepend.extend_from_slice(&x2_bytes);
213        prepend.extend_from_slice(&m.to_bytes_be());
214        prepend.extend_from_slice(&y2_bytes);
215
216        let u = sm3_hash(&prepend);
217        if u != c3_bytes {
218            return Err(Sm2Error::HashNotEqual);
219        }
220        Ok(m.to_bytes_be())
221    }
222}
223
/// Point encoding used for the C1 component of SM2 ciphertexts.
///
/// The name presumably means "compress mode"; it is kept as-is for API compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressModle {
    /// Compressed point encoding (33 bytes when decrypting).
    Compressed,
    /// Uncompressed point encoding (65 bytes when decrypting).
    Uncompressed,
    /// Mixed encoding (65 bytes when decrypting, like uncompressed).
    /// NOTE(review): presumably the SEC1 / GB/T 32918 "hybrid" form — confirm.
    Mixed,
}
230
231/// generate key pair
232pub fn gen_keypair(compress_modle: CompressModle) -> Sm2Result<(Sm2PublicKey, Sm2PrivateKey)> {
233    let d = random_uint();
234    let pk = public_from_private(&d, compress_modle)?;
235    let sk = Sm2PrivateKey {
236        d,
237        compress_modle,
238        public_key: pk,
239    };
240    Ok((pk, sk))
241}
242
243fn public_from_private(sk: &BigUint, compress_modle: CompressModle) -> Sm2Result<Sm2PublicKey> {
244    let p = p256_ecc::g_mul(&sk);
245    if p.is_valid() {
246        Ok(Sm2PublicKey {
247            value: p,
248            compress_modle,
249        })
250    } else {
251        Err(Sm2Error::InvalidPublic)
252    }
253}