gm_rs/sm2/
exchange.rs

1use crate::sm2::error::{Sm2Error, Sm2Result};
2use crate::sm2::key::{gen_keypair, CompressModle, Sm2PrivateKey, Sm2PublicKey};
3use crate::sm2::p256_ecc::{g_mul, scalar_mul, Point, P256C_PARAMS};
4use crate::sm2::util::{compute_za, kdf, random_uint, DEFAULT_ID};
5use crate::sm3::sm3_hash;
6use byteorder::{BigEndian, WriteBytesExt};
7use num_bigint::BigUint;
8use num_traits::{FromPrimitive, One, Pow};
9
10#[derive(Debug)]
11pub struct Exchange {
12    klen: usize,
13    za: [u8; 32],
14    sk: Sm2PrivateKey,
15    v: Option<Point>,
16    r: Option<BigUint>,
17    r_point: Option<Point>,
18    pub(crate) k: Option<Vec<u8>>,
19
20    rhs_za: [u8; 32],
21    rhs_pk: Sm2PublicKey,
22}
23
24/// Build the exchange Pair
25///
26pub fn build_ex_pair(
27    klen: usize,
28    first_id: &str,
29    other_id: &str,
30) -> Sm2Result<(Exchange, Exchange)> {
31    let (pk_a, sk_a) = gen_keypair(CompressModle::Compressed).unwrap();
32    let (pk_b, sk_b) = gen_keypair(CompressModle::Compressed).unwrap();
33    let user_a = Exchange::new(klen, Some(first_id), &pk_a, &sk_a, Some(other_id), &pk_b).unwrap();
34    let user_b = Exchange::new(klen, Some(other_id), &pk_b, &sk_b, Some(first_id), &pk_a).unwrap();
35    Ok((user_a, user_b))
36}
37
38impl Exchange {
39    pub fn new(
40        klen: usize,
41        id: Option<&str>,
42        pk: &Sm2PublicKey,
43        sk: &Sm2PrivateKey,
44        rhs_id: Option<&str>,
45        rhs_pk: &Sm2PublicKey,
46    ) -> Sm2Result<Exchange> {
47        let id = match id {
48            None => DEFAULT_ID,
49            Some(s) => s,
50        };
51        let rhs_id = match rhs_id {
52            None => DEFAULT_ID,
53            Some(s) => s,
54        };
55        Ok(Exchange {
56            klen,
57            za: compute_za(id, &pk)?,
58            sk: sk.clone(),
59            v: None,
60            r: None,
61            r_point: None,
62            k: None,
63            rhs_za: compute_za(rhs_id, &rhs_pk)?,
64            rhs_pk: rhs_pk.clone(),
65        })
66    }
67
68    // Step1: UserA Call
69    // A1:用随机数发生器产生随机数rA ∈ [1, n-1];
70    // A2:计算椭圆曲线点RA = [rA]G=(x1,y1);
71    // A3:将RA发送给用户B;
72    pub fn exchange_1(&mut self) -> Sm2Result<Point> {
73        let r = random_uint();
74        let r_point = g_mul(&r);
75        self.r = Some(r);
76        self.r_point = Some(r_point);
77        Ok(r_point)
78    }
79
80    // Step2: UserB Call
81    //
82    pub fn exchange_2(&mut self, ra_point: &Point) -> Sm2Result<(Point, [u8; 32])> {
83        if !ra_point.is_valid() {
84            return Err(Sm2Error::CheckPointErr);
85        }
86        let n = &P256C_PARAMS.n;
87        let w = ((n.bits() as f64) / 2.0).ceil() - 1.0;
88        let pow_w = BigUint::from_u32(2).unwrap().pow(w as u32);
89
90        let r2 = random_uint();
91        let r2_point = g_mul(&r2);
92        self.r = Some(r2);
93        self.r_point = Some(r2_point);
94        let r2_point_affine = r2_point.to_affine();
95        let x2 = r2_point_affine.x;
96        let y2 = r2_point_affine.y;
97        let x2_b = &pow_w + (x2.to_biguint() & (&pow_w - BigUint::one()));
98        let t2 = (&self.sk.d + self.r.as_ref().unwrap() * &x2_b) % n;
99
100        let ra_point_affine = ra_point.to_affine();
101        let x1 = ra_point_affine.x;
102        let y1 = ra_point_affine.y;
103        let x1_a = &pow_w + (x1.to_biguint() & (&pow_w - BigUint::one()));
104
105        let p = self.rhs_pk.value().add(&scalar_mul(&x1_a, ra_point));
106        let v_point = scalar_mul(&(BigUint::one() * t2), &p);
107        if v_point.is_zero() {
108            return Err(Sm2Error::ZeroPoint);
109        }
110        self.v = Some(v_point);
111
112        let v_affine_p = v_point.to_affine();
113        let xv_bytes = v_affine_p.x.to_bytes_be();
114        let yv_bytes = v_affine_p.y.to_bytes_be();
115
116        let mut prepend = Vec::new();
117        prepend.extend_from_slice(&xv_bytes);
118        prepend.extend_from_slice(&yv_bytes);
119        prepend.extend_from_slice(&self.rhs_za); // User A
120        prepend.extend_from_slice(&self.za); // User B
121
122        let k_b = kdf(&prepend, self.klen);
123        self.k = Some(k_b);
124
125        let mut temp: Vec<u8> = Vec::new();
126        temp.extend_from_slice(&xv_bytes);
127        temp.extend_from_slice(&self.rhs_za);
128        temp.extend_from_slice(&self.za);
129        temp.extend_from_slice(&x1.to_bytes_be());
130        temp.extend_from_slice(&y1.to_bytes_be());
131        temp.extend_from_slice(&x2.to_bytes_be());
132        temp.extend_from_slice(&y2.to_bytes_be());
133
134        let mut prepend: Vec<u8> = Vec::new();
135        prepend.write_u16::<BigEndian>(0x02_u16).unwrap();
136        prepend.extend_from_slice(&yv_bytes);
137        prepend.extend_from_slice(&sm3_hash(&temp));
138        Ok((r2_point, sm3_hash(&prepend)))
139    }
140
141    // Step4: UserA Call
142    //
143    pub fn exchange_3(&mut self, rb_point: &Point, sb: [u8; 32]) -> Sm2Result<[u8; 32]> {
144        if !rb_point.is_valid() {
145            return Err(Sm2Error::CheckPointErr);
146        }
147        let n = &P256C_PARAMS.n;
148        let w = ((n.bits() as f64) / 2.0).ceil() - 1.0;
149        let pow_w = BigUint::from_u32(2).unwrap().pow(w as u32);
150
151        let ra_point_affine = self.r_point.unwrap().to_affine();
152        let x1 = ra_point_affine.x;
153        let y1 = ra_point_affine.y;
154        let x1_a = &pow_w + (x1.to_biguint() & (&pow_w - BigUint::one()));
155        let t_a = (&self.sk.d + x1_a * self.r.as_ref().unwrap()) % n;
156
157        let rb_point_affine = rb_point.to_affine();
158        let x2 = rb_point_affine.x;
159        let y2 = rb_point_affine.y;
160        let x2_b = &pow_w + (x2.to_biguint() & (&pow_w - BigUint::one()));
161
162        let p = self.rhs_pk.value().add(&scalar_mul(&x2_b, rb_point));
163        let u_point = scalar_mul(&(BigUint::one() * t_a), &p);
164        if u_point.is_zero() {
165            return Err(Sm2Error::ZeroPoint);
166        }
167
168        let u_affine_p = u_point.to_affine();
169        let xu_bytes = u_affine_p.x.to_bytes_be();
170        let yu_bytes = u_affine_p.y.to_bytes_be();
171
172        let mut prepend = Vec::new();
173        prepend.extend_from_slice(&xu_bytes);
174        prepend.extend_from_slice(&yu_bytes);
175        prepend.extend_from_slice(&self.za);
176        prepend.extend_from_slice(&self.rhs_za);
177
178        let k_a = kdf(&prepend, self.klen);
179        self.k = Some(k_a);
180
181        let mut temp: Vec<u8> = Vec::new();
182        temp.extend_from_slice(&xu_bytes);
183        temp.extend_from_slice(&self.za);
184        temp.extend_from_slice(&self.rhs_za);
185        temp.extend_from_slice(&x1.to_bytes_be());
186        temp.extend_from_slice(&y1.to_bytes_be());
187        temp.extend_from_slice(&x2.to_bytes_be());
188        temp.extend_from_slice(&y2.to_bytes_be());
189        let temp_hash = sm3_hash(&temp);
190
191        let mut prepend: Vec<u8> = Vec::new();
192        prepend.write_u16::<BigEndian>(0x02_u16).unwrap();
193        prepend.extend_from_slice(&yu_bytes);
194        prepend.extend_from_slice(&temp_hash);
195
196        let s1 = sm3_hash(&prepend);
197        if s1 != sb {
198            return Err(Sm2Error::HashNotEqual);
199        }
200
201        let mut prepend: Vec<u8> = Vec::new();
202        prepend.write_u16::<BigEndian>(0x03_u16).unwrap();
203        prepend.extend_from_slice(&yu_bytes);
204        prepend.extend_from_slice(&temp_hash);
205        Ok(sm3_hash(&prepend))
206    }
207
208    // Step4: UserA Call
209    pub fn exchange_4(&self, sa: [u8; 32], ra_point: &Point) -> Sm2Result<bool> {
210        let ra_point_affine = ra_point.to_affine();
211        let x1 = ra_point_affine.x;
212        let y1 = ra_point_affine.y;
213
214        let r2_point_affine = self.r_point.unwrap().to_affine();
215        let x2 = r2_point_affine.x;
216        let y2 = r2_point_affine.y;
217
218        let v_point_affine = self.v.unwrap().to_affine();
219        let xv = v_point_affine.x;
220        let yv = v_point_affine.y;
221
222        let mut temp: Vec<u8> = Vec::new();
223        temp.extend_from_slice(&xv.to_bytes_be());
224        temp.extend_from_slice(&self.rhs_za);
225        temp.extend_from_slice(&self.za);
226        temp.extend_from_slice(&x1.to_bytes_be());
227        temp.extend_from_slice(&y1.to_bytes_be());
228        temp.extend_from_slice(&x2.to_bytes_be());
229        temp.extend_from_slice(&y2.to_bytes_be());
230
231        let mut prepend: Vec<u8> = Vec::new();
232        prepend.write_u16::<BigEndian>(0x03_u16).unwrap();
233        prepend.extend_from_slice(&yv.to_bytes_be());
234        prepend.extend_from_slice(&sm3_hash(&temp));
235        let s_2 = sm3_hash(&prepend);
236        Ok(s_2 == sa)
237    }
238}