use alloc::vec;
use alloc::vec::Vec;
use subtle::ConstantTimeEq;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Length in bytes of the SM4-GCM authentication tag returned by [`encrypt`]
/// and checked by [`decrypt`]. This module always uses full-length 128-bit tags.
pub const TAG_SIZE: usize = 16;
/// Encrypts `plaintext` with SM4-GCM and returns `(ciphertext, tag)`.
///
/// Follows NIST SP 800-38D with SM4 as the underlying block cipher:
/// - the hash subkey is `H = E_K(0^128)`;
/// - the pre-counter block `J0` is derived from `nonce` (12-byte fast path,
///   GHASH for any other length);
/// - the payload is encrypted with GCTR starting at `inc32(J0)`;
/// - the full 16-byte tag is `GCTR(J0, S)`, where `S` is GHASH over the
///   AAD, the ciphertext and their bit lengths.
///
/// The returned ciphertext has exactly `plaintext.len()` bytes.
///
/// # Security
///
/// A `(key, nonce)` pair must never be used for two different messages.
///
/// This function validates neither the nonce nor the message length:
/// - an empty nonce is accepted, even though SP 800-38D requires a non-empty IV;
/// - plaintext length is not checked against the 32-bit counter space, so
///   callers must enforce the SP 800-38D payload limit themselves.
#[must_use]
pub fn encrypt(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    plaintext: &[u8],
) -> (Vec<u8>, [u8; TAG_SIZE]) {
    let cipher = Sm4Cipher::new(key);

    // GHASH subkey: encrypt the all-zero block in place.
    let h_block = {
        let mut zero = [0u8; BLOCK_SIZE];
        cipher.encrypt_block(&mut zero);
        zero
    };

    let j0 = derive_j0(&h_block, nonce);
    // J0 itself is reserved for masking the tag; payload counters start one later.
    let first_counter = inc32(&j0);

    let mut ciphertext = vec![0u8; plaintext.len()];
    gctr(&cipher, &first_counter, plaintext, &mut ciphertext);

    let mut tag = [0u8; TAG_SIZE];
    gctr(
        &cipher,
        &j0,
        &ghash_a_c_lens(&h_block, aad, &ciphertext),
        &mut tag,
    );
    (ciphertext, tag)
}
/// Verifies `tag` and, only if it is authentic, decrypts `ciphertext` with SM4-GCM.
///
/// The expected tag is recomputed from `aad` and `ciphertext` and compared
/// with `tag` in constant time via `subtle::ConstantTimeEq`.
///
/// Returns `None` when the tag does not match. In that case no keystream is
/// generated for the payload and no plaintext is produced. On success, the
/// returned plaintext has exactly `ciphertext.len()` bytes.
///
/// Nonce handling is the same as in [`encrypt`]: 12 bytes is the fast path,
/// any other length (including empty) is accepted and derived via GHASH.
#[must_use]
pub fn decrypt(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    tag: &[u8; TAG_SIZE],
) -> Option<Vec<u8>> {
    let cipher = Sm4Cipher::new(key);

    // GHASH subkey H = E_K(0^128).
    let h_block = {
        let mut zero = [0u8; BLOCK_SIZE];
        cipher.encrypt_block(&mut zero);
        zero
    };
    let j0 = derive_j0(&h_block, nonce);

    let mut expected_tag = [0u8; TAG_SIZE];
    gctr(
        &cipher,
        &j0,
        &ghash_a_c_lens(&h_block, aad, ciphertext),
        &mut expected_tag,
    );

    // Authenticate before decrypting so forged input never yields plaintext.
    let tag_matches = expected_tag.ct_eq(tag).unwrap_u8() == 1;
    if !tag_matches {
        return None;
    }

    let mut plaintext = vec![0u8; ciphertext.len()];
    gctr(&cipher, &inc32(&j0), ciphertext, &mut plaintext);
    Some(plaintext)
}
/// SP 800-38D `inc_32`: returns a copy of `b` with its last four bytes,
/// read as a big-endian `u32`, incremented modulo 2^32.
///
/// The leading 12 bytes are copied unchanged. The counter wraps from
/// `0xFFFF_FFFF` to `0` without carrying into byte 11.
const fn inc32(b: &[u8; BLOCK_SIZE]) -> [u8; BLOCK_SIZE] {
    let mut next = *b;
    let low_word = u32::from_be_bytes([next[12], next[13], next[14], next[15]])
        .wrapping_add(1)
        .to_be_bytes();
    // `copy_from_slice` is not usable in a `const fn` on older toolchains,
    // so copy the four counter bytes back by hand.
    let mut k = 0;
    while k < 4 {
        next[12 + k] = low_word[k];
        k += 1;
    }
    next
}
/// SP 800-38D `GCTR`: XORs `input` with the SM4 keystream of counter blocks
/// `icb, inc32(icb), inc32(inc32(icb)), ...` and writes the result to `out`.
///
/// A trailing partial block uses only the leading bytes of its keystream block.
/// All counter blocks are built up front and encrypted in one
/// `encrypt_blocks` call, so the buffered keystream is about
/// `input.len()` bytes, rounded up to a whole block.
///
/// `out` is expected to be exactly `input.len()` bytes. This is checked only
/// in debug builds. A shorter `out` panics on the slice below; a longer one
/// keeps its tail untouched.
fn gctr(cipher: &Sm4Cipher, icb: &[u8; BLOCK_SIZE], input: &[u8], out: &mut [u8]) {
    debug_assert_eq!(out.len(), input.len());
    if input.is_empty() {
        return;
    }

    let blocks = input.len().div_ceil(BLOCK_SIZE);
    let mut keystream: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(blocks);
    keystream.extend(core::iter::successors(Some(*icb), |cb| Some(inc32(cb))).take(blocks));
    cipher.encrypt_blocks(&mut keystream);

    let out = &mut out[..input.len()];
    for ((dst, src), ks) in out.iter_mut().zip(input).zip(keystream.iter().flatten()) {
        *dst = src ^ ks;
    }
}
/// Derives the GCM pre-counter block `J0` from `nonce` (SP 800-38D §7.1).
///
/// - 12-byte nonce: `J0 = nonce || 0x00000001`.
/// - Any other length: `J0 = GHASH_H(nonce || zero-pad || 0^64 || [len(nonce) in bits]_64)`,
///   where the zero padding brings `nonce` up to a whole number of blocks.
///
/// An empty nonce is not rejected. It hashes a single all-zero block, so
/// `J0` is the same for every empty-nonce message.
fn derive_j0(h_block: &[u8; BLOCK_SIZE], nonce: &[u8]) -> [u8; BLOCK_SIZE] {
    if nonce.len() != 12 {
        let nonce_bit_len = u64::try_from(nonce.len())
            .unwrap_or(u64::MAX)
            .saturating_mul(8);
        // Size of the nonce once zero-padded to a whole number of blocks.
        let iv_region = nonce.len().div_ceil(BLOCK_SIZE) * BLOCK_SIZE;
        // Zero-initialised: covers both the nonce padding and the 64 zero bits.
        let mut padded = vec![0u8; iv_region + 16];
        padded[..nonce.len()].copy_from_slice(nonce);
        padded[iv_region + 8..].copy_from_slice(&nonce_bit_len.to_be_bytes());
        return ghash(h_block, &padded);
    }

    let mut j0 = [0u8; BLOCK_SIZE];
    j0[..12].copy_from_slice(nonce);
    j0[15] = 0x01;
    j0
}
/// Computes the GCM tag input `S = GHASH_H(A || 0^v || C || 0^u || [len(A)]_64 || [len(C)]_64)`.
///
/// `aad` and `ct` are each zero-padded to whole blocks. Their lengths are
/// appended in bits as big-endian 64-bit integers; a bit length that does not
/// fit in `u64` saturates to `u64::MAX`.
///
/// The whole input is copied into one temporary buffer before hashing.
fn ghash_a_c_lens(h_block: &[u8; BLOCK_SIZE], aad: &[u8], ct: &[u8]) -> [u8; BLOCK_SIZE] {
    let round_up = |len: usize| len.div_ceil(BLOCK_SIZE) * BLOCK_SIZE;
    let bit_len = |len: usize| u64::try_from(len).unwrap_or(u64::MAX).saturating_mul(8);

    let ct_start = round_up(aad.len());
    let lens_start = ct_start + round_up(ct.len());

    // Zero-initialised, so both padding runs need no explicit writes.
    // The trailing 16 bytes hold the two 64-bit length fields.
    let mut buf = vec![0u8; lens_start + 16];
    buf[..aad.len()].copy_from_slice(aad);
    buf[ct_start..ct_start + ct.len()].copy_from_slice(ct);
    buf[lens_start..lens_start + 8].copy_from_slice(&bit_len(aad.len()).to_be_bytes());
    buf[lens_start + 8..].copy_from_slice(&bit_len(ct.len()).to_be_bytes());

    ghash(h_block, &buf)
}
/// SP 800-38D `GHASH_H`: folds `data` block by block as `Y_i = (Y_{i-1} ^ X_i) * H`,
/// starting from `Y_0 = 0`. The multiply is delegated to
/// `gmcrypto_simd::ghash::ghash_mul`.
///
/// `data` must be a whole number of blocks. This is asserted in debug builds
/// only; in release, a short trailing block panics on the index below.
/// Empty `data` returns the all-zero block.
fn ghash(h_block: &[u8; BLOCK_SIZE], data: &[u8]) -> [u8; BLOCK_SIZE] {
    debug_assert_eq!(data.len() % BLOCK_SIZE, 0);
    let mut acc = [0u8; BLOCK_SIZE];
    for chunk in data.chunks(BLOCK_SIZE) {
        let mut x = acc;
        for (k, byte) in x.iter_mut().enumerate() {
            *byte ^= chunk[k];
        }
        acc = gmcrypto_simd::ghash::ghash_mul(h_block, &x);
    }
    acc
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by all tests; it is also the RFC 8998 Appendix A.1 key.
    const KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
        0x10,
    ];
    const NONCE_12: [u8; 12] = [
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
    ];

    /// RFC 8998 Appendix A.1 SM4-GCM known-answer test.
    ///
    /// Round-trip tests alone cannot detect a GHASH/GCTR that is
    /// self-consistent but non-standard. This pins output to the published vector.
    #[test]
    fn rfc8998_known_answer() {
        const IV: [u8; 12] = [
            0x00, 0x00, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, 0xab, 0xcd,
        ];
        const AAD: [u8; 20] = [
            0xfe, 0xed, 0xfa, 0xce, 0xde, 0xad, 0xbe, 0xef, 0xfe, 0xed, 0xfa, 0xce, 0xde, 0xad,
            0xbe, 0xef, 0xab, 0xad, 0xda, 0xd2,
        ];
        const EXPECTED_CT: [u8; 64] = [
            0x17, 0xf3, 0x99, 0xf0, 0x8c, 0x67, 0xd5, 0xee, 0x19, 0xd0, 0xdc, 0x99, 0x69, 0xc4,
            0xbb, 0x7d, 0x5f, 0xd4, 0x6f, 0xd3, 0x75, 0x64, 0x89, 0x06, 0x91, 0x57, 0xb2, 0x82,
            0xbb, 0x20, 0x07, 0x35, 0xd8, 0x27, 0x10, 0xca, 0x5c, 0x22, 0xf0, 0xcc, 0xfa, 0x7c,
            0xbf, 0x93, 0xd4, 0x96, 0xac, 0x15, 0xa5, 0x68, 0x34, 0xcb, 0xcf, 0x98, 0xc3, 0x97,
            0xb4, 0x02, 0x4a, 0x26, 0x91, 0x23, 0x3b, 0x8d,
        ];
        const EXPECTED_TAG: [u8; TAG_SIZE] = [
            0x83, 0xde, 0x35, 0x41, 0xe4, 0xc2, 0xb5, 0x81, 0x77, 0xe0, 0x65, 0xa9, 0xbf, 0x7b,
            0x62, 0xec,
        ];
        // Plaintext is eight 16-byte runs: AA.. BB.. CC.. DD.. EE.. FF.. EE.. AA..
        let plaintext: Vec<u8> = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0xee, 0xaa]
            .iter()
            .flat_map(|&b| [b; 16])
            .collect();

        let (ct, tag) = encrypt(&KEY, &IV, &AAD, &plaintext);
        assert_eq!(ct, EXPECTED_CT);
        assert_eq!(tag, EXPECTED_TAG);

        let recovered =
            decrypt(&KEY, &IV, &AAD, &EXPECTED_CT, &EXPECTED_TAG).expect("RFC 8998 tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trip_canonical_nonce() {
        let aad = b"associated data";
        let plaintext = b"v0.8 W2 SM4-GCM round-trip smoke test";
        let (ct, tag) = encrypt(&KEY, &NONCE_12, aad, plaintext);
        let recovered = decrypt(&KEY, &NONCE_12, aad, &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let aad = b"aad-only message";
        let (ct, tag) = encrypt(&KEY, &NONCE_12, aad, &[]);
        assert!(ct.is_empty());
        let recovered = decrypt(&KEY, &NONCE_12, aad, &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, &[] as &[u8]);
    }

    #[test]
    fn round_trip_empty_aad() {
        let plaintext = b"hello GCM, no AAD";
        let (ct, tag) = encrypt(&KEY, &NONCE_12, &[], plaintext);
        let recovered = decrypt(&KEY, &NONCE_12, &[], &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trip_non_12_byte_nonce() {
        let nonce: [u8; 7] = [0x42u8; 7];
        let aad = b"aad";
        let plaintext = b"short-nonce SM4-GCM";
        let (ct, tag) = encrypt(&KEY, &nonce, aad, plaintext);
        let recovered = decrypt(&KEY, &nonce, aad, &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn tampered_tag_fails() {
        let aad = b"x";
        let plaintext = b"original";
        let (ct, mut tag) = encrypt(&KEY, &NONCE_12, aad, plaintext);
        tag[0] ^= 0x01;
        assert!(decrypt(&KEY, &NONCE_12, aad, &ct, &tag).is_none());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let aad = b"x";
        let plaintext = b"original";
        let (mut ct, tag) = encrypt(&KEY, &NONCE_12, aad, plaintext);
        // An empty ciphertext would leave nothing to tamper with and make this
        // test meaningless, so assert it is non-empty instead of skipping the flip.
        assert!(!ct.is_empty());
        ct[0] ^= 0x01;
        assert!(decrypt(&KEY, &NONCE_12, aad, &ct, &tag).is_none());
    }

    #[test]
    fn tampered_aad_fails() {
        let aad = b"correct-aad";
        let plaintext = b"original";
        let (ct, tag) = encrypt(&KEY, &NONCE_12, aad, plaintext);
        assert!(decrypt(&KEY, &NONCE_12, b"wrong-aad", &ct, &tag).is_none());
    }

    #[test]
    fn wrong_nonce_fails() {
        let aad = b"x";
        let plaintext = b"original";
        let (ct, tag) = encrypt(&KEY, &NONCE_12, aad, plaintext);
        let mut other_nonce = NONCE_12;
        other_nonce[11] ^= 0x01;
        assert!(decrypt(&KEY, &other_nonce, aad, &ct, &tag).is_none());
    }
}