use alloc::vec;
use alloc::vec::Vec;
use subtle::ConstantTimeEq;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
use super::mode_gcm::{GcmTagLen, TAG_SIZE, derive_j0, gctr, inc32};
/// Largest number of payload bytes the streaming encryptor/decryptor in this file will
/// process for one message: 2^36 - 32 bytes (equivalently 2^39 - 256 bits, the per-invocation
/// plaintext limit stated in NIST SP 800-38D).
const GCM_MAX_PT_BYTES: u64 = (1u64 << 36) - 32;
/// Incremental GHASH state: input bytes are staged in `block` and folded into the running
/// value `y` one `BLOCK_SIZE`-byte block at a time.
struct GhashAcc {
/// Multiplier handed to `ghash_mul` on every fold (the value passed to `new`).
h: [u8; BLOCK_SIZE],
/// Running hash value; all-zero until the first block has been folded.
y: [u8; BLOCK_SIZE],
/// Staging buffer for a partially filled block; it is reset to zero after each fold,
/// so bytes at and beyond `block_len` are zero.
block: [u8; BLOCK_SIZE],
/// Number of bytes currently staged in `block`.
block_len: usize,
}
impl GhashAcc {
    /// Builds an accumulator keyed with `h`, with a zero running value and nothing staged.
    const fn new(h: &[u8; BLOCK_SIZE]) -> Self {
        let zero = [0u8; BLOCK_SIZE];
        Self {
            h: *h,
            y: zero,
            block: zero,
            block_len: 0,
        }
    }

    /// Absorbs the staged block: `y <- ghash_mul(h, y XOR block)`, then clears the staging
    /// buffer so the next block starts from zeros.
    fn fold(&mut self) {
        let mut mixed = self.y;
        for (acc, pending) in mixed.iter_mut().zip(self.block.iter()) {
            *acc ^= *pending;
        }
        self.y = gmcrypto_simd::ghash::ghash_mul(&self.h, &mixed);
        self.block.fill(0);
        self.block_len = 0;
    }

    /// Feeds `data` into the accumulator, folding every time the staging buffer fills up.
    /// A trailing partial block stays staged until more data arrives or `pad_to_block` runs.
    fn update(&mut self, data: &[u8]) {
        let mut remaining = data;
        while !remaining.is_empty() {
            // Copy as much as fits into the free space of the staging buffer.
            let room = BLOCK_SIZE - self.block_len;
            let take = room.min(remaining.len());
            let (now, later) = remaining.split_at(take);
            self.block[self.block_len..self.block_len + take].copy_from_slice(now);
            self.block_len += take;
            remaining = later;
            if self.block_len == BLOCK_SIZE {
                self.fold();
            }
        }
    }

    /// Folds a partially filled staging buffer as if it were zero-padded to a full block.
    /// Does nothing when no bytes are staged.
    fn pad_to_block(&mut self) {
        if self.block_len > 0 {
            self.fold();
        }
    }

    /// Test-only: pads any staged bytes and returns the running value without absorbing a
    /// length block.
    #[cfg(test)]
    fn finish_no_lengths(mut self) -> [u8; BLOCK_SIZE] {
        self.pad_to_block();
        self.y
    }

