use crate::asn1::ciphertext::decode;
use crate::sm2::curve::Fp;
use crate::sm2::encrypt::{kdf, point_on_curve, projective_from_affine};
use crate::sm2::private_key::Sm2PrivateKey;
use crate::sm2::scalar_mul::mul_var;
use crate::sm3::Sm3;
use alloc::vec::Vec;
use subtle::{Choice, ConstantTimeEq};
use zeroize::Zeroize;
pub fn decrypt(private: &Sm2PrivateKey, ciphertext_der: &[u8]) -> Result<Vec<u8>, crate::Error> {
let parsed = decode(ciphertext_der).ok_or(crate::Error::Failed)?;
let x1 = Fp::new(&parsed.x);
let y1 = Fp::new(&parsed.y);
if !point_on_curve(&x1, &y1) {
return Err(crate::Error::Failed);
}
let c1 = projective_from_affine(x1, y1);
if bool::from(c1.is_identity()) {
return Err(crate::Error::Failed);
}
let kp = mul_var(private.scalar(), &c1);
let (x2, y2) = kp.to_affine().ok_or(crate::Error::Failed)?;
let mut z = [0u8; 64];
z[..32].copy_from_slice(&x2.retrieve().to_be_bytes());
z[32..].copy_from_slice(&y2.retrieve().to_be_bytes());
let mut t = alloc::vec![0u8; parsed.ciphertext.len()];
kdf(&z, &mut t);
let nonempty: Choice = u8::from(!parsed.ciphertext.is_empty()).into();
let kdf_zero = nonempty & ct_all_zero(&t);
for (i, byte) in parsed.ciphertext.iter().enumerate() {
t[i] ^= byte;
}
let mut plaintext = t;
let mut h = Sm3::new();
h.update(&z[..32]);
h.update(&plaintext);
h.update(&z[32..]);
let u = h.finalize();
let mac_ok = u.ct_eq(&parsed.hash);
let valid = mac_ok & !kdf_zero;
z.zeroize();
if !bool::from(valid) {
plaintext.zeroize();
return Err(crate::Error::Failed);
}
Ok(plaintext)
}
/// Returns a `Choice` that is set exactly when every byte of `buf` is zero.
///
/// An empty slice counts as all-zero. At the source level every byte is OR-ed into one
/// accumulator with no early exit, so the loop does not stop at the first non-zero byte.
fn ct_all_zero(buf: &[u8]) -> Choice {
    let folded = buf.iter().fold(0u8, |bits, &byte| bits | byte);
    folded.ct_eq(&0u8)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asn1::ciphertext::{Sm2Ciphertext, encode};
    use crate::sm2::encrypt::encrypt;
    use crate::sm2::private_key::Sm2PrivateKey;
    use crate::sm2::public_key::Sm2PublicKey;
    use crypto_bigint::U256;
    use getrandom::SysRng;
    use rand_core::UnwrapErr;

    /// Big-endian hex of the private scalar used by most tests.
    const D_A_HEX: &str = "1649AB77A00637BD5E2EFE283FBF353534AA7F7CB89463F208DDBC2920BB0DA0";
    /// A second, unrelated private scalar, used only by the wrong-key test.
    const D_B_HEX: &str = "3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8";

    /// Builds a private key from a big-endian hex scalar, panicking if it is rejected.
    fn key_from_hex(hex: &str) -> Sm2PrivateKey {
        Sm2PrivateKey::from_scalar_inner(U256::from_be_hex(hex)).expect("valid test scalar")
    }

    /// The default test key pair, derived from `D_A_HEX`.
    fn key_pair() -> (Sm2PrivateKey, Sm2PublicKey) {
        let key = key_from_hex(D_A_HEX);
        let pk = Sm2PublicKey::from_point(key.public_key());
        (key, pk)
    }

    #[test]
    fn round_trip_random_nonce() {
        let (key, pk) = key_pair();
        let plaintext = b"encryption standard";
        let mut rng = UnwrapErr(SysRng);
        let der = encrypt(&pk, plaintext, &mut rng).expect("encrypt");
        let recovered = decrypt(&key, &der).expect("decrypt");
        assert_eq!(recovered.as_slice(), plaintext);
    }

    #[test]
    fn round_trip_boundary_lengths() {
        let (key, pk) = key_pair();
        let mut rng = UnwrapErr(SysRng);
        for len in [0usize, 1, 31, 32, 33, 64, 65, 128] {
            // Deterministic pattern: byte i is (i mod 256) * 7, wrapping.
            let plaintext: Vec<u8> = (0..=u8::MAX)
                .cycle()
                .take(len)
                .map(|i| i.wrapping_mul(7))
                .collect();
            let der = encrypt(&pk, &plaintext, &mut rng).expect("encrypt");
            let recovered = decrypt(&key, &der).expect("decrypt");
            assert_eq!(recovered, plaintext, "round-trip mismatch at len={len}");
        }
    }

    #[test]
    fn rejects_malformed_der() {
        let key = key_from_hex(D_A_HEX);
        assert_eq!(decrypt(&key, &[]), Err(crate::Error::Failed));
        assert_eq!(decrypt(&key, b"not DER"), Err(crate::Error::Failed));
        assert_eq!(
            decrypt(&key, &[0x30, 0x05, 0xff, 0xff, 0xff]),
            Err(crate::Error::Failed)
        );
    }

    #[test]
    fn rejects_off_curve_c1() {
        let key = key_from_hex(D_A_HEX);
        // (1, 1) is used as a C1 that does not satisfy the curve equation.
        let off_curve = Sm2Ciphertext {
            x: U256::from_u64(1),
            y: U256::from_u64(1),
            hash: [0u8; 32],
            ciphertext: alloc::vec![0u8; 16],
        };
        let der = encode(&off_curve);
        assert_eq!(decrypt(&key, &der), Err(crate::Error::Failed));
    }

    #[test]
    fn rejects_mac_mismatch() {
        let (key, pk) = key_pair();
        let mut rng = UnwrapErr(SysRng);
        let der = encrypt(&pk, b"encryption standard", &mut rng).expect("encrypt");
        let mut parsed = decode(&der).expect("decode our own DER");
        parsed.hash[0] ^= 0x01;
        let tampered = encode(&parsed);
        assert_eq!(decrypt(&key, &tampered), Err(crate::Error::Failed));
    }

    #[test]
    fn rejects_wrong_private_key() {
        let (key_a, pk_a) = key_pair();
        let key_b = key_from_hex(D_B_HEX);
        let mut rng = UnwrapErr(SysRng);
        let der = encrypt(&pk_a, b"top secret", &mut rng).expect("encrypt to A");
        // Sanity check: the intended recipient can decrypt.
        assert!(decrypt(&key_a, &der).is_ok());
        assert_eq!(decrypt(&key_b, &der), Err(crate::Error::Failed));
    }

    #[test]
    fn rejects_tampered_c2() {
        let (key, pk) = key_pair();
        let mut rng = UnwrapErr(SysRng);
        let der = encrypt(&pk, b"some plaintext data", &mut rng).expect("encrypt");
        let mut parsed = decode(&der).expect("decode our own DER");
        parsed.ciphertext[0] ^= 0xff;
        let tampered = encode(&parsed);
        assert_eq!(decrypt(&key, &tampered), Err(crate::Error::Failed));
    }

    #[test]
    fn rejects_forged_short_ciphertext() {
        let (key, pk) = key_pair();
        let mut rng = UnwrapErr(SysRng);
        for round in 0..32u8 {
            let plaintext = [round];
            let der = encrypt(&pk, &plaintext, &mut rng).expect("encrypt 1-byte");
            let mut parsed = decode(&der).expect("decode our own DER");
            parsed.hash[0] ^= 0x01;
            let tampered = encode(&parsed);
            assert_eq!(
                decrypt(&key, &tampered),
                Err(crate::Error::Failed),
                "forged 1-byte ciphertext on round {round} must fail"
            );
        }
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let (key, pk) = key_pair();
        let mut rng = UnwrapErr(SysRng);
        let der = encrypt(&pk, b"", &mut rng).expect("encrypt empty");
        let recovered = decrypt(&key, &der).expect("decrypt empty");
        assert!(recovered.is_empty(), "empty plaintext round-trip");
    }

    #[test]
    fn ct_all_zero_detects_nonzero_bytes() {
        assert!(bool::from(ct_all_zero(&[])));
        assert!(bool::from(ct_all_zero(&[0, 0, 0])));
        assert!(!bool::from(ct_all_zero(&[0, 0, 1])));
        assert!(!bool::from(ct_all_zero(&[0x80])));
    }
}