use crate::sm4::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
use alloc::vec::Vec;
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater};
/// Streaming SM4-CBC encryptor applying PKCS#7 padding.
///
/// Plaintext may be fed through [`update`](Self::update) in chunks of any size;
/// each complete 16-byte block is encrypted as soon as it is available. Calling
/// [`finalize`](Self::finalize) pads the trailing partial block and returns all
/// ciphertext not yet drained.
pub struct Sm4CbcEncryptor {
    /// Block cipher keyed with the caller's key.
    cipher: Sm4Cipher,
    /// CBC chaining value: the IV at first, then the most recent ciphertext block.
    prev: [u8; BLOCK_SIZE],
    /// Plaintext bytes that do not yet form a full block; only `..buffer_len` is live.
    buffer: [u8; BLOCK_SIZE],
    /// Number of live bytes in `buffer`; strictly less than `BLOCK_SIZE` between calls.
    buffer_len: usize,
    /// Ciphertext produced so far and not yet handed to the caller.
    output: Vec<u8>,
}

impl Sm4CbcEncryptor {
    /// Creates an encryptor for `key` whose CBC chain starts at `iv`.
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE]) -> Self {
        Self {
            cipher: Sm4Cipher::new(key),
            prev: *iv,
            buffer: [0u8; BLOCK_SIZE],
            buffer_len: 0,
            output: Vec::new(),
        }
    }

    /// Absorbs `data`, encrypting every complete block it makes available.
    ///
    /// Bytes that do not complete a block are buffered until the next call to
    /// `update` or `finalize`.
    pub fn update(&mut self, mut data: &[u8]) {
        // First top up a partially filled block left over from a previous call.
        if self.buffer_len > 0 {
            let take = (BLOCK_SIZE - self.buffer_len).min(data.len());
            self.buffer[self.buffer_len..self.buffer_len + take].copy_from_slice(&data[..take]);
            self.buffer_len += take;
            data = &data[take..];
            if self.buffer_len == BLOCK_SIZE {
                let block = self.buffer;
                self.encrypt_one(&block);
                self.buffer_len = 0;
            }
        }

        // Encrypt whole blocks straight from the input, then stash any tail.
        let blocks = data.chunks_exact(BLOCK_SIZE);
        let tail = blocks.remainder();
        for chunk in blocks {
            let mut block = [0u8; BLOCK_SIZE];
            block.copy_from_slice(chunk);
            self.encrypt_one(&block);
        }
        if !tail.is_empty() {
            self.buffer[..tail.len()].copy_from_slice(tail);
            self.buffer_len = tail.len();
        }
    }

    /// Drains and returns the ciphertext produced so far.
    ///
    /// Buffered plaintext that has not yet filled a block stays inside the encryptor.
    #[doc(hidden)]
    pub fn take_output(&mut self) -> Vec<u8> {
        core::mem::take(&mut self.output)
    }

    /// Pads the buffered tail with PKCS#7, encrypts it, and returns the remaining
    /// ciphertext.
    ///
    /// This always emits at least one block. A block-aligned message gains a full
    /// block of padding, because the pad length is `BLOCK_SIZE` when nothing is
    /// buffered.
    #[must_use]
    pub fn finalize(mut self) -> Vec<u8> {
        // `BLOCK_SIZE - buffer_len` is in 1..=BLOCK_SIZE (16), so it fits in a u8.
        #[allow(clippy::cast_possible_truncation)]
        let pad_len = (BLOCK_SIZE - self.buffer_len) as u8;
        self.buffer[self.buffer_len..].fill(pad_len);
        let block = self.buffer;
        self.encrypt_one(&block);
        self.output
    }

    /// CBC-encrypts one plaintext block: XOR with the chaining value, encrypt, then
    /// append the result and make it the new chaining value.
    fn encrypt_one(&mut self, plaintext_block: &[u8; BLOCK_SIZE]) {
        let mut block = *plaintext_block;
        for (b, p) in block.iter_mut().zip(self.prev.iter()) {
            *b ^= *p;
        }
        self.cipher.encrypt_block(&mut block);
        self.prev = block;
        self.output.extend_from_slice(&block);
    }
}
/// Streaming SM4-CBC decryptor that validates and strips PKCS#7 padding.
///
/// Ciphertext may be fed through [`update`](Self::update) in chunks of any size.
/// The most recently decrypted block is always held back, because it may be the
/// final, padded block. It is released only when a later block arrives, or after a
/// successful padding check in [`finalize`](Self::finalize).
///
/// This type performs no authentication of the ciphertext. Plaintext drained early
/// via `take_output` has not been covered by the padding check.
pub struct Sm4CbcDecryptor {
    /// Block cipher keyed with the caller's key.
    cipher: Sm4Cipher,
    /// CBC chaining value: the IV at first, then the most recent ciphertext block.
    prev: [u8; BLOCK_SIZE],
    /// Ciphertext bytes that do not yet form a full block; only `..buffer_len` is live.
    buffer: [u8; BLOCK_SIZE],
    /// Number of live bytes in `buffer`; strictly less than `BLOCK_SIZE` between calls.
    buffer_len: usize,
    /// Plaintext released so far and not yet handed to the caller.
    output: Vec<u8>,
    /// Last decrypted block, withheld until it is known whether it carries padding.
    held_back: Option<[u8; BLOCK_SIZE]>,
}

