1use num_bigint::BigUint;
2use num_traits::{FromPrimitive, One, Zero};
3
4use crate::sm2::error::{Sm2Error, Sm2Result};
5use crate::sm2::p256_ecc::{Point, P256C_PARAMS};
6use crate::sm2::util::{compute_za, kdf, random_uint, DEFAULT_ID};
7use crate::sm2::{p256_ecc, FeOperation};
8use crate::sm3::sm3_hash;
9
/// An SM2 public key: a curve point plus the point-serialization mode used
/// when this key emits the `C1` component of a ciphertext.
#[derive(Debug, Clone, Copy)]
pub struct Sm2PublicKey {
    /// The public curve point `P`.
    value: Point,
    /// Serialization mode passed to `Point::to_byte` for `C1` in `encrypt`.
    compress_modle: CompressModle,
}
15
16impl Sm2PublicKey {
17 pub fn is_valid(&self) -> bool {
18 self.value.is_valid()
19 }
20
21 pub fn encrypt(&self, msg: &[u8]) -> Sm2Result<Vec<u8>> {
22 loop {
23 let klen = msg.len();
24 let k = random_uint();
25 let c1_p = p256_ecc::g_mul(&k);
26 let c1_p = c1_p.to_affine(); let s_p = p256_ecc::scalar_mul(&P256C_PARAMS.h, &self.value);
29 if s_p.is_zero() {
30 return Err(Sm2Error::ZeroPoint);
31 }
32
33 let c2_p = p256_ecc::scalar_mul(&k, &self.value).to_affine();
34 let x2_bytes = c2_p.x.to_bytes_be();
35 let y2_bytes = c2_p.y.to_bytes_be();
36 let mut c2_append = vec![];
37 c2_append.extend_from_slice(&x2_bytes);
38 c2_append.extend_from_slice(&y2_bytes);
39
40 let t = kdf(&c2_append[..], klen);
41 let mut flag = true;
42 for elem in &t {
43 if elem != &0 {
44 flag = false;
45 break;
46 }
47 }
48 if !flag {
49 let c2 = BigUint::from_bytes_be(msg) ^ BigUint::from_bytes_be(&t[..]);
50 let mut c3_append: Vec<u8> = vec![];
51 c3_append.extend_from_slice(&x2_bytes);
52 c3_append.extend_from_slice(msg);
53 c3_append.extend_from_slice(&y2_bytes);
54 let c3 = sm3_hash(&c3_append);
55
56 let mut c: Vec<u8> = vec![];
57 c.extend_from_slice(&c1_p.to_byte(self.compress_modle));
58 c.extend_from_slice(&c2.to_bytes_be());
59 c.extend_from_slice(&c3);
60 return Ok(c);
61 }
62 }
63 }
64
65 pub fn verify(&self, id: Option<&'static str>, msg: &[u8], sig: &[u8]) -> Sm2Result<()> {
66 let id = match id {
67 None => DEFAULT_ID,
68 Some(u_id) => u_id,
69 };
70 let mut digest = compute_za(id, self)?;
71 digest = sm3_hash(&[digest.to_vec(), msg.to_vec()].concat());
72 self.verify_raw(&digest[..], self, sig)
73 }
74
75 fn verify_raw(&self, digest: &[u8], pk: &Sm2PublicKey, sig: &[u8]) -> Sm2Result<()> {
76 if digest.len() != 32 {
77 return Err(Sm2Error::InvalidDigestLen);
78 }
79 let n = &P256C_PARAMS.n;
80 let r = &BigUint::from_bytes_be(&sig[..32]);
81 let s = &BigUint::from_bytes_be(&sig[32..]);
82 if r.is_zero() || s.is_zero() {
83 return Err(Sm2Error::ZeroSig);
84 }
85
86 if r >= n || s >= n {
87 return Err(Sm2Error::InvalidDigest);
88 }
89
90 let t = s.mod_add(r, n);
91 if t.is_zero() {
92 return Err(Sm2Error::InvalidDigest);
93 }
94
95 let s_g = p256_ecc::g_mul(&s);
96 let t_p = p256_ecc::scalar_mul(&t, &pk.value());
97
98 let p = s_g.add(&t_p).to_affine();
99 let x1 = BigUint::from_bytes_be(&p.x.to_bytes_be());
100 let e = BigUint::from_bytes_be(digest);
101 let r1 = x1.mod_add(&e, n);
102 return if &r1 == r {
103 Ok(())
104 } else {
105 Err(Sm2Error::InvalidDigest)
106 };
107 }
108
109 pub fn to_str_hex(&self) -> String {
110 format!(
111 "{}{}",
112 self.value.x.to_str_radix(16),
113 self.value.y.to_str_radix(16)
114 )
115 }
116 pub fn value(&self) -> &Point {
117 &self.value
118 }
119}
120
/// An SM2 private key together with its matching public key.
#[derive(Debug, Clone)]
pub struct Sm2PrivateKey {
    /// The secret scalar `d`.
    pub d: BigUint,
    /// Serialization mode `decrypt` expects for the `C1` point of a ciphertext.
    pub compress_modle: CompressModle,
    /// The public key; `gen_keypair` sets it to the point derived from `d`.
    pub public_key: Sm2PublicKey,
}
128impl Sm2PrivateKey {
129 pub fn sign(&self, id: Option<&'static str>, msg: &[u8]) -> Sm2Result<Vec<u8>> {
130 let id = match id {
131 None => DEFAULT_ID,
132 Some(u_id) => u_id,
133 };
134 let mut digest = compute_za(id, &self.public_key)?;
135 digest = sm3_hash(&[digest.to_vec(), msg.to_vec()].concat());
136 self.sign_raw(&digest[..], &self.d)
137 }
138
139 fn sign_raw(&self, digest: &[u8], sk: &BigUint) -> Sm2Result<Vec<u8>> {
140 if digest.len() != 32 {
141 return Err(Sm2Error::InvalidDigestLen);
142 }
143 let e = BigUint::from_bytes_be(&digest);
144 let n = &P256C_PARAMS.n;
145 loop {
146 let k = random_uint();
147 let p_x = p256_ecc::g_mul(&k).to_affine();
148 let x1 = BigUint::from_bytes_be(&p_x.x.to_bytes_be());
149 let r = e.mod_add(&x1, n);
150 if r.is_zero() || &r + &k == *n {
151 continue;
152 }
153
154 let s1 = &(BigUint::one() + sk).modpow(&(n - BigUint::from_u32(2).unwrap()), n);
155
156 let s2_1 = r.mod_mul(&sk, n);
157 let s2 = k.mod_sub(&s2_1, n);
158
159 let s = s1.mod_mul(&s2, n);
160
161 if s.is_zero() {
162 return Err(Sm2Error::ZeroSig);
163 }
164 let mut sig: Vec<u8> = vec![];
165 sig.extend_from_slice(&r.to_bytes_be());
166 sig.extend_from_slice(&s.to_bytes_be());
167 return Ok(sig);
168 }
169 }
170
171 pub fn decrypt(&self, ciphertext: &[u8]) -> Sm2Result<Vec<u8>> {
172 let c1_end_index = match self.compress_modle {
173 CompressModle::Compressed => 33,
174 CompressModle::Uncompressed | CompressModle::Mixed => 65,
175 };
176
177 let c1_bytes = &ciphertext[0..c1_end_index];
178 let c2_bytes = &ciphertext[c1_end_index..(ciphertext.len() - 32)];
179 let c3_bytes = &ciphertext[(ciphertext.len() - 32)..];
180
181 let kelen = c2_bytes.len();
182 let c1_point = Point::from_byte(c1_bytes, self.compress_modle)?;
183 if !c1_point.to_affine().is_valid_affine() {
184 return Err(Sm2Error::CheckPointErr);
185 }
186
187 let s_point = p256_ecc::scalar_mul(&P256C_PARAMS.h, &c1_point);
188 if s_point.is_zero() {
189 return Err(Sm2Error::ZeroPoint);
190 }
191
192 let c2_point = p256_ecc::scalar_mul(&self.d, &c1_point).to_affine();
193 let x2_bytes = c2_point.x.to_bytes_be();
194 let y2_bytes = c2_point.y.to_bytes_be();
195 let mut prepend: Vec<u8> = vec![];
196 prepend.extend_from_slice(&x2_bytes);
197 prepend.extend_from_slice(&y2_bytes);
198 let t = kdf(&prepend, kelen);
199 let mut flag = true;
200 for elem in &t {
201 if elem != &0 {
202 flag = false;
203 break;
204 }
205 }
206 if flag {
207 return Err(Sm2Error::ZeroData);
208 }
209
210 let m = BigUint::from_bytes_be(c2_bytes) ^ BigUint::from_bytes_be(&t);
211 let mut prepend: Vec<u8> = vec![];
212 prepend.extend_from_slice(&x2_bytes);
213 prepend.extend_from_slice(&m.to_bytes_be());
214 prepend.extend_from_slice(&y2_bytes);
215
216 let u = sm3_hash(&prepend);
217 if u != c3_bytes {
218 return Err(Sm2Error::HashNotEqual);
219 }
220 Ok(m.to_bytes_be())
221 }
222}
223
/// Point-serialization mode used for the `C1` component of ciphertexts.
///
/// In `Sm2PrivateKey::decrypt`, `Compressed` means a 33-byte `C1`.
/// `Uncompressed` and `Mixed` both mean a 65-byte `C1`.
#[derive(Debug, Clone, Copy)]
pub enum CompressModle {
    Compressed,
    Uncompressed,
    Mixed,
}
230
231pub fn gen_keypair(compress_modle: CompressModle) -> Sm2Result<(Sm2PublicKey, Sm2PrivateKey)> {
233 let d = random_uint();
234 let pk = public_from_private(&d, compress_modle)?;
235 let sk = Sm2PrivateKey {
236 d,
237 compress_modle,
238 public_key: pk,
239 };
240 Ok((pk, sk))
241}
242
243fn public_from_private(sk: &BigUint, compress_modle: CompressModle) -> Sm2Result<Sm2PublicKey> {
244 let p = p256_ecc::g_mul(&sk);
245 if p.is_valid() {
246 Ok(Sm2PublicKey {
247 value: p,
248 compress_modle,
249 })
250 } else {
251 Err(Sm2Error::InvalidPublic)
252 }
253}