use alloc::vec::Vec;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Streaming SM4 cipher in counter (CTR) mode.
///
/// Input may be fed in arbitrarily sized pieces through `update`; keystream
/// bytes that were generated but not yet consumed are carried over between
/// calls so the result does not depend on how the input is split.
pub struct Sm4CtrCipher {
/// Block cipher keyed at construction; used to encrypt counter blocks into keystream.
cipher: Sm4Cipher,
/// Counter block that will produce the next keystream block.
counter: [u8; BLOCK_SIZE],
/// Most recent keystream block generated for a partial (sub-block) tail.
leftover: [u8; BLOCK_SIZE],
/// Index of the first unused byte in `leftover`; `BLOCK_SIZE` means none remain.
leftover_pos: usize,
}
impl Sm4CtrCipher {
    /// Upper bound on the number of keystream blocks generated per
    /// `encrypt_blocks` call.
    ///
    /// Keeps the scratch allocation in [`Self::update`] at a fixed size
    /// (`KEYSTREAM_BATCH_BLOCKS * BLOCK_SIZE` bytes) no matter how large the
    /// caller's buffer is.
    const KEYSTREAM_BATCH_BLOCKS: usize = 64;

    /// Creates a streaming CTR cipher from `key` and the initial `counter` block.
    ///
    /// No keystream is generated until the first call to [`Self::update`].
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], counter: &[u8; BLOCK_SIZE]) -> Self {
        Self {
            cipher: Sm4Cipher::new(key),
            counter: *counter,
            leftover: [0u8; BLOCK_SIZE],
            // `BLOCK_SIZE` marks the leftover buffer as fully consumed.
            leftover_pos: BLOCK_SIZE,
        }
    }

    /// XORs `input` with the next keystream bytes and writes the result to
    /// the first `input.len()` bytes of `output`.
    ///
    /// The same call performs encryption and decryption. It may be invoked
    /// any number of times with pieces of any length; unused keystream bytes
    /// from a partially consumed block are kept for the next call. Bytes of
    /// `output` beyond `input.len()` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `output` is shorter than `input`.
    pub fn update(&mut self, input: &[u8], output: &mut [u8]) {
        assert!(
            output.len() >= input.len(),
            "Sm4CtrCipher::update: output buffer too short ({} < {})",
            output.len(),
            input.len(),
        );
        let output = &mut output[..input.len()];

        // Stage 1: use up keystream bytes left over from the previous call.
        // `leftover_pos` never exceeds `BLOCK_SIZE`, so the subtraction is safe.
        let head = (BLOCK_SIZE - self.leftover_pos).min(input.len());
        let (head_in, body_in) = input.split_at(head);
        let (head_out, body_out) = output.split_at_mut(head);
        Self::xor_keystream(
            head_out,
            head_in,
            &self.leftover[self.leftover_pos..self.leftover_pos + head],
        );
        self.leftover_pos += head;

        // Stage 2: whole blocks, processed in bounded batches so that the
        // keystream scratch buffer does not grow with the input length.
        let full_blocks = body_in.len() / BLOCK_SIZE;
        let (blocks_in, tail_in) = body_in.split_at(full_blocks * BLOCK_SIZE);
        let (blocks_out, tail_out) = body_out.split_at_mut(full_blocks * BLOCK_SIZE);
        let batch_bytes = Self::KEYSTREAM_BATCH_BLOCKS * BLOCK_SIZE;
        let mut keystream: Vec<[u8; BLOCK_SIZE]> =
            Vec::with_capacity(full_blocks.min(Self::KEYSTREAM_BATCH_BLOCKS));
        for (src_batch, dst_batch) in blocks_in
            .chunks(batch_bytes)
            .zip(blocks_out.chunks_mut(batch_bytes))
        {
            // Every batch is a whole number of blocks because both the total
            // length and `batch_bytes` are multiples of `BLOCK_SIZE`.
            let batch_blocks = src_batch.len() / BLOCK_SIZE;
            keystream.clear();
            keystream.extend((0..batch_blocks).map(|j| counter_add(&self.counter, j as u128)));
            self.cipher.encrypt_blocks(&mut keystream);
            for ((dst, src), ks) in dst_batch
                .chunks_exact_mut(BLOCK_SIZE)
                .zip(src_batch.chunks_exact(BLOCK_SIZE))
                .zip(&keystream)
            {
                Self::xor_keystream(dst, src, ks);
            }
            self.counter = counter_add(&self.counter, batch_blocks as u128);
        }

        // Stage 3: a trailing partial block. Generate one more keystream
        // block, use its prefix now and keep the rest for the next call.
        if !tail_in.is_empty() {
            self.leftover = self.counter;
            self.cipher.encrypt_block(&mut self.leftover);
            self.counter = counter_add(&self.counter, 1);
            Self::xor_keystream(tail_out, tail_in, &self.leftover[..tail_in.len()]);
            self.leftover_pos = tail_in.len();
        }
    }

    /// Consumes the cipher, ending the stream.
    ///
    /// CTR mode has no padding or trailing output, so there is nothing to
    /// flush; this method currently performs no work.
    pub fn finalize(self) {}

    /// Writes `input[i] ^ keystream[i]` into `output[i]` for every index
    /// covered by all three slices.
    fn xor_keystream(output: &mut [u8], input: &[u8], keystream: &[u8]) {
        for ((dst, &src), &ks) in output.iter_mut().zip(input).zip(keystream) {
            *dst = src ^ ks;
        }
    }
}
/// Returns `counter + offset`, treating the counter block as a 128-bit
/// big-endian unsigned integer.
///
/// The addition wraps around on overflow, so the all-ones counter plus one
/// yields the all-zero block.
const fn counter_add(counter: &[u8; BLOCK_SIZE], offset: u128) -> [u8; BLOCK_SIZE] {
    u128::from_be_bytes(*counter)
        .wrapping_add(offset)
        .to_be_bytes()
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::sm4::mode_ctr;

    /// Fixed key shared by every test.
    const KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
        0x10,
    ];
    /// Fixed initial counter block shared by every test.
    const COUNTER: [u8; 16] = [0x42u8; 16];

    /// Builds `len` bytes of deterministic, non-repeating-looking test data.
    // Truncation to `u32` / `u8` is intentional: only the low bits are used
    // to scramble the index into a byte.
    #[allow(clippy::cast_possible_truncation)]
    fn make_plaintext(len: usize) -> Vec<u8> {
        (0..len)
            .map(|i| {
                let s = (i as u32).wrapping_mul(0x9E37_79B9);
                (s ^ (s >> 17)) as u8
            })
            .collect()
    }

    /// One `update` call must agree with the one-shot CTR implementation for
    /// every length around the one- and two-block boundaries.
    #[test]
    fn streaming_single_call_matches_single_shot() {
        for len in 0..=33 {
            let plaintext = make_plaintext(len);
            let mut out = vec![0u8; len];
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            cipher.update(&plaintext, &mut out);
            cipher.finalize();
            let expected = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
            assert_eq!(out, expected, "single-call divergence at length {len}");
        }
    }

    /// Splitting the input into equally sized chunks must not change the output.
    #[test]
    fn chunked_update_sweep_matches_single_shot() {
        let total = 64;
        let plaintext = make_plaintext(total);
        let reference = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
        for chunk_size in 1..=17 {
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            let mut out = vec![0u8; total];
            let mut written = 0;
            while written < total {
                let take = chunk_size.min(total - written);
                cipher.update(
                    &plaintext[written..written + take],
                    &mut out[written..written + take],
                );
                written += take;
            }
            cipher.finalize();
            assert_eq!(
                out, reference,
                "chunked update divergence at chunk_size {chunk_size}",
            );
        }
    }

    /// Running the stream twice with the same key and counter restores the input.
    #[test]
    fn streaming_round_trip_is_identity() {
        for len in 0..=33 {
            let plaintext = make_plaintext(len);
            let mut ciphertext = vec![0u8; len];
            let mut enc = Sm4CtrCipher::new(&KEY, &COUNTER);
            enc.update(&plaintext, &mut ciphertext);
            enc.finalize();
            let mut recovered = vec![0u8; len];
            let mut dec = Sm4CtrCipher::new(&KEY, &COUNTER);
            dec.update(&ciphertext, &mut recovered);
            dec.finalize();
            assert_eq!(recovered, plaintext, "streaming round-trip at length {len}");
        }
    }

    /// An empty `update` is accepted and does nothing observable.
    #[test]
    fn empty_update() {
        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        let mut out = [];
        cipher.update(&[], &mut out);
        cipher.finalize();
    }

    /// Inputs spanning hundreds of blocks must match the one-shot result,
    /// both in a single call and when split at an unaligned offset.
    #[test]
    fn long_input_matches_single_shot() {
        let total = 4096 + 7;
        let plaintext = make_plaintext(total);
        let reference = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);

        let mut out = vec![0u8; total];
        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        cipher.update(&plaintext, &mut out);
        cipher.finalize();
        assert_eq!(out, reference, "long single-call divergence");

        for split in [1usize, 15, 16, 17, 1023, 1024, 1025, 2051] {
            let mut out = vec![0u8; total];
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            let (first_in, second_in) = plaintext.split_at(split);
            let (first_out, second_out) = out.split_at_mut(split);
            cipher.update(first_in, first_out);
            cipher.update(second_in, second_out);
            cipher.finalize();
            assert_eq!(out, reference, "long split divergence at split {split}");
        }
    }

    /// Only the first `input.len()` bytes of a larger output buffer are written.
    #[test]
    fn oversized_output_tail_is_untouched() {
        for len in [0usize, 5, 16, 21, 48] {
            let plaintext = make_plaintext(len);
            let mut out = vec![0xAAu8; len + 5];
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            cipher.update(&plaintext, &mut out);
            cipher.finalize();

            let produced = out[..len].to_vec();
            let expected = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
            assert_eq!(produced, expected, "oversized-output divergence at length {len}");
            assert!(
                out[len..].iter().all(|&b| b == 0xAA),
                "bytes past the input length were modified at length {len}",
            );
        }
    }

    /// An output buffer shorter than the input is a caller bug and must panic.
    #[test]
    #[should_panic(expected = "output buffer too short")]
    fn short_output_panics() {
        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        let mut out = [0u8; 3];
        cipher.update(&[0u8; 4], &mut out);
    }
}