use alloc::vec;
use alloc::vec::Vec;
use subtle::ConstantTimeEq;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Length in bytes of a full (untruncated) GCM authentication tag.
pub const TAG_SIZE: usize = 16;
/// Largest plaintext/ciphertext length, in bytes, accepted by the one-shot functions in this
/// file: 2^36 - 32 bytes (= 2^39 - 256 bits, the plaintext ceiling of NIST SP 800-38D).
pub(crate) const GCM_MAX_PT_BYTES: u64 = (1u64 << 36) - 32;
/// A validated GCM authentication-tag length, measured in bytes.
///
/// The wrapped value can only be produced by [`GcmTagLen::new`], so holding a `GcmTagLen`
/// means the length is one of 4, 8, 12, 13, 14, 15 or 16.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GcmTagLen(usize);

impl GcmTagLen {
    /// Wraps `bytes` when it is a permitted tag length (4, 8, or 12 through 16).
    ///
    /// Returns `None` for every other value.
    #[must_use]
    pub const fn new(bytes: usize) -> Option<Self> {
        if matches!(bytes, 4 | 8 | 12..=16) {
            Some(Self(bytes))
        } else {
            None
        }
    }

    /// Returns the tag length in bytes.
    #[must_use]
    pub const fn as_usize(self) -> usize {
        let Self(bytes) = self;
        bytes
    }
}
/// Encrypts `plaintext` under SM4-GCM and returns `(ciphertext, tag)`, where `tag` is the
/// full [`TAG_SIZE`]-byte authentication tag.
///
/// `aad` is authenticated but not encrypted. The ciphertext has the same length as
/// `plaintext`.
///
/// Returns `None` when
/// - `nonce` is empty, or
/// - `plaintext` is longer than `GCM_MAX_PT_BYTES` bytes.
#[must_use]
pub fn encrypt(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    plaintext: &[u8],
) -> Option<(Vec<u8>, [u8; TAG_SIZE])> {
    // GCM requires a nonce of at least one bit. With a zero-length nonce the GHASH input is a
    // single all-zero block, so J0 would be all-zero and the tag mask E_K(J0) would coincide
    // with the hash subkey H = E_K(0^128), exposing H through the tag.
    if nonce.is_empty() {
        return None;
    }
    // A length that does not even fit in a u64 is certainly above the ceiling.
    if u64::try_from(plaintext.len()).map_or(true, |len| len > GCM_MAX_PT_BYTES) {
        return None;
    }

    let cipher = Sm4Cipher::new(key);

    // Hash subkey: the all-zero block encrypted in place.
    let mut h_block = [0u8; BLOCK_SIZE];
    cipher.encrypt_block(&mut h_block);

    // Pre-counter block J0; the payload keystream starts at inc32(J0).
    let j0 = derive_j0(&h_block, nonce);
    let mut ciphertext = vec![0u8; plaintext.len()];
    gctr(&cipher, &inc32(&j0), plaintext, &mut ciphertext);

    // Tag = GCTR(J0, GHASH(aad || ciphertext || lengths)).
    let s = ghash_a_c_lens(&h_block, aad, &ciphertext);
    let mut tag = [0u8; TAG_SIZE];
    gctr(&cipher, &j0, &s, &mut tag);

    Some((ciphertext, tag))
}
/// Verifies `tag` over `aad` and `ciphertext` under SM4-GCM and, only if it matches, returns
/// the decrypted plaintext.
///
/// The expected tag is computed before any decryption happens and is compared with `tag`
/// through `subtle`'s constant-time equality.
///
/// Returns `None` when
/// - `nonce` is empty,
/// - `ciphertext` is longer than `GCM_MAX_PT_BYTES` bytes, or
/// - the tag does not match.
#[must_use]
pub fn decrypt(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    tag: &[u8; TAG_SIZE],
) -> Option<Vec<u8>> {
    // GCM requires a nonce of at least one bit. With a zero-length nonce J0 would be the
    // all-zero block and the tag mask E_K(J0) would equal the hash subkey H, which makes tags
    // forgeable; never accept such a message.
    if nonce.is_empty() {
        return None;
    }
    // A length that does not even fit in a u64 is certainly above the ceiling.
    if u64::try_from(ciphertext.len()).map_or(true, |len| len > GCM_MAX_PT_BYTES) {
        return None;
    }

    let cipher = Sm4Cipher::new(key);

    // Hash subkey: the all-zero block encrypted in place.
    let mut h_block = [0u8; BLOCK_SIZE];
    cipher.encrypt_block(&mut h_block);

    let j0 = derive_j0(&h_block, nonce);

    // Recompute the tag over the received data and check it before producing any plaintext.
    let s = ghash_a_c_lens(&h_block, aad, ciphertext);
    let mut expected_tag = [0u8; TAG_SIZE];
    gctr(&cipher, &j0, &s, &mut expected_tag);
    if expected_tag.ct_eq(tag).unwrap_u8() != 1 {
        return None;
    }

    // Tag verified: the payload keystream starts at inc32(J0).
    let mut plaintext = vec![0u8; ciphertext.len()];
    gctr(&cipher, &inc32(&j0), ciphertext, &mut plaintext);
    Some(plaintext)
}
/// Encrypts like [`encrypt`] but returns only the leading `tag_len` bytes of the tag.
///
/// The ciphertext is exactly what [`encrypt`] produces; the returned tag is the prefix of the
/// full tag with length `tag_len.as_usize()`.
///
/// Returns `None` when `plaintext` is longer than `GCM_MAX_PT_BYTES` bytes or when
/// [`encrypt`] itself returns `None`.
#[must_use]
pub fn encrypt_with_tag_len(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    plaintext: &[u8],
    tag_len: GcmTagLen,
) -> Option<(Vec<u8>, Vec<u8>)> {
    let within_ceiling = plaintext.len() as u64 <= GCM_MAX_PT_BYTES;
    if !within_ceiling {
        return None;
    }

    encrypt(key, nonce, aad, plaintext).map(|(ciphertext, full_tag)| {
        // Keep only the most significant `tag_len` bytes of the full tag.
        let keep = tag_len.as_usize();
        (ciphertext, full_tag[..keep].to_vec())
    })
}
/// Verifies a possibly truncated `tag` over `aad` and `ciphertext` under SM4-GCM and, only if
/// it matches, returns the decrypted plaintext.
///
/// The tag length is taken from `tag.len()` and must be one that [`GcmTagLen::new`] accepts.
/// The leading `tag.len()` bytes of the full expected tag are compared with `tag` through
/// `subtle`'s constant-time equality before any decryption happens.
///
/// Returns `None` when
/// - `nonce` is empty,
/// - `ciphertext` is longer than `GCM_MAX_PT_BYTES` bytes,
/// - `tag.len()` is not a permitted tag length, or
/// - the tag does not match.
#[must_use]
pub fn decrypt_with_tag_len(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    tag: &[u8],
) -> Option<Vec<u8>> {
    // GCM requires a nonce of at least one bit. With a zero-length nonce J0 would be the
    // all-zero block and the tag mask E_K(J0) would equal the hash subkey H, which makes tags
    // forgeable; never accept such a message.
    if nonce.is_empty() {
        return None;
    }
    // A length that does not even fit in a u64 is certainly above the ceiling.
    if u64::try_from(ciphertext.len()).map_or(true, |len| len > GCM_MAX_PT_BYTES) {
        return None;
    }
    let t = GcmTagLen::new(tag.len())?.as_usize();

    let cipher = Sm4Cipher::new(key);

    // Hash subkey: the all-zero block encrypted in place.
    let mut h_block = [0u8; BLOCK_SIZE];
    cipher.encrypt_block(&mut h_block);

    let j0 = derive_j0(&h_block, nonce);

    // Recompute the full tag, then compare only its first `t` bytes with the supplied tag.
    let s = ghash_a_c_lens(&h_block, aad, ciphertext);
    let mut expected_full = [0u8; TAG_SIZE];
    gctr(&cipher, &j0, &s, &mut expected_full);
    if expected_full[..t].ct_eq(tag).unwrap_u8() != 1 {
        return None;
    }

    // Tag verified: the payload keystream starts at inc32(J0).
    let mut plaintext = vec![0u8; ciphertext.len()];
    gctr(&cipher, &inc32(&j0), ciphertext, &mut plaintext);
    Some(plaintext)
}
/// Returns a copy of `b` whose bytes 12..16, read as a big-endian `u32`, are incremented by
/// one with wrap-around. Bytes 0..12 are left untouched.
pub(super) const fn inc32(b: &[u8; BLOCK_SIZE]) -> [u8; BLOCK_SIZE] {
    let [c0, c1, c2, c3] = u32::from_be_bytes([b[12], b[13], b[14], b[15]])
        .wrapping_add(1)
        .to_be_bytes();

    let mut next = *b;
    next[12] = c0;
    next[13] = c1;
    next[14] = c2;
    next[15] = c3;
    next
}
/// XORs `input` with a counter-mode keystream and writes the result to `out`.
///
/// The keystream is the encryption of `icb`, `inc32(icb)`, `inc32(inc32(icb))`, ... with one
/// block per started `BLOCK_SIZE` bytes of `input`. `out` is expected to have the same length
/// as `input` (checked in debug builds). An empty `input` leaves `out` untouched.
pub(super) fn gctr(cipher: &Sm4Cipher, icb: &[u8; BLOCK_SIZE], input: &[u8], out: &mut [u8]) {
    debug_assert_eq!(out.len(), input.len());
    if input.is_empty() {
        return;
    }

    // Lay out every counter block first, then encrypt them all in one batch call.
    let blocks_needed = input.len().div_ceil(BLOCK_SIZE);
    let mut keystream: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(blocks_needed);
    keystream.extend(
        core::iter::successors(Some(*icb), |counter| Some(inc32(counter))).take(blocks_needed),
    );
    cipher.encrypt_blocks(&mut keystream);

    // Walk the keystream byte by byte; a trailing partial block uses only its leading bytes.
    let mask_bytes = keystream.iter().flatten();
    for (pos, (&byte, &mask)) in input.iter().zip(mask_bytes).enumerate() {
        out[pos] = byte ^ mask;
    }
}
/// Derives the GCM pre-counter block J0 from `nonce`.
///
/// * A 12-byte nonce is copied into bytes 0..12 and byte 15 is set to `0x01`.
/// * Any other length is zero-padded to a multiple of `BLOCK_SIZE`, followed by eight zero
///   bytes and the nonce length in bits as a big-endian `u64`, and the whole buffer is run
///   through GHASH under `h_block`.
pub(super) fn derive_j0(h_block: &[u8; BLOCK_SIZE], nonce: &[u8]) -> [u8; BLOCK_SIZE] {
    // The conversion succeeds exactly when the nonce is 12 bytes long.
    if let Ok(iv96) = <&[u8; 12]>::try_from(nonce) {
        let mut j0 = [0u8; BLOCK_SIZE];
        j0[..12].copy_from_slice(iv96);
        j0[15] = 0x01;
        return j0;
    }

    // Bit length saturates at u64::MAX instead of wrapping.
    let bit_len = u64::try_from(nonce.len()).map_or(u64::MAX, |bytes| bytes.saturating_mul(8));
    let padded_len = nonce.len().next_multiple_of(BLOCK_SIZE);

    let mut ghash_input = Vec::with_capacity(nonce.len() + BLOCK_SIZE + BLOCK_SIZE);
    ghash_input.extend_from_slice(nonce);
    ghash_input.resize(padded_len, 0);
    ghash_input.extend_from_slice(&[0u8; 8]);
    ghash_input.extend_from_slice(&bit_len.to_be_bytes());
    ghash(h_block, &ghash_input)
}
/// Computes GHASH under `h_block` over `aad` and `ct`, each zero-padded to a multiple of
/// `BLOCK_SIZE`, followed by the bit lengths of `aad` and `ct` as two big-endian `u64`s.
fn ghash_a_c_lens(h_block: &[u8; BLOCK_SIZE], aad: &[u8], ct: &[u8]) -> [u8; BLOCK_SIZE] {
    // Bit length of a byte string, saturating at u64::MAX instead of wrapping.
    let bit_len =
        |bytes: &[u8]| u64::try_from(bytes.len()).map_or(u64::MAX, |n| n.saturating_mul(8));

    let mut message =
        Vec::with_capacity(aad.len() + BLOCK_SIZE + ct.len() + BLOCK_SIZE + BLOCK_SIZE);

    message.extend_from_slice(aad);
    let aad_padded_len = message.len().next_multiple_of(BLOCK_SIZE);
    message.resize(aad_padded_len, 0);
    let ct_start = message.len();

    message.extend_from_slice(ct);
    let ct_padded_len = message.len().next_multiple_of(BLOCK_SIZE);
    message.resize(ct_padded_len, 0);
    debug_assert_eq!((message.len() - ct_start) % BLOCK_SIZE, 0);

    // Final block: len(aad) || len(ct), both in bits.
    message.extend_from_slice(&bit_len(aad).to_be_bytes());
    message.extend_from_slice(&bit_len(ct).to_be_bytes());

    ghash(h_block, &message)
}
/// Runs GHASH under `h_block` over `data`.
///
/// Starting from an all-zero accumulator, every `BLOCK_SIZE`-byte block of `data` is XORed
/// into the accumulator, which is then passed through `ghash_mul` together with `h_block`.
/// `data.len()` is expected to be a multiple of `BLOCK_SIZE` (checked in debug builds).
fn ghash(h_block: &[u8; BLOCK_SIZE], data: &[u8]) -> [u8; BLOCK_SIZE] {
    debug_assert_eq!(data.len() % BLOCK_SIZE, 0);

    data.chunks(BLOCK_SIZE)
        .fold([0u8; BLOCK_SIZE], |acc, block| {
            let mut mixed = acc;
            // Index `block` so that a short trailing chunk panics rather than being skipped.
            for (lane, byte) in mixed.iter_mut().enumerate() {
                *byte ^= block[lane];
            }
            gmcrypto_simd::ghash::ghash_mul(h_block, &mixed)
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed key and 12-byte nonce shared by the tests below.
    const KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
        0x10,
    ];
    const NONCE_12: [u8; 12] = [
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
    ];
    // Every tag length `GcmTagLen::new` is expected to accept.
    const VALID_TAG_LENS: [usize; 7] = [4, 8, 12, 13, 14, 15, 16];

    // Encrypts under KEY / NONCE_12 with the full-size tag.
    fn seal(aad: &[u8], plaintext: &[u8]) -> (Vec<u8>, [u8; TAG_SIZE]) {
        encrypt(&KEY, &NONCE_12, aad, plaintext).expect("under ceiling")
    }

    // Decrypts under KEY / NONCE_12 with the full-size tag.
    fn open(aad: &[u8], ciphertext: &[u8], tag: &[u8; TAG_SIZE]) -> Option<Vec<u8>> {
        decrypt(&KEY, &NONCE_12, aad, ciphertext, tag)
    }

    // Encrypts under KEY / NONCE_12 with a tag truncated to `tag_bytes` bytes.
    fn seal_truncated(aad: &[u8], plaintext: &[u8], tag_bytes: usize) -> (Vec<u8>, Vec<u8>) {
        let tag_len = GcmTagLen::new(tag_bytes).unwrap();
        encrypt_with_tag_len(&KEY, &NONCE_12, aad, plaintext, tag_len).expect("under ceiling")
    }

    // Decrypts under KEY / NONCE_12 with a tag of arbitrary length.
    fn open_truncated(aad: &[u8], ciphertext: &[u8], tag: &[u8]) -> Option<Vec<u8>> {
        decrypt_with_tag_len(&KEY, &NONCE_12, aad, ciphertext, tag)
    }

    #[test]
    fn round_trip_canonical_nonce() {
        let aad = b"associated data";
        let plaintext = b"v0.8 W2 SM4-GCM round-trip smoke test";
        let (ct, tag) = seal(aad, plaintext);
        let recovered = open(aad, &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let aad = b"aad-only message";
        let (ct, tag) = seal(aad, &[]);
        assert!(ct.is_empty());
        let recovered = open(aad, &ct, &tag).expect("tag verifies");
        assert!(recovered.is_empty());
    }

    #[test]
    fn round_trip_empty_aad() {
        let plaintext = b"hello GCM, no AAD";
        let (ct, tag) = seal(&[], plaintext);
        let recovered = open(&[], &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trip_non_12_byte_nonce() {
        // A 7-byte nonce takes the GHASH-based J0 derivation path.
        let nonce = [0x42u8; 7];
        let aad = b"aad";
        let plaintext = b"short-nonce SM4-GCM";
        let (ct, tag) = encrypt(&KEY, &nonce, aad, plaintext).expect("under ceiling");
        let recovered = decrypt(&KEY, &nonce, aad, &ct, &tag).expect("tag verifies");
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn tampered_tag_fails() {
        let aad = b"x";
        let (ct, mut tag) = seal(aad, b"original");
        tag[0] ^= 0x01;
        assert!(open(aad, &ct, &tag).is_none());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let aad = b"x";
        let (mut ct, tag) = seal(aad, b"original");
        if let Some(first) = ct.first_mut() {
            *first ^= 0x01;
        }
        assert!(open(aad, &ct, &tag).is_none());
    }

    #[test]
    fn tampered_aad_fails() {
        let (ct, tag) = seal(b"correct-aad", b"original");
        assert!(open(b"wrong-aad", &ct, &tag).is_none());
    }

    #[test]
    fn gcm_tag_len_accepts_valid_lengths() {
        for &n in &VALID_TAG_LENS {
            let accepted = GcmTagLen::new(n).map(GcmTagLen::as_usize);
            assert_eq!(accepted, Some(n));
        }
    }

    #[test]
    fn gcm_tag_len_rejects_invalid_lengths() {
        let invalid = [0usize, 1, 2, 3, 5, 6, 7, 9, 10, 11, 17, 32];
        for &n in &invalid {
            assert!(GcmTagLen::new(n).is_none(), "len {n} must be rejected");
        }
    }

    #[test]
    fn tag_len_truncation_matches_full_tag_prefix() {
        let aad = b"hdr";
        let pt = b"truncate me to a short tag";
        let (ct_full, tag_full) = seal(aad, pt);
        for &n in &VALID_TAG_LENS {
            let (ct_t, tag_t) = seal_truncated(aad, pt, n);
            assert_eq!(ct_t, ct_full, "ciphertext invariant under tag_len {n}");
            assert_eq!(tag_t.as_slice(), &tag_full[..n], "tag = MSB_n(full) at {n}");
        }
    }

    #[test]
    fn tag_len_round_trip() {
        let aad = b"hdr";
        let pt = b"round trip under every tag length";
        for &n in &VALID_TAG_LENS {
            let (ct, tag) = seal_truncated(aad, pt, n);
            let got = open_truncated(aad, &ct, &tag);
            assert_eq!(
                got.as_deref(),
                Some(pt.as_slice()),
                "round trip at tag_len {n}"
            );
        }
    }

    #[test]
    fn tag_len_decrypt_rejects_bad_tag_and_bad_len() {
        let aad = b"hdr";
        let (ct, mut tag) = seal_truncated(aad, b"reject me", 12);
        tag[0] ^= 0x01;
        // Wrong tag value at a permitted length.
        assert!(open_truncated(aad, &ct, &tag).is_none());
        // Five bytes is not a permitted tag length.
        assert!(open_truncated(aad, &ct, &tag[..5]).is_none());
    }

    #[test]
    fn gcm_max_pt_bytes_matches_spec() {
        assert_eq!(GCM_MAX_PT_BYTES, (1u64 << 36) - 32);
        assert_eq!(GCM_MAX_PT_BYTES, 68_719_476_704);
    }

    #[test]
    fn tag_len_full_16_matches_plain_decrypt() {
        let aad = b"hdr";
        let pt = b"cross-API consistency";
        let (ct, tag) = seal_truncated(aad, pt, 16);
        let tag16: [u8; TAG_SIZE] = tag.as_slice().try_into().unwrap();
        assert_eq!(open(aad, &ct, &tag16).as_deref(), Some(pt.as_slice()));
    }
}