use crate::Error;
use crate::sm2::curve::Fn;
use crate::sm2::encrypt::{KDF_MAX_OUTPUT, kdf};
use crate::sm2::point::ProjectivePoint;
use crate::sm2::scalar_mul::{mul_g, mul_var};
use crate::sm2::sign::{MAX_ID_LEN, compute_z, sample_nonzero_scalar};
use crate::sm2::{Sm2PrivateKey, Sm2PublicKey};
use crate::sm3::Sm3;
use alloc::vec::Vec;
use crypto_bigint::U256;
use rand_core::TryCryptoRng;
use subtle::{Choice, ConstantTimeEq};
use zeroize::{Zeroize, ZeroizeOnDrop};
type Result<T> = core::result::Result<T, Error>;
/// An ephemeral public point (`R_A` or `R_B`) exchanged during SM2 key
/// agreement, held as 65 raw bytes in uncompressed SEC1 layout
/// (`0x04 || X || Y`).
///
/// `from_bytes` does not validate anything; the peer's bytes are parsed with
/// `Sm2PublicKey::from_sec1_bytes` when the protocol consumes them.
#[derive(Clone)]
pub struct Sm2KxEphemeralPoint([u8; 65]);
impl Sm2KxEphemeralPoint {
    /// Returns a copy of the 65-byte encoding.
    #[must_use]
    pub const fn to_bytes(&self) -> [u8; 65] {
        let Self(encoded) = self;
        *encoded
    }
    /// Wraps a 65-byte encoding received from the peer (copied, unchecked).
    #[must_use]
    pub const fn from_bytes(b: &[u8; 65]) -> Self {
        let owned = *b;
        Self(owned)
    }
}
/// The agreed session key (`K_A` on the initiator side, `K_B` on the
/// responder side), `klen` bytes long.
///
/// The underlying buffer is zeroized when the value is dropped.
#[derive(ZeroizeOnDrop)]
pub struct Sm2SharedKey(Vec<u8>);
impl Sm2SharedKey {
    /// Borrows the raw key bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// A 32-byte key-confirmation tag (`S_B` sent by the responder, `S_A` sent
/// by the initiator).
#[derive(Clone)]
pub struct Sm2KxConfirm([u8; 32]);
impl Sm2KxConfirm {
    /// Returns a copy of the tag bytes.
    #[must_use]
    pub const fn to_bytes(&self) -> [u8; 32] {
        let Self(tag) = self;
        *tag
    }
    /// Wraps a tag received from the peer (copied as-is).
    #[must_use]
    pub const fn from_bytes(b: &[u8; 32]) -> Self {
        let owned = *b;
        Self(owned)
    }
}
fn validate_params(p_peer: &Sm2PublicKey, id_a: &[u8], id_b: &[u8], klen: usize) -> Result<()> {
let klen64 = u64::try_from(klen).map_err(|_| Error::Failed)?;
if klen == 0
|| klen64 > KDF_MAX_OUTPUT
|| id_a.len() > MAX_ID_LEN
|| id_b.len() > MAX_ID_LEN
|| bool::from(p_peer.point().is_identity())
{
return Err(Error::Failed);
}
Ok(())
}
/// Initiator (party A) state before its ephemeral point exists.
///
/// Build it with `Sm2KxInitiator::new`, then call `produce_ephemeral`.
pub struct Sm2KxInitiator {
    /// A's long-term private key `d_A`.
    d: Sm2PrivateKey,
    /// B's long-term public key `P_B`.
    p_peer: ProjectivePoint,
    /// `Z_A`, from `compute_z` over A's public key and identity.
    z_a: [u8; 32],
    /// `Z_B`, from `compute_z` over B's public key and identity.
    z_b: [u8; 32],
    /// Requested session-key length in bytes.
    klen: usize,
}
impl Sm2KxInitiator {
    /// Sets up party A for a key exchange with party B.
    ///
    /// `d_a` is A's private key, `p_b` is B's public key, `id_a`/`id_b` are
    /// the two identities, and `klen` is the desired key length in bytes.
    ///
    /// # Errors
    /// Returns [`Error::Failed`] if `klen` is zero or exceeds the KDF limit,
    /// if either identity is longer than `MAX_ID_LEN`, or if `p_b` is the
    /// identity point.
    pub fn new(
        d_a: &Sm2PrivateKey,
        p_b: &Sm2PublicKey,
        id_a: &[u8],
        id_b: &[u8],
        klen: usize,
    ) -> Result<Self> {
        validate_params(p_b, id_a, id_b, klen)?;
        let own_public = d_a.public_key();
        let z_a = compute_z(&own_public, id_a);
        let z_b = compute_z(p_b, id_b);
        Ok(Self {
            d: d_a.clone(),
            p_peer: p_b.point(),
            z_a,
            z_b,
            klen,
        })
    }
}
/// Owns an ephemeral scalar (`r_A` or `r_B`) and zeroizes it on drop.
struct EphScalar(Fn);
impl Drop for EphScalar {
    fn drop(&mut self) {
        let Self(scalar) = self;
        scalar.zeroize();
    }
}
fn sample_ephemeral<R: TryCryptoRng>(rng: &mut R) -> Result<(Fn, [u8; 65])> {
let (r, sample_ok) = sample_nonzero_scalar(rng).ok_or(Error::Failed)?;
let r_point = mul_g(&r);
let (x, y) = r_point.to_affine().ok_or(Error::Failed)?;
let mut sec1 = [0u8; 65];
sec1[0] = 0x04;
sec1[1..33].copy_from_slice(&crate::u256_to_be32(&x.retrieve()));
sec1[33..65].copy_from_slice(&crate::u256_to_be32(&y.retrieve()));
if !bool::from(sample_ok) {
return Err(Error::Failed);
}
Ok((r, sec1))
}
/// Initiator (party A) state after `R_A` has been produced; waits for the
/// responder's `R_B` and `S_B`.
pub struct Sm2KxInitiatorWaiting {
    /// Long-term parameters carried over from `Sm2KxInitiator`.
    inner: Sm2KxInitiator,
    /// A's ephemeral scalar `r_A` (zeroized on drop by `EphScalar`).
    r_eph: EphScalar,
    /// SEC1 encoding of `R_A`, reused for `x̄_1` and the confirmation tags.
    r_point_bytes: [u8; 65],
}
impl Sm2KxInitiator {
    /// Generates A's ephemeral key pair `(r_A, R_A)`.
    ///
    /// Returns `R_A`, to be sent to the responder, and the waiting state that
    /// later consumes the responder's reply.
    ///
    /// # Errors
    /// Returns [`Error::Failed`] if ephemeral sampling fails.
    pub fn produce_ephemeral<R: TryCryptoRng>(
        self,
        rng: &mut R,
    ) -> Result<(Sm2KxEphemeralPoint, Sm2KxInitiatorWaiting)> {
        let (scalar, encoded) = sample_ephemeral(rng)?;
        let waiting = Sm2KxInitiatorWaiting {
            inner: self,
            r_eph: EphScalar(scalar),
            r_point_bytes: encoded,
        };
        Ok((Sm2KxEphemeralPoint(encoded), waiting))
    }
}
/// Responder (party B) state before A's ephemeral point has arrived.
///
/// Build it with `Sm2KxResponder::new`, then call `respond`.
pub struct Sm2KxResponder {
    /// B's long-term private key `d_B`.
    d: Sm2PrivateKey,
    /// A's long-term public key `P_A`.
    p_peer: ProjectivePoint,
    /// `Z_A`, from `compute_z` over A's public key and identity.
    z_a: [u8; 32],
    /// `Z_B`, from `compute_z` over B's public key and identity.
    z_b: [u8; 32],
    /// Requested session-key length in bytes.
    klen: usize,
}
impl Sm2KxResponder {
    /// Sets up party B for a key exchange with party A.
    ///
    /// `d_b` is B's private key, `p_a` is A's public key, `id_a`/`id_b` are
    /// the two identities, and `klen` is the desired key length in bytes.
    ///
    /// # Errors
    /// Returns [`Error::Failed`] if `klen` is zero or exceeds the KDF limit,
    /// if either identity is longer than `MAX_ID_LEN`, or if `p_a` is the
    /// identity point.
    pub fn new(
        d_b: &Sm2PrivateKey,
        p_a: &Sm2PublicKey,
        id_a: &[u8],
        id_b: &[u8],
        klen: usize,
    ) -> Result<Self> {
        validate_params(p_a, id_a, id_b, klen)?;
        let own_public = d_b.public_key();
        let z_a = compute_z(p_a, id_a);
        let z_b = compute_z(&own_public, id_b);
        Ok(Self {
            d: d_b.clone(),
            p_peer: p_a.point(),
            z_a,
            z_b,
            klen,
        })
    }
}
/// SM2's truncated coordinate `x̄ = 2^127 + (x & (2^127 - 1))`, computed from
/// a 32-byte big-endian `x` and returned as a scalar.
fn avf(x_be: &[u8; 32]) -> Fn {
    // Keep only the low 128 bits (bytes 16..32), then force bit 127 on.
    // Setting bit 127 after keeping it equals masking it off and adding 2^127.
    let mut truncated = [0u8; 32];
    truncated[16..].copy_from_slice(&x_be[16..]);
    truncated[16] |= 0x80;
    Fn::new(&U256::from_be_slice(&truncated))
}
/// Splits a 65-byte uncompressed SEC1 point (`prefix || X || Y`) into its
/// 32-byte `X` and `Y` halves. The leading prefix byte is not inspected.
fn split_xy(sec1: &[u8; 65]) -> ([u8; 32], [u8; 32]) {
    const COORD_LEN: usize = 32;
    let x = core::array::from_fn(|i| sec1[1 + i]);
    let y = core::array::from_fn(|i| sec1[1 + COORD_LEN + i]);
    (x, y)
}
/// Computes an SM2 key-confirmation tag:
/// `SM3(prefix || yu || SM3(xu || za || zb || x1 || y1 || x2 || y2))`.
///
/// Callers pass `0x02` for the responder's tag and `0x03` for the
/// initiator's tag, with `(x1, y1)` taken from `R_A` and `(x2, y2)` from `R_B`.
// The nine inputs mirror the tag formula term-for-term; bundling them into a
// struct would hide that one-to-one mapping.
#[allow(clippy::too_many_arguments)]
fn s_tag(
    prefix: u8,
    yu: &[u8; 32],
    xu: &[u8; 32],
    za: &[u8; 32],
    zb: &[u8; 32],
    x1: &[u8; 32],
    y1: &[u8; 32],
    x2: &[u8; 32],
    y2: &[u8; 32],
) -> [u8; 32] {
    let inner_digest = {
        let mut transcript = Sm3::new();
        transcript.update(xu);
        transcript.update(za);
        transcript.update(zb);
        transcript.update(x1);
        transcript.update(y1);
        transcript.update(x2);
        transcript.update(y2);
        transcript.finalize()
    };
    let mut outer = Sm3::new();
    outer.update(&[prefix]);
    outer.update(yu);
    outer.update(&inner_digest);
    outer.finalize()
}
#[allow(clippy::too_many_arguments)]
fn shared_secret(
d: &Fn,
r_eph: &Fn,
r_local_x: &[u8; 32],
peer_r: &[u8; 65],
p_peer: &ProjectivePoint,
z_a: &[u8; 32],
z_b: &[u8; 32],
klen: usize,
) -> Result<(Vec<u8>, [u8; 32], [u8; 32])> {
let peer_pub = Sm2PublicKey::from_sec1_bytes(peer_r).ok_or(Error::Failed)?;
let peer_point = peer_pub.point();
let mut peer_x = [0u8; 32];
peer_x.copy_from_slice(&peer_r[1..33]);
let x_bar_local = avf(r_local_x); let x_bar_peer = avf(&peer_x);
let mut xr = x_bar_local * *r_eph;
let mut t = *d + xr;
let sum = p_peer.add(&mul_var(&x_bar_peer, &peer_point));
let u = mul_var(&t, &sum);
t.zeroize();
xr.zeroize();
let (mut xu, mut yu) = u.to_affine().ok_or(Error::Failed)?; let xu_b = crate::u256_to_be32(&xu.retrieve());
let yu_b = crate::u256_to_be32(&yu.retrieve());
xu.zeroize();
yu.zeroize();
let mut kin = Vec::with_capacity(128);
kin.extend_from_slice(&xu_b);
kin.extend_from_slice(&yu_b);
kin.extend_from_slice(z_a);
kin.extend_from_slice(z_b);
let mut key = alloc::vec![0u8; klen];
kdf(&kin, &mut key);
kin.zeroize();
let mut allzero = Choice::from(1u8);
for b in &key {
allzero &= b.ct_eq(&0u8);
}
if bool::from(allzero) {
return Err(Error::Failed);
}
Ok((key, xu_b, yu_b))
}
impl Sm2KxInitiatorWaiting {
pub fn confirm(
self,
r_b: &Sm2KxEphemeralPoint,
s_b: &Sm2KxConfirm,
) -> Result<(Sm2SharedKey, Sm2KxConfirm)> {
let Self {
inner,
r_eph,
r_point_bytes,
} = self;
let mut local_x = [0u8; 32];
local_x.copy_from_slice(&r_point_bytes[1..33]);
let (key, mut xu_b, mut yu_b) = shared_secret(
inner.d.scalar(),
&r_eph.0,
&local_x,
&r_b.0,
&inner.p_peer,
&inner.z_a,
&inner.z_b,
inner.klen,
)?;
let key = Sm2SharedKey(key);
let (x1, y1) = split_xy(&r_point_bytes);
let (x2, y2) = split_xy(&r_b.0);
let expected_s_b = s_tag(
0x02, &yu_b, &xu_b, &inner.z_a, &inner.z_b, &x1, &y1, &x2, &y2,
);
let ok = expected_s_b[..].ct_eq(&s_b.0[..]);
if !bool::from(ok) {
xu_b.zeroize();
yu_b.zeroize();
return Err(Error::Failed);
}
let s_a = Sm2KxConfirm(s_tag(
0x03, &yu_b, &xu_b, &inner.z_a, &inner.z_b, &x1, &y1, &x2, &y2,
));
xu_b.zeroize();
yu_b.zeroize();
Ok((key, s_a))
}
}
/// Responder (party B) state after `R_B` and `S_B` have been produced; waits
/// for the initiator's `S_A`.
pub struct Sm2KxResponderWaiting {
    /// `K_B`, released by `finish` once `S_A` verifies.
    key: Sm2SharedKey,
    /// The `S_A` value the initiator is expected to send.
    expected_s_a: [u8; 32],
}
impl Sm2KxResponder {
    /// Runs the responder's half of the exchange.
    ///
    /// Generates `(r_B, R_B)`, derives `K_B` from the initiator's `R_A`, and
    /// returns `R_B`, the responder's tag `S_B`, and a waiting state that
    /// holds `K_B` until the initiator's `S_A` is checked.
    ///
    /// # Errors
    /// Returns [`Error::Failed`] if ephemeral sampling fails, if `r_a` is
    /// rejected, or if the shared-secret derivation fails.
    pub fn respond<R: TryCryptoRng>(
        self,
        r_a: &Sm2KxEphemeralPoint,
        rng: &mut R,
    ) -> Result<(Sm2KxEphemeralPoint, Sm2KxConfirm, Sm2KxResponderWaiting)> {
        let (scalar, rb_bytes) = sample_ephemeral(rng)?;
        let r_eph = EphScalar(scalar);
        // (x1, y1) = R_A (theirs), (x2, y2) = R_B (ours).
        let (x1, y1) = split_xy(&r_a.0);
        let (x2, y2) = split_xy(&rb_bytes);
        let (raw_key, mut xu_b, mut yu_b) = shared_secret(
            self.d.scalar(),
            &r_eph.0,
            &x2,
            &r_a.0,
            &self.p_peer,
            &self.z_a,
            &self.z_b,
            self.klen,
        )?;
        let key = Sm2SharedKey(raw_key);
        let tag = |prefix: u8| {
            s_tag(
                prefix, &yu_b, &xu_b, &self.z_a, &self.z_b, &x1, &y1, &x2, &y2,
            )
        };
        let s_b = Sm2KxConfirm(tag(0x02));
        let expected_s_a = tag(0x03);
        xu_b.zeroize();
        yu_b.zeroize();
        let waiting = Sm2KxResponderWaiting { key, expected_s_a };
        Ok((Sm2KxEphemeralPoint(rb_bytes), s_b, waiting))
    }
}
impl Sm2KxResponderWaiting {
pub fn finish(self, s_a: &Sm2KxConfirm) -> Result<Sm2SharedKey> {
let ok = self.expected_s_a[..].ct_eq(&s_a.0[..]);
if !bool::from(ok) {
return Err(Error::Failed);
}
Ok(self.key)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::U256;

    /// Deterministic test RNG: every `try_fill_bytes` call yields the same
    /// 32 bytes, so ephemeral scalars are reproducible across runs.
    pub(super) struct FixedRng(pub [u8; 32]);
    impl rand_core::TryRng for FixedRng {
        type Error = core::convert::Infallible;
        fn try_next_u32(&mut self) -> core::result::Result<u32, Self::Error> {
            Ok(0)
        }
        fn try_next_u64(&mut self) -> core::result::Result<u64, Self::Error> {
            Ok(0)
        }
        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> core::result::Result<(), Self::Error> {
            assert_eq!(dst.len(), 32);
            dst.copy_from_slice(&self.0);
            Ok(())
        }
    }
    impl rand_core::TryCryptoRng for FixedRng {}

    #[test]
    fn s_tag_prefixes_differ_and_deterministic() {
        let z = [1u8; 32];
        let xu = [2u8; 32];
        let yu = [3u8; 32];
        let r = [4u8; 32];
        let sb = s_tag(0x02, &yu, &xu, &z, &z, &r, &r, &r, &r);
        let sa = s_tag(0x03, &yu, &xu, &z, &z, &r, &r, &r, &r);
        assert_ne!(sb, sa, "domain-separation prefix must change the tag");
        assert_eq!(
            sb,
            s_tag(0x02, &yu, &xu, &z, &z, &r, &r, &r, &r),
            "deterministic"
        );
    }

    #[test]
    fn confirm_rejects_tampered_s_b() {
        use crate::sm2::Sm2PrivateKey;
        let da = Sm2PrivateKey::from_bytes_be(&[5u8; 32]).unwrap();
        let db = Sm2PrivateKey::from_bytes_be(&[6u8; 32]).unwrap();
        let (pa, pb) = (da.public_key(), db.public_key());
        let init = Sm2KxInitiator::new(&da, &pb, b"a", b"b", 16).unwrap();
        let (ra, iw) = init.produce_ephemeral(&mut FixedRng([11u8; 32])).unwrap();
        let resp = Sm2KxResponder::new(&db, &pa, b"a", b"b", 16).unwrap();
        let (rb, sb, _rw) = resp.respond(&ra, &mut FixedRng([12u8; 32])).unwrap();
        let mut bad = sb.to_bytes();
        bad[0] ^= 1;
        let sb_bad = Sm2KxConfirm::from_bytes(&bad);
        assert!(iw.confirm(&rb, &sb_bad).is_err(), "tampered S_B accepted");
    }

    #[test]
    fn finish_rejects_tampered_s_a() {
        use crate::sm2::Sm2PrivateKey;
        let da = Sm2PrivateKey::from_bytes_be(&[7u8; 32]).unwrap();
        let db = Sm2PrivateKey::from_bytes_be(&[8u8; 32]).unwrap();
        let (pa, pb) = (da.public_key(), db.public_key());
        let init = Sm2KxInitiator::new(&da, &pb, b"a", b"b", 16).unwrap();
        let (ra, iw) = init.produce_ephemeral(&mut FixedRng([13u8; 32])).unwrap();
        let resp = Sm2KxResponder::new(&db, &pa, b"a", b"b", 16).unwrap();
        let (rb, sb, rw) = resp.respond(&ra, &mut FixedRng([14u8; 32])).unwrap();
        let (_k_a, sa) = iw.confirm(&rb, &sb).unwrap();
        let mut bad = sa.to_bytes();
        bad[31] ^= 0x80;
        let sa_bad = Sm2KxConfirm::from_bytes(&bad);
        assert!(rw.finish(&sa_bad).is_err(), "tampered S_A accepted");
    }

    #[test]
    fn round_trip_shared_key_matches() {
        use crate::sm2::Sm2PrivateKey;
        let da = Sm2PrivateKey::from_bytes_be(&[3u8; 32]).unwrap();
        let db = Sm2PrivateKey::from_bytes_be(&[4u8; 32]).unwrap();
        let (pa, pb) = (da.public_key(), db.public_key());
        let mut rng_a = FixedRng([9u8; 32]);
        let mut rng_b = FixedRng([10u8; 32]);
        let init = Sm2KxInitiator::new(&da, &pb, b"a", b"b", 32).unwrap();
        let (ra, init_w) = init.produce_ephemeral(&mut rng_a).unwrap();
        let resp = Sm2KxResponder::new(&db, &pa, b"a", b"b", 32).unwrap();
        let (rb, sb, resp_w) = resp.respond(&ra, &mut rng_b).unwrap();
        let (k_a, sa) = init_w.confirm(&rb, &sb).unwrap();
        let k_b = resp_w.finish(&sa).unwrap();
        assert_eq!(k_a.as_bytes(), k_b.as_bytes());
        assert_eq!(k_a.as_bytes().len(), 32);
        assert!(k_a.as_bytes().iter().any(|&b| b != 0));
    }

    #[test]
    fn produce_ephemeral_yields_on_curve_point() {
        use crate::sm2::Sm2PrivateKey;
        let d = Sm2PrivateKey::from_bytes_be(&[2u8; 32]).unwrap();
        let p = d.public_key();
        let init = Sm2KxInitiator::new(&d, &p, b"a", b"b", 16).unwrap();
        let mut rng = FixedRng([7u8; 32]);
        let (r_a, _waiting) = init.produce_ephemeral(&mut rng).unwrap();
        assert!(Sm2PublicKey::from_sec1_bytes(&r_a.to_bytes()).is_some());
    }

    #[test]
    fn avf_sets_bit_127_and_masks_low_127() {
        let x = [0xFFu8; 32];
        let got = avf(&x).retrieve();
        let expect =
            U256::from_be_hex("00000000000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF");
        assert_eq!(got, expect);
    }

    #[test]
    fn initiator_new_rejects_overlong_id() {
        use crate::sm2::{Sm2PrivateKey, sign::MAX_ID_LEN};
        let d = Sm2PrivateKey::from_bytes_be(&[1u8; 32]).unwrap();
        let p = d.public_key();
        let too_long = alloc::vec![0u8; MAX_ID_LEN + 1];
        assert!(Sm2KxInitiator::new(&d, &p, &too_long, b"b", 16).is_err());
        assert!(Sm2KxInitiator::new(&d, &p, b"a", &too_long, 16).is_err());
        assert!(Sm2KxInitiator::new(&d, &p, b"a", b"b", 16).is_ok());
        assert!(Sm2KxResponder::new(&d, &p, &too_long, b"b", 16).is_err());
        assert!(Sm2KxResponder::new(&d, &p, b"a", &too_long, 16).is_err());
        assert!(Sm2KxResponder::new(&d, &p, b"a", b"b", 16).is_ok());
    }

    #[test]
    fn new_rejects_bad_klen() {
        use crate::sm2::Sm2PrivateKey;
        let d = Sm2PrivateKey::from_bytes_be(&[1u8; 32]).unwrap();
        let p = d.public_key();
        assert!(Sm2KxInitiator::new(&d, &p, b"a", b"b", 0).is_err());
        assert!(Sm2KxResponder::new(&d, &p, b"a", b"b", 0).is_err());
        // 32 * (2^32 - 1) + 1 only fits in `usize` on 64-bit targets; on
        // narrower targets such a length cannot be requested at all, so skip
        // the case instead of panicking on `unwrap`.
        if let Ok(over) = usize::try_from(32u64 * ((1u64 << 32) - 1) + 1) {
            assert!(Sm2KxInitiator::new(&d, &p, b"a", b"b", over).is_err());
            assert!(Sm2KxResponder::new(&d, &p, b"a", b"b", over).is_err());
        }
    }

    #[test]
    fn new_rejects_identity_peer_pubkey() {
        use crate::sm2::point::ProjectivePoint;
        use crate::sm2::{Sm2PrivateKey, Sm2PublicKey};
        let d = Sm2PrivateKey::from_bytes_be(&[1u8; 32]).unwrap();
        let identity = Sm2PublicKey::from_point(ProjectivePoint::identity());
        assert!(Sm2KxInitiator::new(&d, &identity, b"a", b"b", 16).is_err());
        assert!(Sm2KxResponder::new(&d, &identity, b"a", b"b", 16).is_err());
    }

    #[test]
    fn avf_zero_input_yields_exactly_bit_127() {
        let x = [0u8; 32];
        let got = avf(&x).retrieve();
        let expect =
            U256::from_be_hex("0000000000000000000000000000000080000000000000000000000000000000");
        assert_eq!(got, expect);
    }
}