use alloc::vec::Vec;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Encrypts `plaintext` with SM4 in counter (CTR) mode.
///
/// `counter` is the initial counter block. Every further `BLOCK_SIZE`-byte
/// chunk of input is XOR-ed with the encryption of the next counter value.
/// Counter values are treated as big-endian 128-bit integers that wrap
/// modulo 2^128. CTR needs no padding, so the ciphertext is exactly as long
/// as `plaintext`.
///
/// # Security
///
/// The keystream depends only on `(key, counter)`. Encrypting two messages
/// with the same key and the same (or overlapping) counter range reveals the
/// XOR of their plaintexts.
#[must_use]
pub fn encrypt(key: &[u8; KEY_SIZE], counter: &[u8; BLOCK_SIZE], plaintext: &[u8]) -> Vec<u8> {
    apply_keystream(key, counter, plaintext)
}
#[must_use]
pub fn decrypt(key: &[u8; KEY_SIZE], counter: &[u8; BLOCK_SIZE], ciphertext: &[u8]) -> Vec<u8> {
apply_keystream(key, counter, ciphertext)
}
/// XORs `input` with the SM4-CTR keystream for `key` starting at `counter`.
///
/// Keystream block `i` is `E_K(counter + i)`. The counter is advanced by
/// [`counter_add`], i.e. big-endian and wrapping modulo 2^128. The result has
/// exactly `input.len()` bytes. A trailing partial chunk consumes only the
/// leading bytes of its keystream block. XOR is an involution, so this one
/// routine serves both encryption and decryption.
fn apply_keystream(key: &[u8; KEY_SIZE], counter: &[u8; BLOCK_SIZE], input: &[u8]) -> Vec<u8> {
    let cipher = Sm4Cipher::new(key);
    let block_count = input.len().div_ceil(BLOCK_SIZE);

    // Materialise every counter block up front so the cipher can encrypt them
    // in a single batched call.
    let mut keystream: Vec<[u8; BLOCK_SIZE]> = (0..block_count)
        .map(|i| counter_add(counter, i as u128))
        .collect();
    cipher.encrypt_blocks(&mut keystream);

    // Pair each input chunk with its keystream block and XOR in place. This
    // replaces per-byte `/` and `%` plus double-indexed (bounds-checked)
    // access. `zip` stops at the shorter side, so a short final chunk uses
    // only a prefix of its block.
    let mut out = input.to_vec();
    for (chunk, block) in out.chunks_mut(BLOCK_SIZE).zip(&keystream) {
        for (byte, ks_byte) in chunk.iter_mut().zip(block) {
            *byte ^= *ks_byte;
        }
    }
    out
}
/// Returns the counter block `offset` steps after `counter`.
///
/// The block is read as one big-endian 128-bit integer. Overflow wraps
/// modulo 2^128, so the all-`0xFF` block followed by one step is all zeros.
const fn counter_add(counter: &[u8; BLOCK_SIZE], offset: u128) -> [u8; BLOCK_SIZE] {
    let base = u128::from_be_bytes(*counter);
    let advanced = base.wrapping_add(offset);
    advanced.to_be_bytes()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key (and plaintext) of the SM4 example in GB/T 32907-2016.
    const STANDARD_KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
        0x32, 0x10,
    ];

    /// Builds a counter block whose low two bytes hold `value` (big-endian), rest zero.
    fn counter_block(value: u16) -> [u8; BLOCK_SIZE] {
        let mut block = [0u8; BLOCK_SIZE];
        block[BLOCK_SIZE - 2..].copy_from_slice(&value.to_be_bytes());
        block
    }

    #[test]
    fn ctr_round_trip_is_identity() {
        let key: [u8; 16] = [0x42; 16];
        let counter: [u8; 16] = [0x01; 16];
        for len in 0..=64usize {
            // Iterating a `u8` range yields a varied byte pattern without a truncating cast.
            let plaintext: Vec<u8> = (0..=u8::MAX).take(len).map(|b| b ^ 0xAB).collect();
            let ciphertext = encrypt(&key, &counter, &plaintext);
            assert_eq!(ciphertext.len(), plaintext.len(), "len mismatch at {len}");
            let recovered = decrypt(&key, &counter, &ciphertext);
            assert_eq!(recovered, plaintext, "round-trip at length {len}");
        }
    }

    #[test]
    fn ctr_keystream_matches_ecb_of_counter_blocks() {
        let key = STANDARD_KEY;
        let counter: [u8; 16] = [0u8; 16];
        let zeros = [0u8; 16];
        let ctr_out = encrypt(&key, &counter, &zeros);
        let mut ecb_out = counter;
        Sm4Cipher::new(&key).encrypt_block(&mut ecb_out);
        assert_eq!(&ctr_out[..], &ecb_out[..]);
    }

    #[test]
    fn ctr_multi_block_keystream_uses_successive_counters() {
        // Three full blocks plus a partial one. Starting at ...00FE forces the
        // third counter to carry into the next byte (...00FF -> ...0100).
        const LEN: usize = 3 * BLOCK_SIZE + 5;
        let key = STANDARD_KEY;
        let cipher = Sm4Cipher::new(&key);
        let mut expected: Vec<u8> = Vec::with_capacity(4 * BLOCK_SIZE);
        for value in 0x00FE..=0x0101u16 {
            let mut block = counter_block(value);
            cipher.encrypt_block(&mut block);
            expected.extend_from_slice(&block);
        }
        expected.truncate(LEN);
        let ctr_out = encrypt(&key, &counter_block(0x00FE), &[0u8; LEN]);
        assert_eq!(ctr_out, expected);
    }

    #[test]
    fn ctr_keystream_wraps_counter_across_blocks() {
        let key = STANDARD_KEY;
        let counter = [0xFFu8; BLOCK_SIZE];
        let ctr_out = encrypt(&key, &counter, &[0u8; 2 * BLOCK_SIZE]);

        let cipher = Sm4Cipher::new(&key);
        let mut first = counter;
        cipher.encrypt_block(&mut first);
        let mut second = [0u8; BLOCK_SIZE];
        cipher.encrypt_block(&mut second);

        assert_eq!(&ctr_out[..BLOCK_SIZE], &first[..]);
        assert_eq!(&ctr_out[BLOCK_SIZE..], &second[..], "second block must use wrapped counter 0");
    }

    #[test]
    fn ctr_single_block_matches_sm4_standard_vector() {
        // With an all-zero plaintext, CTR output is just E_K(counter). Using the
        // standard's plaintext as the counter reproduces its published ciphertext.
        let expected: [u8; 16] = [
            0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e,
            0x42, 0x46,
        ];
        let ctr_out = encrypt(&STANDARD_KEY, &STANDARD_KEY, &[0u8; 16]);
        assert_eq!(&ctr_out[..], &expected[..]);
    }

    #[test]
    fn counter_add_wraps_at_2_to_128() {
        let max_counter: [u8; 16] = [0xFF; 16];
        let next = counter_add(&max_counter, 1);
        assert_eq!(next, [0u8; 16], "counter must wrap at 2^128");
    }

    #[test]
    fn counter_add_carries_between_bytes() {
        assert_eq!(counter_add(&counter_block(0x00FF), 1), counter_block(0x0100));
        assert_eq!(counter_add(&[0u8; 16], u128::MAX), [0xFF; 16]);
    }

    #[test]
    fn empty_input_returns_empty_output() {
        let key = [0u8; 16];
        let counter = [0u8; 16];
        assert!(encrypt(&key, &counter, &[]).is_empty());
        assert!(decrypt(&key, &counter, &[]).is_empty());
    }
}