#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use crypto_bigint::{Zero, U256};
use rand_core::RngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::error::Error;
use crate::sm2::ec::{AffinePoint, JacobianPoint};
use crate::sm2::field::{fn_add, fn_mul, fp_to_bytes, Fn, GROUP_ORDER_MINUS_1};
use crate::sm2::get_z;
use crate::sm3::Sm3Hasher;
/// Computes the SM2 key-agreement value x̄ = 2^w + (x mod 2^w) with w = 127.
///
/// `x_bytes` is a big-endian 32-byte x-coordinate. Only its low 16 bytes survive, and
/// bit 127 (the top bit of byte 16) is then forced to 1. OR-ing that bit in is the same
/// as "clear bit 127, then add 2^127".
fn x_bar(x_bytes: &[u8; 32]) -> U256 {
    let mut folded = [0u8; 32];
    // Copy the low-order half (big-endian bytes 16..32); the high half stays zero.
    for (dst, &src) in folded.iter_mut().zip(x_bytes.iter()).skip(16) {
        *dst = src;
    }
    folded[16] |= 0x80;
    U256::from_be_slice(&folded)
}
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct EphemeralKey {
r_bytes: [u8; 32],
#[zeroize(skip)]
r_point: [u8; 65],
}
impl EphemeralKey {
pub fn generate<R: RngCore>(rng: &mut R) -> Self {
loop {
let mut r_bytes = [0u8; 32];
rng.fill_bytes(&mut r_bytes);
let r = U256::from_be_slice(&r_bytes);
if bool::from(r.is_zero()) || r >= GROUP_ORDER_MINUS_1 {
r_bytes.zeroize();
continue;
}
let r_jac = JacobianPoint::scalar_mul_g(&r);
let r_aff = r_jac.to_affine().expect("valid r produces valid point");
return EphemeralKey {
r_bytes,
r_point: r_aff.to_bytes(),
};
}
}
pub fn from_scalar(r: &U256) -> Result<Self, Error> {
if bool::from(r.is_zero()) || *r >= GROUP_ORDER_MINUS_1 {
return Err(Error::InvalidPrivateKey);
}
let r_jac = JacobianPoint::scalar_mul_g(r);
let r_aff = r_jac.to_affine().map_err(|_| Error::InvalidPrivateKey)?;
Ok(EphemeralKey {
r_bytes: r.to_be_bytes(),
r_point: r_aff.to_bytes(),
})
}
pub fn public_key(&self) -> &[u8; 65] {
&self.r_point
}
}
/// Plain SM2 elliptic-curve Diffie–Hellman: returns the big-endian x-coordinate of
/// `[d]·P_peer`, where `d` is `my_priv` and `P_peer` is decoded from `peer_pub`.
///
/// # Errors
/// Propagates the error from `AffinePoint::from_bytes` when `peer_pub` does not decode
/// (presumably this includes off-curve points — confirm in `ec`). Also propagates the
/// error from `to_affine` if the product cannot be converted to affine form.
pub fn ecdh(my_priv: &crate::sm2::PrivateKey, peer_pub: &[u8; 65]) -> Result<[u8; 32], Error> {
    let peer_point = JacobianPoint::from_affine(&AffinePoint::from_bytes(peer_pub)?);
    let scalar = U256::from_be_slice(my_priv.as_bytes());
    let shared_point = JacobianPoint::scalar_mul(&scalar, &peer_point).to_affine()?;
    Ok(fp_to_bytes(&shared_point.x))
}
pub fn ecdh_from_slice(
my_priv: &crate::sm2::PrivateKey,
peer_pub: &[u8],
) -> Result<[u8; 32], Error> {
let pub_fixed: &[u8; 65] = peer_pub.try_into().map_err(|_| Error::InvalidInputLength)?;
ecdh(my_priv, pub_fixed)
}
/// One party's output from the SM2 key exchange, as returned by `exchange_a` (initiator)
/// and `exchange_b` (responder).
///
/// NOTE(review): `key` is secret material, but no zeroize-on-drop is derived for this type.
/// Callers that need it wiped must do so themselves.
#[cfg(feature = "alloc")]
pub struct ExchangeResult {
    /// Key produced by the KDF for the requested `klen`.
    pub key: Vec<u8>,
    /// Confirmation tag for this party's own role. For the initiator it is the hash
    /// prefixed with `0x03`; for the responder it is the one prefixed with `0x02`.
    pub s_self: [u8; 32],
    /// Confirmation tag for the peer's role (the other of the two hashes). When both sides
    /// derived the same shared point, it equals the peer's `s_self`.
    pub s_peer: [u8; 32],
}
/// Runs the SM2 key exchange as the initiator (user A).
///
/// Parameters:
/// - `pri_key_a`, `eph_key_a`: A's own static private key and ephemeral key.
/// - `pub_key_a`, `pub_key_b`: the two static public keys.
/// - `r_b`: B's encoded ephemeral public point, as received.
/// - `klen`: passed unchanged to the KDF. Its unit is defined by `crate::sm2::kdf::kdf`
///   (TODO confirm bytes vs bits).
///
/// # Errors
/// Same as the shared implementation: point-decoding errors, or `Error::KeyExchangeFailed`.
#[cfg(feature = "alloc")]
#[allow(clippy::too_many_arguments)]
pub fn exchange_a(
    klen: usize,
    id_a: &[u8],
    id_b: &[u8],
    pri_key_a: &crate::sm2::PrivateKey,
    pub_key_a: &[u8; 65],
    pub_key_b: &[u8; 65],
    eph_key_a: &EphemeralKey,
    r_b: &[u8; 65],
) -> Result<ExchangeResult, Error> {
    let role_is_initiator = true;
    compute_shared(
        role_is_initiator,
        klen,
        id_a,
        id_b,
        pri_key_a,
        pub_key_a,
        pub_key_b,
        eph_key_a,
        r_b,
    )
}
/// Runs the SM2 key exchange as the responder (user B).
///
/// Parameters:
/// - `pri_key_b`, `eph_key_b`: B's own static private key and ephemeral key.
/// - `pub_key_a`, `pub_key_b`: the two static public keys.
/// - `r_a`: A's encoded ephemeral public point, as received.
/// - `klen`: passed unchanged to the KDF. Its unit is defined by `crate::sm2::kdf::kdf`
///   (TODO confirm bytes vs bits).
///
/// # Errors
/// Same as the shared implementation: point-decoding errors, or `Error::KeyExchangeFailed`.
#[cfg(feature = "alloc")]
#[allow(clippy::too_many_arguments)]
pub fn exchange_b(
    klen: usize,
    id_a: &[u8],
    id_b: &[u8],
    pri_key_b: &crate::sm2::PrivateKey,
    pub_key_a: &[u8; 65],
    pub_key_b: &[u8; 65],
    eph_key_b: &EphemeralKey,
    r_a: &[u8; 65],
) -> Result<ExchangeResult, Error> {
    let role_is_initiator = false;
    compute_shared(
        role_is_initiator,
        klen,
        id_a,
        id_b,
        pri_key_b,
        pub_key_a,
        pub_key_b,
        eph_key_b,
        r_a,
    )
}
/// Shared implementation of the SM2 key agreement for both roles.
///
/// `is_initiator == true` runs the protocol as user A, with peer static key `pub_key_b`.
/// `false` runs it as user B, with peer static key `pub_key_a`. `Z_A` and `Z_B` are always
/// derived from `(id_a, pub_key_a)` and `(id_b, pub_key_b)`, whatever the role.
///
/// Steps:
/// 1. `t = (d_self + x̄_self · r_self) mod n`
/// 2. `V = [t](P_peer + [x̄_peer]·R_peer)`
/// 3. `key = KDF(x_V || y_V || Z_A || Z_B, klen)`
/// 4. `inner = SM3(x_V || Z_A || Z_B || x1 || y1 || x2 || y2)`, where `(x1, y1)` is A's
///    ephemeral point and `(x2, y2)` is B's.
/// 5. Confirmation tags `SM3(0x02 || y_V || inner)` and `SM3(0x03 || y_V || inner)`.
///    The initiator keeps the `0x03` tag as `s_self`; the responder keeps the `0x02` tag.
///
/// The byte copies of `x_V`, `y_V` and the heap KDF input are wiped before returning.
///
/// # Errors
/// * Any error from `AffinePoint::from_bytes`, checked in this order: own ephemeral
///   point, `r_peer`, then the peer's static public key.
/// * `Error::KeyExchangeFailed` if `V` cannot be converted to affine form.
/// * `Error::KeyExchangeFailed` if the derived key is all zero bytes. An empty key counts
///   as all-zero, which is presumably what `klen == 0` produces.
#[cfg(feature = "alloc")]
#[allow(clippy::too_many_arguments)]
fn compute_shared(
    is_initiator: bool,
    klen: usize,
    id_a: &[u8],
    id_b: &[u8],
    pri_key_self: &crate::sm2::PrivateKey,
    pub_key_a: &[u8; 65],
    pub_key_b: &[u8; 65],
    eph_key_self: &EphemeralKey,
    r_peer: &[u8; 65],
) -> Result<ExchangeResult, Error> {
    let z_a = get_z(id_a, pub_key_a);
    let z_b = get_z(id_b, pub_key_b);

    // Decode both ephemeral points and derive x̄ for each.
    let r_self_aff = AffinePoint::from_bytes(eph_key_self.public_key())?;
    let r_peer_aff = AffinePoint::from_bytes(r_peer)?;
    let x_self_bytes = fp_to_bytes(&r_self_aff.x);
    let x_peer_bytes = fp_to_bytes(&r_peer_aff.x);
    let x_bar_self = x_bar(&x_self_bytes);
    let x_bar_peer = x_bar(&x_peer_bytes);

    // t = (d_self + x̄_self · r_self) mod n
    let d_self = U256::from_be_slice(pri_key_self.as_bytes());
    let r_self = U256::from_be_slice(&eph_key_self.r_bytes);
    let t_fn = fn_add(
        &Fn::new(&d_self),
        &fn_mul(&Fn::new(&x_bar_self), &Fn::new(&r_self)),
    );

    // V = [t](P_peer + [x̄_peer]·R_peer). No cofactor multiplication is applied here,
    // presumably because the curve's cofactor is 1.
    let peer_pub_bytes = if is_initiator { pub_key_b } else { pub_key_a };
    let peer_pub_aff = AffinePoint::from_bytes(peer_pub_bytes)?;
    let peer_pub_jac = JacobianPoint::from_affine(&peer_pub_aff);
    let r_peer_jac = JacobianPoint::from_affine(&r_peer_aff);
    let x_bar_peer_r = JacobianPoint::scalar_mul(&x_bar_peer, &r_peer_jac);
    let combined = JacobianPoint::add(&peer_pub_jac, &x_bar_peer_r);
    let v_point = JacobianPoint::scalar_mul(&t_fn.retrieve(), &combined);
    let v_aff = v_point.to_affine().map_err(|_| Error::KeyExchangeFailed)?;
    let mut xv = fp_to_bytes(&v_aff.x);
    let mut yv = fp_to_bytes(&v_aff.y);

    // key = KDF(x_V || y_V || Z_A || Z_B, klen). The capacity is exact so the buffer is
    // never reallocated; a reallocation would free an unwiped copy of x_V/y_V.
    let mut kdf_input = Vec::with_capacity(xv.len() + yv.len() + z_a.len() + z_b.len());
    kdf_input.extend_from_slice(&xv);
    kdf_input.extend_from_slice(&yv);
    kdf_input.extend_from_slice(&z_a);
    kdf_input.extend_from_slice(&z_b);
    let key = crate::sm2::kdf::kdf(&kdf_input, klen);
    kdf_input.as_mut_slice().zeroize();

    // (x1, y1) is always A's ephemeral point and (x2, y2) always B's, whatever our role.
    let (x1, y1, x2, y2) = if is_initiator {
        (
            x_self_bytes,
            fp_to_bytes(&r_self_aff.y),
            x_peer_bytes,
            fp_to_bytes(&r_peer_aff.y),
        )
    } else {
        (
            x_peer_bytes,
            fp_to_bytes(&r_peer_aff.y),
            x_self_bytes,
            fp_to_bytes(&r_self_aff.y),
        )
    };
    let mut h = Sm3Hasher::new();
    h.update(&xv);
    h.update(&z_a);
    h.update(&z_b);
    h.update(&x1);
    h.update(&y1);
    h.update(&x2);
    h.update(&y2);
    let hash_v = h.finalize();
    let tag_02 = confirmation_hash(0x02, &yv, &hash_v);
    let tag_03 = confirmation_hash(0x03, &yv, &hash_v);

    xv.zeroize();
    yv.zeroize();

    // Reject a degenerate all-zero key (`all` is vacuously true for an empty key).
    if key.iter().all(|&b| b == 0) {
        return Err(Error::KeyExchangeFailed);
    }

    let (s_self, s_peer) = if is_initiator {
        (tag_03, tag_02)
    } else {
        (tag_02, tag_03)
    };
    Ok(ExchangeResult {
        key,
        s_self,
        s_peer,
    })
}

