use super::{Aes128, Aes256, BlockCipher, TagMismatch};
use crate::ct::ConstantTimeEq;
/// Multiplies two POLYVAL field elements and returns `a * b * x^-128`.
///
/// Elements are 128-bit polynomials over GF(2) where bit `i` of the `u128`
/// is the coefficient of `x^i`, reduced modulo
/// `x^128 + x^127 + x^126 + x^121 + 1`.
///
/// The loop is branch-free: both the conditional reduction and the
/// conditional accumulation are performed with all-zeros / all-ones masks.
fn polyval_mul(a: u128, b: u128) -> u128 {
    // XORed in after a right shift whenever the bit shifted out was set;
    // this is the field polynomial divided by x (bits 127, 126, 125, 120).
    const REDUCE: u128 = 0xe1u128 << 120;

    let mut product = 0u128;
    let mut shifted = a;
    for bit_index in (0..128u32).rev() {
        // shifted <- shifted * x^-1 (mod the field polynomial).
        let carry_mask = (shifted & 1).wrapping_neg();
        shifted = (shifted >> 1) ^ (carry_mask & REDUCE);

        // Accumulate the current multiple of `a` iff this bit of `b` is set.
        let select_mask = ((b >> bit_index) & 1).wrapping_neg();
        product ^= select_mask & shifted;
    }
    product
}
/// Incremental POLYVAL accumulator over 16-byte blocks.
struct Polyval {
    /// Hash key, loaded from its little-endian byte encoding.
    h: u128,
    /// Running accumulator; starts at zero.
    acc: u128,
}

impl Polyval {
    /// Creates an accumulator keyed with `h` and an all-zero state.
    fn new(h: &[u8; 16]) -> Self {
        let key = u128::from_le_bytes(*h);
        Self { h: key, acc: 0 }
    }

    /// Absorbs one block: XORs it into the state, then multiplies the state
    /// by the hash key with [`polyval_mul`].
    fn update_block(&mut self, block: &[u8; 16]) {
        let mixed = self.acc ^ u128::from_le_bytes(*block);
        self.acc = polyval_mul(mixed, self.h);
    }

    /// Consumes the accumulator and returns the state as little-endian bytes.
    fn finish(self) -> [u8; 16] {
        let Self { acc, .. } = self;
        acc.to_le_bytes()
    }
}
/// Runtime choice between the two AES key sizes used in this file.
enum Cipher {
    /// AES with a 128-bit key.
    Aes128(Aes128),
    /// AES with a 256-bit key.
    Aes256(Aes256),
}

impl Cipher {
    /// Encrypts a single 16-byte block in place with whichever AES variant
    /// this value holds.
    fn encrypt_block(&self, block: &mut [u8; 16]) {
        match self {
            Self::Aes128(aes) => aes.encrypt_block(block),
            Self::Aes256(aes) => aes.encrypt_block(block),
        }
    }

