1use crate::error::{Sm9Error, Sm9Result};
2use crate::fields::{mod_n_add, mod_n_from_hash, mod_n_inv, mod_n_mul, mod_n_sub, FieldElement};
3use crate::points::{sm9_u256_pairing, twist_point_add_full, Point, TwistPoint};
4use crate::u256::{sm9_random_u256, u256_cmp, xor, U256};
5use crate::{
6 SM9_HASH1_PREFIX, SM9_HASH2_PREFIX, SM9_HID_ENC, SM9_HID_EXCH, SM9_HID_SIGN, SM9_N_MINUS_ONE,
7 SM9_POINT_MONT_P1, SM9_TWIST_POINT_MONT_P2,
8};
9use gm_sm3::sm3_hash;
10
11#[derive(Copy, Debug, Clone)]
12pub struct Sm9EncKey {
13 pub ppube: Point,
14 pub de: TwistPoint,
15}
16
17#[derive(Copy, Debug, Clone)]
18pub struct Sm9EncMasterKey {
19 pub ke: U256,
20 pub ppube: Point,
21}
22
23pub fn generate_sign_master_key() -> Sm9SignMasterKey {
24 let ks = sm9_random_u256(&SM9_N_MINUS_ONE);
25 Sm9SignMasterKey {
26 ks,
27 ppubs: TwistPoint::g_mul(&ks),
28 }
29}
30
31pub fn generate_enc_master_key() -> Sm9EncMasterKey {
32 let ke = sm9_random_u256(&SM9_N_MINUS_ONE);
33 Sm9EncMasterKey {
34 ke,
35 ppube: Point::g_mul(&ke),
36 }
37}
38
39impl Sm9EncKey {
40 pub fn decrypt(&self, idb: &[u8], data: &[u8]) -> Sm9Result<Vec<u8>> {
41 let c1_bytes = &data[0..65];
42 let c2 = &data[(65 + 32)..];
43 let c3 = &data[65..(65 + 32)];
44 let c1 = Point::from_bytes(c1_bytes);
45 let w = sm9_u256_pairing(&self.de, &c1);
46 let w_bytes = w.to_bytes_be();
47 let mut k_append: Vec<u8> = vec![];
48 k_append.extend_from_slice(&c1_bytes[1..65]);
49 k_append.extend_from_slice(&w_bytes);
50 k_append.extend_from_slice(idb);
51 let k = kdf(&k_append, (255 + 32) as usize);
52 fn is_zero(x: &Vec<u8>) -> bool {
53 x.iter().all(|&byte| byte == 0)
54 }
55
56 if !is_zero(&k) {
57 let k = k.as_slice();
58 let mlen = data.len() - (65 + 32);
59 let k1 = &k[0..mlen];
60 let k2 = &k[mlen..];
61 let u = sm3_hmac(k2, c2, 32);
62 if !u.as_slice().eq(c3) {
63 return Err(Sm9Error::InvalidDigest);
64 }
65 let m = xor(c2, &k1, k1.len());
66 Ok(m)
67 } else {
68 Err(Sm9Error::KdfHashError)
69 }
70 }
71}
72
73impl Sm9EncMasterKey {
74 pub fn master_key_generate() -> Sm9EncMasterKey {
75 let ke = sm9_random_u256(&SM9_N_MINUS_ONE);
77 Self {
78 ke,
79 ppube: Point::g_mul(&ke), }
81 }
82
83 pub fn encrypt(&self, idb: &[u8], data: &[u8]) -> Vec<u8> {
84 let t = sm9_u256_hash1(idb, SM9_HID_ENC);
86 let mut c1 = SM9_POINT_MONT_P1.point_mul(&t);
87 c1 = c1.point_add(&self.ppube);
88
89 let mut k = vec![];
90 loop {
91 let r = sm9_random_u256(&SM9_N_MINUS_ONE);
93
94 c1 = c1.point_mul(&r);
96 let cbuf = c1.to_bytes_be();
97 let cbuf = cbuf.as_slice();
98
99 let mut g = sm9_u256_pairing(&SM9_TWIST_POINT_MONT_P2, &self.ppube);
101
102 g = g.pow(&r);
104 let gbuf = g.to_bytes_be();
105 let gbuf = gbuf.as_slice();
106
107 let mut k_append: Vec<u8> = vec![];
109 k_append.extend_from_slice(&cbuf[1..cbuf.len()]);
111 k_append.extend_from_slice(gbuf);
112 k_append.extend_from_slice(idb);
113 k = kdf(&k_append, (255 + 32) as usize);
114 fn is_zero(x: &Vec<u8>) -> bool {
115 x.iter().all(|&byte| byte == 0)
116 }
117
118 if !is_zero(&k) {
119 break;
120 }
121 }
122
123 let k1 = &k[0..data.len()];
124 let k2 = &k[data.len()..];
125 let c2 = xor(k1, &data, data.len());
126 let c3 = sm3_hmac(k2, &c2, 32usize);
127 let mut c: Vec<u8> = vec![];
128 c.extend_from_slice(&c1.to_bytes_be());
129 c.extend_from_slice(&c3);
130 c.extend_from_slice(&c2);
131 c
132 }
133
134 pub fn extract_key(&self, id: &[u8]) -> Option<Sm9EncKey> {
135 let mut t = sm9_u256_hash1(id, SM9_HID_ENC);
137 t = mod_n_add(&t, &self.ke);
138 if t.is_zero() {
139 return None;
140 }
141 t = mod_n_inv(&t);
143
144 t = mod_n_mul(&t, &self.ke);
146 Some(Sm9EncKey {
147 ppube: self.ppube,
148 de: TwistPoint::g_mul(&t),
149 })
150 }
151
152 pub fn extract_exch_key(&self, id: &[u8]) -> Option<Sm9EncKey> {
153 let mut t = sm9_u256_hash1(id, SM9_HID_EXCH);
155 t = mod_n_add(&t, &self.ke);
156 if t.is_zero() {
157 return None;
158 }
159 t = mod_n_inv(&t);
161
162 t = mod_n_mul(&t, &self.ke);
164 Some(Sm9EncKey {
165 ppube: self.ppube,
166 de: TwistPoint::g_mul(&t),
167 })
168 }
169}
170
/// Length in bytes of the HMAC key pads used by `sm3_hmac` (SM3's 512-bit block).
const BLOCK_SIZE: usize = 512 / 8;
172
173fn sm3_hmac(key: &[u8], message: &[u8], klen: usize) -> Vec<u8> {
174 let mut ipad = [0x36u8; BLOCK_SIZE];
175 let mut opad = [0x5cu8; BLOCK_SIZE];
176
177 let mut key_block = [0u8; 64];
178
179 if klen > BLOCK_SIZE {
181 key_block[..32].copy_from_slice(&sm3_hash(key));
182 } else {
183 key_block[..klen].copy_from_slice(&key[0..klen]);
184 };
185
186 for i in 0..BLOCK_SIZE {
188 ipad[i] ^= key_block[i];
189 opad[i] ^= key_block[i];
190 }
191
192 let mut ipad_append = vec![];
194 ipad_append.extend_from_slice(&ipad);
195 ipad_append.extend_from_slice(message);
196 let inner_result = sm3_hash(&ipad_append);
197
198 let mut opad_append = vec![];
200 opad_append.extend_from_slice(&opad);
201 opad_append.extend_from_slice(&inner_result);
202 sm3_hash(&opad_append).to_vec()
203}
204
205fn sm9_u256_hash1(id: &[u8], hid: u8) -> U256 {
206 let ct1: [u8; 4] = [0x00, 0x00, 0x00, 0x01];
207 let ct2: [u8; 4] = [0x00, 0x00, 0x00, 0x02];
208 let mut c3_append: Vec<u8> = vec![];
209 c3_append.extend_from_slice(&vec![SM9_HASH1_PREFIX]);
210 c3_append.extend_from_slice(id);
211 c3_append.extend_from_slice(&vec![hid]);
212 c3_append.extend_from_slice(&ct1);
213 let ha1 = sm3_hash(&c3_append);
214
215 let mut c3_append2: Vec<u8> = vec![];
216 c3_append2.extend_from_slice(&vec![SM9_HASH1_PREFIX]);
217 c3_append2.extend_from_slice(id);
218 c3_append2.extend_from_slice(&vec![hid]);
219 c3_append2.extend_from_slice(&ct2);
220 let ha2 = sm3_hash(&c3_append2);
221
222 let mut ha = vec![];
223 ha.extend_from_slice(&ha1);
224 ha.extend_from_slice(&ha2);
225 let r = mod_n_from_hash(&ha);
226 r
227}
228
229fn sm9_u256_hash2(data: &[u8], wbuf: &[u8]) -> U256 {
230 let ct1: [u8; 4] = [0x00, 0x00, 0x00, 0x01];
231 let ct2: [u8; 4] = [0x00, 0x00, 0x00, 0x02];
232 let mut c3_append: Vec<u8> = vec![];
233 c3_append.extend_from_slice(&vec![SM9_HASH2_PREFIX]);
234 c3_append.extend_from_slice(data);
235 c3_append.extend_from_slice(wbuf);
236 c3_append.extend_from_slice(&ct1);
237 let ha1 = sm3_hash(&c3_append);
238
239 let mut c3_append2: Vec<u8> = vec![];
240 c3_append2.extend_from_slice(&vec![SM9_HASH2_PREFIX]);
241 c3_append2.extend_from_slice(data);
242 c3_append2.extend_from_slice(wbuf);
243 c3_append2.extend_from_slice(&ct2);
244 let ha2 = sm3_hash(&c3_append2);
245
246 let mut ha = vec![];
247 ha.extend_from_slice(&ha1);
248 ha.extend_from_slice(&ha2);
249 let r = mod_n_from_hash(&ha);
250 r
251}
252
253fn kdf(z: &[u8], klen: usize) -> Vec<u8> {
254 let mut ct = 0x00000001u32;
255 let bound = ((klen as f64) / 32.0).ceil() as u32;
256 let mut h_a = Vec::new();
257 for _i in 1..bound {
258 let mut prepend = Vec::new();
259 prepend.extend_from_slice(z);
260 prepend.extend_from_slice(&ct.to_be_bytes());
261
262 let h_a_i = sm3_hash(&prepend[..]);
263 h_a.extend_from_slice(&h_a_i);
264 ct += 1;
265 }
266
267 let mut prepend = Vec::new();
268 prepend.extend_from_slice(z);
269 prepend.extend_from_slice(&ct.to_be_bytes());
270
271 let last = sm3_hash(&prepend[..]);
272 if klen % 32 == 0 {
273 h_a.extend_from_slice(&last);
274 } else {
275 h_a.extend_from_slice(&last[0..(klen % 32)]);
276 }
277 h_a
278}
279
280#[derive(Copy, Debug, Clone)]
281pub struct Sm9SignKey {
282 pub ppubs: TwistPoint,
283 pub ds: Point,
284}
285
286impl Sm9SignKey {
287 pub fn sign(&self, data: &[u8]) -> Sm9Result<(U256, Point)> {
289 let g = sm9_u256_pairing(&self.ppubs, &SM9_POINT_MONT_P1);
291 let mut h: U256 = [0, 0, 0, 0];
292 let mut r: U256 = [0, 0, 0, 0];
293 loop {
294 r = sm9_random_u256(&SM9_N_MINUS_ONE);
296
297 let w = g.pow(&r);
299 let wbuf = w.to_bytes_be();
300 let wbuf = wbuf.as_slice();
301
302 h = sm9_u256_hash2(data, wbuf);
304
305 r = mod_n_sub(&r, &h);
307
308 if !r.is_zero() {
309 break;
310 }
311 }
312
313 let s = self.ds.point_mul(&r);
315
316 Ok((h, s))
317 }
318}
319
320#[derive(Copy, Debug, Clone)]
321pub struct Sm9SignMasterKey {
322 pub ks: U256,
323 pub ppubs: TwistPoint,
324}
325
326impl Sm9SignMasterKey {
327 pub fn master_key_generate() -> Self {
328 let ks = sm9_random_u256(&SM9_N_MINUS_ONE);
330 Self {
331 ks,
332 ppubs: TwistPoint::g_mul(&ks), }
334 }
335
336 pub fn extract_key(&self, idb: &[u8]) -> Option<Sm9SignKey> {
337 let mut t = sm9_u256_hash1(idb, SM9_HID_SIGN);
339 t = mod_n_add(&t, &self.ks);
340 if t.is_zero() {
341 return None;
342 }
343 t = mod_n_inv(&t);
345
346 t = mod_n_mul(&t, &self.ks);
348 Some(Sm9SignKey {
349 ppubs: self.ppubs,
350 ds: Point::g_mul(&t),
351 })
352 }
353
354 pub fn verify_sign(&self, id: &[u8], data: &[u8], h: &U256, s: &Point) -> Sm9Result<()> {
355 let g = sm9_u256_pairing(&self.ppubs, &SM9_POINT_MONT_P1);
356 let t = g.pow(h);
357 let h1 = sm9_u256_hash1(id, SM9_HID_SIGN);
359 let mut p = TwistPoint::g_mul(&h1);
360 p = twist_point_add_full(&self.ppubs, &p);
361
362 let u = sm9_u256_pairing(&p, s);
363 let w = u.fp_mul(&t);
364 let wbuf = w.to_bytes_be();
365 let wbuf = wbuf.as_slice();
366 let h2 = sm9_u256_hash2(data, wbuf);
367 if u256_cmp(&h2, h) != 0 {
368 Err(Sm9Error::InvalidDigest)
369 } else {
370 Ok(())
371 }
372 }
373}
374
375pub fn exch_step_1a(msk: &Sm9EncMasterKey, idb: &[u8]) -> (Point, U256) {
376 let mut ra = sm9_u256_hash1(idb, SM9_HID_EXCH);
378 let mut r = SM9_POINT_MONT_P1.point_mul(&ra);
379 r = r.point_add(&msk.ppube);
380
381 ra = sm9_random_u256(&SM9_N_MINUS_ONE);
383 r = r.point_mul(&ra);
387
388 (r, ra)
389}
390
391pub fn exch_step_1b(
392 msk: &Sm9EncMasterKey,
393 ida: &[u8],
394 idb: &[u8],
395 key: &Sm9EncKey,
396 ra: &Point,
397 klen: usize,
398) -> Sm9Result<(Point, Vec<u8>)> {
399 let mut rb = sm9_u256_hash1(ida, SM9_HID_EXCH);
401 let mut r = SM9_POINT_MONT_P1.point_mul(&rb);
402 r = r.point_add(&msk.ppube);
403 let mut sk = vec![];
404 loop {
405 rb = sm9_random_u256(&SM9_N_MINUS_ONE);
407
408 r = r.point_mul(&rb);
412
413 if !ra.is_on_curve() {
415 return Err(Sm9Error::InvalidPoint);
416 }
417
418 let g1 = sm9_u256_pairing(&key.de, &ra);
419 let mut g2 = sm9_u256_pairing(&SM9_TWIST_POINT_MONT_P2, &msk.ppube);
420 g2 = g2.pow(&rb);
421 let g3 = g1.pow(&rb);
422 let ta = ra.to_bytes_be();
423 let tb = r.to_bytes_be();
424
425 let g1 = g1.to_bytes_be();
426 let g2 = g2.to_bytes_be();
427 let g3 = g3.to_bytes_be();
428
429 let mut pre_append = vec![];
430 pre_append.extend_from_slice(ida);
431 pre_append.extend_from_slice(idb);
432 pre_append.extend_from_slice(&ta[1..]);
433 pre_append.extend_from_slice(&tb[1..]);
434 pre_append.extend_from_slice(&g1);
435 pre_append.extend_from_slice(&g2);
436 pre_append.extend_from_slice(&g3);
437
438 sk = kdf(&pre_append, klen);
439
440 fn is_zero(x: &Vec<u8>, klen: usize) -> bool {
441 let mut ret = true;
442 for i in 0..klen {
443 if x[i] != 0 {
444 ret = false;
445 }
446 }
447 ret
448 }
449
450 if !is_zero(&sk, klen) {
451 break;
452 }
453 }
454 Ok((r, sk))
455}
456
457pub fn exch_step_2a(
458 msk: &Sm9EncMasterKey,
459 ida: &[u8],
460 idb: &[u8],
461 key: &Sm9EncKey,
462 ra_: U256,
463 ra: &Point,
464 rb: &Point,
465 klen: usize,
466) -> Sm9Result<Vec<u8>> {
467 let mut sk = vec![];
468 loop {
469 if !rb.is_on_curve() {
470 return Err(Sm9Error::InvalidPoint);
471 }
472
473 let mut g1 = sm9_u256_pairing(&SM9_TWIST_POINT_MONT_P2, &msk.ppube);
474 g1 = g1.pow(&ra_);
475
476 let g2 = sm9_u256_pairing(&key.de, &rb);
477 let g3 = g2.pow(&ra_);
478
479 let ta = ra.to_bytes_be();
480 let tb = rb.to_bytes_be();
481
482 let g1 = g1.to_bytes_be();
483 let g2 = g2.to_bytes_be();
484 let g3 = g3.to_bytes_be();
485
486 let mut pre_append = vec![];
487 pre_append.extend_from_slice(ida);
488 pre_append.extend_from_slice(idb);
489 pre_append.extend_from_slice(&ta[1..]);
490 pre_append.extend_from_slice(&tb[1..]);
491 pre_append.extend_from_slice(&g1);
492 pre_append.extend_from_slice(&g2);
493 pre_append.extend_from_slice(&g3);
494
495 sk = kdf(&pre_append, klen);
496 fn is_zero(x: &Vec<u8>, klen: usize) -> bool {
497 let mut ret = true;
498 for i in 0..klen {
499 if x[i] != 0 {
500 ret = false;
501 }
502 }
503 ret
504 }
505
506 if !is_zero(&sk, klen) {
507 break;
508 }
509 }
510 Ok(sk)
511}
512
#[cfg(test)]
mod sm9_key_test {
    use crate::key::{
        exch_step_1a, exch_step_1b, exch_step_2a, Sm9EncKey, Sm9EncMasterKey, Sm9SignMasterKey,
    };
    use crate::points::{Point, TwistPoint};
    use crate::u256::u256_from_be_bytes;