/// Computes the key-confirmation hash `SM3(tag || yv || inner)`.
#[cfg(feature = "alloc")]
fn confirmation_hash(tag: u8, yv: &[u8; 32], inner: &[u8; 32]) -> [u8; 32] {
    let mut h = Sm3Hasher::new();
    h.update(&[tag]);
    h.update(yv);
    h.update(inner);
    h.finalize()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm2::PrivateKey;

    /// Static private key bytes for party A.
    const D_A: [u8; 32] = [
        0x39, 0x45, 0x20, 0x8f, 0x7b, 0x21, 0x44, 0xb1, 0x3f, 0x36, 0xe3, 0x8a, 0xc6, 0xd3,
        0x9f, 0x95, 0x88, 0x93, 0x93, 0x69, 0x28, 0x60, 0xb5, 0x1a, 0x42, 0xfb, 0x81, 0xef,
        0x4d, 0xf7, 0xc5, 0xb8,
    ];

    /// Static private key bytes for party B.
    const D_B: [u8; 32] = [
        0x59, 0x27, 0x6e, 0x27, 0xd5, 0x06, 0x86, 0x1a, 0x16, 0x68, 0x0f, 0x3a, 0xd9, 0xc0,
        0x2d, 0xcc, 0xef, 0x3c, 0xc1, 0xfa, 0x3c, 0xdb, 0xe4, 0xce, 0x6d, 0x54, 0xb8, 0x0d,
        0xea, 0xc1, 0xbc, 0x21,
    ];

    /// Big-endian bytes of a small, non-zero, in-range ephemeral scalar.
    const EPH_SCALAR_BYTES: [u8; 32] = [
        0x83, 0xa2, 0xc9, 0xc8, 0xb9, 0x6e, 0x5a, 0xf7, 0x0b, 0xd4, 0x80, 0xb4, 0x72, 0x40,
        0x9a, 0x9a, 0x32, 0x72, 0x57, 0xf1, 0xeb, 0xb7, 0x3f, 0x5b, 0x07, 0x33, 0x54, 0xb2,
        0x48, 0x66, 0x85, 0x63,
    ];

    /// RNG that repeats a fixed 32-byte pattern, so key generation is deterministic.
    struct FakeRng([u8; 32]);
    impl RngCore for FakeRng {
        fn next_u32(&mut self) -> u32 {
            0
        }
        fn next_u64(&mut self) -> u64 {
            0
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            for (b, &src) in dest.iter_mut().zip(self.0.iter().cycle()) {
                *b = src;
            }
        }
        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
            self.fill_bytes(dest);
            Ok(())
        }
    }

    fn alice() -> PrivateKey {
        PrivateKey::from_bytes(&D_A).unwrap()
    }

    fn bob() -> PrivateKey {
        PrivateKey::from_bytes(&D_B).unwrap()
    }

    /// The fixed ephemeral keys (A's, B's) used by the exchange tests.
    #[cfg(feature = "alloc")]
    fn fixed_ephemerals() -> (EphemeralKey, EphemeralKey) {
        let ra =
            U256::from_be_hex("83A2C9C8B96E5AF70BD480B472409A9A327257F1EBB73F5B073354B248668563");
        let rb =
            U256::from_be_hex("33FE21940342161C55619C4A0C060293D543C80AF19748CE176D83477DE71C80");
        (
            EphemeralKey::from_scalar(&ra).unwrap(),
            EphemeralKey::from_scalar(&rb).unwrap(),
        )
    }

    /// Runs both sides of a 16-unit exchange with the fixed keys; returns (A's, B's) results.
    #[cfg(feature = "alloc")]
    fn exchange_both(id_a: &[u8], id_b: &[u8]) -> (ExchangeResult, ExchangeResult) {
        let pri_a = alice();
        let pri_b = bob();
        let pub_a = pri_a.public_key();
        let pub_b = pri_b.public_key();
        let (eph_a, eph_b) = fixed_ephemerals();
        let result_a = exchange_a(
            16,
            id_a,
            id_b,
            &pri_a,
            &pub_a,
            &pub_b,
            &eph_a,
            eph_b.public_key(),
        )
        .unwrap();
        let result_b = exchange_b(
            16,
            id_a,
            id_b,
            &pri_b,
            &pub_a,
            &pub_b,
            &eph_b,
            eph_a.public_key(),
        )
        .unwrap();
        (result_a, result_b)
    }

    #[test]
    fn test_x_bar() {
        let mut x_bytes = [0u8; 32];
        x_bytes[31] = 0x01;
        let mut expected = [0u8; 32];
        expected[16] = 0x80;
        expected[31] = 0x01;
        assert_eq!(x_bar(&x_bytes), U256::from_be_slice(&expected));
    }

    #[test]
    fn test_x_bar_high_bits_cleared() {
        let mut expected = [0u8; 32];
        expected[16..32].copy_from_slice(&[0xFF; 16]);
        expected[16] |= 0x80;
        assert_eq!(x_bar(&[0xFFu8; 32]), U256::from_be_slice(&expected));
    }

    #[test]
    fn test_generate_is_driven_by_rng() {
        let mut rng = FakeRng(EPH_SCALAR_BYTES);
        let generated = EphemeralKey::generate(&mut rng);
        let expected =
            EphemeralKey::from_scalar(&U256::from_be_slice(&EPH_SCALAR_BYTES)).unwrap();
        assert_eq!(generated.public_key(), expected.public_key());
    }

    #[test]
    fn test_ecdh_roundtrip() {
        let pri_a = alice();
        let pri_b = bob();
        let pub_a = pri_a.public_key();
        let pub_b = pri_b.public_key();
        assert_eq!(ecdh(&pri_a, &pub_b).unwrap(), ecdh(&pri_b, &pub_a).unwrap());
    }

    #[test]
    fn test_ecdh_invalid_pubkey() {
        let pri_a = alice();
        let mut bad_pub = [0x04u8; 65];
        bad_pub[1] = 0x01;
        assert!(ecdh(&pri_a, &bad_pub).is_err());
    }

    #[test]
    fn test_ecdh_from_slice_length_check() {
        let pri_a = alice();
        assert!(ecdh_from_slice(&pri_a, &[0x04u8; 64]).is_err());
        assert!(ecdh_from_slice(&pri_a, &[0x04u8; 66]).is_err());
    }

    #[test]
    fn test_ecdh_from_slice_equals_ecdh() {
        let pri_a = alice();
        let pub_b = bob().public_key();
        let r1 = ecdh(&pri_a, &pub_b).unwrap();
        let r2 = ecdh_from_slice(&pri_a, &pub_b).unwrap();
        assert_eq!(r1, r2);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_exchange_roundtrip() {
        let (result_a, result_b) = exchange_both(b"Alice@test.com", b"Bob@test.com");
        assert_eq!(result_a.key, result_b.key);
        assert!(!result_a.key.is_empty());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_exchange_confirmation() {
        let (result_a, result_b) = exchange_both(b"1234567812345678", b"1234567812345678");
        assert_eq!(result_a.s_peer, result_b.s_self);
        assert_eq!(result_b.s_peer, result_a.s_self);
        // The two directions must use distinct tags, or a tag could be reflected back.
        assert_ne!(result_a.s_self, result_a.s_peer);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_exchange_different_ids() {
        let (first, _) = exchange_both(b"ID_A_1", b"ID_B_1");
        let (second, _) = exchange_both(b"ID_A_2", b"ID_B_2");
        assert_ne!(first.key, second.key);
    }
}