use core::slice;
use crate::Error;
use aes::Aes256;
use generic_array::{GenericArray, typenum::U16};
use ghash::{GHash, universal_hash::UniversalHash};
pub use subtle::ConstantTimeEq;
use ctr::cipher::{BlockEncrypt, KeyInit, KeyIvInit, StreamCipher, StreamCipherSeek};
type Aes256Ctr = ctr::Ctr128BE<aes::Aes256>;
/// Size of one AES block in bytes (128 bits).
pub const BLOCK_SIZE: usize = 16;
/// Length of a GCM authentication tag in bytes: one full block.
pub const TAG_LENGTH: usize = BLOCK_SIZE;
/// Size of an AES-256 key in bytes (256 bits).
pub const KEY_SIZE: usize = 256 / 8;
/// Size of the AES-GCM nonce in bytes (96 bits).
pub const NONCE_AES_SIZE: usize = 12;
/// Twice the key size (64 bytes).
pub const KEY_COMMITMENT_SIZE: usize = 2 * KEY_SIZE;
/// `2^39 - 256`, the NIST SP 800-38D maximum plaintext length.
///
/// The specification states this bound in **bits**; divide by 8 before comparing
/// it with byte counts.
const AES256_GCM_MAX_PLAINTEXT_LENGTH: u64 = (1u64 << 39) - 256;
/// A 96-bit AES-GCM nonce.
pub type Nonce = [u8; NONCE_AES_SIZE];
/// A 256-bit AES key.
pub type Key = [u8; KEY_SIZE];
/// Streaming AES-256-GCM state: a CTR keystream plus a running GHASH over the
/// associated data and the ciphertext produced so far.
pub struct AesGcm256 {
    /// AES-256-CTR keystream. After `new` it is positioned one block past the
    /// initial counter block, which is reserved for masking the tag.
    cipher: Aes256Ctr,
    /// Running GHASH over the (padded) associated data and the ciphertext.
    ghash: GHash,
    /// Associated-data length in bits, written into GCM's final length block.
    associated_data_bits_len: u64,
    /// Ciphertext of a trailing partial block (fewer than `BLOCK_SIZE` bytes)
    /// that has not been fed to GHASH yet.
    current_block: Vec<u8>,
    /// Total number of bytes passed to `encrypt` so far.
    bytes_encrypted: u64,
}
/// A 16-byte GCM authentication tag.
pub type Tag = GenericArray<u8, U16>;
impl AesGcm256 {
#[allow(clippy::unnecessary_wraps)]
pub fn new(key: &Key, nonce: &Nonce, associated_data: &[u8]) -> Result<AesGcm256, Error> {
let associated_data_len = u64::try_from(associated_data.len()).unwrap();
assert!(associated_data_len < AES256_GCM_MAX_PLAINTEXT_LENGTH);
let mut counter_block = [0u8; BLOCK_SIZE];
counter_block[..12].copy_from_slice(nonce);
counter_block[15] = 1;
let mut ghash_key = GenericArray::default();
let cipher = Aes256::new(GenericArray::from_slice(key));
cipher.encrypt_block(&mut ghash_key);
let mut ghash = GHash::new(&ghash_key);
ghash.update_padded(associated_data);
let mut cipher = Aes256Ctr::new(key.into(), &counter_block.into());
cipher.seek(BLOCK_SIZE as u64);
Ok(AesGcm256 {
cipher,
ghash,
associated_data_bits_len: associated_data_len.checked_mul(8).unwrap(),
current_block: Vec::with_capacity(BLOCK_SIZE),
bytes_encrypted: 0,
})
}
pub fn encrypt(&mut self, mut buffer: &mut [u8]) {
self.bytes_encrypted = self
.bytes_encrypted
.checked_add(u64::try_from(buffer.len()).unwrap())
.unwrap();
assert!(
self.bytes_encrypted
.saturating_add(self.associated_data_bits_len / 8)
< AES256_GCM_MAX_PLAINTEXT_LENGTH,
"Attempted to encrypt more than what AES-GCM is secure for"
);
if !self.current_block.is_empty() {
if (self.current_block.len().checked_add(buffer.len()).unwrap()) < BLOCK_SIZE {
self.cipher.apply_keystream(buffer);
self.current_block.extend_from_slice(buffer);
return;
}
let (in_block, out_block) =
buffer.split_at_mut(BLOCK_SIZE.checked_sub(self.current_block.len()).unwrap());
self.cipher.apply_keystream(in_block);
self.current_block.extend_from_slice(in_block);
self.ghash
.update(slice::from_ref(self.current_block.as_slice().into()));
self.current_block.clear();
buffer = out_block;
}
let mut chunks = buffer.chunks_exact_mut(BLOCK_SIZE);
for chunk in &mut chunks {
self.cipher.apply_keystream(chunk);
self.ghash
.update(slice::from_ref(GenericArray::from_slice(chunk)));
}
let rem = chunks.into_remainder();
if !rem.is_empty() {
self.cipher.apply_keystream(rem);
self.current_block.extend_from_slice(rem);
}
}
pub fn into_tag(mut self) -> Tag {
self.ghash.update_padded(&self.current_block);
let buffer_bits = self.bytes_encrypted.checked_mul(8).unwrap();
let mut block = GenericArray::default();
block[..8].copy_from_slice(&self.associated_data_bits_len.to_be_bytes());
block[8..].copy_from_slice(&buffer_bits.to_be_bytes());
self.ghash.update(&[block]);
let mut tag = self.ghash.finalize();
self.cipher.seek(0);
self.cipher.apply_keystream(tag.as_mut_slice());
tag
}
pub fn decrypt_unauthenticated(&mut self, buffer: &mut [u8]) {
self.cipher.apply_keystream(buffer);
}
pub fn decrypt(&mut self, buffer: &mut [u8]) -> Tag {
let buffer_len = u64::try_from(buffer.len()).unwrap();
assert!(
buffer_len.saturating_add(self.associated_data_bits_len / 8)
< AES256_GCM_MAX_PLAINTEXT_LENGTH,
"Attempted to decrypt more than what AES-GCM is secure for"
);
let mut chunks = buffer.chunks_exact_mut(BLOCK_SIZE);
for chunk in &mut chunks {
self.ghash
.update(slice::from_ref(GenericArray::from_slice(chunk)));
self.cipher.apply_keystream(chunk);
}
let rem = chunks.into_remainder();
if !rem.is_empty() {
self.ghash.update_padded(rem);
self.cipher.apply_keystream(rem);
}
let buffer_bits = buffer_len.checked_mul(8).unwrap();
let mut block = GenericArray::default();
block[..8].copy_from_slice(&self.associated_data_bits_len.to_be_bytes());
block[8..].copy_from_slice(&buffer_bits.to_be_bytes());
self.ghash.update(&[block]);
let mut tag = self.ghash.clone().finalize();
self.cipher.seek(0);
self.cipher.apply_keystream(tag.as_mut_slice());
tag
}
}
// Test-only submodule loaded from its own file; judging by its name it presumably
// holds AES-GCM test vectors (contents not visible here).
#[cfg(test)]
mod aes_gcm_navs_test_vector;
#[cfg(test)]
mod tests {
    use super::*;
    use aead::Payload;
    use aes_gcm::{Aes256Gcm, aead::Aead};

