use crate::asn1::ciphertext::{Sm2Ciphertext, encode};
use crate::sm2::curve::{Fn, Fp, b};
use crate::sm2::point::ProjectivePoint;
use crate::sm2::public_key::Sm2PublicKey;
use crate::sm2::scalar_mul::{mul_g, mul_var};
use crate::sm2::sign::sample_nonzero_scalar;
use crate::sm3::{DIGEST_SIZE, Sm3};
use alloc::vec::Vec;
use crypto_bigint::U256;
use rand_core::{CryptoRng, Rng};
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
/// Maximum number of ephemeral scalars [`encrypt`] samples before it gives up
/// and returns [`EncryptError::Failed`].
const ENCRYPT_RETRY_BUDGET: usize = 64;
/// Error type returned by [`encrypt`].
///
/// Only one variant exists, so callers cannot tell the individual failure
/// causes apart.
// NOTE(review): the single opaque variant looks deliberate (no failure oracle) — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptError {
/// No ciphertext was produced: either the public key is the identity point,
/// or every attempt within the retry budget was rejected.
Failed,
}
/// Encrypts `plaintext` for the holder of `public` and returns the serialized
/// ciphertext produced by `asn1::ciphertext::encode`.
///
/// A fresh ephemeral scalar is drawn from `rng` for each attempt; attempts that
/// [`try_encrypt_once`] rejects are retried up to [`ENCRYPT_RETRY_BUDGET`] times.
///
/// # Errors
///
/// Returns [`EncryptError::Failed`] when `public` is the identity point, or when
/// no attempt within the retry budget yields a ciphertext.
pub fn encrypt<R: CryptoRng + Rng>(
    public: &Sm2PublicKey,
    plaintext: &[u8],
    rng: &mut R,
) -> Result<Vec<u8>, EncryptError> {
    // Reject the identity key up front, before any randomness is consumed.
    let key_is_identity = bool::from(public.point().is_identity());
    if key_is_identity {
        return Err(EncryptError::Failed);
    }

    // Stop at the first attempt that yields a ciphertext; each attempt draws
    // its own scalar.
    (0..ENCRYPT_RETRY_BUDGET)
        .find_map(|_| {
            let ephemeral = sample_nonzero_scalar(&mut *rng);
            try_encrypt_once(public, plaintext, &ephemeral)
        })
        .map(|ciphertext| encode(&ciphertext))
        .ok_or(EncryptError::Failed)
}
/// Performs one encryption attempt with the ephemeral scalar `k`.
///
/// Steps, in order:
/// 1. `C1 = [k]G`, converted to affine coordinates.
/// 2. `(x2, y2) = [k]P`, where `P` is the recipient's public point.
/// 3. Keystream `t = kdf(x2 || y2)` with the same length as `plaintext`.
/// 4. `C2 = plaintext XOR t`.
/// 5. `C3 = SM3(x2 || plaintext || y2)`.
///
/// Returns `None` when either point has no affine representation, or when the
/// keystream for a non-empty plaintext is all zero; the caller can then retry
/// with a different scalar. The 64-byte `x2 || y2` buffer is zeroized on every
/// path that created it.
fn try_encrypt_once(public: &Sm2PublicKey, plaintext: &[u8], k: &Fn) -> Option<Sm2Ciphertext> {
    let (c1_x, c1_y) = mul_g(k).to_affine()?;
    let (shared_x, shared_y) = mul_var(k, &public.point()).to_affine()?;

    // shared = x2 || y2, each coordinate as 32 big-endian bytes.
    let mut shared = [0u8; 64];
    {
        let (x_half, y_half) = shared.split_at_mut(32);
        x_half.copy_from_slice(&shared_x.retrieve().to_be_bytes());
        y_half.copy_from_slice(&shared_y.retrieve().to_be_bytes());
    }

    let mut keystream = alloc::vec![0u8; plaintext.len()];
    kdf(&shared, &mut keystream);

    // An all-zero keystream would emit the plaintext unchanged, so reject it.
    // An empty plaintext has no keystream to check.
    if !plaintext.is_empty() && all_zero_ct(&keystream) {
        shared.zeroize();
        keystream.zeroize();
        return None;
    }

    // XOR the message into the keystream in place; the buffer becomes C2.
    for (stream_byte, message_byte) in keystream.iter_mut().zip(plaintext) {
        *stream_byte ^= message_byte;
    }

    // C3 binds the message between the two shared coordinates.
    let mut hasher = Sm3::new();
    hasher.update(&shared[..32]);
    hasher.update(plaintext);
    hasher.update(&shared[32..]);
    let tag = hasher.finalize();
    shared.zeroize();

    Some(Sm2Ciphertext {
        x: c1_x.retrieve(),
        y: c1_y.retrieve(),
        hash: tag,
        ciphertext: keystream,
    })
}
/// Fills `output` with key material derived from `z`.
///
/// Block `i` (counting from 1) is `SM3(z || i)`, with `i` encoded as a 32-bit
/// big-endian integer. Blocks are written back to back, and the final block is
/// truncated to whatever space remains. An empty `output` performs no hashing.
pub(super) fn kdf(z: &[u8], output: &mut [u8]) {
    let mut counter: u32 = 1;
    for block in output.chunks_mut(DIGEST_SIZE) {
        let mut hasher = Sm3::new();
        hasher.update(z);
        hasher.update(&counter.to_be_bytes());
        let digest = hasher.finalize();

        // Only the last block can be shorter than a full digest.
        let take = block.len();
        block.copy_from_slice(&digest[..take]);
        counter += 1;
    }
}
/// Reports whether every byte of `buf` is zero.
///
/// All bytes are OR-ed together with no early exit, and the accumulated value
/// is compared against zero through `ConstantTimeEq`. An empty slice yields
/// `true`.
fn all_zero_ct(buf: &[u8]) -> bool {
    let combined = buf.iter().fold(0u8, |seen, byte| seen | byte);
    bool::from(combined.ct_eq(&0u8))
}
/// Checks whether the affine pair `(x, y)` satisfies `y^2 = x^3 - 3x + b`,
/// where `b` is the value returned by the curve module's `b()`.
///
/// The two sides are compared through `ConstantTimeEq` on their retrieved
/// integer values.
pub(crate) fn point_on_curve(x: &Fp, y: &Fp) -> bool {
    let (x_val, y_val) = (*x, *y);
    let three = Fp::new(&U256::from_u64(3));

    let y_squared = y_val * y_val;
    let x_cubed = x_val * x_val * x_val;
    let curve_rhs = x_cubed - three * x_val + b();

    let sides_match = y_squared.retrieve().ct_eq(&curve_rhs.retrieve());
    bool::from(sides_match)
}
/// Builds a [`ProjectivePoint`] from affine coordinates by setting `z` to one.
///
/// No validation is performed: the caller is responsible for `(x, y)` being a
/// point on the curve.
pub(crate) const fn projective_from_affine(x: Fp, y: Fp) -> ProjectivePoint {
    let z = Fp::new(&U256::ONE);
    ProjectivePoint { x, y, z }
}
// Unit tests for the KDF, the curve-equation check, and the `encrypt` entry point.
#[cfg(test)]
mod tests {
use super::*;
use crate::sm2::private_key::Sm2PrivateKey;
use core::convert::Infallible;
use rand_core::{TryCryptoRng, TryRng};
/// RNG stub that hands out the same 32 bytes on every `try_fill_bytes` call,
/// so the ephemeral scalar can be pinned in tests.
struct FixedScalarRng {
bytes: [u8; 32],
}
impl FixedScalarRng {
/// Wraps the 32 bytes that every fill request will receive.
const fn new(bytes: [u8; 32]) -> Self {
Self { bytes }
}
}
impl TryRng for FixedScalarRng {
type Error = Infallible;
// The integer outputs are fixed at zero; only `try_fill_bytes` carries the
// configured bytes.
fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
Ok(0)
}
fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
Ok(0)
}
fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
// Panics on any request that is not exactly 32 bytes long.
// NOTE(review): presumably the scalar sampler asks for 32 bytes per draw — confirm.
assert_eq!(dst.len(), 32);
dst.copy_from_slice(&self.bytes);
Ok(())
}
}
// Marker impl so the stub satisfies the crypto-RNG bound.
impl TryCryptoRng for FixedScalarRng {}
/// Deterministic 64-byte KDF input: byte `i` is `i * 7` modulo 256.
fn synthetic_z() -> [u8; 64] {
let mut z = [0u8; 64];
for (i, b) in z.iter_mut().enumerate() {
// `i` ranges over 0..64 here, so the cast to `u8` cannot truncate.
#[allow(clippy::cast_possible_truncation)]
{
*b = (i as u8).wrapping_mul(7);
}
}
z
}
// A 32-byte output must equal SM3(z || 00000001).
#[test]
fn kdf_single_block_matches_manual_sm3() {
let z = synthetic_z();
let mut out = [0u8; 32];
kdf(&z, &mut out);
let mut h = Sm3::new();
h.update(&z);
h.update(&1u32.to_be_bytes());
let expected = h.finalize();
assert_eq!(out, expected);
}
// A 40-byte output is block 1 in full followed by the first 8 bytes of block 2.
#[test]
fn kdf_two_block_matches_manual_sm3() {
let z = synthetic_z();
let mut out = [0u8; 40];
kdf(&z, &mut out);
let mut h1 = Sm3::new();
h1.update(&z);
h1.update(&1u32.to_be_bytes());
let block1 = h1.finalize();
let mut h2 = Sm3::new();
h2.update(&z);
h2.update(&2u32.to_be_bytes());
let block2 = h2.finalize();
assert_eq!(&out[..32], &block1);
assert_eq!(&out[32..40], &block2[..8]);
}
// A zero-length output must be accepted without panicking.
#[test]
fn kdf_empty_output_is_noop() {
let z = b"whatever";
let mut out: [u8; 0] = [];
kdf(z, &mut out);
}
// The generator's affine coordinates must satisfy the curve equation.
#[test]
fn point_on_curve_accepts_generator() {
let g = ProjectivePoint::generator();
let (gx, gy) = g.to_affine().expect("G is finite");
assert!(point_on_curve(&gx, &gy));
}
// (1, 1) is expected to fail the curve equation.
#[test]
fn point_on_curve_rejects_off_curve() {
let x = Fp::new(&U256::ONE);
let y = Fp::new(&U256::ONE);
assert!(!point_on_curve(&x, &y));
}
// An identity public key must be rejected with `EncryptError::Failed`.
#[test]
fn encrypt_rejects_identity_pubkey() {
let pk = Sm2PublicKey::from_point(ProjectivePoint::identity());
let mut rng = rand_core::UnwrapErr(getrandom::SysRng);
assert_eq!(
encrypt(&pk, b"any plaintext", &mut rng),
Err(EncryptError::Failed)
);
}
// With the scalar bytes pinned, two encryptions of the same message under the
// same key must produce identical output.
// NOTE(review): `d`, `k` and the message look like the GB/T 32918.4 worked
// example inputs — confirm; this test checks determinism only, not a known answer.
#[test]
fn encrypt_with_fixed_k_is_deterministic() {
let d =
U256::from_be_hex("1649AB77A00637BD5E2EFE283FBF353534AA7F7CB89463F208DDBC2920BB0DA0");
let key = Sm2PrivateKey::new(d).expect("valid d");
let pk = Sm2PublicKey::from_point(key.public_key());
let k_bytes =
U256::from_be_hex("4C62EEFD6ECFC2B95B92FD6C3D9575148AFA17425546D49018E5388D49DD7B4F")
.to_be_bytes();
let mut bytes = [0u8; 32];
bytes.copy_from_slice(&k_bytes);
// Two independent stubs with the same bytes, one per encryption call.
let mut rng_a = rand_core::UnwrapErr(FixedScalarRng::new(bytes));
let mut rng_b = rand_core::UnwrapErr(FixedScalarRng::new(bytes));
let der_a = encrypt(&pk, b"encryption standard", &mut rng_a).expect("encrypt a");
let der_b = encrypt(&pk, b"encryption standard", &mut rng_b).expect("encrypt b");
assert_eq!(der_a, der_b, "fixed-k encrypt must be deterministic");
}
}