use aws_lc_rs::aead::AES_128_GCM;
use aws_lc_rs::aead::Aad;
use aws_lc_rs::aead::LessSafeKey;
use aws_lc_rs::aead::Nonce;
use aws_lc_rs::aead::UnboundKey;
use crate::crypto::cipher::Cipher;
use crate::types::Error;
use crate::types::Result;
use crate::wire::Packet;
/// Size in bytes of the GCM authentication tag stored after each packet body
/// (128 bits).
const TAG_LEN: usize = 128 / 8;
/// Size in bytes of an AES-128 key (128 bits).
const KEY_LEN: usize = 128 / 8;
/// Number of leading nonce bytes taken verbatim from the fixed IV.
const FIXED_IV_LEN: usize = 4;
/// Number of trailing nonce bytes holding the big-endian 64-bit counter.
const INVOCATION_COUNTER_LEN: usize = 64 / 8;
/// Total nonce length: fixed IV followed by the invocation counter.
const NONCE_LEN: usize = FIXED_IV_LEN + INVOCATION_COUNTER_LEN;
pub struct Aes128Gcm {
key: LessSafeKey,
fixed_iv: [u8; FIXED_IV_LEN],
}
impl Aes128Gcm {
#[must_use]
pub fn new(key: [u8; KEY_LEN], fixed_iv: [u8; FIXED_IV_LEN]) -> Result<Self> {
let key = UnboundKey::new(&AES_128_GCM, &key).map_err(|_| Error::Crypto)?;
let key = LessSafeKey::new(key);
Ok(Self { key, fixed_iv })
}
fn make_nonce(&self, sequence_number: u32) -> Nonce {
let mut nonce_bytes = [0u8; NONCE_LEN];
nonce_bytes[..FIXED_IV_LEN].copy_from_slice(&self.fixed_iv);
nonce_bytes[FIXED_IV_LEN..].copy_from_slice(&(sequence_number as u64).to_be_bytes());
Nonce::assume_unique_for_key(nonce_bytes)
}
}
impl Cipher for Aes128Gcm {
const AEAD_LENGTH: Option<usize> = Some(TAG_LEN);
fn encrypt_packet<'buf, B>(
&self,
packet: &'buf mut Packet<&'buf mut B>,
sequence_number: u32,
) -> Result<()>
where
B: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
{
let nonce = self.make_nonce(sequence_number);
let packet_len = packet.packet_length()?;
let packet_len_bytes = packet_len.to_be_bytes();
let full_packet = packet.packet_mut()?;
let (_, rest) = full_packet.split_at_mut(4); let (payload, tag_space) = rest.split_at_mut(packet_len as usize);
let aad = Aad::from(&packet_len_bytes);
let tag = self
.key
.seal_in_place_separate_tag(nonce, aad, payload)
.map_err(|_| Error::Crypto)?;
tag_space[..TAG_LEN].copy_from_slice(tag.as_ref());
Ok(())
}
fn decrypt_packet_length<B>(&self, packet: &Packet<B>, _sequence_number: u32) -> Result<u32>
where
B: AsRef<[u8]>,
{
packet.packet_length()
}
fn decrypt_packet<'buf, B>(
&self,
packet: &'buf mut Packet<&'buf mut B>,
sequence_number: u32,
) -> Result<()>
where
B: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
{
let nonce = self.make_nonce(sequence_number);
let packet_len = packet.packet_length()?;
let packet_len_bytes = packet_len.to_be_bytes();
let aad = Aad::from(&packet_len_bytes);
let full_packet = packet.packet_mut()?;
let ciphertext_and_tag = &mut full_packet[4..];
self.key
.open_in_place(nonce, aad, ciphertext_and_tag)
.map_err(|_| Error::Crypto)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cipher under test: key of all `0x42` bytes, fixed IV `01 02 03 04`.
    fn test_cipher() -> Aes128Gcm {
        Aes128Gcm::new([0x42; KEY_LEN], [0x01, 0x02, 0x03, 0x04]).unwrap()
    }

    /// Plaintext fixture with length field `9`. The body is `06 'h' 'i'` plus
    /// six zero bytes, followed by `TAG_LEN` zeroed bytes reserved for the tag.
    fn plaintext_packet() -> Vec<u8> {
        let mut buf = vec![0x00, 0x00, 0x00, 0x09, 0x06, b'h', b'i'];
        buf.extend_from_slice(&[0u8; 6]);
        buf.extend_from_slice(&[0u8; TAG_LEN]);
        buf
    }

    /// Encrypts `data` in place, panicking on failure.
    fn encrypt_in_place(cipher: &Aes128Gcm, data: &mut Vec<u8>, sequence_number: u32) {
        let mut packet = Packet::new(&mut *data, TAG_LEN as u8);
        cipher.encrypt_packet(&mut packet, sequence_number).unwrap();
    }

    /// Attempts to decrypt `data` in place and returns the cipher's result.
    fn decrypt_in_place(
        cipher: &Aes128Gcm,
        data: &mut Vec<u8>,
        sequence_number: u32,
    ) -> crate::types::Result<()> {
        let mut packet = Packet::new(&mut *data, TAG_LEN as u8);
        cipher.decrypt_packet(&mut packet, sequence_number)
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let cipher = test_cipher();
        let mut data = plaintext_packet();
        let length_before = data[..4].to_vec();
        let body_before = data[4..7].to_vec();

        encrypt_in_place(&cipher, &mut data, 0);
        // The length field stays in the clear; the body must be scrambled.
        assert_eq!(&data[..4], &length_before[..]);
        assert_ne!(&data[4..7], &body_before[..]);

        decrypt_in_place(&cipher, &mut data, 0).unwrap();
        assert_eq!(&data[..4], &[0x00u8, 0x00, 0x00, 0x09]);
        assert_eq!(data[4], 0x06);
        assert_eq!(&data[5..7], b"hi");
        assert_eq!(&data[7..13], &[0u8; 6]);
    }

    #[test]
    fn decrypt_with_wrong_sequence_fails() {
        let cipher = test_cipher();
        let mut data = plaintext_packet();
        encrypt_in_place(&cipher, &mut data, 0);
        assert!(decrypt_in_place(&cipher, &mut data, 1).is_err());
    }

    #[test]
    fn decrypt_with_tampered_aad_fails() {
        let cipher = test_cipher();
        let mut data = plaintext_packet();
        encrypt_in_place(&cipher, &mut data, 0);
        // Change the authenticated-but-unencrypted length field from 9 to 10.
        data[3] = 0x0a;
        assert!(decrypt_in_place(&cipher, &mut data, 0).is_err());
    }
}