use gmcrypto_core::sm4::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Fixed key shared by the tests below, so every run exercises the same key schedule.
///
/// NOTE(review): the byte pattern `01 23 45 67 89 ab cd ef fe dc ba 98 76 54 32 10`
/// looks like the example key from the SM4 standard's test vector — confirm before
/// relying on that.
const KEY: [u8; KEY_SIZE] = [
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
];
#[allow(clippy::cast_possible_truncation)] fn make_blocks(n: usize) -> Vec<[u8; BLOCK_SIZE]> {
let mut out = Vec::with_capacity(n);
for block_index in 0..n {
let mut block = [0u8; BLOCK_SIZE];
for (lane_index, byte) in block.iter_mut().enumerate() {
let seed = (block_index as u32)
.wrapping_mul(0x9E37_79B9)
.wrapping_add(lane_index as u32);
*byte = (seed ^ (seed >> 17) ^ (seed >> 9)) as u8;
}
out.push(block);
}
out
}
/// For every block count from 0 through 33, encrypting the whole slice with
/// `encrypt_blocks` must give the same bytes as calling `encrypt_block` on each
/// block in turn.
#[test]
fn encrypt_blocks_matches_per_block_at_every_length() {
    let cipher = Sm4Cipher::new(&KEY);
    for n in 0..=33 {
        let plaintext = make_blocks(n);

        // Bulk path: one call over the whole buffer.
        let mut batched = plaintext.clone();
        cipher.encrypt_blocks(&mut batched);

        // Reference path: one call per block, in order.
        let mut sequential = plaintext;
        sequential.iter_mut().for_each(|block| {
            cipher.encrypt_block(block);
        });

        assert_eq!(
            batched, sequential,
            "encrypt_blocks divergence at length {n}",
        );
    }
}
/// For every block count from 0 through 33, decrypting the whole slice with
/// `decrypt_blocks` must give the same bytes as calling `decrypt_block` on each
/// block in turn.
#[test]
fn decrypt_blocks_matches_per_block_at_every_length() {
    let cipher = Sm4Cipher::new(&KEY);
    for n in 0..=33 {
        let input = make_blocks(n);

        // Bulk path: one call over the whole buffer.
        let mut batched = input.clone();
        cipher.decrypt_blocks(&mut batched);

        // Reference path: one call per block, in order.
        let mut sequential = input;
        sequential.iter_mut().for_each(|block| {
            cipher.decrypt_block(block);
        });

        assert_eq!(
            batched, sequential,
            "decrypt_blocks divergence at length {n}",
        );
    }
}
/// For every block count from 0 through 33, running `encrypt_blocks` followed by
/// `decrypt_blocks` on the same buffer must restore the generated input exactly.
#[test]
fn round_trip_encrypt_then_decrypt_is_identity() {
    let cipher = Sm4Cipher::new(&KEY);
    for n in 0..=33 {
        let expected = make_blocks(n);

        // Work on a copy so the generated input stays available for comparison.
        let mut scratch = expected.clone();
        cipher.encrypt_blocks(&mut scratch);
        cipher.decrypt_blocks(&mut scratch);

        assert_eq!(scratch, expected, "round-trip divergence at length {n}");
    }
}
/// Round-trips a hand-picked list of block counts that sit on and around 4, 8,
/// 16, and 32, so a failure message names the exact length involved.
///
/// NOTE(review): these values presumably straddle the internal batch widths of
/// `encrypt_blocks` / `decrypt_blocks` — confirm against the implementation.
#[test]
fn batch_boundary_named_lengths_round_trip() {
    const LENGTHS: [usize; 11] = [3, 4, 5, 7, 8, 9, 15, 16, 17, 32, 33];

    let cipher = Sm4Cipher::new(&KEY);
    for n in LENGTHS {
        let expected = make_blocks(n);

        // Encrypt and decrypt in place on a copy, then compare with the input.
        let mut scratch = expected.clone();
        cipher.encrypt_blocks(&mut scratch);
        cipher.decrypt_blocks(&mut scratch);

        assert_eq!(
            scratch, expected,
            "boundary round-trip divergence at length {n}"
        );
    }
}