use crate::sm4::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
use alloc::vec::Vec;
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater};
/// Encrypts `plaintext` with SM4 in CBC mode, applying PKCS#7 padding.
///
/// Between 1 and `BLOCK_SIZE` padding bytes are always appended, each equal to
/// the number of bytes added. The result is therefore a non-zero multiple of
/// `BLOCK_SIZE` and strictly longer than `plaintext`; an empty plaintext
/// produces exactly one block.
///
/// `iv` is XOR-ed into the first plaintext block. Every later block is XOR-ed
/// with the ciphertext block produced just before it.
///
/// NOTE(review): CBC is only secure with a fresh, unpredictable IV per message
/// (NIST SP 800-38A). This function uses whatever `iv` it is given and adds no
/// integrity protection.
#[must_use]
// The `expect` below cannot fire: `chunks_exact_mut` only yields full blocks.
#[allow(clippy::missing_panics_doc)]
pub fn encrypt(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE], plaintext: &[u8]) -> Vec<u8> {
    let padding = BLOCK_SIZE - plaintext.len() % BLOCK_SIZE;
    // `padding` is in 1..=BLOCK_SIZE, so it fits in a byte.
    #[allow(clippy::cast_possible_truncation)]
    let pad_byte = padding as u8;

    let mut out = Vec::with_capacity(plaintext.len() + padding);
    out.extend_from_slice(plaintext);
    out.resize(plaintext.len() + padding, pad_byte);

    let cipher = Sm4Cipher::new(key);
    // CBC chaining value: starts as the IV, then tracks the previous ciphertext block.
    let mut chain = *iv;
    for chunk in out.chunks_exact_mut(BLOCK_SIZE) {
        let block: &mut [u8; BLOCK_SIZE] = chunk.try_into().expect("chunk is 16 bytes");
        for (byte, mask) in block.iter_mut().zip(chain.iter()) {
            *byte ^= *mask;
        }
        cipher.encrypt_block(block);
        chain = *block;
    }
    out
}
/// Decrypts SM4-CBC `ciphertext` and strips its PKCS#7 padding.
///
/// Returns `None` in any of these cases:
/// - `ciphertext` is empty;
/// - its length is not a multiple of `BLOCK_SIZE`;
/// - the decrypted data does not end in well-formed PKCS#7 padding.
///
/// Otherwise it returns the unpadded plaintext.
///
/// # Security
/// There is no authentication here. If callers let an attacker distinguish a
/// `None` (bad padding) from other failures, this function can act as a
/// padding oracle. Pair it with a MAC or AEAD construction when the
/// ciphertext is untrusted.
#[must_use]
// The `expect` below cannot fire: `chunks_exact_mut` only yields full blocks.
#[allow(clippy::missing_panics_doc)]
pub fn decrypt(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE], ciphertext: &[u8]) -> Option<Vec<u8>> {
    let whole_blocks = !ciphertext.is_empty() && ciphertext.len() % BLOCK_SIZE == 0;
    if !whole_blocks {
        return None;
    }

    let mut out = ciphertext.to_vec();
    let cipher = Sm4Cipher::new(key);
    // CBC chaining value: starts as the IV, then tracks the previous ciphertext block.
    let mut chain = *iv;
    for chunk in out.chunks_exact_mut(BLOCK_SIZE) {
        let block: &mut [u8; BLOCK_SIZE] = chunk.try_into().expect("chunk is 16 bytes");
        // Save the ciphertext before decrypting in place; it chains into the next block.
        let next_chain = *block;
        cipher.decrypt_block(block);
        for (byte, mask) in block.iter_mut().zip(chain.iter()) {
            *byte ^= *mask;
        }
        chain = next_chain;
    }

    strip_pkcs7_ct(&mut out)?;
    Some(out)
}
/// Removes PKCS#7 padding from `buf` in place, validating it with the
/// constant-time primitives from the `subtle` crate.
///
/// The final byte `last` is the claimed padding length. The padding is valid
/// when `1 <= last <= BLOCK_SIZE` and every one of the final `last` bytes
/// equals `last`. Only the final `BLOCK_SIZE` bytes are inspected, and all of
/// them are visited regardless of `last`. The loop count and indices therefore
/// do not depend on the padding value.
///
/// Returns `Some(())` and truncates `buf` by `last` bytes when the padding is
/// valid. Returns `None`, leaving `buf` untouched, in any of these cases:
/// - `buf` is empty;
/// - its length is not a multiple of `BLOCK_SIZE`;
/// - the padding is malformed.
///
/// NOTE(review): the final `bool::from(valid)` branch necessarily reveals
/// validity. Whether the rest compiles to branch-free code depends on
/// codegen; confirm if that matters for the threat model.
fn strip_pkcs7_ct(buf: &mut Vec<u8>) -> Option<()> {
    let n = buf.len();
    // Defensive re-check of the caller's precondition. It also guarantees
    // `n >= BLOCK_SIZE`, so `n - BLOCK_SIZE + i` below cannot underflow.
    if n == 0 || n % BLOCK_SIZE != 0 {
        return None;
    }
    // Final byte = claimed padding length.
    let last = buf[n - 1];
    // A padding length of 0 is never valid PKCS#7.
    let pad_nonzero = !last.ct_eq(&0u8);
    #[allow(clippy::cast_possible_truncation)]
    let pad_le_block = !last.ct_gt(&(BLOCK_SIZE as u8));
    let pad_in_range = pad_nonzero & pad_le_block;
    // OR-accumulates `byte ^ last` over every byte inside the claimed padding;
    // stays 0 only if all of those bytes equal `last`.
    let mut acc: u8 = 0;
    // Always scan the whole final block so the iteration count is independent of `last`.
    for i in 0..BLOCK_SIZE {
        // pos_from_end runs BLOCK_SIZE down to 1; the final byte of `buf` has position 1.
        #[allow(clippy::cast_possible_truncation)]
        let pos_from_end = (BLOCK_SIZE - i) as u8; let byte = buf[n - BLOCK_SIZE + i];
        // The byte belongs to the padding iff pos_from_end <= last.
        let in_padding = !pos_from_end.ct_gt(&last);
        let diff = byte ^ last;
        // Select `diff` for padding bytes and 0 for data bytes, without branching.
        let masked = u8::conditional_select(&0u8, &diff, in_padding);
        acc |= masked;
    }
    let acc_zero = acc.ct_eq(&0u8);
    let valid = pad_in_range & acc_zero;
    // The only data-dependent branch: validity must be revealed to return an Option.
    if bool::from(valid) {
        // `valid` implies 1 <= last <= BLOCK_SIZE <= n, so this cannot underflow.
        buf.truncate(n - last as usize);
        Some(())
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key shared by the encrypt/decrypt tests.
    const KEY: [u8; KEY_SIZE] = [0x42; KEY_SIZE];
    /// Fixed IV shared by the encrypt/decrypt tests.
    const IV: [u8; BLOCK_SIZE] = [0x33; BLOCK_SIZE];

    #[test]
    fn round_trip_boundary_lengths() {
        for len in [0usize, 1, 15, 16, 17, 31, 32, 100] {
            #[allow(clippy::cast_possible_truncation)]
            let plaintext: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(7)).collect();
            let ciphertext = encrypt(&KEY, &IV, &plaintext);
            // PKCS#7 always adds 1..=BLOCK_SIZE bytes, so the ciphertext is the
            // next multiple of BLOCK_SIZE strictly above the plaintext length.
            assert_eq!(
                ciphertext.len(),
                (len / BLOCK_SIZE + 1) * BLOCK_SIZE,
                "unexpected ciphertext length for len={len}"
            );
            let recovered = decrypt(&KEY, &IV, &ciphertext).expect("decrypt must succeed");
            assert_eq!(recovered, plaintext, "round-trip mismatch at len={len}");
        }
    }

    #[test]
    fn empty_plaintext_yields_one_block() {
        let ciphertext = encrypt(&KEY, &IV, b"");
        assert_eq!(ciphertext.len(), BLOCK_SIZE, "empty PT → exactly one block");
        let recovered = decrypt(&KEY, &IV, &ciphertext).expect("decrypt empty");
        assert_eq!(recovered, b"");
    }

    #[test]
    fn decrypt_rejects_misaligned_length() {
        assert!(decrypt(&KEY, &IV, &[0u8; 15]).is_none());
        assert!(decrypt(&KEY, &IV, &[0u8; 17]).is_none());
    }

    #[test]
    fn decrypt_rejects_empty() {
        assert!(decrypt(&KEY, &IV, &[]).is_none());
    }

    #[test]
    fn decrypt_rejects_tampered_final_block() {
        let plaintext = b"this is a test message that spans multiple blocks";
        let mut ciphertext = encrypt(&KEY, &IV, plaintext);
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;
        // Tampering with the final ciphertext block scrambles the whole final
        // plaintext block, so in general there is roughly a 1/256 chance the
        // result still ends in valid padding. This assertion relies on these
        // fixed inputs not hitting that case; `decrypt_rejects_zeroed_pad_length`
        // is the deterministic check.
        assert!(
            decrypt(&KEY, &IV, &ciphertext).is_none(),
            "tampered last byte must break PKCS#7"
        );
    }

    #[test]
    fn decrypt_rejects_zeroed_pad_length() {
        let plaintext = b"this is a test message that spans multiple blocks";
        let mut ciphertext = encrypt(&KEY, &IV, plaintext);
        assert!(
            ciphertext.len() >= 2 * BLOCK_SIZE,
            "need a ciphertext block before the final one"
        );
        let pad = u8::try_from(BLOCK_SIZE - plaintext.len() % BLOCK_SIZE)
            .expect("pad length fits in u8");
        // CBC: P_last = D(C_last) ^ C_prev. XOR-ing the last byte of C_prev with
        // `pad` turns the final plaintext byte (the pad length) into 0. PKCS#7
        // never permits a 0 pad length, and this holds for any key or cipher output.
        let idx = ciphertext.len() - BLOCK_SIZE - 1;
        ciphertext[idx] ^= pad;
        assert!(
            decrypt(&KEY, &IV, &ciphertext).is_none(),
            "pad length 0 must be rejected"
        );
    }

    #[test]
    fn strip_pkcs7_known_good() {
        // A full block of padding strips to nothing.
        let mut buf = alloc::vec![0x10u8; 16];
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 0);

        // Single byte of padding.
        let mut buf = alloc::vec![0u8; 16];
        buf[15] = 0x01;
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 15);

        // Partial padding; the preceding zero bytes are data.
        let mut buf = alloc::vec![0u8; 16];
        buf[12..].fill(0x04);
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 12);

        // Multi-block buffer: only the final block carries padding.
        let mut buf = alloc::vec![0u8; 32];
        buf[30..].fill(0x02);
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 30);
    }

    #[test]
    fn strip_pkcs7_known_bad() {
        // Pad length 0; rejected input must be left untouched.
        let mut buf = alloc::vec![0u8; 16];
        assert!(strip_pkcs7_ct(&mut buf).is_none());
        assert_eq!(buf.len(), 16, "rejected input must not be truncated");

        // Pad length larger than a block.
        let mut buf = alloc::vec![0u8; 16];
        buf[15] = 17;
        assert!(strip_pkcs7_ct(&mut buf).is_none());

        // One byte inside the padding differs from the pad length.
        let mut buf = alloc::vec![0u8; 16];
        buf[12..].fill(0x04);
        buf[13] = 0xff;
        assert!(strip_pkcs7_ct(&mut buf).is_none());

        // Full-block padding whose first byte is wrong.
        let mut buf = alloc::vec![0x10u8; 16];
        buf[0] = 0x0f;
        assert!(strip_pkcs7_ct(&mut buf).is_none());

        // Non-block-aligned and empty buffers are rejected outright.
        let mut buf = alloc::vec![0x01u8; 15];
        assert!(strip_pkcs7_ct(&mut buf).is_none());
        let mut buf: Vec<u8> = Vec::new();
        assert!(strip_pkcs7_ct(&mut buf).is_none());
    }
}