use alloc::vec::Vec;
use crypto_bigint::U256;
use subtle::ConstantTimeLess;
use crate::sm2::curve::Fp;
use super::{reader, writer};
/// Length in bytes of the `hash` field of [`Sm2Ciphertext`]; `decode` rejects an encoded
/// hash of any other length.
const HASH_LEN: usize = 32;
/// In-memory form of an SM2 ciphertext as handled by `encode` and `decode`.
///
/// The DER form is a SEQUENCE containing, in order: INTEGER `x`, INTEGER `y`,
/// OCTET STRING `hash`, OCTET STRING `ciphertext`.
#[derive(Clone, Debug)]
pub struct Sm2Ciphertext {
    /// First coordinate, carried as the first INTEGER. `decode` only yields values
    /// strictly below `Fp::MODULUS`.
    pub x: U256,
    /// Second coordinate, carried as the second INTEGER. `decode` only yields values
    /// strictly below `Fp::MODULUS`.
    pub y: U256,
    /// Fixed-size digest carried as the first OCTET STRING (the tests in this file
    /// describe it as an SM3 output).
    pub hash: [u8; HASH_LEN],
    /// Variable-length payload carried as the second OCTET STRING; may be empty.
    pub ciphertext: Vec<u8>,
}
/// Serializes `ct` into its DER form: a SEQUENCE holding, in order, INTEGER `x`,
/// INTEGER `y`, OCTET STRING `hash` and OCTET STRING `ciphertext`.
///
/// The coordinates are handed to the writer as big-endian bytes. The returned buffer
/// contains exactly the outer SEQUENCE and nothing else.
#[must_use]
pub fn encode(ct: &Sm2Ciphertext) -> Vec<u8> {
    // The constants below are capacity hints only; they never influence the encoded bytes.
    //
    // A 256-bit INTEGER needs at most tag + length + one 0x00 sign pad + 32 value bytes
    // (assuming the writer emits minimal DER INTEGERs).
    const MAX_COORD_TLV: usize = 2 + 33;
    // Allowance for one tag byte plus a long-form length (0x84 followed by 4 length bytes).
    const MAX_TLV_HEADER: usize = 6;

    let x_be = ct.x.to_be_bytes();
    let y_be = ct.y.to_be_bytes();

    // Size the body for both coordinates, the hash element (short-form header) and the
    // ciphertext element, so that appending the four fields does not force a regrow.
    let hash_tlv = 2 + ct.hash.len();
    let body_hint = 2 * MAX_COORD_TLV + hash_tlv + MAX_TLV_HEADER + ct.ciphertext.len();
    let mut body = Vec::with_capacity(body_hint);
    writer::write_integer(&mut body, &x_be);
    writer::write_integer(&mut body, &y_be);
    writer::write_octet_string(&mut body, &ct.hash);
    writer::write_octet_string(&mut body, &ct.ciphertext);

    // The outer header grows with the body length, so reserve for the long form here too.
    let mut out = Vec::with_capacity(body.len() + MAX_TLV_HEADER);
    writer::write_sequence(&mut out, &body);
    out
}
/// Parses a DER-encoded ciphertext back into an [`Sm2Ciphertext`].
///
/// Returns `None` when any of the following holds:
/// - the outer SEQUENCE cannot be read, or bytes follow it;
/// - either coordinate fails `read_field_element` (unreadable INTEGER, more than
///   32 bytes, or not below `Fp::MODULUS`);
/// - either OCTET STRING cannot be read;
/// - bytes remain inside the SEQUENCE after the fourth element;
/// - the hash is not exactly `HASH_LEN` bytes long.
///
/// Only the range of the coordinates is checked here; this function performs no
/// curve-membership test.
#[must_use]
pub fn decode(input: &[u8]) -> Option<Sm2Ciphertext> {
    // The outer SEQUENCE has to account for every input byte.
    let (fields, _) = reader::read_sequence(input).filter(|(_, trailing)| trailing.is_empty())?;

    let (x, fields) = read_field_element(fields)?;
    let (y, fields) = read_field_element(fields)?;
    let (digest, fields) = reader::read_octet_string(fields)?;
    let (payload, leftover) = reader::read_octet_string(fields)?;

    // Reject extra content inside the SEQUENCE and digests of the wrong size.
    if !leftover.is_empty() || digest.len() != HASH_LEN {
        return None;
    }

    let mut hash = [0u8; HASH_LEN];
    hash.copy_from_slice(digest);

    Some(Sm2Ciphertext {
        x,
        y,
        hash,
        ciphertext: payload.to_vec(),
    })
}
/// Reads one INTEGER from the front of `input` and interprets it as a field element.
///
/// The integer bytes returned by the reader are right-aligned into a 32-byte
/// big-endian buffer and converted to a `U256`. Returns the value together with the
/// unread remainder of `input`, or `None` when the INTEGER cannot be read, is longer
/// than 32 bytes, or is not strictly below `Fp::MODULUS`.
fn read_field_element(input: &[u8]) -> Option<(U256, &[u8])> {
    let (magnitude, rest) = reader::read_integer(input)?;

    // `checked_sub` yields `None` for anything wider than 32 bytes, which rejects it.
    let offset = 32usize.checked_sub(magnitude.len())?;
    let mut be = [0u8; 32];
    be[offset..].copy_from_slice(magnitude);

    let candidate = U256::from_be_slice(&be);
    if bool::from(candidate.ct_lt(Fp::MODULUS.as_ref())) {
        Some((candidate, rest))
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ciphertext with fixed coordinates, a constant `0xa5` hash and the
    /// supplied payload.
    fn make_ct(ciphertext: Vec<u8>) -> Sm2Ciphertext {
        let x = U256::from_be_hex(
            "1234567890ABCDEF1234567890ABCDEF1234567890ABCDEF1234567890ABCDEF",
        );
        let y = U256::from_be_hex(
            "FEDCBA0987654321FEDCBA0987654321FEDCBA0987654321FEDCBA0987654321",
        );
        Sm2Ciphertext {
            x,
            y,
            hash: [0xa5u8; 32],
            ciphertext,
        }
    }

    /// Wraps `body` in an outer SEQUENCE using the writer.
    fn wrap_sequence(body: &[u8]) -> Vec<u8> {
        let mut der = Vec::new();
        writer::write_sequence(&mut der, body);
        der
    }

    /// Appends the hash and payload OCTET STRINGs to already-encoded coordinates and
    /// wraps the result in the outer SEQUENCE.
    fn finish(mut body: Vec<u8>, hash: &[u8], payload: &[u8]) -> Vec<u8> {
        writer::write_octet_string(&mut body, hash);
        writer::write_octet_string(&mut body, payload);
        wrap_sequence(&body)
    }

    #[test]
    fn round_trip_short() {
        let original = make_ct(b"hello world".to_vec());
        let decoded = decode(&encode(&original)).expect("decode round-trip");
        assert_eq!(decoded.x, original.x);
        assert_eq!(decoded.y, original.y);
        assert_eq!(decoded.hash, original.hash);
        assert_eq!(decoded.ciphertext, original.ciphertext);
    }

    #[test]
    fn round_trip_x_high_bit_set() {
        // Top bit of x is set, so its encoding must survive the sign handling.
        let original = Sm2Ciphertext {
            x: U256::from_be_hex(
                "FFEDCBA9876543210FEDCBA9876543210FEDCBA9876543210FEDCBA987654321",
            ),
            ..make_ct(b"x".to_vec())
        };
        let decoded = decode(&encode(&original)).expect("decode high-bit round-trip");
        assert_eq!(decoded.x, original.x);
    }

    #[test]
    fn round_trip_medium_ciphertext_300_bytes() {
        // Truncating the index to u8 is intentional: it just produces a varied pattern.
        #[allow(clippy::cast_possible_truncation)]
        let payload: Vec<u8> = (0..300usize)
            .map(|i| (i as u8).wrapping_mul(13))
            .collect();
        let original = make_ct(payload.clone());
        let decoded = decode(&encode(&original)).expect("decode 300-byte round-trip");
        assert_eq!(decoded.ciphertext, payload);
    }

    #[test]
    fn round_trip_empty_ciphertext() {
        let original = make_ct(Vec::new());
        let decoded = decode(&encode(&original)).expect("decode empty-ciphertext round-trip");
        assert!(decoded.ciphertext.is_empty());
    }

    #[test]
    fn rejects_malformed() {
        assert!(decode(&[]).is_none(), "empty input");
        assert!(decode(&[0x30]).is_none(), "truncated SEQUENCE header");
        assert!(decode(&[0x31, 0x00]).is_none(), "wrong outer tag");
        // Declared SEQUENCE length (5) exceeds the bytes that actually follow.
        assert!(decode(&[0x30, 0x05, 0x02, 0x01, 0x01]).is_none());
    }

    #[test]
    fn rejects_wrong_hash_length() {
        let short_hash = [0x55u8; 31];
        let mut coords: Vec<u8> = Vec::new();
        writer::write_integer(&mut coords, &[0x01]);
        writer::write_integer(&mut coords, &[0x02]);
        let der = finish(coords, &short_hash, b"x");
        assert!(
            decode(&der).is_none(),
            "31-byte HASH must be rejected; SM3 always produces 32 bytes"
        );
    }

    #[test]
    fn rejects_non_canonical_x_leading_zero() {
        let mut coords: Vec<u8> = Vec::new();
        // Hand-written INTEGER for x: value 1 with a redundant 0x00 in front.
        coords.extend_from_slice(&[0x02, 0x02, 0x00, 0x01]);
        writer::write_integer(&mut coords, &[0x02]);
        let der = finish(coords, &[0u8; 32], b"");
        assert!(
            decode(&der).is_none(),
            "non-canonical 00-pad on x must be rejected"
        );
    }

    #[test]
    fn rejects_negative_y_encoding() {
        let mut coords: Vec<u8> = Vec::new();
        writer::write_integer(&mut coords, &[0x01]);
        // Hand-written INTEGER for y whose single content byte has the top bit set.
        coords.extend_from_slice(&[0x02, 0x01, 0x80]);
        let der = finish(coords, &[0u8; 32], b"");
        assert!(decode(&der).is_none());
    }

    #[test]
    fn rejects_trailing_bytes() {
        let mut der = encode(&make_ct(b"hi".to_vec()));
        der.push(0xff);
        assert!(decode(&der).is_none());
    }

    #[test]
    fn round_trip_x_zero() {
        let original = Sm2Ciphertext {
            x: U256::ZERO,
            ..make_ct(b"z".to_vec())
        };
        let decoded = decode(&encode(&original)).expect("decode round-trip with x = 0");
        assert_eq!(decoded.x, U256::ZERO);
        assert_eq!(decoded.y, original.y);
    }

    #[test]
    fn rejects_x_at_or_above_p() {
        // Case 1: x equal to the modulus itself.
        let p = *Fp::MODULUS.as_ref();
        let p_bytes = p.to_be_bytes();
        let mut coords: Vec<u8> = Vec::new();
        writer::write_integer(&mut coords, &p_bytes);
        writer::write_integer(&mut coords, &[0x01]);
        let der = finish(coords, &[0u8; 32], b"");
        assert!(
            decode(&der).is_none(),
            "x = p is not a field element and must be rejected"
        );

        // Case 2: x with all 256 bits set.
        let max_bytes = [0xffu8; 32];
        let mut coords: Vec<u8> = Vec::new();
        writer::write_integer(&mut coords, &max_bytes);
        writer::write_integer(&mut coords, &[0x01]);
        let der = finish(coords, &[0u8; 32], b"");
        assert!(decode(&der).is_none(), "x = 2^256 - 1 must be rejected");
    }

    #[test]
    fn round_trip_x_p_minus_one() {
        // Largest value the range check in `read_field_element` still accepts.
        let p_minus_one = Fp::MODULUS.as_ref().wrapping_sub(&U256::ONE);
        let original = Sm2Ciphertext {
            x: p_minus_one,
            ..make_ct(b"q".to_vec())
        };
        let decoded = decode(&encode(&original)).expect("decode round-trip with x = p - 1");
        assert_eq!(decoded.x, p_minus_one);
    }

    #[test]
    fn round_trip_65536_byte_ciphertext_uses_3byte_length() {
        let payload = alloc::vec![0xa5u8; 65_536];
        let original = make_ct(payload.clone());
        let decoded = decode(&encode(&original)).expect("decode 65,536-byte round-trip");
        assert_eq!(decoded.ciphertext, payload);
    }
}