    /// Round-trips a message with a fixed master key and checks the extracted
    /// key against a known value. Also checks that a tampered ciphertext is
    /// rejected.
    #[test]
    fn test_encrypt() {
        let data: [u8; 21] = [
            0x43, 0x68, 0x69, 0x6E, 0x65, 0x73, 0x65, 0x20, 0x49, 0x42, 0x53, 0x20, 0x73, 0x74,
            0x61, 0x6E, 0x64, 0x61, 0x72, 0x64, 0x64,
        ];

        let idb = [0x42, 0x6F, 0x62u8];

        let ke = u256_from_be_bytes(
            &hex::decode("0001EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22")
                .unwrap(),
        );

        let msk = Sm9EncMasterKey {
            ke,
            ppube: Point::g_mul(&ke),
        };

        let key = msk.extract_key(&idb).expect("extract_key returned None");
        let r_de = TwistPoint::from_hex(
            [
                "115BAE85F5D8BC6C3DBD9E5342979ACCCF3C2F4F28420B1CB4F8C0B59A19B158",
                "94736ACD2C8C8796CC4785E938301A139A059D3537B6414140B2D31EECF41683",
            ],
            [
                "27538A62E7F7BFB51DCE08704796D94C9D56734F119EA44732B50E31CDEB75C1",
                "7AA5E47570DA7600CD760A0CF7BEAF71C447F3844753FE74FA7BA92CA7D3B55F",
            ],
        );
        assert!(key.de.point_equals(&r_de));

        let ret = msk.encrypt(&idb, &data);
        println!("Ciphertext = {:?}", ret);
        let m = key.decrypt(&idb, &ret).expect("Decryption failed");
        assert_eq!(&data[..], m.as_slice());

        // Flipping a bit of C2 must break the MAC check.
        let mut tampered = ret.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(key.decrypt(&idb, &tampered).is_err());
    }