    /// Splits a reference `aes_gcm` output into `(ciphertext, tag)`.
    fn split_tag(sealed: &[u8]) -> (&[u8], &[u8]) {
        sealed.split_at(sealed.len().checked_sub(TAG_LENGTH).unwrap())
    }

    /// Checks that [`AesGcm256`] matches the reference `aes_gcm` crate for `msg`.
    /// It checks both a single `encrypt` call and streaming the message in pieces
    /// of several sizes.
    fn test_against_aesgcm(key: &Key, nonce: &Nonce, associated_data: &[u8], msg: &[u8]) {
        let extern_cipher = Aes256Gcm::new(key.into());
        let extern_ciphertext = extern_cipher
            .encrypt(
                &GenericArray::clone_from_slice(nonce),
                Payload {
                    msg,
                    aad: associated_data,
                },
            )
            .expect("encryption failure!");
        let (expected_ciphertext, expected_tag) = split_tag(&extern_ciphertext);

        let mut crate_cipher = AesGcm256::new(key, nonce, associated_data).unwrap();
        let mut buf = msg.to_vec();
        crate_cipher.encrypt(&mut buf);
        let tag = crate_cipher.into_tag();
        assert_eq!(tag.len(), TAG_LENGTH);
        assert_eq!(expected_ciphertext, buf.as_slice());
        assert_eq!(expected_tag, &tag[..]);

        // The sizes exercise every branch of the streaming path: pieces shorter
        // than a block, exactly one block, and pieces that straddle block
        // boundaries.
        for size in &[
            1,
            7,
            BLOCK_SIZE - 1,
            BLOCK_SIZE,
            BLOCK_SIZE + 1,
            2 * BLOCK_SIZE + 3,
        ] {
            let mut crate_cipher = AesGcm256::new(key, nonce, associated_data).unwrap();
            let mut buffer = msg.to_vec();
            for chunk in buffer.chunks_mut(*size) {
                crate_cipher.encrypt(chunk);
            }
            let tag = crate_cipher.into_tag();
            assert_eq!(tag.len(), TAG_LENGTH);
            assert_eq!(expected_ciphertext, buffer.as_slice(), "chunk size {}", size);
            assert_eq!(expected_tag, &tag[..], "chunk size {}", size);
        }
    }

