#![allow(clippy::cast_possible_truncation)]
use alloc::vec;
use alloc::vec::Vec;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Tag lengths, in bytes, accepted by [`encrypt`] and [`decrypt`]: the even
/// values from 4 through 16. Any other `tag_len` makes both functions return `None`.
const VALID_TAG_LENS: [usize; 7] = [4, 6, 8, 10, 12, 14, 16];
/// Shortest nonce, in bytes, accepted by parameter validation.
const MIN_NONCE_LEN: usize = 7;
/// Longest nonce, in bytes, accepted by parameter validation. The length-field
/// width used throughout this file is `q = 15 - nonce.len()`, so a 13-byte
/// nonce leaves a 2-byte length field.
const MAX_NONCE_LEN: usize = 13;
#[must_use]
pub fn encrypt(
key: &[u8; KEY_SIZE],
nonce: &[u8],
aad: &[u8],
plaintext: &[u8],
tag_len: usize,
) -> Option<Vec<u8>> {
validate_params(nonce, plaintext.len(), aad.len(), tag_len)?;
let cipher = Sm4Cipher::new(key);
let q = 15 - nonce.len();
let b0 = build_b0(nonce, plaintext.len(), !aad.is_empty(), tag_len, q);
let mut auth = Vec::with_capacity(BLOCK_SIZE);
auth.extend_from_slice(&b0);
if !aad.is_empty() {
format_aad_into(&mut auth, aad);
}
let plaintext_offset = auth.len();
auth.extend_from_slice(plaintext);
while auth.len() % BLOCK_SIZE != 0 {
auth.push(0);
}
let _ = plaintext_offset;
let t = cbc_mac(&cipher, &auth);
let mut a0 = [0u8; BLOCK_SIZE];
a0[0] = (q - 1) as u8; a0[1..=nonce.len()].copy_from_slice(nonce);
let mut s0 = a0;
cipher.encrypt_block(&mut s0);
let mut tag = [0u8; 16];
for i in 0..tag_len {
tag[i] = s0[i] ^ t[i];
}
let mut ct = vec![0u8; plaintext.len()];
ccm_ctr_xor(&cipher, &a0, plaintext, &mut ct);
let mut output = Vec::with_capacity(ct.len() + tag_len);
output.extend_from_slice(&ct);
output.extend_from_slice(&tag[..tag_len]);
Some(output)
}
/// Decrypts and authenticates an SM4-CCM message produced by [`encrypt`].
///
/// # Parameters
/// - `key`: SM4 key.
/// - `nonce`: the 7 to 13 byte nonce used for encryption.
/// - `aad`: the associated data used for encryption (may be empty).
/// - `ciphertext_with_tag`: ciphertext immediately followed by the `tag_len`-byte tag.
/// - `tag_len`: tag length in bytes; must be one of `VALID_TAG_LENS`.
///
/// # Returns
/// `Some(plaintext)` when the tag verifies. `None` when `tag_len` is invalid,
/// the input is shorter than `tag_len`, `validate_params` rejects the
/// parameters, or the tag does not match. On a tag mismatch the candidate
/// plaintext is zeroized before returning.
#[must_use]
pub fn decrypt(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    ciphertext_with_tag: &[u8],
    tag_len: usize,
) -> Option<Vec<u8>> {
    if !VALID_TAG_LENS.contains(&tag_len) {
        return None;
    }
    // `checked_sub` yields `None` when the input cannot even hold the tag.
    let split = ciphertext_with_tag.len().checked_sub(tag_len)?;
    let (ct, wire_tag) = ciphertext_with_tag.split_at(split);
    validate_params(nonce, ct.len(), aad.len(), tag_len)?;
    let cipher = Sm4Cipher::new(key);
    let q = 15 - nonce.len();

    // A0: flags byte (q - 1), the nonce, and a zero counter.
    let mut a0 = [0u8; BLOCK_SIZE];
    a0[0] = (q - 1) as u8;
    a0[1..=nonce.len()].copy_from_slice(nonce);

    // The plaintext is unauthenticated until the tag comparison below succeeds.
    let mut tentative_pt = vec![0u8; ct.len()];
    ccm_ctr_xor(&cipher, &a0, ct, &mut tentative_pt);

    // Upper bound on the CBC-MAC input: B0, the encoded AAD (at most a 10-byte
    // length header plus padding), and the padded payload. Reserving it up
    // front means the vector never reallocates, so no stale plaintext copy is
    // left behind in freed memory where the wipe below cannot reach it.
    let auth_capacity =
        BLOCK_SIZE + (10 + aad.len() + BLOCK_SIZE) + (tentative_pt.len() + BLOCK_SIZE);
    let mut auth = Vec::with_capacity(auth_capacity);
    auth.extend_from_slice(&build_b0(nonce, tentative_pt.len(), !aad.is_empty(), tag_len, q));
    if !aad.is_empty() {
        format_aad_into(&mut auth, aad);
    }
    auth.extend_from_slice(&tentative_pt);
    // Zero-pad the payload to a whole number of blocks.
    let padded_len = auth.len().div_ceil(BLOCK_SIZE) * BLOCK_SIZE;
    auth.resize(padded_len, 0);
    let t = cbc_mac(&cipher, &auth);
    // `auth` holds a second copy of the unauthenticated plaintext. Wipe it on
    // every path; otherwise zeroizing `tentative_pt` on failure achieves nothing.
    auth.zeroize();

    // S0 = E(K, A0) masks the raw CBC-MAC value to form the expected tag.
    let mut s0 = a0;
    cipher.encrypt_block(&mut s0);
    let mut expected_tag = [0u8; BLOCK_SIZE];
    for (slot, (mac, pad)) in expected_tag.iter_mut().zip(t.iter().zip(s0.iter())).take(tag_len) {
        *slot = mac ^ pad;
    }

    // Constant-time comparison so timing does not reveal how many tag bytes matched.
    if expected_tag[..tag_len].ct_eq(wire_tag).unwrap_u8() != 1 {
        tentative_pt.zeroize();
        return None;
    }
    Some(tentative_pt)
}
/// Checks the CCM parameters shared by encryption and decryption.
///
/// Returns `Some(())` when `tag_len` is one of `VALID_TAG_LENS`, the nonce is
/// between `MIN_NONCE_LEN` and `MAX_NONCE_LEN` bytes long, and `pt_len` fits in
/// the `q = 15 - nonce.len()` byte length field; returns `None` otherwise.
/// `aad_len` is accepted for symmetry with the callers but is not checked.
fn validate_params(nonce: &[u8], pt_len: usize, aad_len: usize, tag_len: usize) -> Option<()> {
    // No limit is enforced on the associated-data length.
    let _ = aad_len;

    let tag_ok = VALID_TAG_LENS.iter().any(|&allowed| allowed == tag_len);
    if !tag_ok {
        return None;
    }

    let nonce_ok = (MIN_NONCE_LEN..=MAX_NONCE_LEN).contains(&nonce.len());
    if !nonce_ok {
        return None;
    }

    // With q == 8 the length field is a full 64 bits, so every `usize` fits.
    // Otherwise the length fits exactly when nothing remains above the low
    // 8 * q bits.
    let q = 15 - nonce.len();
    let fits = q >= 8 || (pt_len as u64) >> (8 * q) == 0;
    fits.then_some(())
}
/// Builds the first CBC-MAC block (B0).
///
/// Layout: one flags byte, the nonce, then `pt_len` as a big-endian integer in
/// the last `q` bytes. The flags byte combines `0x40` when `has_aad` is set,
/// `(tag_len - 2) / 2` shifted into bits 3..=5, and `q - 1` in the low bits.
///
/// Callers are expected to pass `q = 15 - nonce.len()` with `q <= 8`; a larger
/// `q` trips a debug assertion and leaves the length field zeroed in release builds.
fn build_b0(
    nonce: &[u8],
    pt_len: usize,
    has_aad: bool,
    tag_len: usize,
    q: usize,
) -> [u8; BLOCK_SIZE] {
    let mut block = [0u8; BLOCK_SIZE];

    // Assemble the flags byte field by field.
    let mut flags: u8 = 0;
    if has_aad {
        flags |= 0x40;
    }
    flags |= (((tag_len - 2) / 2) as u8) << 3;
    flags |= (q - 1) as u8;
    block[0] = flags;

    block[1..=nonce.len()].copy_from_slice(nonce);

    let length_start = 16 - q;
    if q > 8 {
        // A u64 cannot fill a field wider than 8 bytes; callers never ask for one.
        debug_assert!(q <= 8);
        return block;
    }
    // Keep only the low `q` bytes of the big-endian length.
    let length_be = (pt_len as u64).to_be_bytes();
    block[length_start..16].copy_from_slice(&length_be[8 - q..8]);
    block
}
fn format_aad_into(out: &mut Vec<u8>, aad: &[u8]) {
let alen = aad.len();
if alen < 0xFF00 {
let l = alen as u16;
out.extend_from_slice(&l.to_be_bytes());
} else if alen <= 0xFFFF_FFFF {
out.push(0xFF);
out.push(0xFE);
out.extend_from_slice(&(alen as u32).to_be_bytes());
} else {
out.push(0xFF);
out.push(0xFF);
out.extend_from_slice(&(alen as u64).to_be_bytes());
}
out.extend_from_slice(aad);
while out.len() % BLOCK_SIZE != 0 {
out.push(0);
}
}
/// Computes the raw CBC-MAC of `data` under `cipher`, starting from an all-zero state.
///
/// Each `BLOCK_SIZE`-byte block is XORed into the running state, which is then
/// encrypted in place; the state after the last block is returned. `data` must
/// be a whole number of blocks (checked with a debug assertion).
fn cbc_mac(cipher: &Sm4Cipher, data: &[u8]) -> [u8; BLOCK_SIZE] {
    debug_assert_eq!(data.len() % BLOCK_SIZE, 0);
    let mut state = [0u8; BLOCK_SIZE];
    for block in data.chunks(BLOCK_SIZE) {
        // Indexing (rather than zipping) keeps the out-of-bounds panic on a
        // ragged final block.
        for (lane, acc) in state.iter_mut().enumerate() {
            *acc ^= block[lane];
        }
        cipher.encrypt_block(&mut state);
    }
    state
}
/// XORs `input` with the CCM counter-mode keystream and writes the result to `output`.
///
/// Keystream block `i` (starting at 1) is the encryption of `a0` with its last
/// `q` bytes replaced by `i` in big-endian form, where `q = a0[0] + 1` (the
/// flags byte of `a0` stores `q - 1`). `input` and `output` must have the same
/// length (checked with a debug assertion). The function encrypts and decrypts
/// alike because the operation is a plain XOR.
fn ccm_ctr_xor(cipher: &Sm4Cipher, a0: &[u8; BLOCK_SIZE], input: &[u8], output: &mut [u8]) {
    debug_assert_eq!(input.len(), output.len());
    if input.is_empty() {
        return;
    }

    // Recover the counter width from the flags byte.
    let q = usize::from(a0[0]) + 1;
    let counter_start = BLOCK_SIZE - q;
    let block_count = input.len().div_ceil(BLOCK_SIZE);

    // Build every counter block first so the cipher can process them in one call.
    let mut keystream: Vec<[u8; BLOCK_SIZE]> = (1..=block_count)
        .map(|i| {
            let mut a_i = *a0;
            let counter_be = (i as u64).to_be_bytes();
            a_i[counter_start..].copy_from_slice(&counter_be[8 - q..]);
            a_i
        })
        .collect();
    cipher.encrypt_blocks(&mut keystream);

    for ((dst, src), pad) in output.iter_mut().zip(input).zip(keystream.iter().flatten()) {
        *dst = src ^ pad;
    }

    // Keystream combined with the ciphertext reveals the plaintext, so wipe it
    // before the buffer is freed.
    keystream.zeroize();
}
/// Round-trip, tamper-detection and parameter-rejection tests for SM4-CCM.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed 16-byte key shared by every test.
    const KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
        0x10,
    ];

    /// A 12-byte nonce with AAD and a 16-byte tag round-trips.
    #[test]
    fn round_trip_canonical() {
        let nonce = [0x00u8; 12];
        let aad = b"associated data";
        let pt = b"v0.8 W3 SM4-CCM smoke";
        let ct = encrypt(&KEY, &nonce, aad, pt, 16).expect("valid params");
        let recovered = decrypt(&KEY, &nonce, aad, &ct, 16).expect("tag verifies");
        assert_eq!(recovered, pt);
    }

    /// Every supported tag length round-trips and yields the expected output size.
    #[test]
    fn round_trip_short_tag() {
        let nonce = [0x00u8; 12];
        for &tag_len in &[4, 6, 8, 10, 12, 14, 16] {
            let pt = b"varying tag length";
            let ct = encrypt(&KEY, &nonce, b"aad", pt, tag_len).expect("valid params");
            assert_eq!(ct.len(), pt.len() + tag_len);
            let recovered = decrypt(&KEY, &nonce, b"aad", &ct, tag_len).expect("tag verifies");
            assert_eq!(recovered, pt);
        }
    }

    /// Every supported nonce length (7 through 13 bytes) round-trips.
    #[test]
    fn round_trip_nonce_length_sweep() {
        for nonce_len in 7..=13 {
            let nonce = vec![0x42u8; nonce_len];
            let ct = encrypt(&KEY, &nonce, b"x", b"hi", 16).expect("valid params");
            let recovered = decrypt(&KEY, &nonce, b"x", &ct, 16).expect("tag verifies");
            assert_eq!(recovered, b"hi");
        }
    }

    /// An empty plaintext produces a tag-only output that still verifies.
    #[test]
    fn round_trip_empty_pt() {
        let nonce = [0x42u8; 12];
        let ct = encrypt(&KEY, &nonce, b"aad", &[], 16).expect("valid params");
        assert_eq!(ct.len(), 16);
        let recovered = decrypt(&KEY, &nonce, b"aad", &ct, 16).expect("tag verifies");
        assert!(recovered.is_empty());
    }

    /// Empty associated data round-trips.
    #[test]
    fn round_trip_empty_aad() {
        let nonce = [0x42u8; 12];
        let pt = b"hello";
        let ct = encrypt(&KEY, &nonce, &[], pt, 16).expect("valid params");
        let recovered = decrypt(&KEY, &nonce, &[], &ct, 16).expect("tag verifies");
        assert_eq!(recovered, pt);
    }

    /// Flipping a bit in the tag makes decryption fail.
    #[test]
    fn tampered_tag_fails() {
        let nonce = [0x42u8; 12];
        let mut ct = encrypt(&KEY, &nonce, b"aad", b"hello", 16).expect("valid params");
        let len = ct.len();
        ct[len - 1] ^= 0x01;
        assert!(decrypt(&KEY, &nonce, b"aad", &ct, 16).is_none());
    }

    /// Flipping a bit in the ciphertext body makes decryption fail.
    #[test]
    fn tampered_ciphertext_fails() {
        let nonce = [0x42u8; 12];
        let mut ct = encrypt(&KEY, &nonce, b"aad", b"hello", 16).expect("valid params");
        ct[0] ^= 0x01;
        assert!(decrypt(&KEY, &nonce, b"aad", &ct, 16).is_none());
    }

    /// Decrypting with different associated data fails.
    #[test]
    fn tampered_aad_fails() {
        let nonce = [0x42u8; 12];
        let ct = encrypt(&KEY, &nonce, b"correct-aad", b"hello", 16).expect("valid params");
        assert!(decrypt(&KEY, &nonce, b"wrong-aad", &ct, 16).is_none());
    }

    /// Nonces shorter than 7 or longer than 13 bytes are rejected in both directions.
    #[test]
    fn invalid_nonce_length_rejected() {
        assert!(encrypt(&KEY, &[0u8; 6], &[], &[], 16).is_none());
        assert!(encrypt(&KEY, &[0u8; 14], &[], &[], 16).is_none());
        assert!(decrypt(&KEY, &[0u8; 6], &[], &[0u8; 16], 16).is_none());
        assert!(decrypt(&KEY, &[0u8; 14], &[], &[0u8; 16], 16).is_none());
    }

    /// Tag lengths outside the supported set are rejected in both directions.
    #[test]
    fn invalid_tag_length_rejected() {
        let nonce = [0x42u8; 12];
        for tag_len in [0usize, 3, 5, 7, 9, 11, 13, 15, 17, 32] {
            assert!(
                encrypt(&KEY, &nonce, &[], &[], tag_len).is_none(),
                "encrypt accepted invalid tag_len={tag_len}",
            );
            // 32 input bytes are enough to hold every tag length tried here, so
            // the rejection comes from the tag-length check, not the size check.
            assert!(
                decrypt(&KEY, &nonce, &[], &[0u8; 32], tag_len).is_none(),
                "decrypt accepted invalid tag_len={tag_len}",
            );
        }
    }

    /// Input too short to contain the tag is rejected.
    #[test]
    fn ct_with_tag_shorter_than_tag_len_rejected() {
        let nonce = [0x42u8; 12];
        assert!(decrypt(&KEY, &nonce, b"aad", &[0u8; 8], 16).is_none());
    }

    /// With a 13-byte nonce the length field is 2 bytes, so 65535 bytes is the
    /// largest accepted plaintext and 65536 bytes is rejected.
    #[test]
    fn plaintext_length_bound_enforced() {
        let nonce = [0x42u8; 13];
        let longest = vec![0u8; (1 << 16) - 1];
        let ct = encrypt(&KEY, &nonce, &[], &longest, 16).expect("valid params");
        assert_eq!(ct.len(), longest.len() + 16);
        let too_long = vec![0u8; 1 << 16];
        assert!(encrypt(&KEY, &nonce, &[], &too_long, 16).is_none());
    }
}