    /// Forwards `blocks` to the held AES variant's `encrypt_blocks`, which
    /// encrypts the buffer in place.
    fn encrypt_blocks(&self, blocks: &mut [u8]) {
        match self {
            Self::Aes128(aes) => aes.encrypt_blocks(blocks),
            Self::Aes256(aes) => aes.encrypt_blocks(blocks),
        }
    }
}
pub struct AesGcmSiv {
cipher: Cipher,
key_len: usize,
kgk: [u8; 32],
}
impl AesGcmSiv {
pub fn new(key: &[u8]) -> Self {
let mut kgk = [0u8; 32];
let cipher = match key.len() {
16 => {
kgk[..16].copy_from_slice(key);
Cipher::Aes128(Aes128::new(key.try_into().unwrap()))
}
32 => {
kgk.copy_from_slice(key);
Cipher::Aes256(Aes256::new(key.try_into().unwrap()))
}
_ => panic!("AES-GCM-SIV key must be 16 bytes (AES-128) or 32 bytes (AES-256)"),
};
AesGcmSiv {
cipher,
key_len: key.len(),
kgk,
}
}
fn derive_keys(&self, nonce: &[u8; 12]) -> ([u8; 16], Cipher) {
let mut block = [0u8; 16];
block[4..].copy_from_slice(nonce);
let mut auth_key = [0u8; 16];
let mut enc_key = [0u8; 32];
let enc_blocks = self.key_len / 8;
for counter in 0u32..(2 + enc_blocks as u32) {
block[..4].copy_from_slice(&counter.to_le_bytes());
let mut b = block;
self.cipher.encrypt_block(&mut b);
let half = &b[..8];
let idx = counter as usize;
if idx < 2 {
auth_key[idx * 8..idx * 8 + 8].copy_from_slice(half);
} else {
let j = idx - 2;
enc_key[j * 8..j * 8 + 8].copy_from_slice(half);
}
}
let enc_cipher = match self.key_len {
16 => Cipher::Aes128(Aes128::new(enc_key[..16].try_into().unwrap())),
_ => Cipher::Aes256(Aes256::new(&enc_key)),
};
enc_key = [0u8; 32];
let _ = core::hint::black_box(&enc_key);
(auth_key, enc_cipher)
}
pub const MAX_PLAINTEXT_LEN: u64 = 1u64 << 36;
pub const MAX_AAD_LEN: u64 = 1u64 << 36;
fn validate(aad: &[u8], buffer: &[u8]) {
assert!(
(buffer.len() as u64) <= Self::MAX_PLAINTEXT_LEN,
"AES-GCM-SIV plaintext exceeds 2^36 bytes (RFC 8452 §6)"
);
assert!(
(aad.len() as u64) <= Self::MAX_AAD_LEN,
"AES-GCM-SIV AAD exceeds 2^36 bytes (RFC 8452 §6)"
);
}
fn make_tag(
auth_key: &[u8; 16],
enc_cipher: &Cipher,
nonce: &[u8; 12],
aad: &[u8],
plaintext: &[u8],
) -> [u8; 16] {
let mut pv = Polyval::new(auth_key);
let mut chunks = aad.chunks_exact(16);
for c in chunks.by_ref() {
let mut b = [0u8; 16];
b.copy_from_slice(c);
pv.update_block(&b);
}
let rem = chunks.remainder();
if !rem.is_empty() {
let mut b = [0u8; 16];
b[..rem.len()].copy_from_slice(rem);
pv.update_block(&b);
}
let mut chunks = plaintext.chunks_exact(16);
for c in chunks.by_ref() {
let mut b = [0u8; 16];
b.copy_from_slice(c);
pv.update_block(&b);
}
let rem = chunks.remainder();
if !rem.is_empty() {
let mut b = [0u8; 16];
b[..rem.len()].copy_from_slice(rem);
pv.update_block(&b);
}
let mut len_block = [0u8; 16];
len_block[..8].copy_from_slice(&((aad.len() as u64) * 8).to_le_bytes());
len_block[8..].copy_from_slice(&((plaintext.len() as u64) * 8).to_le_bytes());
pv.update_block(&len_block);
let mut s = pv.finish();
for i in 0..12 {
s[i] ^= nonce[i];
}
s[15] &= 0x7f;
enc_cipher.encrypt_block(&mut s);
s
}
fn ctr(enc_cipher: &Cipher, tag: &[u8; 16], buf: &mut [u8]) {
let mut counter = *tag;
counter[15] |= 0x80;
const W: usize = 64; let mut ks = [0u8; 16 * W];
let mut off = 0;
while off < buf.len() {
let n = (buf.len() - off).min(16 * W);
let blocks = n.div_ceil(16);
for blk in ks[..blocks * 16].chunks_exact_mut(16) {
blk.copy_from_slice(&counter);
let c = u32::from_le_bytes([counter[0], counter[1], counter[2], counter[3]])
.wrapping_add(1);
counter[..4].copy_from_slice(&c.to_le_bytes());
}
enc_cipher.encrypt_blocks(&mut ks[..blocks * 16]);
for (b, k) in buf[off..off + n].iter_mut().zip(ks[..n].iter()) {
*b ^= *k;
}
off += n;
}
}
pub fn encrypt(&self, nonce: &[u8; 12], aad: &[u8], buffer: &mut [u8]) -> [u8; 16] {
Self::validate(aad, buffer);
let (auth_key, enc_cipher) = self.derive_keys(nonce);
let tag = Self::make_tag(&auth_key, &enc_cipher, nonce, aad, buffer);
Self::ctr(&enc_cipher, &tag, buffer);
tag
}
pub fn decrypt(
&self,
nonce: &[u8; 12],
aad: &[u8],
buffer: &mut [u8],
tag: &[u8; 16],
) -> Result<(), TagMismatch> {
Self::validate(aad, buffer);
let (auth_key, enc_cipher) = self.derive_keys(nonce);
Self::ctr(&enc_cipher, tag, buffer);
let expected = Self::make_tag(&auth_key, &enc_cipher, nonce, aad, buffer);
if bool::from(expected.ct_eq(tag)) {
Ok(())
} else {
for b in buffer.iter_mut() {
*b = 0;
}
Err(TagMismatch)
}
}
}
impl Drop for AesGcmSiv {
    /// Overwrites the retained key-generating key copy with zeros.
    ///
    /// The `black_box` call is a best-effort hint that discourages the
    /// optimizer from dropping the store. Only `kgk` is touched here.
    fn drop(&mut self) {
        self.kgk.fill(0);
        let _ = core::hint::black_box(&self.kgk);
    }
}
/// Alias of [`AesGcmSiv`]. It names exactly the same type as
/// [`Aes256GcmSiv`]; the alias does not constrain the key size, which is
/// decided by the length of the key passed to `AesGcmSiv::new`.
pub type Aes128GcmSiv = AesGcmSiv;
/// Alias of [`AesGcmSiv`]. It names exactly the same type as
/// [`Aes128GcmSiv`]; the alias does not constrain the key size, which is
/// decided by the length of the key passed to `AesGcmSiv::new`.
pub type Aes256GcmSiv = AesGcmSiv;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{from_hex, from_hex_vec};

    // Inputs shared by the known-answer tests below.
    const KEY_128_HEX: &str = "01000000000000000000000000000000";
    const KEY_256_HEX: &str = "01000000000000000000000000000000\
00000000000000000000000000000000";
    const NONCE_HEX: &str = "030000000000000000000000";

    /// POLYVAL over two blocks must reproduce the expected digest.
    #[test]
    fn polyval_worked_example() {
        let hash_key = from_hex::<16>("25629347589242761d31f826ba4b757b");
        let mut hasher = Polyval::new(&hash_key);
        for block_hex in [
            "4f4f95668c83dfb6401762bb2d01a262",
            "d1a24ddd2721d006bbe45f20d3c9f362",
        ] {
            hasher.update_block(&from_hex::<16>(block_hex));
        }
        let expected = from_hex::<16>("f7a3b47b846119fae5b7866cf5e5b77e");
        assert_eq!(hasher.finish(), expected);
    }

    /// AES-128, empty plaintext, empty AAD: only the tag is produced.
    #[test]
    fn rfc8452_c1_empty() {
        let aead = AesGcmSiv::new(&from_hex::<16>(KEY_128_HEX));
        let iv = from_hex::<12>(NONCE_HEX);
        let mut empty: [u8; 0] = [];
        let tag = aead.encrypt(&iv, &[], &mut empty);
        assert_eq!(tag, from_hex::<16>("dc20e2d83f25705bb49e439eca56de25"));
    }

    /// AES-128, 8-byte plaintext, empty AAD: checks the ciphertext and
    /// ciphertext || tag.
    #[test]
    fn rfc8452_c1_8byte() {
        let aead = AesGcmSiv::new(&from_hex::<16>(KEY_128_HEX));
        let iv = from_hex::<12>(NONCE_HEX);
        let mut data = from_hex::<8>("0100000000000000");
        let tag = aead.encrypt(&iv, &[], &mut data);
        assert_eq!(data, from_hex::<8>("b5d839330ac7b786"));

        let sealed = [&data[..], &tag[..]].concat();
        assert_eq!(
            sealed,
            from_hex_vec("b5d839330ac7b786578782fff6013b815b287c22493a364c")
        );
    }

    /// AES-128, 8-byte plaintext with one byte of AAD.
    #[test]
    fn rfc8452_c1_with_aad() {
        let aead = AesGcmSiv::new(&from_hex::<16>(KEY_128_HEX));
        let iv = from_hex::<12>(NONCE_HEX);
        let associated = from_hex::<1>("01");
        let mut data = from_hex::<8>("0200000000000000");
        let tag = aead.encrypt(&iv, &associated, &mut data);
        assert_eq!(data, from_hex::<8>("1e6daba35669f427"));
        assert_eq!(tag, from_hex::<16>("3b0a1a2560969cdf790d99759abd1508"));
    }

    /// AES-256, empty plaintext, empty AAD.
    #[test]
    fn rfc8452_c2_aes256_empty() {
        let aead = AesGcmSiv::new(&from_hex::<32>(KEY_256_HEX));
        let iv = from_hex::<12>(NONCE_HEX);
        let mut empty: [u8; 0] = [];
        let tag = aead.encrypt(&iv, &[], &mut empty);
        assert_eq!(tag, from_hex::<16>("07f5f4169bbf55a8400cd47ea6fd400f"));
    }

    /// AES-256, 8-byte plaintext, empty AAD.
    #[test]
    fn rfc8452_c2_aes256_8byte() {
        let aead = AesGcmSiv::new(&from_hex::<32>(KEY_256_HEX));
        let iv = from_hex::<12>(NONCE_HEX);
        let mut data = from_hex::<8>("0100000000000000");
        let tag = aead.encrypt(&iv, &[], &mut data);
        assert_eq!(data, from_hex::<8>("c2ef328e5c71c83b"));
        assert_eq!(tag, from_hex::<16>("843122130f7364b761e0b97427e3df28"));
    }

    /// The published caps are 2^36 bytes and small inputs pass validation.
    #[test]
    fn rfc8452_length_caps() {
        let two_pow_36 = 1u64 << 36;
        assert_eq!(AesGcmSiv::MAX_PLAINTEXT_LEN, two_pow_36);
        assert_eq!(AesGcmSiv::MAX_AAD_LEN, two_pow_36);
        AesGcmSiv::validate(&[0u8; 32], &[0u8; 64]);
        AesGcmSiv::validate(&[], &[]);
    }

    /// Checks the comparison used for the cap on opaque values; it does not
    /// call `validate` itself.
    #[test]
    fn length_cap_comparison_branch() {
        let limit = core::hint::black_box(AesGcmSiv::MAX_PLAINTEXT_LEN);
        let beyond = core::hint::black_box(limit + 1);
        assert!(
            beyond > AesGcmSiv::MAX_PLAINTEXT_LEN,
            "over-cap must exceed cap"
        );
        assert!(
            limit <= AesGcmSiv::MAX_PLAINTEXT_LEN,
            "cap itself is accepted"
        );
    }

    /// Encrypt-then-decrypt restores the plaintext; a flipped tag bit is
    /// rejected and the buffer is zeroed.
    #[test]
    fn roundtrip_and_reject() {
        let aead = AesGcmSiv::new(&from_hex::<16>(KEY_128_HEX));
        let iv = from_hex::<12>(NONCE_HEX);
        let associated = b"some associated data";
        let message = *b"GCM-SIV nonce-misuse-resistant!!";

        // Happy path.
        let mut work = message;
        let tag = aead.encrypt(&iv, associated, &mut work);
        aead.decrypt(&iv, associated, &mut work, &tag).unwrap();
        assert_eq!(work, message);

        // Tampered tag.
        let mut work = message;
        let tag = aead.encrypt(&iv, associated, &mut work);
        let mut forged = tag;
        forged[0] ^= 1;
        assert!(aead.decrypt(&iv, associated, &mut work, &forged).is_err());
        assert_eq!(work, [0u8; 32]);
    }
}