use alloc::vec::Vec;
use crypto_bigint::U256;
use subtle::ConstantTimeLess;
use crate::sm2::curve::Fp;
/// Length in bytes of the `HASH` field. [`decode`] rejects any other length
/// (presumably an SM3 digest, which is 32 bytes — see the hash-length test).
const HASH_LEN: usize = 32;
/// Structured SM2 ciphertext handled by [`encode`] and [`decode`].
///
/// On the wire it is a DER
/// `SEQUENCE { INTEGER x, INTEGER y, OCTET STRING hash, OCTET STRING ciphertext }`.
#[derive(Clone, Debug)]
pub struct Sm2Ciphertext {
/// First `INTEGER` of the sequence; [`decode`] only accepts values below `Fp::MODULUS`.
pub x: U256,
/// Second `INTEGER` of the sequence; [`decode`] only accepts values below `Fp::MODULUS`.
pub y: U256,
/// Fixed-size `HASH` octet string (`HASH_LEN` bytes).
pub hash: [u8; HASH_LEN],
/// Variable-length ciphertext octet string (may be empty).
pub ciphertext: Vec<u8>,
}
/// Serialises `ct` as a DER `SEQUENCE` of `x`, `y` (minimal `INTEGER`s),
/// `hash` and `ciphertext` (`OCTET STRING`s), in that order.
///
/// # Panics
///
/// Panics if any DER length in the output reaches 2^24 bytes (for example a
/// ciphertext of 16 MiB or more). This module's length encoder only emits up
/// to three length octets.
#[must_use]
pub fn encode(ct: &Sm2Ciphertext) -> Vec<u8> {
    let fields = [
        encode_integer(&ct.x.to_be_bytes()),
        encode_integer(&ct.y.to_be_bytes()),
        encode_octet_string(&ct.hash),
        encode_octet_string(&ct.ciphertext),
    ];
    let content_len: usize = fields.iter().map(Vec::len).sum();

    // Tag (1) + length (at most 4) fits comfortably in the extra 8 bytes.
    let mut der = Vec::with_capacity(content_len + 8);
    der.push(0x30);
    push_length(&mut der, content_len);
    for field in &fields {
        der.extend_from_slice(field);
    }
    der
}
/// Parses a DER-encoded SM2 ciphertext (the format produced by [`encode`]).
///
/// The input must be exactly one `SEQUENCE` containing, in order, the `x`
/// and `y` `INTEGER`s, a `HASH_LEN`-byte `OCTET STRING` and the ciphertext
/// `OCTET STRING`. Nothing may precede or follow the sequence.
///
/// Returns `None` on any tag, length, canonical-form or range violation.
#[must_use]
pub fn decode(input: &[u8]) -> Option<Sm2Ciphertext> {
    let (&outer_tag, after_tag) = input.split_first()?;
    if outer_tag != 0x30 {
        return None;
    }
    let (seq_len, seq) = read_length(after_tag)?;
    // The SEQUENCE must cover the remainder of the input exactly.
    if seq.len() != seq_len {
        return None;
    }

    let (x, seq) = read_integer(seq)?;
    let (y, seq) = read_integer(seq)?;
    let (hash_field, seq) = read_octet_string(seq)?;
    let (payload, seq) = read_octet_string(seq)?;
    if !seq.is_empty() || hash_field.len() != HASH_LEN {
        return None;
    }

    let mut hash = [0u8; HASH_LEN];
    hash.copy_from_slice(hash_field);
    Some(Sm2Ciphertext {
        x,
        y,
        hash,
        ciphertext: payload.to_vec(),
    })
}
/// Encodes a big-endian unsigned magnitude as a minimal DER `INTEGER` (tag 0x02).
///
/// Leading zero octets are stripped, keeping a single `0x00` for the value
/// zero. If the first remaining octet has its top bit set, a `0x00` pad is
/// prepended so the value is not read back as negative. An empty slice is
/// treated as zero and encodes as `02 01 00`.
fn encode_integer(value_be: &[u8]) -> Vec<u8> {
    // `None` means the input is empty or all zeros, i.e. the value is zero.
    let magnitude: &[u8] = match value_be.iter().position(|&b| b != 0) {
        Some(first_significant) => &value_be[first_significant..],
        None => &[0x00],
    };
    let needs_pad = (magnitude[0] & 0x80) != 0;
    let int_len = magnitude.len() + usize::from(needs_pad);

    let mut out = Vec::with_capacity(int_len + 4);
    out.push(0x02);
    push_length(&mut out, int_len);
    if needs_pad {
        out.push(0x00);
    }
    out.extend_from_slice(magnitude);
    out
}
/// Parses one DER `INTEGER` that must hold a value strictly below `Fp::MODULUS`.
///
/// Returns the value and the bytes that follow it. Returns `None` if:
/// - the tag is not 0x02;
/// - the length is malformed or overruns `input`;
/// - the content is empty, negative or non-minimally padded;
/// - the content is wider than 32 octets;
/// - the value is not below `Fp::MODULUS`.
fn read_integer(input: &[u8]) -> Option<(U256, &[u8])> {
    let (&tag, after_tag) = input.split_first()?;
    if tag != 0x02 {
        return None;
    }
    let (content_len, after_len) = read_length(after_tag)?;
    if content_len > after_len.len() {
        return None;
    }
    let (content, remainder) = after_len.split_at(content_len);

    // DER INTEGERs are two's complement and must be minimal. Only
    // non-negative values are accepted, and a lone 0x00 pad is allowed only
    // in front of an octet whose top bit is set.
    let magnitude = match content {
        [] => return None,
        [lead, ..] if (*lead & 0x80) != 0 => return None,
        [0x00] => content,
        [0x00, next, ..] if (*next & 0x80) == 0 => return None,
        [0x00, unpadded @ ..] => unpadded,
        _ => content,
    };
    if magnitude.len() > 32 {
        return None;
    }

    let mut be32 = [0u8; 32];
    be32[32 - magnitude.len()..].copy_from_slice(magnitude);
    let value = U256::from_be_slice(&be32);
    let below_modulus: bool = value.ct_lt(Fp::MODULUS.as_ref()).into();
    if below_modulus {
        Some((value, remainder))
    } else {
        None
    }
}
/// Encodes `value` as a DER `OCTET STRING` (tag 0x04) with a minimal length.
fn encode_octet_string(value: &[u8]) -> Vec<u8> {
    let payload_len = value.len();
    let mut der = Vec::with_capacity(4 + payload_len);
    der.push(0x04);
    push_length(&mut der, payload_len);
    der.extend_from_slice(value);
    der
}
/// Parses one DER `OCTET STRING` from the front of `input`.
///
/// Returns `(contents, remainder)`. Returns `None` if the tag is not 0x04,
/// the length field is malformed, or the declared length overruns `input`.
fn read_octet_string(input: &[u8]) -> Option<(&[u8], &[u8])> {
    match input.split_first()? {
        (&0x04, after_tag) => {
            let (len, body) = read_length(after_tag)?;
            (body.len() >= len).then(|| body.split_at(len))
        }
        _ => None,
    }
}
/// Appends a DER definite-length field for `len` to `out`.
///
/// Values below 0x80 use the one-octet short form. Larger values use the
/// long form: a `0x81`–`0x83` prefix followed by the minimal big-endian
/// length octets.
///
/// # Panics
///
/// Panics if `len` is 2^24 (16 MiB) or more, which would need a fourth
/// length octet.
fn push_length(out: &mut Vec<u8>, len: usize) {
    // (optional long-form prefix, number of low-order big-endian octets to copy)
    let (prefix, width) = match len {
        0..=0x7f => (None, 1),
        0x80..=0xff => (Some(0x81u8), 1),
        0x100..=0xffff => (Some(0x82), 2),
        0x1_0000..=0xff_ffff => (Some(0x83), 3),
        _ => panic!("ciphertext DER length overflow (> 16 MB)"),
    };
    if let Some(marker) = prefix {
        out.push(marker);
    }
    let be = len.to_be_bytes();
    out.extend_from_slice(&be[be.len() - width..]);
}
/// Parses a DER definite-length field from the front of `input`.
///
/// Accepts the short form (one octet below 0x80) and the long forms `0x81`,
/// `0x82` and `0x83`. Long forms must be minimal: the encoded value must not
/// fit in fewer octets. Returns the length and the bytes after the length
/// field.
///
/// Returns `None` for truncated input, non-minimal encodings, the
/// indefinite form (0x80), or more than three length octets.
fn read_length(input: &[u8]) -> Option<(usize, &[u8])> {
    let (&head, tail) = input.split_first()?;
    if head < 0x80 {
        return Some((usize::from(head), tail));
    }

    // Long form: how many length octets follow, and the smallest value that
    // legitimately requires that many.
    let (octets, minimum) = match head {
        0x81 => (1, 0x80),
        0x82 => (2, 0x100),
        0x83 => (3, 0x1_0000),
        _ => return None,
    };
    if tail.len() < octets {
        return None;
    }
    let (digits, rest) = tail.split_at(octets);
    let len = digits
        .iter()
        .fold(0usize, |acc, &digit| (acc << 8) | usize::from(digit));
    if len < minimum {
        return None;
    }
    Some((len, rest))
}
// Round-trip and rejection tests for the SM2 ciphertext DER codec.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ciphertext with fixed `x`, `y` and `hash` values and the given payload.
    fn make_ct(ciphertext: Vec<u8>) -> Sm2Ciphertext {
        Sm2Ciphertext {
            x: U256::from_be_hex(
                "1234567890ABCDEF1234567890ABCDEF1234567890ABCDEF1234567890ABCDEF",
            ),
            y: U256::from_be_hex(
                "FEDCBA0987654321FEDCBA0987654321FEDCBA0987654321FEDCBA0987654321",
            ),
            hash: [0xa5u8; 32],
            ciphertext,
        }
    }

    /// Concatenates pre-encoded `x`/`y` INTEGERs with the hash and ciphertext
    /// encoded as OCTET STRINGs, forming a SEQUENCE body.
    fn body_from(x_der: &[u8], y_der: &[u8], hash: &[u8], ciphertext: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend_from_slice(x_der);
        body.extend_from_slice(y_der);
        body.extend_from_slice(&encode_octet_string(hash));
        body.extend_from_slice(&encode_octet_string(ciphertext));
        body
    }

    /// Prefixes a hand-built body with the outer SEQUENCE tag and length,
    /// mirroring the framing `encode` emits.
    fn wrap_sequence(body: &[u8]) -> Vec<u8> {
        let mut der = alloc::vec![0x30];
        push_length(&mut der, body.len());
        der.extend_from_slice(body);
        der
    }

    #[test]
    fn round_trip_short() {
        let ct = make_ct(b"hello world".to_vec());
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode round-trip");
        assert_eq!(decoded.x, ct.x);
        assert_eq!(decoded.y, ct.y);
        assert_eq!(decoded.hash, ct.hash);
        assert_eq!(decoded.ciphertext, ct.ciphertext);
    }

    #[test]
    fn round_trip_x_high_bit_set() {
        let mut ct = make_ct(b"x".to_vec());
        ct.x =
            U256::from_be_hex("FFEDCBA9876543210FEDCBA9876543210FEDCBA9876543210FEDCBA987654321");
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode high-bit round-trip");
        assert_eq!(decoded.x, ct.x);
    }

    #[test]
    fn round_trip_medium_ciphertext_300_bytes() {
        let mut payload = alloc::vec![0u8; 300];
        for (i, b) in payload.iter_mut().enumerate() {
            // Truncation is intended: we only want a varying byte pattern.
            #[allow(clippy::cast_possible_truncation)]
            {
                *b = (i as u8).wrapping_mul(13);
            }
        }
        let ct = make_ct(payload.clone());
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode 300-byte round-trip");
        assert_eq!(decoded.ciphertext, payload);
    }

    #[test]
    fn round_trip_empty_ciphertext() {
        let ct = make_ct(Vec::new());
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode empty-ciphertext round-trip");
        assert!(decoded.ciphertext.is_empty());
    }

    #[test]
    fn rejects_malformed() {
        assert!(decode(&[]).is_none(), "empty input");
        assert!(decode(&[0x30]).is_none(), "truncated SEQUENCE header");
        assert!(decode(&[0x31, 0x00]).is_none(), "wrong outer tag");
        // Declared body length (5) exceeds the bytes that follow (3).
        assert!(decode(&[0x30, 0x05, 0x02, 0x01, 0x01]).is_none());
    }

    #[test]
    fn rejects_wrong_hash_length() {
        let bad_hash = [0x55u8; 31];
        let ciphertext = b"x";
        let body = body_from(
            &encode_integer(&[0x01]),
            &encode_integer(&[0x02]),
            &bad_hash,
            ciphertext,
        );
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "31-byte HASH must be rejected; SM3 always produces 32 bytes"
        );
    }

    #[test]
    fn rejects_non_canonical_x_leading_zero() {
        // x = 1 written with a redundant 0x00 pad.
        let body = body_from(
            &[0x02, 0x02, 0x00, 0x01],
            &encode_integer(&[0x02]),
            &[0u8; 32],
            b"",
        );
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "non-canonical 00-pad on x must be rejected"
        );
    }

    #[test]
    fn rejects_negative_y_encoding() {
        // y = 0x80 without a 0x00 pad reads as -128 in two's complement.
        let body = body_from(&encode_integer(&[0x01]), &[0x02, 0x01, 0x80], &[0u8; 32], b"");
        assert!(decode(&wrap_sequence(&body)).is_none());
    }

    #[test]
    fn rejects_trailing_bytes() {
        let ct = make_ct(b"hi".to_vec());
        let mut der = encode(&ct);
        // Junk after the SEQUENCE must not be silently ignored.
        der.push(0xff);
        assert!(decode(&der).is_none());
    }

    #[test]
    fn rejects_non_minimal_and_indefinite_lengths() {
        let der = encode(&make_ct(b"hi".to_vec()));
        assert!(der[1] < 0x80, "fixture must use a short-form outer length");

        // Same body, but the length is written in the redundant 0x81 long form.
        let mut long_form = alloc::vec![0x30, 0x81];
        long_form.extend_from_slice(&der[1..]);
        assert!(
            decode(&long_form).is_none(),
            "0x81 followed by a value < 0x80 is not minimal DER"
        );

        // Same body, but framed with a BER indefinite length and end-of-contents.
        let mut indefinite = alloc::vec![0x30, 0x80];
        indefinite.extend_from_slice(&der[2..]);
        indefinite.extend_from_slice(&[0x00, 0x00]);
        assert!(decode(&indefinite).is_none(), "indefinite length is not DER");
    }

    #[test]
    fn length_encoding_boundaries_round_trip() {
        let cases: [(usize, &[u8]); 7] = [
            (0, &[0x00]),
            (127, &[0x7f]),
            (128, &[0x81, 0x80]),
            (255, &[0x81, 0xff]),
            (256, &[0x82, 0x01, 0x00]),
            (65_535, &[0x82, 0xff, 0xff]),
            (65_536, &[0x83, 0x01, 0x00, 0x00]),
        ];
        for &(len, expected) in cases.iter() {
            let mut buf = Vec::new();
            push_length(&mut buf, len);
            assert_eq!(buf.as_slice(), expected, "encoding of {}", len);
            buf.push(0xee);
            let (parsed, rest) = read_length(&buf).expect("minimal length must parse");
            assert_eq!(parsed, len);
            assert_eq!(rest, &[0xeeu8][..]);
        }
    }

    #[test]
    fn round_trip_x_zero() {
        let mut ct = make_ct(b"z".to_vec());
        ct.x = U256::ZERO;
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode round-trip with x = 0");
        assert_eq!(decoded.x, U256::ZERO);
        assert_eq!(decoded.y, ct.y);
    }

    #[test]
    fn rejects_x_at_or_above_p() {
        let p = *Fp::MODULUS.as_ref();
        let p_bytes = p.to_be_bytes();
        let body = body_from(
            &encode_integer(&p_bytes),
            &encode_integer(&[0x01]),
            &[0u8; 32],
            b"",
        );
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "x = p is not a field element and must be rejected"
        );

        let max_bytes = [0xffu8; 32];
        let body = body_from(
            &encode_integer(&max_bytes),
            &encode_integer(&[0x01]),
            &[0u8; 32],
            b"",
        );
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "x = 2^256 - 1 must be rejected"
        );
    }

    #[test]
    fn round_trip_x_p_minus_one() {
        let p_minus_one = Fp::MODULUS.as_ref().wrapping_sub(&U256::ONE);
        let mut ct = make_ct(b"q".to_vec());
        ct.x = p_minus_one;
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode round-trip with x = p - 1");
        assert_eq!(decoded.x, p_minus_one);
    }

    #[test]
    fn round_trip_65536_byte_ciphertext_uses_3byte_length() {
        let payload = alloc::vec![0xa5u8; 65_536];
        let ct = make_ct(payload.clone());
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode 65,536-byte round-trip");
        assert_eq!(decoded.ciphertext, payload);
    }
}