impl Sm4CbcDecryptor {
    /// Creates a decryptor for `key` whose CBC chain starts at `iv`.
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE]) -> Self {
        Self {
            cipher: Sm4Cipher::new(key),
            prev: *iv,
            buffer: [0u8; BLOCK_SIZE],
            buffer_len: 0,
            output: Vec::new(),
            held_back: None,
        }
    }

    /// Absorbs ciphertext `data`, decrypting every complete block it makes available.
    ///
    /// With the `sm4-bitsliced-simd` feature, runs of `SIMD_BATCH` blocks go through
    /// the batched path. Bytes that do not complete a block are buffered until the
    /// next call.
    pub fn update(&mut self, mut data: &[u8]) {
        // First top up a partially filled block left over from a previous call.
        if self.buffer_len > 0 {
            let take = (BLOCK_SIZE - self.buffer_len).min(data.len());
            self.buffer[self.buffer_len..self.buffer_len + take].copy_from_slice(&data[..take]);
            self.buffer_len += take;
            data = &data[take..];
            if self.buffer_len == BLOCK_SIZE {
                let block = self.buffer;
                self.decrypt_one(&block);
                self.buffer_len = 0;
            }
        }

        #[cfg(feature = "sm4-bitsliced-simd")]
        {
            use super::cipher::SIMD_BATCH;
            let batches = data.chunks_exact(SIMD_BATCH * BLOCK_SIZE);
            let rest = batches.remainder();
            for batch_bytes in batches {
                let mut batch = [[0u8; BLOCK_SIZE]; SIMD_BATCH];
                for (dst, src) in batch.iter_mut().zip(batch_bytes.chunks_exact(BLOCK_SIZE)) {
                    dst.copy_from_slice(src);
                }
                self.decrypt_batch(&batch);
            }
            data = rest;
        }

        // Decrypt remaining whole blocks one at a time, then stash any tail.
        let blocks = data.chunks_exact(BLOCK_SIZE);
        let tail = blocks.remainder();
        for chunk in blocks {
            let mut block = [0u8; BLOCK_SIZE];
            block.copy_from_slice(chunk);
            self.decrypt_one(&block);
        }
        if !tail.is_empty() {
            self.buffer[..tail.len()].copy_from_slice(tail);
            self.buffer_len = tail.len();
        }
    }

    /// Drains and returns the plaintext released so far.
    ///
    /// The held-back final-candidate block stays inside the decryptor.
    #[doc(hidden)]
    pub fn take_output(&mut self) -> Vec<u8> {
        core::mem::take(&mut self.output)
    }

    /// Validates the PKCS#7 padding of the final block and returns the remaining
    /// plaintext.
    ///
    /// Returns `None` in three cases:
    /// - the total ciphertext was not a multiple of `BLOCK_SIZE`;
    /// - no block was supplied at all;
    /// - the final block's padding is malformed.
    #[must_use]
    pub fn finalize(mut self) -> Option<Vec<u8>> {
        if self.buffer_len != 0 {
            return None;
        }
        let last = self.held_back?;
        let keep = strip_pkcs7_block(&last)?;
        self.output.extend_from_slice(&last[..keep]);
        Some(self.output)
    }

    /// CBC-decrypts one ciphertext block.
    ///
    /// Releases the previously held block and holds back the new plaintext block.
    fn decrypt_one(&mut self, ciphertext_block: &[u8; BLOCK_SIZE]) {
        let mut plain = *ciphertext_block;
        self.cipher.decrypt_block(&mut plain);
        for (b, p) in plain.iter_mut().zip(self.prev.iter()) {
            *b ^= *p;
        }
        // The consumed ciphertext block is the chaining value for the next one.
        self.prev = *ciphertext_block;
        if let Some(released) = self.held_back.take() {
            self.output.extend_from_slice(&released);
        }
        self.held_back = Some(plain);
    }

    /// CBC-decrypts `SIMD_BATCH` consecutive ciphertext blocks in one batched call.
    ///
    /// This has the same release/hold-back behavior as repeated
    /// [`decrypt_one`](Self::decrypt_one) calls.
    #[cfg(feature = "sm4-bitsliced-simd")]
    fn decrypt_batch(&mut self, ct_blocks: &[[u8; BLOCK_SIZE]; super::cipher::SIMD_BATCH]) {
        use super::cipher::SIMD_BATCH;
        let mut pt_blocks = *ct_blocks;
        self.cipher.decrypt_blocks_simd(&mut pt_blocks);

        // P[i] = D(C[i]) ^ C[i-1], where C[-1] is the carried chaining value.
        let mut chain = self.prev;
        for (pt, ct) in pt_blocks.iter_mut().zip(ct_blocks.iter()) {
            for (b, c) in pt.iter_mut().zip(chain.iter()) {
                *b ^= *c;
            }
            chain = *ct;
        }

        if let Some(released) = self.held_back.take() {
            self.output.extend_from_slice(&released);
        }
        for pt in &pt_blocks[..SIMD_BATCH - 1] {
            self.output.extend_from_slice(pt);
        }
        self.held_back = Some(pt_blocks[SIMD_BATCH - 1]);
        self.prev = chain;
    }
}
/// Checks PKCS#7 padding on a final plaintext block without data-dependent branches.
///
/// On success, returns how many leading bytes of `block` are message data
/// (`BLOCK_SIZE - pad`, where `pad` is the last byte). Otherwise returns `None`.
/// Padding is accepted only when `1 <= pad <= BLOCK_SIZE` and each of the last
/// `pad` bytes equals `pad`. Every byte is inspected, and the individual checks are
/// combined as `subtle::Choice` values, so only the final accept/reject decision
/// branches on the data.
fn strip_pkcs7_block(block: &[u8; BLOCK_SIZE]) -> Option<usize> {
    let pad = block[BLOCK_SIZE - 1];

    // BLOCK_SIZE is 16, which fits in a u8.
    #[allow(clippy::cast_possible_truncation)]
    let limit = BLOCK_SIZE as u8;
    let in_range = !pad.ct_eq(&0u8) & !pad.ct_gt(&limit);

    // OR `byte ^ pad` across the padding region; any mismatch leaves a bit set.
    let mismatch = block.iter().enumerate().fold(0u8, |acc, (idx, &byte)| {
        // Distance from the end is in 1..=BLOCK_SIZE, so it fits in a u8.
        #[allow(clippy::cast_possible_truncation)]
        let distance = (BLOCK_SIZE - idx) as u8;
        let covered = !distance.ct_gt(&pad);
        acc | u8::conditional_select(&0u8, &(byte ^ pad), covered)
    });

    let accepted = in_range & mismatch.ct_eq(&0u8);
    // Subtract only after acceptance, so an out-of-range `pad` can never underflow.
    if bool::from(accepted) {
        Some(BLOCK_SIZE - usize::from(pad))
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm4::mode_cbc;

    #[test]
    fn encrypt_single_chunk_matches_v02() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let plaintext = b"streaming round trip";
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(plaintext);
        let stream_ct = enc.finalize();
        let oneshot_ct = mode_cbc::encrypt(&key, &iv, plaintext);
        assert_eq!(stream_ct, oneshot_ct);
    }

    #[test]
    fn encrypt_chunked_matches_v02() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt: Vec<u8> = (0..100u8).collect();
        for chunk_size in [1usize, 7, 16, 17, 31, 32, 100] {
            let mut enc = Sm4CbcEncryptor::new(&key, &iv);
            for chunk in pt.chunks(chunk_size) {
                enc.update(chunk);
            }
            let stream_ct = enc.finalize();
            let oneshot_ct = mode_cbc::encrypt(&key, &iv, &pt);
            assert_eq!(stream_ct, oneshot_ct, "chunk_size={chunk_size}");
        }
    }

    /// Draining ciphertext mid-stream must not change the concatenated result.
    #[test]
    fn encrypt_take_output_then_finalize_matches_v02() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt: Vec<u8> = (0..40u8).collect();
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(&pt[..20]);
        let mut combined = enc.take_output();
        // 20 bytes in: one full block encrypted, 4 bytes still buffered.
        assert_eq!(combined.len(), BLOCK_SIZE);
        enc.update(&pt[20..]);
        combined.extend_from_slice(&enc.finalize());
        let oneshot_ct = mode_cbc::encrypt(&key, &iv, &pt);
        assert_eq!(combined, oneshot_ct);
    }

    #[test]
    fn streaming_round_trip() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        for len in [0usize, 1, 15, 16, 17, 31, 32, 33, 100, 256] {
            #[allow(clippy::cast_possible_truncation)]
            let pt: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(13)).collect();
            let mut enc = Sm4CbcEncryptor::new(&key, &iv);
            enc.update(&pt);
            let ct = enc.finalize();
            for chunk_size in [1usize, 7, 16, 17, 31, 32, ct.len().max(1)] {
                let mut dec = Sm4CbcDecryptor::new(&key, &iv);
                for chunk in ct.chunks(chunk_size) {
                    dec.update(chunk);
                }
                let recovered = dec.finalize().expect("decrypt");
                assert_eq!(recovered, pt, "len={len} chunk_size={chunk_size}");
            }
        }
    }

    #[test]
    fn decrypt_rejects_truncated() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&[0xAB; 31]);
        assert!(dec.finalize().is_none());
    }

    #[test]
    fn decrypt_rejects_empty() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let dec = Sm4CbcDecryptor::new(&key, &iv);
        assert!(dec.finalize().is_none());
    }

    #[test]
    fn decrypt_rejects_bad_padding() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt = b"this is a test message that spans multiple blocks";
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(pt);
        let mut ct = enc.finalize();
        let last = ct.len() - 1;
        ct[last] ^= 0x01;
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&ct);
        assert!(dec.finalize().is_none());
    }

    /// Every pad length from 1 to BLOCK_SIZE is accepted and stripped exactly.
    #[test]
    fn strip_pkcs7_accepts_every_valid_pad_length() {
        for pad in 1..=BLOCK_SIZE {
            let pad_byte = u8::try_from(pad).expect("pad <= BLOCK_SIZE fits in u8");
            let mut block = [0xEEu8; BLOCK_SIZE];
            block[BLOCK_SIZE - pad..].fill(pad_byte);
            assert_eq!(strip_pkcs7_block(&block), Some(BLOCK_SIZE - pad), "pad={pad}");
        }
    }

    /// A final byte of 0 or greater than BLOCK_SIZE is never valid padding.
    #[test]
    fn strip_pkcs7_rejects_out_of_range_pad() {
        assert_eq!(strip_pkcs7_block(&[0u8; BLOCK_SIZE]), None);
        let just_over = u8::try_from(BLOCK_SIZE + 1).expect("BLOCK_SIZE + 1 fits in u8");
        assert_eq!(strip_pkcs7_block(&[just_over; BLOCK_SIZE]), None);
        assert_eq!(strip_pkcs7_block(&[0xFFu8; BLOCK_SIZE]), None);
    }

    /// A single wrong byte inside the padding run causes rejection.
    #[test]
    fn strip_pkcs7_rejects_inconsistent_padding_bytes() {
        let mut block = [0u8; BLOCK_SIZE];
        block[BLOCK_SIZE - 4..].fill(4);
        assert_eq!(strip_pkcs7_block(&block), Some(BLOCK_SIZE - 4));
        block[BLOCK_SIZE - 3] = 5;
        assert_eq!(strip_pkcs7_block(&block), None);
    }

    #[cfg(feature = "sm4-bitsliced-simd")]
    #[test]
    fn cbc_decrypt_simd_batch_boundary_sweep() {
        use super::super::cipher::SIMD_BATCH;
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let block_counts: [usize; 8] = [
            0,
            1,
            SIMD_BATCH.saturating_sub(1),
            SIMD_BATCH,
            SIMD_BATCH + 1,
            (2 * SIMD_BATCH).saturating_sub(1),
            2 * SIMD_BATCH,
            2 * SIMD_BATCH + 1,
        ];
        for &n_blocks in &block_counts {
            let pt: Vec<u8> = (0..(n_blocks * BLOCK_SIZE))
                .map(|i| u8::try_from(i & 0xFF).unwrap_or(0))
                .collect();
            let canonical = mode_cbc::encrypt(&key, &iv, &pt);
            let mut dec = Sm4CbcDecryptor::new(&key, &iv);
            dec.update(&canonical);
            let recovered = dec.finalize().expect("decrypt");
            assert_eq!(
                recovered, pt,
                "boundary sweep: n_blocks={n_blocks} (SIMD_BATCH={SIMD_BATCH})",
            );
        }
    }

    #[cfg(feature = "sm4-bitsliced-simd")]
    #[test]
    fn cbc_decrypt_simd_chunked_update_sweep() {
        use super::super::cipher::SIMD_BATCH;
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let total_blocks = 3 * SIMD_BATCH + 1;
        let pt: Vec<u8> = (0..(total_blocks * BLOCK_SIZE - 5))
            .map(|i| u8::try_from((i * 17) & 0xFF).unwrap_or(0))
            .collect();
        let ct = mode_cbc::encrypt(&key, &iv, &pt);
        let batch_bytes = SIMD_BATCH * BLOCK_SIZE;
        let chunk_sizes = [
            1,
            7,
            BLOCK_SIZE,
            BLOCK_SIZE + 1,
            batch_bytes - 1,
            batch_bytes,
            batch_bytes + 1,
            2 * batch_bytes,
            ct.len().max(1),
        ];
        for &chunk_size in &chunk_sizes {
            let mut dec = Sm4CbcDecryptor::new(&key, &iv);
            for chunk in ct.chunks(chunk_size) {
                dec.update(chunk);
            }
            let recovered = dec.finalize().expect("decrypt");
            assert_eq!(
                recovered, pt,
                "chunked update: chunk_size={chunk_size} (batch_bytes={batch_bytes})",
            );
        }
    }

    #[cfg(feature = "sm4-bitsliced-simd")]
    #[test]
    fn cbc_decrypt_simd_take_output_preserves_held_back() {
        use super::super::cipher::SIMD_BATCH;
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let total_blocks = 2 * SIMD_BATCH + 1;
        let pt: Vec<u8> = (0..(total_blocks * BLOCK_SIZE))
            .map(|i| u8::try_from((i ^ 0xA5) & 0xFF).unwrap_or(0))
            .collect();
        let ct = mode_cbc::encrypt(&key, &iv, &pt);
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        let split = (SIMD_BATCH + 1) * BLOCK_SIZE;
        dec.update(&ct[..split]);
        let first_chunk_pt = dec.take_output();
        dec.update(&ct[split..]);
        let rest = dec.finalize().expect("decrypt");
        let mut combined = first_chunk_pt;
        combined.extend_from_slice(&rest);
        assert_eq!(combined, pt);
    }

    #[test]
    fn streaming_decrypt_matches_v02_oneshot() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt = b"test message for cross-validation";
        let canonical = mode_cbc::encrypt(&key, &iv, pt);
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&canonical);
        let stream_pt = dec.finalize().expect("streaming decrypt");
        assert_eq!(stream_pt, pt);
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(pt);
        let blob = enc.finalize();
        let recovered = mode_cbc::decrypt(&key, &iv, &blob).expect("oneshot decrypt");
        assert_eq!(recovered, pt);
    }
}