    /// Pads any staged bytes, absorbs the final length block, and returns the hash.
    ///
    /// The length block is the big-endian 64-bit bit count of the AAD followed by the
    /// big-endian 64-bit bit count of the ciphertext; both byte counts are converted to bits
    /// with a saturating multiply by 8.
    fn finish_with_lengths(mut self, aad_len: u64, ct_len: u64) -> [u8; BLOCK_SIZE] {
        self.pad_to_block();
        let aad_bits = aad_len.saturating_mul(8).to_be_bytes();
        let ct_bits = ct_len.saturating_mul(8).to_be_bytes();
        // The staging buffer is empty after padding, so it can hold the length block directly.
        let (aad_half, ct_half) = self.block.split_at_mut(8);
        aad_half.copy_from_slice(&aad_bits);
        ct_half.copy_from_slice(&ct_bits);
        self.block_len = BLOCK_SIZE;
        self.fold();
        self.y
    }
}
/// Encrypts the counter block `icb` with `cipher` and writes `s XOR E(icb)` into `out`.
///
/// Only the first `out.len()` bytes of the XOR are produced. Indexing panics if `out` is
/// longer than `BLOCK_SIZE`.
fn gctr_block(cipher: &Sm4Cipher, icb: &[u8; BLOCK_SIZE], s: &[u8; BLOCK_SIZE], out: &mut [u8]) {
    let mut keystream = *icb;
    cipher.encrypt_block(&mut keystream);
    out.iter_mut()
        .enumerate()
        .for_each(|(idx, dst)| *dst = s[idx] ^ keystream[idx]);
}
/// Shared setup for the streaming encryptor and decryptor.
///
/// Returns `(cipher, j0, ghash, aad_len)` where:
/// - `cipher` is the SM4 instance keyed with `key`,
/// - `j0` is the pre-counter block derived from the hash subkey and `nonce`,
/// - `ghash` has already absorbed `aad` (padded to a block boundary),
/// - `aad_len` is the AAD length in bytes.
fn init_gcm(
    key: &[u8; KEY_SIZE],
    nonce: &[u8],
    aad: &[u8],
) -> (Sm4Cipher, [u8; BLOCK_SIZE], GhashAcc, u64) {
    let cipher = Sm4Cipher::new(key);

    // The hash subkey is the encryption of the all-zero block.
    let mut hash_subkey = [0u8; BLOCK_SIZE];
    cipher.encrypt_block(&mut hash_subkey);
    let j0 = derive_j0(&hash_subkey, nonce);

    // `usize -> u64` cannot fail on supported targets; fall back to `u64::MAX` in release
    // builds rather than panicking if it somehow does.
    let aad_len_conv = u64::try_from(aad.len());
    debug_assert!(
        aad_len_conv.is_ok(),
        "AAD length exceeds u64 — unreachable on real hardware",
    );
    let aad_len = aad_len_conv.unwrap_or(u64::MAX);

    let mut ghash = GhashAcc::new(&hash_subkey);
    ghash.update(aad);
    ghash.pad_to_block();

    (cipher, j0, ghash, aad_len)
}
/// Streaming SM4-GCM encryptor: plaintext is fed in chunks through `update`, which returns
/// the matching ciphertext, and `finalize` produces the authentication tag.
pub struct Sm4GcmEncryptor {
/// Block cipher keyed at construction.
cipher: Sm4Cipher,
/// Pre-counter block; encrypted at finalization to mask the GHASH output into the tag.
j0: [u8; BLOCK_SIZE],
/// Counter block that will be encrypted to produce the next keystream block.
counter: [u8; BLOCK_SIZE],
/// Most recently generated keystream block.
ks: [u8; BLOCK_SIZE],
/// Index of the next unused byte in `ks`; `BLOCK_SIZE` means the block is used up.
ks_pos: usize,
/// GHASH over the AAD (absorbed at construction) and all ciphertext emitted so far.
ghash: GhashAcc,
/// AAD length in bytes, needed for the final length block.
aad_len: u64,
/// Total ciphertext bytes emitted so far.
ct_len: u64,
/// Set when an `update` would exceed `GCM_MAX_PT_BYTES`; later `update` calls return `None`.
poisoned: bool,
}
impl Sm4GcmEncryptor {
    /// Creates an encryptor for one message under `key` and `nonce`, authenticating `aad`.
    ///
    /// The AAD is absorbed into GHASH immediately; the payload counter starts at `inc32(j0)`.
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], nonce: &[u8], aad: &[u8]) -> Self {
        let (cipher, j0, ghash, aad_len) = init_gcm(key, nonce, aad);
        let counter = inc32(&j0);
        Self {
            cipher,
            j0,
            counter,
            ks: [0u8; BLOCK_SIZE],
            // Start with the keystream buffer marked as exhausted so the first payload byte
            // triggers generation of a fresh keystream block.
            ks_pos: BLOCK_SIZE,
            ghash,
            aad_len,
            ct_len: 0,
            poisoned: false,
        }
    }

    /// Encrypts the current counter block into `ks` and advances the counter.
    /// Does not touch `ks_pos`; callers decide how much of the new block is consumed.
    fn refill_keystream(&mut self) {
        self.ks = self.counter;
        self.cipher.encrypt_block(&mut self.ks);
        self.counter = inc32(&self.counter);
    }

    /// Encrypts `chunk` and returns the ciphertext (same length as `chunk`).
    ///
    /// Returns `None` without changing the stream position when:
    /// - the encryptor is already poisoned,
    /// - the running length cannot be represented or would overflow `u64`, or
    /// - the running length would exceed `GCM_MAX_PT_BYTES` (this case also poisons the
    ///   encryptor, so every later call returns `None`).
    #[must_use]
    pub fn update(&mut self, chunk: &[u8]) -> Option<Vec<u8>> {
        if self.poisoned {
            return None;
        }
        let chunk_len = u64::try_from(chunk.len()).ok()?;
        let new_len = self.ct_len.checked_add(chunk_len)?;
        if new_len > GCM_MAX_PT_BYTES {
            self.poisoned = true;
            return None;
        }

        let mut out = vec![0u8; chunk.len()];

        // Partition the chunk into three regions:
        //   head — covered by keystream left over from the previous call,
        //   body — whole blocks, each with a fresh keystream block,
        //   tail — a final partial block whose unused keystream is kept for the next call.
        let carried = (BLOCK_SIZE - self.ks_pos).min(chunk.len());
        let whole = (chunk.len() - carried) / BLOCK_SIZE * BLOCK_SIZE;
        let (head_in, rest_in) = chunk.split_at(carried);
        let (body_in, tail_in) = rest_in.split_at(whole);
        let (head_out, rest_out) = out.split_at_mut(carried);
        let (body_out, tail_out) = rest_out.split_at_mut(whole);

        let leftover = &self.ks[self.ks_pos..];
        for ((dst, &src), &k) in head_out.iter_mut().zip(head_in).zip(leftover) {
            *dst = src ^ k;
        }
        self.ks_pos += carried;

        for (dst_block, src_block) in body_out
            .chunks_exact_mut(BLOCK_SIZE)
            .zip(body_in.chunks_exact(BLOCK_SIZE))
        {
            self.refill_keystream();
            for ((dst, &src), &k) in dst_block.iter_mut().zip(src_block).zip(&self.ks) {
                *dst = src ^ k;
            }
        }

        if !tail_in.is_empty() {
            self.refill_keystream();
            for ((dst, &src), &k) in tail_out.iter_mut().zip(tail_in).zip(&self.ks) {
                *dst = src ^ k;
            }
            // Remember how much of this keystream block has been spent.
            self.ks_pos = tail_in.len();
        }

        // GHASH runs over the ciphertext, not the plaintext.
        self.ghash.update(&out);
        self.ct_len = new_len;
        Some(out)
    }

    /// Consumes the encryptor and returns the full `TAG_SIZE`-byte authentication tag over
    /// the AAD and all ciphertext produced by successful `update` calls.
    #[must_use]
    pub fn finalize(self) -> [u8; TAG_SIZE] {
        let digest = self.ghash.finish_with_lengths(self.aad_len, self.ct_len);
        let mut tag = [0u8; TAG_SIZE];
        gctr_block(&self.cipher, &self.j0, &digest, &mut tag);
        tag
    }

    /// Consumes the encryptor and returns the tag truncated to its first
    /// `tag_len.as_usize()` bytes.
    #[must_use]
    pub fn finalize_with_tag_len(self, tag_len: GcmTagLen) -> Vec<u8> {
        let full_tag = self.finalize();
        let wanted = tag_len.as_usize();
        Vec::from(&full_tag[..wanted])
    }
}
/// Streaming SM4-GCM decryptor: ciphertext is fed in chunks through `update` and buffered;
/// plaintext is only produced by `finalize_verify` after the tag has been checked.
pub struct Sm4GcmDecryptor {
/// Block cipher keyed at construction.
cipher: Sm4Cipher,
/// Pre-counter block; used to compute the expected tag and the first payload counter.
j0: [u8; BLOCK_SIZE],
/// GHASH over the AAD (absorbed at construction) and all ciphertext received so far.
ghash: GhashAcc,
/// All ciphertext received so far, kept until verification.
ct_buf: Vec<u8>,
/// AAD length in bytes, needed for the final length block.
aad_len: u64,
/// Set when an `update` would push the buffered length past `GCM_MAX_PT_BYTES`;
/// later `update` calls are ignored and `finalize_verify` returns `None`.
overflowed: bool,
}
impl Sm4GcmDecryptor {
    /// Creates a decryptor for one message under `key` and `nonce`, authenticating `aad`.
    /// The AAD is absorbed into GHASH immediately.
    #[must_use]
    pub fn new(key: &[u8; KEY_SIZE], nonce: &[u8], aad: &[u8]) -> Self {
        let (cipher, j0, ghash, aad_len) = init_gcm(key, nonce, aad);
        let ct_buf = Vec::new();
        Self {
            cipher,
            j0,
            ghash,
            ct_buf,
            aad_len,
            overflowed: false,
        }
    }