    /// Signs with a key extracted from a fixed master key, verifies the
    /// signature, and checks that verification rejects a different message.
    #[test]
    fn test_sign_verify() {
        let data: [u8; 20] = [
            0x43, 0x68, 0x69, 0x6E, 0x65, 0x73, 0x65, 0x20, 0x49, 0x42, 0x53, 0x20, 0x73, 0x74,
            0x61, 0x6E, 0x64, 0x61, 0x72, 0x64,
        ];

        let ida = [0x41, 0x6C, 0x69, 0x63, 0x65u8];

        let ks = u256_from_be_bytes(
            &hex::decode("000130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
                .unwrap(),
        );
        let msk = Sm9SignMasterKey {
            ks,
            ppubs: TwistPoint::g_mul(&ks),
        };

        let r_ds = Point::from_hex([
            "A5702F05CF1315305E2D6EB64B0DEB923DB1A0BCF0CAFF90523AC8754AA69820",
            "78559A844411F9825C109F5EE3F52D720DD01785392A727BB1556952B2B013D3",
        ]);
        let ps = msk.extract_key(&ida).expect("extract_key returned None");
        assert!(ps.ds.point_equals(&r_ds));

        let (h, s) = ps.sign(&data).unwrap();
        let r = msk.verify_sign(&ida, &data, &h, &s);
        assert!(r.is_ok(), "valid signature rejected: {:?}", r);

        let mut other = data;
        other[0] ^= 0x01;
        assert!(msk.verify_sign(&ida, &other, &h, &s).is_err());
    }

    /// Runs the three key-exchange steps and requires both parties to derive
    /// the same klen-byte key.
    #[test]
    fn test_exchange_key() {
        let msk: Sm9EncMasterKey = Sm9EncMasterKey::master_key_generate();
        let klen = 20usize;
        let ida = [0x41, 0x6C, 0x69, 0x63, 0x65u8];
        let idb = [0x42, 0x6F, 0x62u8];
        let key_a: Sm9EncKey = msk.extract_exch_key(&ida).unwrap();
        let key_b: Sm9EncKey = msk.extract_exch_key(&idb).unwrap();

        let (ra, ra_) = exch_step_1a(&msk, &idb);
        let (rb, skb) = exch_step_1b(&msk, &ida, &idb, &key_b, &ra, klen).unwrap();
        let ska = exch_step_2a(&msk, &ida, &idb, &key_a, ra_, &ra, &rb, klen).unwrap();
        assert_eq!(ska.len(), klen);
        assert_eq!(ska, skb);
    }
}
624}