1use crate::sm2::error::{Sm2Error, Sm2Result};
2use crate::sm2::key::{gen_keypair, CompressModle, Sm2PrivateKey, Sm2PublicKey};
3use crate::sm2::p256_ecc::{g_mul, scalar_mul, Point, P256C_PARAMS};
4use crate::sm2::util::{compute_za, kdf, random_uint, DEFAULT_ID};
5use crate::sm3::sm3_hash;
6use byteorder::{BigEndian, WriteBytesExt};
7use num_bigint::BigUint;
8use num_traits::{FromPrimitive, One, Pow};
9
10#[derive(Debug)]
11pub struct Exchange {
12 klen: usize,
13 za: [u8; 32],
14 sk: Sm2PrivateKey,
15 v: Option<Point>,
16 r: Option<BigUint>,
17 r_point: Option<Point>,
18 pub(crate) k: Option<Vec<u8>>,
19
20 rhs_za: [u8; 32],
21 rhs_pk: Sm2PublicKey,
22}
23
24pub fn build_ex_pair(
27 klen: usize,
28 first_id: &str,
29 other_id: &str,
30) -> Sm2Result<(Exchange, Exchange)> {
31 let (pk_a, sk_a) = gen_keypair(CompressModle::Compressed).unwrap();
32 let (pk_b, sk_b) = gen_keypair(CompressModle::Compressed).unwrap();
33 let user_a = Exchange::new(klen, Some(first_id), &pk_a, &sk_a, Some(other_id), &pk_b).unwrap();
34 let user_b = Exchange::new(klen, Some(other_id), &pk_b, &sk_b, Some(first_id), &pk_a).unwrap();
35 Ok((user_a, user_b))
36}
37
38impl Exchange {
39 pub fn new(
40 klen: usize,
41 id: Option<&str>,
42 pk: &Sm2PublicKey,
43 sk: &Sm2PrivateKey,
44 rhs_id: Option<&str>,
45 rhs_pk: &Sm2PublicKey,
46 ) -> Sm2Result<Exchange> {
47 let id = match id {
48 None => DEFAULT_ID,
49 Some(s) => s,
50 };
51 let rhs_id = match rhs_id {
52 None => DEFAULT_ID,
53 Some(s) => s,
54 };
55 Ok(Exchange {
56 klen,
57 za: compute_za(id, &pk)?,
58 sk: sk.clone(),
59 v: None,
60 r: None,
61 r_point: None,
62 k: None,
63 rhs_za: compute_za(rhs_id, &rhs_pk)?,
64 rhs_pk: rhs_pk.clone(),
65 })
66 }
67
68 pub fn exchange_1(&mut self) -> Sm2Result<Point> {
73 let r = random_uint();
74 let r_point = g_mul(&r);
75 self.r = Some(r);
76 self.r_point = Some(r_point);
77 Ok(r_point)
78 }
79
80 pub fn exchange_2(&mut self, ra_point: &Point) -> Sm2Result<(Point, [u8; 32])> {
83 if !ra_point.is_valid() {
84 return Err(Sm2Error::CheckPointErr);
85 }
86 let n = &P256C_PARAMS.n;
87 let w = ((n.bits() as f64) / 2.0).ceil() - 1.0;
88 let pow_w = BigUint::from_u32(2).unwrap().pow(w as u32);
89
90 let r2 = random_uint();
91 let r2_point = g_mul(&r2);
92 self.r = Some(r2);
93 self.r_point = Some(r2_point);
94 let r2_point_affine = r2_point.to_affine();
95 let x2 = r2_point_affine.x;
96 let y2 = r2_point_affine.y;
97 let x2_b = &pow_w + (x2.to_biguint() & (&pow_w - BigUint::one()));
98 let t2 = (&self.sk.d + self.r.as_ref().unwrap() * &x2_b) % n;
99
100 let ra_point_affine = ra_point.to_affine();
101 let x1 = ra_point_affine.x;
102 let y1 = ra_point_affine.y;
103 let x1_a = &pow_w + (x1.to_biguint() & (&pow_w - BigUint::one()));
104
105 let p = self.rhs_pk.value().add(&scalar_mul(&x1_a, ra_point));
106 let v_point = scalar_mul(&(BigUint::one() * t2), &p);
107 if v_point.is_zero() {
108 return Err(Sm2Error::ZeroPoint);
109 }
110 self.v = Some(v_point);
111
112 let v_affine_p = v_point.to_affine();
113 let xv_bytes = v_affine_p.x.to_bytes_be();
114 let yv_bytes = v_affine_p.y.to_bytes_be();
115
116 let mut prepend = Vec::new();
117 prepend.extend_from_slice(&xv_bytes);
118 prepend.extend_from_slice(&yv_bytes);
119 prepend.extend_from_slice(&self.rhs_za); prepend.extend_from_slice(&self.za); let k_b = kdf(&prepend, self.klen);
123 self.k = Some(k_b);
124
125 let mut temp: Vec<u8> = Vec::new();
126 temp.extend_from_slice(&xv_bytes);
127 temp.extend_from_slice(&self.rhs_za);
128 temp.extend_from_slice(&self.za);
129 temp.extend_from_slice(&x1.to_bytes_be());
130 temp.extend_from_slice(&y1.to_bytes_be());
131 temp.extend_from_slice(&x2.to_bytes_be());
132 temp.extend_from_slice(&y2.to_bytes_be());
133
134 let mut prepend: Vec<u8> = Vec::new();
135 prepend.write_u16::<BigEndian>(0x02_u16).unwrap();
136 prepend.extend_from_slice(&yv_bytes);
137 prepend.extend_from_slice(&sm3_hash(&temp));
138 Ok((r2_point, sm3_hash(&prepend)))
139 }
140
141 pub fn exchange_3(&mut self, rb_point: &Point, sb: [u8; 32]) -> Sm2Result<[u8; 32]> {
144 if !rb_point.is_valid() {
145 return Err(Sm2Error::CheckPointErr);
146 }
147 let n = &P256C_PARAMS.n;
148 let w = ((n.bits() as f64) / 2.0).ceil() - 1.0;
149 let pow_w = BigUint::from_u32(2).unwrap().pow(w as u32);
150
151 let ra_point_affine = self.r_point.unwrap().to_affine();
152 let x1 = ra_point_affine.x;
153 let y1 = ra_point_affine.y;
154 let x1_a = &pow_w + (x1.to_biguint() & (&pow_w - BigUint::one()));
155 let t_a = (&self.sk.d + x1_a * self.r.as_ref().unwrap()) % n;
156
157 let rb_point_affine = rb_point.to_affine();
158 let x2 = rb_point_affine.x;
159 let y2 = rb_point_affine.y;
160 let x2_b = &pow_w + (x2.to_biguint() & (&pow_w - BigUint::one()));
161
162 let p = self.rhs_pk.value().add(&scalar_mul(&x2_b, rb_point));
163 let u_point = scalar_mul(&(BigUint::one() * t_a), &p);
164 if u_point.is_zero() {
165 return Err(Sm2Error::ZeroPoint);
166 }
167
168 let u_affine_p = u_point.to_affine();
169 let xu_bytes = u_affine_p.x.to_bytes_be();
170 let yu_bytes = u_affine_p.y.to_bytes_be();
171
172 let mut prepend = Vec::new();
173 prepend.extend_from_slice(&xu_bytes);
174 prepend.extend_from_slice(&yu_bytes);
175 prepend.extend_from_slice(&self.za);
176 prepend.extend_from_slice(&self.rhs_za);
177
178 let k_a = kdf(&prepend, self.klen);
179 self.k = Some(k_a);
180
181 let mut temp: Vec<u8> = Vec::new();
182 temp.extend_from_slice(&xu_bytes);
183 temp.extend_from_slice(&self.za);
184 temp.extend_from_slice(&self.rhs_za);
185 temp.extend_from_slice(&x1.to_bytes_be());
186 temp.extend_from_slice(&y1.to_bytes_be());
187 temp.extend_from_slice(&x2.to_bytes_be());
188 temp.extend_from_slice(&y2.to_bytes_be());
189 let temp_hash = sm3_hash(&temp);
190
191 let mut prepend: Vec<u8> = Vec::new();
192 prepend.write_u16::<BigEndian>(0x02_u16).unwrap();
193 prepend.extend_from_slice(&yu_bytes);
194 prepend.extend_from_slice(&temp_hash);
195
196 let s1 = sm3_hash(&prepend);
197 if s1 != sb {
198 return Err(Sm2Error::HashNotEqual);
199 }
200
201 let mut prepend: Vec<u8> = Vec::new();
202 prepend.write_u16::<BigEndian>(0x03_u16).unwrap();
203 prepend.extend_from_slice(&yu_bytes);
204 prepend.extend_from_slice(&temp_hash);
205 Ok(sm3_hash(&prepend))
206 }
207
208 pub fn exchange_4(&self, sa: [u8; 32], ra_point: &Point) -> Sm2Result<bool> {
210 let ra_point_affine = ra_point.to_affine();
211 let x1 = ra_point_affine.x;
212 let y1 = ra_point_affine.y;
213
214 let r2_point_affine = self.r_point.unwrap().to_affine();
215 let x2 = r2_point_affine.x;
216 let y2 = r2_point_affine.y;
217
218 let v_point_affine = self.v.unwrap().to_affine();
219 let xv = v_point_affine.x;
220 let yv = v_point_affine.y;
221
222 let mut temp: Vec<u8> = Vec::new();
223 temp.extend_from_slice(&xv.to_bytes_be());
224 temp.extend_from_slice(&self.rhs_za);
225 temp.extend_from_slice(&self.za);
226 temp.extend_from_slice(&x1.to_bytes_be());
227 temp.extend_from_slice(&y1.to_bytes_be());
228 temp.extend_from_slice(&x2.to_bytes_be());
229 temp.extend_from_slice(&y2.to_bytes_be());
230
231 let mut prepend: Vec<u8> = Vec::new();
232 prepend.write_u16::<BigEndian>(0x03_u16).unwrap();
233 prepend.extend_from_slice(&yv.to_bytes_be());
234 prepend.extend_from_slice(&sm3_hash(&temp));
235 let s_2 = sm3_hash(&prepend);
236 Ok(s_2 == sa)
237 }
238}