use crate::sm4::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
use alloc::vec::Vec;
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater};
/// Incremental SM4-CBC encryptor with PKCS#7 padding.
///
/// Plaintext can be fed in arbitrary chunks through `update`; `finalize` pads
/// the remaining tail and returns the complete ciphertext. Ciphertext is
/// accumulated internally and is only handed back by `finalize`.
///
/// NOTE(review): CBC as implemented here provides confidentiality only — no
/// MAC is computed, so ciphertext integrity must be ensured elsewhere.
pub struct Sm4CbcEncryptor {
    // Block cipher keyed once at construction.
    cipher: Sm4Cipher,
    // CBC chaining value: the IV at first, then the most recent ciphertext block.
    prev: [u8; BLOCK_SIZE],
    // Plaintext bytes that do not yet form a complete block; only
    // `buffer[..buffer_len]` is meaningful.
    buffer: [u8; BLOCK_SIZE],
    // Number of meaningful bytes in `buffer`; `update` keeps it below BLOCK_SIZE.
    buffer_len: usize,
    // All ciphertext produced so far.
    output: Vec<u8>,
}
impl Sm4CbcEncryptor {
    /// Creates an encryptor for a single message.
    ///
    /// `iv` becomes the first CBC chaining value. CBC requires a fresh,
    /// unpredictable IV for every message encrypted under the same key; that is
    /// the caller's responsibility and is not checked here.
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE]) -> Self {
        Self {
            cipher: Sm4Cipher::new(key),
            prev: *iv,
            buffer: [0u8; BLOCK_SIZE],
            buffer_len: 0,
            output: Vec::new(),
        }
    }

    /// Absorbs more plaintext.
    ///
    /// Input may be split at arbitrary byte boundaries across calls. Every
    /// complete block is encrypted immediately. A trailing partial block is
    /// buffered until more data arrives or `finalize` pads it.
    pub fn update(&mut self, mut data: &[u8]) {
        // First top up a block left partially filled by a previous call.
        if self.buffer_len > 0 {
            let take = (BLOCK_SIZE - self.buffer_len).min(data.len());
            self.buffer[self.buffer_len..self.buffer_len + take].copy_from_slice(&data[..take]);
            self.buffer_len += take;
            data = &data[take..];
            if self.buffer_len < BLOCK_SIZE {
                // `data` was fully consumed and the block is still incomplete.
                return;
            }
            let block = self.buffer;
            self.encrypt_one(&block);
            self.buffer_len = 0;
        }

        // `buffer` is empty here, so whole blocks can be taken straight from `data`.
        let mut blocks = data.chunks_exact(BLOCK_SIZE);
        for chunk in &mut blocks {
            let mut block = [0u8; BLOCK_SIZE];
            block.copy_from_slice(chunk);
            self.encrypt_one(&block);
        }

        // Stash the sub-block tail for the next call (or for `finalize`).
        let tail = blocks.remainder();
        if !tail.is_empty() {
            self.buffer[..tail.len()].copy_from_slice(tail);
            self.buffer_len = tail.len();
        }
    }

    /// Applies PKCS#7 padding to the buffered tail, encrypts it, and returns
    /// the complete ciphertext.
    ///
    /// Between 1 and `BLOCK_SIZE` padding bytes are always added. The output
    /// length is therefore a positive multiple of `BLOCK_SIZE`: an empty or
    /// block-aligned message still gains one full block of padding.
    #[must_use]
    pub fn finalize(mut self) -> Vec<u8> {
        // `update` keeps buffer_len < BLOCK_SIZE, so pad_len lies in 1..=BLOCK_SIZE.
        // PKCS#7 stores that count in a single byte. This assumes BLOCK_SIZE <= 255
        // (SM4 blocks are 16 bytes), so the cast cannot truncate.
        #[allow(clippy::cast_possible_truncation)]
        let pad_len = (BLOCK_SIZE - self.buffer_len) as u8;
        self.buffer[self.buffer_len..].fill(pad_len);
        let block = self.buffer;
        self.encrypt_one(&block);
        self.output
    }

    /// One CBC step: XOR the plaintext block with the chaining value, then
    /// encrypt it. The ciphertext becomes the next chaining value and is
    /// appended to `output`.
    fn encrypt_one(&mut self, plaintext_block: &[u8; BLOCK_SIZE]) {
        let mut block = *plaintext_block;
        for (b, p) in block.iter_mut().zip(self.prev.iter()) {
            *b ^= *p;
        }
        self.cipher.encrypt_block(&mut block);
        self.prev = block;
        self.output.extend_from_slice(&block);
    }
}
/// Incremental SM4-CBC decryptor that validates and removes PKCS#7 padding.
///
/// Ciphertext can be fed in arbitrary chunks through `update`. The most
/// recently decrypted block is held back (`held_back`) because only the final
/// block carries padding; it is released when another block arrives or when
/// `finalize` strips its padding. Plaintext is only returned by a successful
/// `finalize`.
///
/// NOTE(review): `finalize` reports invalid padding as `None`. With no MAC in
/// this module, exposing that outcome to an attacker can act as a padding
/// oracle — confirm the ciphertext is authenticated before decryption.
pub struct Sm4CbcDecryptor {
    // Block cipher keyed once at construction.
    cipher: Sm4Cipher,
    // CBC chaining value: the IV at first, then the last ciphertext block consumed.
    prev: [u8; BLOCK_SIZE],
    // Ciphertext bytes of an incomplete block; only `buffer[..buffer_len]` is meaningful.
    buffer: [u8; BLOCK_SIZE],
    // Number of meaningful bytes in `buffer`; non-zero at `finalize` means misaligned input.
    buffer_len: usize,
    // Plaintext of every decrypted block except the latest one.
    output: Vec<u8>,
    // Plaintext of the latest decrypted block, still awaiting padding removal.
    held_back: Option<[u8; BLOCK_SIZE]>,
}
impl Sm4CbcDecryptor {
    /// Builds a decryptor whose CBC chain starts at `iv`, keyed with `key`.
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE]) -> Self {
        let cipher = Sm4Cipher::new(key);
        Self {
            cipher,
            prev: *iv,
            buffer: [0u8; BLOCK_SIZE],
            buffer_len: 0,
            output: Vec::new(),
            held_back: None,
        }
    }

    /// Consumes another slice of ciphertext.
    ///
    /// Chunk boundaries need not align with blocks. Complete blocks are
    /// decrypted as soon as they are available; a trailing fragment is kept
    /// until later input completes it.
    pub fn update(&mut self, data: &[u8]) {
        let mut rest = data;

        if self.buffer_len != 0 {
            // Fill the pending fragment from the front of the new input.
            let room = BLOCK_SIZE - self.buffer_len;
            let (head, tail) = rest.split_at(room.min(rest.len()));
            self.buffer[self.buffer_len..self.buffer_len + head.len()].copy_from_slice(head);
            self.buffer_len += head.len();
            rest = tail;

            if self.buffer_len != BLOCK_SIZE {
                // All input went into the fragment and it is still short.
                return;
            }
            let full = self.buffer;
            self.decrypt_one(&full);
            self.buffer_len = 0;
        }

        let mut whole = rest.chunks_exact(BLOCK_SIZE);
        for chunk in whole.by_ref() {
            let mut ct = [0u8; BLOCK_SIZE];
            ct.copy_from_slice(chunk);
            self.decrypt_one(&ct);
        }

        let leftover = whole.remainder();
        if !leftover.is_empty() {
            self.buffer[..leftover.len()].copy_from_slice(leftover);
            self.buffer_len = leftover.len();
        }
    }

    /// Strips PKCS#7 padding from the final block and returns the plaintext.
    ///
    /// Yields `None` when the total ciphertext length was not a multiple of
    /// `BLOCK_SIZE`, when no ciphertext was supplied at all, or when the
    /// padding of the last block is invalid.
    #[must_use]
    pub fn finalize(mut self) -> Option<Vec<u8>> {
        let final_block = match (self.buffer_len, self.held_back) {
            (0, Some(block)) => block,
            _ => return None,
        };
        let keep = strip_pkcs7_block(&final_block)?;
        self.output.extend_from_slice(&final_block[..keep]);
        Some(self.output)
    }

    /// One CBC step: decrypt the ciphertext block and XOR it with the chaining
    /// value. The ciphertext becomes the next chaining value. The new plaintext
    /// replaces the held-back block, whose contents move to `output`.
    fn decrypt_one(&mut self, ciphertext_block: &[u8; BLOCK_SIZE]) {
        let mut plain = *ciphertext_block;
        self.cipher.decrypt_block(&mut plain);
        plain
            .iter_mut()
            .zip(&self.prev)
            .for_each(|(byte, chain)| *byte ^= chain);
        self.prev = *ciphertext_block;
        if let Some(ready) = self.held_back.replace(plain) {
            self.output.extend_from_slice(&ready);
        }
    }
}
/// Validates PKCS#7 padding on a final decrypted block.
///
/// On success, returns how many leading bytes of `block` are message data,
/// i.e. `BLOCK_SIZE - pad_len`, where `pad_len` is the block's last byte.
/// Returns `None` in three cases:
/// - that byte is 0;
/// - it exceeds `BLOCK_SIZE`;
/// - any of the trailing `pad_len` bytes differs from it.
///
/// All checks are combined through `subtle::Choice` values. The loop visits
/// every byte of the block rather than stopping at the first mismatch, so the
/// only data-dependent branch is the final `valid` decision.
///
/// NOTE(review): the caller still learns valid/invalid. Without a MAC this can
/// serve as a padding oracle if the result is observable by an attacker.
fn strip_pkcs7_block(block: &[u8; BLOCK_SIZE]) -> Option<usize> {
    // The last byte claims the padding length.
    let last = block[BLOCK_SIZE - 1];
    // A padding length of 0 is never valid PKCS#7.
    let pad_nonzero = !last.ct_eq(&0u8);
    // Padding longer than one block is invalid.
    // The cast assumes BLOCK_SIZE <= 255 — TODO confirm against the constant's definition.
    #[allow(clippy::cast_possible_truncation)]
    let pad_le_block = !last.ct_gt(&(BLOCK_SIZE as u8));
    let pad_in_range = pad_nonzero & pad_le_block;
    // OR of (byte ^ last) over padding positions only; any set bit marks a mismatch.
    let mut acc: u8 = 0;
    for (i, byte) in block.iter().enumerate() {
        // 1-based distance from the end: BLOCK_SIZE for block[0], 1 for the last byte.
        #[allow(clippy::cast_possible_truncation)]
        let pos_from_end = (BLOCK_SIZE - i) as u8;
        // The byte belongs to the padding when pos_from_end <= last.
        let in_padding = !pos_from_end.ct_gt(&last);
        let diff = *byte ^ last;
        // Keep `diff` for padding bytes, 0 otherwise (selects `diff` when in_padding is set).
        let masked = u8::conditional_select(&0u8, &diff, in_padding);
        acc |= masked;
    }
    let acc_zero = acc.ct_eq(&0u8);
    let valid = pad_in_range & acc_zero;
    if bool::from(valid) {
        // `valid` implies 1 <= last <= BLOCK_SIZE, so this cannot underflow.
        Some(BLOCK_SIZE - last as usize)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm4::mode_cbc;

    #[test]
    fn encrypt_single_chunk_matches_v02() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let plaintext = b"streaming round trip";
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(plaintext);
        let stream_ct = enc.finalize();
        let oneshot_ct = mode_cbc::encrypt(&key, &iv, plaintext);
        assert_eq!(stream_ct, oneshot_ct);
    }

    #[test]
    fn encrypt_chunked_matches_v02() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt: Vec<u8> = (0..100u8).collect();
        for chunk_size in [1usize, 7, 16, 17, 31, 32, 100] {
            let mut enc = Sm4CbcEncryptor::new(&key, &iv);
            for chunk in pt.chunks(chunk_size) {
                enc.update(chunk);
            }
            let stream_ct = enc.finalize();
            let oneshot_ct = mode_cbc::encrypt(&key, &iv, &pt);
            assert_eq!(stream_ct, oneshot_ct, "chunk_size={chunk_size}");
        }
    }

    #[test]
    fn streaming_round_trip() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        for len in [0usize, 1, 15, 16, 17, 31, 32, 33, 100, 256] {
            #[allow(clippy::cast_possible_truncation)]
            let pt: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(13)).collect();
            let mut enc = Sm4CbcEncryptor::new(&key, &iv);
            enc.update(&pt);
            let ct = enc.finalize();
            for chunk_size in [1usize, 7, 16, 17, 31, 32, ct.len().max(1)] {
                let mut dec = Sm4CbcDecryptor::new(&key, &iv);
                for chunk in ct.chunks(chunk_size) {
                    dec.update(chunk);
                }
                let recovered = dec.finalize().expect("decrypt");
                assert_eq!(recovered, pt, "len={len} chunk_size={chunk_size}");
            }
        }
    }

    #[test]
    fn decrypt_rejects_truncated() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&[0xAB; 31]);
        assert!(dec.finalize().is_none());
    }

    #[test]
    fn decrypt_rejects_empty() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let dec = Sm4CbcDecryptor::new(&key, &iv);
        assert!(dec.finalize().is_none());
    }

    #[test]
    fn decrypt_rejects_bad_padding() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt = b"this is a test message that spans multiple blocks";
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(pt);
        let mut ct = enc.finalize();
        // In CBC, flipping bits of ciphertext block i-1 flips exactly the same
        // bits of plaintext block i. XOR-ing the known pad length into the last
        // byte of the penultimate block turns the final pad byte into 0. A pad
        // byte of 0 is never valid PKCS#7, regardless of key.
        #[allow(clippy::cast_possible_truncation)]
        let pad = (BLOCK_SIZE - pt.len() % BLOCK_SIZE) as u8;
        let idx = ct.len() - BLOCK_SIZE - 1;
        ct[idx] ^= pad;
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&ct);
        assert!(dec.finalize().is_none());
    }

    #[test]
    fn decrypt_tampered_last_block_never_yields_original() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt = b"this is a test message that spans multiple blocks";
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(pt);
        let mut ct = enc.finalize();
        let last = ct.len() - 1;
        ct[last] ^= 0x01;
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&ct);
        // Whether the garbled final block happens to carry valid padding is
        // key-dependent, but the original plaintext must never come back.
        assert_ne!(dec.finalize().as_deref(), Some(&pt[..]));
    }

    #[test]
    fn strip_pkcs7_block_edge_cases() {
        // A whole block of padding leaves no message bytes.
        #[allow(clippy::cast_possible_truncation)]
        let full = [BLOCK_SIZE as u8; BLOCK_SIZE];
        assert_eq!(strip_pkcs7_block(&full), Some(0));

        // Minimal padding of one byte.
        let mut one = [0xAAu8; BLOCK_SIZE];
        one[BLOCK_SIZE - 1] = 1;
        assert_eq!(strip_pkcs7_block(&one), Some(BLOCK_SIZE - 1));

        // A pad length of zero is invalid.
        let mut zero = [0xAAu8; BLOCK_SIZE];
        zero[BLOCK_SIZE - 1] = 0;
        assert_eq!(strip_pkcs7_block(&zero), None);

        // A pad length larger than the block is invalid even if every byte agrees.
        #[allow(clippy::cast_possible_truncation)]
        let too_long = [BLOCK_SIZE as u8 + 1; BLOCK_SIZE];
        assert_eq!(strip_pkcs7_block(&too_long), None);

        // Inconsistent padding bytes are invalid.
        let mut mismatch = [0xAAu8; BLOCK_SIZE];
        mismatch[BLOCK_SIZE - 1] = 2;
        mismatch[BLOCK_SIZE - 2] = 3;
        assert_eq!(strip_pkcs7_block(&mismatch), None);
    }

    #[test]
    fn streaming_decrypt_matches_v02_oneshot() {
        let key = [0x42u8; KEY_SIZE];
        let iv = [0x33u8; BLOCK_SIZE];
        let pt = b"test message for cross-validation";
        let canonical = mode_cbc::encrypt(&key, &iv, pt);
        let mut dec = Sm4CbcDecryptor::new(&key, &iv);
        dec.update(&canonical);
        let stream_pt = dec.finalize().expect("streaming decrypt");
        assert_eq!(stream_pt, pt);
        let mut enc = Sm4CbcEncryptor::new(&key, &iv);
        enc.update(pt);
        let blob = enc.finalize();
        let recovered = mode_cbc::decrypt(&key, &iv, &blob).expect("oneshot decrypt");
        assert_eq!(recovered, pt);
    }
}