    /// Buffers `chunk` as ciphertext and absorbs it into GHASH.
    ///
    /// If the total buffered length cannot be represented or would exceed
    /// `GCM_MAX_PT_BYTES`, the chunk is dropped and the decryptor is marked as overflowed;
    /// every later call is then a no-op.
    pub fn update(&mut self, chunk: &[u8]) {
        if self.overflowed {
            return;
        }
        let projected = match (
            u64::try_from(self.ct_buf.len()),
            u64::try_from(chunk.len()),
        ) {
            (Ok(buffered), Ok(incoming)) => buffered.checked_add(incoming),
            _ => None,
        };
        match projected {
            Some(total) if total <= GCM_MAX_PT_BYTES => {
                self.ghash.update(chunk);
                self.ct_buf.extend_from_slice(chunk);
            }
            _ => self.overflowed = true,
        }
    }

    /// Verifies `tag` against the buffered ciphertext and, on success, returns the plaintext.
    ///
    /// Returns `None` when the decryptor overflowed, when `GcmTagLen::new` rejects
    /// `tag.len()`, or when the tag does not match. `tag` is compared against the leading
    /// `tag.len()` bytes of the full expected tag using `ct_eq`. No plaintext is computed
    /// unless the comparison succeeds.
    #[must_use]
    pub fn finalize_verify(self, tag: &[u8]) -> Option<Vec<u8>> {
        if self.overflowed {
            return None;
        }
        // Only the validity of the tag length matters here; the value itself is unused.
        let _ = GcmTagLen::new(tag.len())?;
        let ct_len = u64::try_from(self.ct_buf.len()).ok()?;

        let digest = self.ghash.finish_with_lengths(self.aad_len, ct_len);
        let mut expected = [0u8; TAG_SIZE];
        gctr_block(&self.cipher, &self.j0, &digest, &mut expected);

        let tag_matches = expected[..tag.len()].ct_eq(tag).unwrap_u8() == 1;
        if !tag_matches {
            return None;
        }

        // Payload keystream starts one counter step after J0.
        let first_counter = inc32(&self.j0);
        let mut plaintext = vec![0u8; self.ct_buf.len()];
        gctr(&self.cipher, &first_counter, &self.ct_buf, &mut plaintext);
        Some(plaintext)
    }
}
// Unit tests: the streaming encryptor/decryptor are checked against the single-shot
// `mode_gcm::encrypt` output for a range of chunk sizes, tag lengths and nonce lengths.
#[cfg(test)]
mod tests {
use super::*;
use crate::sm4::mode_gcm;
// Fixed 16-byte key shared by all tests.
const KEY: [u8; 16] = [
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
0x10,
];
// Fixed 12-byte nonce used by most tests.
const NONCE_12: [u8; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
// Builds a deterministic `len`-byte payload. The `as` casts truncate on purpose
// (only the low byte of the mixed index is kept), hence the lint allowance.
#[allow(clippy::cast_possible_truncation)]
fn make_payload(len: usize) -> Vec<u8> {
(0..len as u32).map(|i| (i ^ (i >> 3)) as u8).collect()
}
// Absorbing a single all-zero block into a fresh accumulator is expected to leave the
// running value at zero.
#[test]
fn ghash_incremental_single_zero_block_is_zero() {
let h = [
0x66u8, 0xe9, 0x4b, 0xd4, 0xef, 0x8a, 0x2c, 0x3b, 0x88, 0x4c, 0xfa, 0x59, 0xca, 0x34,
0x2b, 0x2e,
];
let mut g = GhashAcc::new(&h);
g.update(&[0u8; 16]);
assert_eq!(g.finish_no_lengths(), [0u8; 16]);
}
// Feeding the plaintext in chunks of various sizes (below, at and above the block size,
// and the whole message at once) must give the same ciphertext and tag as single-shot.
#[test]
fn encryptor_chunked_matches_single_shot() {
let aad = b"associated header";
let pt = make_payload(200);
let (ref_ct, ref_tag) =
mode_gcm::encrypt(&KEY, &NONCE_12, aad, &pt).expect("under ceiling");
for chunk in [1usize, 7, 15, 16, 17, 31, 32, 33, 100, pt.len().max(1)] {
let mut enc = Sm4GcmEncryptor::new(&KEY, &NONCE_12, aad);
let mut ct = Vec::new();
let mut off = 0;
while off < pt.len() {
let take = chunk.min(pt.len() - off);
ct.extend_from_slice(&enc.update(&pt[off..off + take]).expect("under ceiling"));
off += take;
}
let tag = enc.finalize();
assert_eq!(ct, ref_ct, "ct divergence at chunk {chunk}");
assert_eq!(tag, ref_tag, "tag divergence at chunk {chunk}");
}
}
// A 12-byte tag from `finalize_with_tag_len` must equal the first 12 bytes of the
// single-shot tag.
#[test]
fn encryptor_tag_len_matches_single_shot_truncation() {
let aad = b"h";
let pt = b"tag-len finalize path";
let (_, full) = mode_gcm::encrypt(&KEY, &NONCE_12, aad, pt).expect("under ceiling");
let mut enc = Sm4GcmEncryptor::new(&KEY, &NONCE_12, aad);
let _ = enc.update(pt).unwrap();
let tag = enc.finalize_with_tag_len(GcmTagLen::new(12).unwrap());
assert_eq!(tag.as_slice(), &full[..12]);
}
// Empty chunks succeed and yield empty output; finalizing afterwards must not panic.
#[test]
fn encryptor_empty_updates_are_noops() {
let mut enc = Sm4GcmEncryptor::new(&KEY, &NONCE_12, b"a");
assert_eq!(enc.update(&[]).unwrap().len(), 0);
assert_eq!(enc.update(&[]).unwrap().len(), 0);
let _ = enc.finalize();
}
// Feeding the ciphertext in chunks of various sizes must recover the original plaintext.
#[test]
fn decryptor_chunked_matches_single_shot() {
let aad = b"associated header";
let pt = make_payload(200);
let (ct, tag) = mode_gcm::encrypt(&KEY, &NONCE_12, aad, &pt).expect("under ceiling");
for chunk in [1usize, 7, 15, 16, 17, 31, 32, 33, 100, ct.len().max(1)] {
let mut dec = Sm4GcmDecryptor::new(&KEY, &NONCE_12, aad);
let mut off = 0;
while off < ct.len() {
let take = chunk.min(ct.len() - off);
dec.update(&ct[off..off + take]);
off += take;
}
let got = dec.finalize_verify(&tag);
assert_eq!(
got.as_deref(),
Some(pt.as_slice()),
"divergence at chunk {chunk}"
);
}
}
// Flipping one bit of the tag must make verification fail.
#[test]
fn decryptor_rejects_tampered_tag() {
let aad = b"h";
let pt = b"tamper target";
let (ct, mut tag) = mode_gcm::encrypt(&KEY, &NONCE_12, aad, pt).expect("under ceiling");
tag[0] ^= 0x01;
let mut dec = Sm4GcmDecryptor::new(&KEY, &NONCE_12, aad);
dec.update(&ct);
assert!(dec.finalize_verify(&tag).is_none());
}
// A 5-byte tag must be rejected (presumably because `GcmTagLen::new` refuses that length,
// even though the bytes are a correct prefix of the real tag).
#[test]
fn decryptor_rejects_invalid_tag_length() {
let aad = b"h";
let pt = b"bad tag length";
let (ct, tag) = mode_gcm::encrypt(&KEY, &NONCE_12, aad, pt).expect("under ceiling");
let mut dec = Sm4GcmDecryptor::new(&KEY, &NONCE_12, aad);
dec.update(&ct);
assert!(dec.finalize_verify(&tag[..5]).is_none());
}
// A 12-byte truncated tag produced by the streaming encryptor must verify and decrypt.
#[test]
fn decryptor_supports_truncated_tag() {
let aad = b"h";
let pt = b"short tag decrypt";
let mut enc = Sm4GcmEncryptor::new(&KEY, &NONCE_12, aad);
let ct = enc.update(pt).unwrap();
let tag12 = enc.finalize_with_tag_len(GcmTagLen::new(12).unwrap());
let mut dec = Sm4GcmDecryptor::new(&KEY, &NONCE_12, aad);
dec.update(&ct);
assert_eq!(dec.finalize_verify(&tag12).as_deref(), Some(pt.as_slice()));
}
// An AAD-only message (empty plaintext) must verify and return an empty plaintext, even
// when `update` is called with empty chunks.
#[test]
fn decryptor_empty_then_verify() {
let (ct, tag) = mode_gcm::encrypt(&KEY, &NONCE_12, b"a", &[]).expect("under ceiling");
let mut dec = Sm4GcmDecryptor::new(&KEY, &NONCE_12, b"a");
dec.update(&[]);
dec.update(&ct);
assert_eq!(dec.finalize_verify(&tag).as_deref(), Some(&[][..]));
}
// Streaming encrypt (13-byte chunks) followed by streaming decrypt (11-byte chunks) must
// round-trip; the chunk sizes deliberately differ and neither divides the block size.
#[test]
fn round_trip_through_streaming_both_directions() {
let aad = b"end to end";
let pt = make_payload(137);
let mut enc = Sm4GcmEncryptor::new(&KEY, &NONCE_12, aad);
let mut ct = Vec::new();
for c in pt.chunks(13) {
ct.extend_from_slice(&enc.update(c).unwrap());
}
let tag = enc.finalize();
let mut dec = Sm4GcmDecryptor::new(&KEY, &NONCE_12, aad);
for c in ct.chunks(11) {
dec.update(c);
}
assert_eq!(dec.finalize_verify(&tag).as_deref(), Some(pt.as_slice()));
}
// With a 7-byte nonce (not the 12-byte size used elsewhere) the streaming output must
// still match single-shot.
#[test]
fn streaming_matches_single_shot_with_non_12_byte_nonce() {
let nonce: [u8; 7] = [0x42; 7];
let aad = b"short nonce";
let pt = make_payload(80);
let (ref_ct, ref_tag) = mode_gcm::encrypt(&KEY, &nonce, aad, &pt).expect("under ceiling");
let mut enc = Sm4GcmEncryptor::new(&KEY, &nonce, aad);
let mut ct = Vec::new();
for c in pt.chunks(16) {
ct.extend_from_slice(&enc.update(c).unwrap());
}
assert_eq!(ct, ref_ct);
assert_eq!(enc.finalize(), ref_tag);
}
}