use crate::asn1::ciphertext::Sm2Ciphertext;
use crate::sm2::curve::Fp;
use crate::sm2::encrypt::point_on_curve;
use alloc::vec::Vec;
use crypto_bigint::U256;
use subtle::ConstantTimeLess;
/// Byte length of the C1 component: one tag byte followed by the 32-byte big-endian
/// X and Y coordinates of an uncompressed point.
pub const C1_LEN: usize = 1 + 32 + 32;
/// Byte length of the C3 (hash) component.
pub const C3_LEN: usize = 32;
/// Leading tag byte of an uncompressed point encoding.
const SEC1_UNCOMPRESSED: u8 = 0x04;
/// Serialises `ct` into the raw `C1 || C3 || C2` byte layout.
///
/// The output is `0x04 || X || Y || hash || ciphertext`, where X and Y are the
/// big-endian bytes of the C1 coordinates. Its length is therefore
/// `C1_LEN + C3_LEN + ct.ciphertext.len()`.
#[must_use]
pub fn encode_c1c3c2(ct: &Sm2Ciphertext) -> Vec<u8> {
    let total_len = C1_LEN + C3_LEN + ct.ciphertext.len();
    let mut buf = Vec::with_capacity(total_len);
    // C1: tag byte followed by both coordinates, X first.
    buf.push(SEC1_UNCOMPRESSED);
    for coord in [&ct.x, &ct.y] {
        buf.extend_from_slice(&coord.to_be_bytes());
    }
    // C3 precedes C2 in this (non-legacy) layout.
    buf.extend_from_slice(&ct.hash);
    buf.extend_from_slice(&ct.ciphertext);
    buf
}
/// Parses a raw `C1 || C3 || C2` ciphertext into an [`Sm2Ciphertext`].
///
/// Returns `None` in any of these cases:
/// - the input is shorter than `C1_LEN + C3_LEN`;
/// - the input does not start with the uncompressed-point tag;
/// - either coordinate is not strictly below the field modulus;
/// - the decoded C1 is not on the curve.
///
/// Everything after C1 and C3 is taken as C2 and may be empty.
#[must_use]
pub fn decode_c1c3c2(input: &[u8]) -> Option<Sm2Ciphertext> {
    split_c1_c3_c2(input).map(|(x, y, hash, c2)| Sm2Ciphertext {
        x,
        y,
        hash,
        ciphertext: c2.to_vec(),
    })
}
/// Parses a raw ciphertext in the legacy `C1 || C2 || C3` layout.
///
/// C1 is validated exactly as in the `C1 || C3 || C2` decoder:
/// - the input must be at least `C1_LEN + C3_LEN` bytes long;
/// - it must start with the `0x04` tag;
/// - both coordinates must be canonical field elements;
/// - the point must lie on the curve.
///
/// The final `C3_LEN` bytes are taken as C3. Everything between C1 and C3 is C2.
#[must_use]
pub fn decode_c1c2c3_legacy(input: &[u8]) -> Option<Sm2Ciphertext> {
    // C1 and C3 are fixed-size; whatever is left over is C2. Underflow means "too short".
    let c2_len = input.len().checked_sub(C1_LEN + C3_LEN)?;
    let (c1, body) = input.split_at(C1_LEN);

    let (&tag, coords) = c1.split_first()?;
    if tag != SEC1_UNCOMPRESSED {
        return None;
    }
    let (x_bytes, y_bytes) = coords.split_at(32);
    let x = read_field_element(x_bytes)?;
    let y = read_field_element(y_bytes)?;
    if !point_on_curve(&Fp::new(&x), &Fp::new(&y)) {
        return None;
    }

    // Legacy ordering: the digest trails the ciphertext body.
    let (c2, c3) = body.split_at(c2_len);
    let hash = <[u8; C3_LEN]>::try_from(c3).ok()?;
    Some(Sm2Ciphertext {
        x,
        y,
        hash,
        ciphertext: c2.to_vec(),
    })
}
/// Splits a raw `C1 || C3 || C2` buffer into `(x, y, C3, C2)`, validating C1 on the way.
///
/// Returns `None` in any of these cases:
/// - the buffer is shorter than `C1_LEN + C3_LEN`;
/// - it lacks the uncompressed-point tag;
/// - a coordinate is not a canonical field element;
/// - the point is not on the curve.
///
/// The returned C2 slice borrows from `input`.
fn split_c1_c3_c2(input: &[u8]) -> Option<(U256, U256, [u8; C3_LEN], &[u8])> {
    // `||` short-circuits, so `input[0]` is only read once the length is known to be sufficient.
    if input.len() < C1_LEN + C3_LEN || input[0] != SEC1_UNCOMPRESSED {
        return None;
    }
    let (c1, tail) = input.split_at(C1_LEN);
    let x = read_field_element(&c1[1..33])?;
    let y = read_field_element(&c1[33..C1_LEN])?;

    let on_curve = {
        let x_fp = Fp::new(&x);
        let y_fp = Fp::new(&y);
        point_on_curve(&x_fp, &y_fp)
    };
    if !on_curve {
        return None;
    }

    let (c3_bytes, c2) = tail.split_at(C3_LEN);
    let mut c3 = [0u8; C3_LEN];
    c3.copy_from_slice(c3_bytes);
    Some((x, y, c3, c2))
}
/// Decodes a 32-byte big-endian coordinate, accepting only canonical encodings.
///
/// Returns `None` if `bytes` is not exactly 32 bytes long. It also returns `None` if the
/// decoded value is not strictly less than `Fp::MODULUS`, so that a value `>= p` cannot
/// alias a smaller field element.
fn read_field_element(bytes: &[u8]) -> Option<U256> {
    if bytes.len() == 32 {
        let candidate = U256::from_be_slice(bytes);
        let modulus = *Fp::MODULUS.as_ref();
        let canonical: bool = candidate.ct_lt(&modulus).into();
        canonical.then_some(candidate)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm2::ProjectivePoint;
    use crate::sm2::{Sm2PrivateKey, Sm2PublicKey, decrypt, encrypt};
    use crypto_bigint::U256;
    use rand_core::UnwrapErr;

    /// Builds a ciphertext whose C1 is the curve generator, so it passes the on-curve check.
    /// It carries a fixed `0xA5` C3 pattern and the given C2 bytes.
    fn sample_ct(c2: &[u8]) -> Sm2Ciphertext {
        let g = ProjectivePoint::generator();
        let (x, y) = g.to_affine().expect("G finite");
        Sm2Ciphertext {
            x: x.retrieve(),
            y: y.retrieve(),
            hash: [0xA5; C3_LEN],
            ciphertext: c2.to_vec(),
        }
    }

    /// Re-orders a modern `C1 || C3 || C2` encoding into the legacy `C1 || C2 || C3` layout.
    fn to_legacy(modern: &[u8]) -> Vec<u8> {
        let mut legacy = Vec::with_capacity(modern.len());
        legacy.extend_from_slice(&modern[..C1_LEN]);
        legacy.extend_from_slice(&modern[C1_LEN + C3_LEN..]);
        legacy.extend_from_slice(&modern[C1_LEN..C1_LEN + C3_LEN]);
        legacy
    }

    #[test]
    fn modern_round_trip_boundary_lengths() {
        for len in [0usize, 1, 16, 32, 100, 1024] {
            // Truncating `i` to `u8` is intentional: it only builds a repeating byte pattern.
            #[allow(clippy::cast_possible_truncation)]
            let c2: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(7)).collect();
            let ct = sample_ct(&c2);
            let bytes = encode_c1c3c2(&ct);
            assert_eq!(bytes.len(), C1_LEN + C3_LEN + len);
            assert_eq!(bytes[0], 0x04);
            let recovered = decode_c1c3c2(&bytes).expect("decode");
            assert_eq!(recovered.x, ct.x);
            assert_eq!(recovered.y, ct.y);
            assert_eq!(recovered.hash, ct.hash);
            assert_eq!(recovered.ciphertext, ct.ciphertext);
        }
    }

    #[test]
    fn decode_rejects_too_short() {
        assert!(decode_c1c3c2(&[]).is_none());
        assert!(decode_c1c3c2(&[0x04; 32]).is_none());
        assert!(decode_c1c3c2(&[0x04; 65]).is_none());
        // One byte short of the minimum `C1_LEN + C3_LEN`.
        assert!(decode_c1c3c2(&[0x04; C1_LEN + C3_LEN - 1]).is_none());
    }

    #[test]
    fn decode_rejects_wrong_tag() {
        let mut bytes = encode_c1c3c2(&sample_ct(b"hi"));
        bytes[0] = 0x02;
        assert!(decode_c1c3c2(&bytes).is_none());
        bytes[0] = 0x03;
        assert!(decode_c1c3c2(&bytes).is_none());
        bytes[0] = 0x00;
        assert!(decode_c1c3c2(&bytes).is_none());
    }

    #[test]
    fn decode_rejects_off_curve() {
        let mut bytes = encode_c1c3c2(&sample_ct(b"hi"));
        // Flip one bit of X so (X, Y) no longer matches the generator.
        bytes[5] ^= 0x01;
        assert!(decode_c1c3c2(&bytes).is_none());
    }

    #[test]
    fn decode_rejects_x_at_p() {
        let mut bytes = encode_c1c3c2(&sample_ct(b"hi"));
        let p = *Fp::MODULUS.as_ref();
        bytes[1..33].copy_from_slice(&p.to_be_bytes());
        assert!(decode_c1c3c2(&bytes).is_none());
    }

    #[test]
    fn modern_empty_c2() {
        let ct = sample_ct(&[]);
        let bytes = encode_c1c3c2(&ct);
        assert_eq!(bytes.len(), C1_LEN + C3_LEN);
        let recovered = decode_c1c3c2(&bytes).expect("decode empty");
        assert_eq!(recovered.ciphertext.len(), 0);
    }

    #[test]
    fn legacy_decode_swaps_c2_c3_position() {
        let ct = sample_ct(b"legacy-format-test");
        let legacy = to_legacy(&encode_c1c3c2(&ct));
        let recovered = decode_c1c2c3_legacy(&legacy).expect("legacy decode");
        assert_eq!(recovered.x, ct.x);
        assert_eq!(recovered.y, ct.y);
        assert_eq!(recovered.hash, ct.hash);
        assert_eq!(recovered.ciphertext, ct.ciphertext);
    }

    #[test]
    fn legacy_decode_empty_c2() {
        let ct = sample_ct(&[]);
        let legacy = to_legacy(&encode_c1c3c2(&ct));
        assert_eq!(legacy.len(), C1_LEN + C3_LEN);
        let recovered = decode_c1c2c3_legacy(&legacy).expect("legacy decode empty");
        assert!(recovered.ciphertext.is_empty());
        assert_eq!(recovered.hash, ct.hash);
    }

    #[test]
    fn legacy_decode_rejects_off_curve() {
        let mut legacy = to_legacy(&encode_c1c3c2(&sample_ct(b"x")));
        legacy[5] ^= 0x01;
        assert!(decode_c1c2c3_legacy(&legacy).is_none());
    }

    #[test]
    fn legacy_decode_rejects_wrong_tag() {
        let mut legacy = to_legacy(&encode_c1c3c2(&sample_ct(b"hi")));
        for tag in [0x00, 0x02, 0x03] {
            legacy[0] = tag;
            assert!(decode_c1c2c3_legacy(&legacy).is_none());
        }
    }

    #[test]
    fn legacy_decode_rejects_x_at_p() {
        let mut legacy = to_legacy(&encode_c1c3c2(&sample_ct(b"hi")));
        let p = *Fp::MODULUS.as_ref();
        legacy[1..33].copy_from_slice(&p.to_be_bytes());
        assert!(decode_c1c2c3_legacy(&legacy).is_none());
    }

    #[test]
    fn legacy_decode_rejects_too_short() {
        assert!(decode_c1c2c3_legacy(&[0x04; 64]).is_none());
        assert!(decode_c1c2c3_legacy(&[]).is_none());
        // One byte short of the minimum `C1_LEN + C3_LEN`.
        assert!(decode_c1c2c3_legacy(&[0x04; C1_LEN + C3_LEN - 1]).is_none());
    }

    #[test]
    fn modern_raw_round_trips_via_full_decrypt() {
        let d =
            U256::from_be_hex("3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8");
        let key = Sm2PrivateKey::from_scalar_inner(d).expect("valid d");
        let pk = Sm2PublicKey::from_point(key.public_key());
        let mut rng = UnwrapErr(getrandom::SysRng);
        let plaintext = b"raw-ciphertext modern roundtrip";
        let der = encrypt(&pk, plaintext, &mut rng).expect("encrypt");
        let ct = crate::asn1::ciphertext::decode(&der).expect("DER decode");
        let raw = encode_c1c3c2(&ct);
        let ct2 = decode_c1c3c2(&raw).expect("raw decode");
        let der2 = crate::asn1::ciphertext::encode(&ct2);
        let recovered = decrypt(&key, &der2).expect("decrypt round-trip");
        assert_eq!(recovered, plaintext);
    }
}