    #[test]
    fn test_against_aes_gcm() {
        test_against_aesgcm(
            b"\xe3\xc0\x8a\x8f\x06\xc6\xe3\xad\x95\xa7\x05\x57\xb2\x3f\x75\x48\x3c\xe3\x30\x21\xa9\xc7\x2b\x70\x25\x66\x62\x04\xc6\x9c\x0b\x72",
            b"\x12\x15\x35\x24\xc0\x89\x5e\x81\xb2\xc2\x84\x65",
            b"\xd6\x09\xb1\xf0\x56\x63\x7a\x0d\x46\xdf\x99\x8d\x88\xe5\x2e\x00\xb2\xc2\x84\x65\x12\x15\x35\x24\xc0\x89\x5e\x81",
            b"\x08\x00\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x3a\x00\x02",
        );
        test_against_aesgcm(
            b"\xe3\xc0\x8a\x8f\x06\xc6\xe3\xad\x95\xa7\x05\x57\xb2\x3f\x75\x48\x3c\xe3\x30\x21\xa9\xc7\x2b\x70\x25\x66\x62\x04\xc6\x9c\x0b\x72",
            b"\x12\x15\x35\x24\xc0\x89\x5e\x81\xb2\xc2\x84\x65",
            b"\xd6\x09\xb1\xf0\x56\x63\x7a\x0d\x46\xdf\x99\x8d\x88\xe5\x2e\x00\xb2\xc2\x84\x65\x12\x15\x35\x24\xc0\x89\x5e\x81",
            b"\x08\x00\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x3a\x00",
        );
    }

    #[test]
    fn test_against_aes_gcm_edge_cases() {
        let key: Key = [0x42; KEY_SIZE];
        let nonce: Nonce = [0x24; NONCE_AES_SIZE];
        // Empty message and empty associated data.
        test_against_aesgcm(&key, &nonce, b"", b"");
        // Associated data only, no message.
        test_against_aesgcm(&key, &nonce, b"associated data only", b"");
        // Exactly one block of message and no associated data.
        test_against_aesgcm(&key, &nonce, b"", b"exactly sixteen!");
    }

    #[test]
    fn test_decryption() {
        let key = b"\xe3\xc0\x8a\x8f\x06\xc6\xe3\xad\x95\xa7\x05\x57\xb2\x3f\x75\x48\x3c\xe3\x30\x21\xa9\xc7\x2b\x70\x25\x66\x62\x04\xc6\x9c\x0b\x72";
        let nonce = b"\x12\x15\x35\x24\xc0\x89\x5e\x81\xb2\xc2\x84\x65";
        let associated_data = b"\xd6\x09\xb1\xf0\x56\x63\x7a\x0d\x46\xdf\x99\x8d\x88\xe5\x2e\x00\xb2\xc2\x84\x65\x12\x15\x35\x24\xc0\x89\x5e\x81";
        let msg = b"\x08\x00\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x3a\x00\x02";
        let extern_cipher = Aes256Gcm::new(key.into());
        let extern_ciphertext = extern_cipher
            .encrypt(
                &GenericArray::clone_from_slice(nonce),
                Payload {
                    msg,
                    aad: associated_data,
                },
            )
            .expect("encryption failure!");
        let (expected_ciphertext, expected_tag) = split_tag(&extern_ciphertext);

        // Sanity check: our encryption agrees with the reference implementation.
        let mut crate_cipher = AesGcm256::new(key, nonce, associated_data).unwrap();
        let mut buf = msg.to_vec();
        crate_cipher.encrypt(&mut buf);
        let tag = crate_cipher.into_tag();
        assert_eq!(tag.len(), TAG_LENGTH);
        assert_eq!(buf.as_slice(), expected_ciphertext);
        assert_eq!(&tag[..], expected_tag);

        // Unauthenticated decryption does not depend on the associated data.
        let mut crate_cipher = AesGcm256::new(key, nonce, b"").unwrap();
        let mut buf = expected_ciphertext.to_vec();
        crate_cipher.decrypt_unauthenticated(&mut buf);
        assert_eq!(buf.as_slice(), &msg[..]);

        // Authenticated decryption recovers the plaintext and the expected tag.
        let mut crate_cipher = AesGcm256::new(key, nonce, associated_data).unwrap();
        let mut buf = expected_ciphertext.to_vec();
        let tag = crate_cipher.decrypt(&mut buf);
        assert_eq!(buf.as_slice(), &msg[..]);
        assert_eq!(&tag[..], expected_tag);
        assert!(bool::from(tag.as_slice().ct_eq(expected_tag)));

        // A modified ciphertext must produce a tag that does not verify.
        let mut crate_cipher = AesGcm256::new(key, nonce, associated_data).unwrap();
        let mut tampered = expected_ciphertext.to_vec();
        tampered[0] ^= 1;
        let tampered_tag = crate_cipher.decrypt(&mut tampered);
        assert!(!bool::from(tampered_tag.as_slice().ct_eq(expected_tag)));
    }
}