use aes::cipher::{Array, BlockCipherEncrypt, KeyInit};
use aes::{Aes128, Aes256};
use oxicrypto_core::{Aead, CryptoError};
/// Nonce length in bytes. Together with the leading flags byte this fills
/// bytes 0..14 of a 16-byte block, leaving the last two bytes for the
/// message-length field (in B0) or the block counter (in CTR blocks).
const NONCE_LEN: usize = 13;
/// Length in bytes of the authentication tag appended after the ciphertext.
const TAG_LEN: usize = 16;
/// Size in bytes of the blocks processed by the CBC-MAC and CTR routines.
const BLOCK_SIZE: usize = 16;
/// Builds the first CBC-MAC block (B0) for the given nonce and lengths.
///
/// Layout of the returned block:
/// * byte 0      – flags: `0b0011_1001`, with bit 6 (`0x40`) additionally set
///   when `aad_len` is non-zero;
/// * bytes 1..14 – the 13-byte nonce;
/// * bytes 14..16 – `msg_len` as a big-endian 16-bit value.
///
/// Only the low 16 bits of `msg_len` are encoded; callers are responsible for
/// keeping the message length within that range.
fn encode_b0(nonce: &[u8; NONCE_LEN], aad_len: usize, msg_len: usize) -> [u8; BLOCK_SIZE] {
    let mut block = [0u8; BLOCK_SIZE];
    block[0] = match aad_len {
        0 => 0b0011_1001,
        _ => 0b0111_1001,
    };
    block[1..1 + NONCE_LEN].copy_from_slice(nonce);
    // Deliberate truncation to the two-byte length field.
    let length_field = (msg_len as u16).to_be_bytes();
    block[14..].copy_from_slice(&length_field);
    block
}
/// Encrypts `state` in place with one block operation of `cipher`.
///
/// # Errors
///
/// Returns `CryptoError::Internal("ccm block invariant")` if `state` cannot be
/// viewed as the cipher's block type.
fn encrypt_state_block(
    cipher: &impl BlockCipherEncrypt,
    state: &mut [u8; BLOCK_SIZE],
) -> Result<(), CryptoError> {
    let Ok(block) = <&mut Array<u8, _>>::try_from(state.as_mut_slice()) else {
        return Err(CryptoError::Internal("ccm block invariant"));
    };
    cipher.encrypt_block(block);
    Ok(())
}
/// Absorbs `data` into the running CBC-MAC `state`.
///
/// `data` is consumed in `BLOCK_SIZE` pieces; each piece is XORed into the
/// front of `state`, which is then encrypted. A trailing piece shorter than
/// `BLOCK_SIZE` only touches the bytes it covers, which is equivalent to
/// padding it with zeros. Empty `data` leaves `state` untouched.
///
/// # Errors
///
/// Propagates any error from `encrypt_state_block`.
fn cbc_mac_update(
    cipher: &impl BlockCipherEncrypt,
    state: &mut [u8; BLOCK_SIZE],
    data: &[u8],
) -> Result<(), CryptoError> {
    for piece in data.chunks(BLOCK_SIZE) {
        for (acc, byte) in state.iter_mut().zip(piece) {
            *acc ^= *byte;
        }
        encrypt_state_block(cipher, state)?;
    }
    Ok(())
}
/// Computes the raw (not yet CTR-masked) CBC-MAC tag over `aad` and `msg`.
///
/// The MAC input is, in order:
/// 1. the B0 block produced by `encode_b0`;
/// 2. if `aad` is non-empty, a length prefix followed by `aad`, zero-padded to
///    a block boundary;
/// 3. `msg`, zero-padded to a block boundary.
///
/// The AAD length prefix depends on the AAD size:
/// * fewer than `0xFF00` bytes – two big-endian bytes;
/// * up to `u32::MAX` bytes    – `0xFF 0xFE` followed by four big-endian bytes;
/// * anything larger           – `0xFF 0xFF` followed by eight big-endian bytes.
///
/// # Errors
///
/// * `CryptoError::InvalidKey` if `key` is rejected by the cipher `C`.
/// * Any error propagated from `encrypt_state_block`.
fn compute_tag<C: BlockCipherEncrypt + KeyInit>(
    key: &[u8],
    nonce: &[u8; NONCE_LEN],
    aad: &[u8],
    msg: &[u8],
) -> Result<[u8; TAG_LEN], CryptoError> {
    /// Longest possible AAD length prefix (2 marker bytes + 8 length bytes).
    const MAX_PREFIX_LEN: usize = 10;

    let cipher = C::new_from_slice(key).map_err(|_| CryptoError::InvalidKey)?;

    // Start the MAC chain by encrypting B0.
    let mut mac_state = encode_b0(nonce, aad.len(), msg.len());
    encrypt_state_block(&cipher, &mut mac_state)?;

    if !aad.is_empty() {
        // Widening cast: `usize` is never wider than 64 bits on supported targets.
        let aad_len = aad.len() as u64;
        let mut prefix = [0u8; MAX_PREFIX_LEN];
        let prefix_len = if aad_len < 0xFF00 {
            // Guarded above, so the narrowing cast is lossless.
            prefix[..2].copy_from_slice(&(aad_len as u16).to_be_bytes());
            2
        } else if aad_len <= u64::from(u32::MAX) {
            prefix[..2].copy_from_slice(&[0xFF, 0xFE]);
            // Guarded above, so the narrowing cast is lossless.
            prefix[2..6].copy_from_slice(&(aad_len as u32).to_be_bytes());
            6
        } else {
            prefix[..2].copy_from_slice(&[0xFF, 0xFF]);
            prefix[2..10].copy_from_slice(&aad_len.to_be_bytes());
            10
        };

        // The first AAD block holds the prefix plus as much AAD as still fits.
        let first_chunk = aad.len().min(BLOCK_SIZE - prefix_len);
        let (head, rest) = aad.split_at(first_chunk);
        let mut first_block = [0u8; BLOCK_SIZE];
        first_block[..prefix_len].copy_from_slice(&prefix[..prefix_len]);
        first_block[prefix_len..prefix_len + first_chunk].copy_from_slice(head);

        for (acc, byte) in mac_state.iter_mut().zip(first_block.iter()) {
            *acc ^= *byte;
        }
        encrypt_state_block(&cipher, &mut mac_state)?;

        if !rest.is_empty() {
            cbc_mac_update(&cipher, &mut mac_state, rest)?;
        }
    }

    cbc_mac_update(&cipher, &mut mac_state, msg)?;
    Ok(mac_state)
}
/// XORs `data` with the CCM counter-mode keystream and writes the result to
/// the front of `out`.
///
/// Keystream block `i` is the encryption of
/// `[0x01, nonce (13 bytes), counter as big-endian u16]`, where the counter
/// starts at `start_counter` and increases by one per 16-byte block. Because
/// the operation is an XOR it both encrypts and decrypts.
///
/// All argument checks happen before anything is written to `out`.
///
/// # Errors
///
/// * `CryptoError::InvalidKey` if `key` is rejected by the cipher `C`.
/// * `CryptoError::BufferTooSmall` if `out` is shorter than `data`.
/// * `CryptoError::BadInput` if `data` needs more blocks than remain between
///   `start_counter` and `u16::MAX`; continuing would wrap the counter and
///   reuse keystream.
/// * Any error propagated from `encrypt_state_block`.
fn ctr_crypt<C: BlockCipherEncrypt + KeyInit>(
    key: &[u8],
    nonce: &[u8; NONCE_LEN],
    data: &[u8],
    out: &mut [u8],
    start_counter: u16,
) -> Result<(), CryptoError> {
    const FLAGS: u8 = 0x01;

    let cipher = C::new_from_slice(key).map_err(|_| CryptoError::InvalidKey)?;

    if out.len() < data.len() {
        return Err(CryptoError::BufferTooSmall);
    }
    // Number of counter increments needed after the first block. It must not
    // exceed the distance from `start_counter` to the largest counter value.
    let increments_needed = data.len().saturating_sub(1) / BLOCK_SIZE;
    if increments_needed > usize::from(u16::MAX - start_counter) {
        return Err(CryptoError::BadInput);
    }

    // Flags and nonce are identical for every block; only the counter changes.
    let mut template = [0u8; BLOCK_SIZE];
    template[0] = FLAGS;
    template[1..1 + NONCE_LEN].copy_from_slice(nonce);

    let mut counter = start_counter;
    for (src, dst) in data.chunks(BLOCK_SIZE).zip(out.chunks_mut(BLOCK_SIZE)) {
        let mut keystream = template;
        keystream[14..].copy_from_slice(&counter.to_be_bytes());
        encrypt_state_block(&cipher, &mut keystream)?;
        for ((o, d), k) in dst.iter_mut().zip(src).zip(&keystream) {
            *o = d ^ k;
        }
        // The range check above guarantees a wrapped value is never used.
        counter = counter.wrapping_add(1);
    }
    Ok(())
}
/// Encrypts `pt` and authenticates it together with `aad`, writing
/// `ciphertext || tag` to the front of `ct_out`.
///
/// The raw tag is computed over the plaintext, the payload is encrypted with
/// counter values starting at 1, and the tag is masked with counter block 0.
///
/// Returns the number of bytes written, `pt.len() + TAG_LEN`.
///
/// # Errors
///
/// * `CryptoError::BadInput` if `pt` is longer than 65535 bytes, the most the
///   two-byte length field of the first MAC block can describe.
/// * `CryptoError::BufferTooSmall` if `ct_out` cannot hold the output.
/// * Any error propagated from `compute_tag` or `ctr_crypt`.
fn ccm_seal<C: BlockCipherEncrypt + KeyInit>(
    key: &[u8],
    nonce: &[u8; NONCE_LEN],
    aad: &[u8],
    pt: &[u8],
    ct_out: &mut [u8],
) -> Result<usize, CryptoError> {
    /// Largest message length that fits the 16-bit length field.
    const MAX_MSG_LEN: usize = 0xFFFF;

    if pt.len() > MAX_MSG_LEN {
        return Err(CryptoError::BadInput);
    }
    let required = pt.len().checked_add(TAG_LEN).ok_or(CryptoError::BadInput)?;
    if ct_out.len() < required {
        return Err(CryptoError::BufferTooSmall);
    }

    let (body, tail) = ct_out.split_at_mut(pt.len());
    let raw_tag = compute_tag::<C>(key, nonce, aad, pt)?;
    ctr_crypt::<C>(key, nonce, pt, body, 1)?;
    // `tail` holds at least TAG_LEN bytes thanks to the size check above.
    ctr_crypt::<C>(key, nonce, &raw_tag, &mut tail[..TAG_LEN], 0)?;
    Ok(required)
}
/// Verifies and decrypts `ct` (laid out as `ciphertext || tag`), writing the
/// plaintext to the front of `pt_out`.
///
/// The transmitted tag is unmasked with counter block 0, the payload is
/// decrypted with counter values starting at 1, and the tag recomputed over
/// the decrypted payload is compared in constant time with the transmitted
/// one. Whenever a step after payload decryption fails, the plaintext region
/// of `pt_out` is overwritten with zeros so unverified data is never exposed.
///
/// Returns the plaintext length, `ct.len() - TAG_LEN`.
///
/// # Errors
///
/// * `CryptoError::BadInput` if `ct` is shorter than a tag, or if the payload
///   is longer than 65535 bytes (the most the two-byte length field of the
///   first MAC block can describe).
/// * `CryptoError::BufferTooSmall` if `pt_out` cannot hold the plaintext.
/// * `CryptoError::InvalidTag` if authentication fails.
/// * Any error propagated from `compute_tag` or `ctr_crypt`.
fn ccm_open<C: BlockCipherEncrypt + KeyInit>(
    key: &[u8],
    nonce: &[u8; NONCE_LEN],
    aad: &[u8],
    ct: &[u8],
    pt_out: &mut [u8],
) -> Result<usize, CryptoError> {
    use subtle::ConstantTimeEq as _;

    /// Largest message length that fits the 16-bit length field.
    const MAX_MSG_LEN: usize = 0xFFFF;

    if ct.len() < TAG_LEN {
        return Err(CryptoError::BadInput);
    }
    let pt_len = ct.len() - TAG_LEN;
    if pt_len > MAX_MSG_LEN {
        return Err(CryptoError::BadInput);
    }
    if pt_out.len() < pt_len {
        return Err(CryptoError::BufferTooSmall);
    }
    let (ciphertext, tag_bytes) = ct.split_at(pt_len);

    // Unmask the tag first: this step does not touch `pt_out`, so an early
    // failure here leaves the caller's buffer unmodified.
    let mut raw_tag = [0u8; TAG_LEN];
    ctr_crypt::<C>(key, nonce, tag_bytes, &mut raw_tag, 0)?;

    let outcome = match ctr_crypt::<C>(key, nonce, ciphertext, &mut pt_out[..pt_len], 1) {
        Ok(()) => compute_tag::<C>(key, nonce, aad, &pt_out[..pt_len]),
        Err(e) => Err(e),
    };

    match outcome {
        Ok(expected_tag) if bool::from(raw_tag.ct_eq(&expected_tag)) => Ok(pt_len),
        failure => {
            // Wipe whatever was decrypted before reporting the failure.
            pt_out[..pt_len].fill(0);
            Err(failure.err().unwrap_or(CryptoError::InvalidTag))
        }
    }
}
/// AES-128 in CCM mode, using a 13-byte nonce and a 16-byte tag.
#[derive(Debug, Default, Clone, Copy)]
pub struct Aes128Ccm;

impl Aead for Aes128Ccm {
    /// Human-readable algorithm name.
    fn name(&self) -> &'static str {
        "AES-128-CCM"
    }

    /// Required key length in bytes.
    fn key_len(&self) -> usize {
        16
    }

    /// Required nonce length in bytes.
    fn nonce_len(&self) -> usize {
        NONCE_LEN
    }

    /// Length in bytes of the tag appended to the ciphertext.
    fn tag_len(&self) -> usize {
        TAG_LEN
    }

    /// Encrypts `pt` into `ct_out` as `ciphertext || tag` and returns the
    /// number of bytes written.
    ///
    /// Fails with `CryptoError::InvalidKey` unless `key` is 16 bytes long and
    /// with `CryptoError::InvalidNonce` unless `nonce` is `NONCE_LEN` bytes
    /// long (the key is checked first); other errors come from `ccm_seal`.
    fn seal(
        &self,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        pt: &[u8],
        ct_out: &mut [u8],
    ) -> Result<usize, CryptoError> {
        if key.len() != self.key_len() {
            return Err(CryptoError::InvalidKey);
        }
        // The slice-to-array conversion succeeds only for exactly NONCE_LEN bytes.
        match <&[u8; NONCE_LEN]>::try_from(nonce) {
            Ok(nonce_arr) => ccm_seal::<Aes128>(key, nonce_arr, aad, pt, ct_out),
            Err(_) => Err(CryptoError::InvalidNonce),
        }
    }

    /// Verifies and decrypts `ct` into `pt_out` and returns the plaintext
    /// length.
    ///
    /// Fails with `CryptoError::InvalidKey` unless `key` is 16 bytes long and
    /// with `CryptoError::InvalidNonce` unless `nonce` is `NONCE_LEN` bytes
    /// long (the key is checked first); other errors come from `ccm_open`.
    fn open(
        &self,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        ct: &[u8],
        pt_out: &mut [u8],
    ) -> Result<usize, CryptoError> {
        if key.len() != self.key_len() {
            return Err(CryptoError::InvalidKey);
        }
        // The slice-to-array conversion succeeds only for exactly NONCE_LEN bytes.
        match <&[u8; NONCE_LEN]>::try_from(nonce) {
            Ok(nonce_arr) => ccm_open::<Aes128>(key, nonce_arr, aad, ct, pt_out),
            Err(_) => Err(CryptoError::InvalidNonce),
        }
    }
}
/// AES-256 in CCM mode, using a 13-byte nonce and a 16-byte tag.
#[derive(Debug, Default, Clone, Copy)]
pub struct Aes256Ccm;

impl Aead for Aes256Ccm {
    /// Human-readable algorithm name.
    fn name(&self) -> &'static str {
        "AES-256-CCM"
    }

    /// Required key length in bytes.
    fn key_len(&self) -> usize {
        32
    }

    /// Required nonce length in bytes.
    fn nonce_len(&self) -> usize {
        NONCE_LEN
    }

    /// Length in bytes of the tag appended to the ciphertext.
    fn tag_len(&self) -> usize {
        TAG_LEN
    }

    /// Encrypts `pt` into `ct_out` as `ciphertext || tag` and returns the
    /// number of bytes written.
    ///
    /// Fails with `CryptoError::InvalidKey` unless `key` is 32 bytes long and
    /// with `CryptoError::InvalidNonce` unless `nonce` is `NONCE_LEN` bytes
    /// long (the key is checked first); other errors come from `ccm_seal`.
    fn seal(
        &self,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        pt: &[u8],
        ct_out: &mut [u8],
    ) -> Result<usize, CryptoError> {
        if key.len() != self.key_len() {
            return Err(CryptoError::InvalidKey);
        }
        // The slice-to-array conversion succeeds only for exactly NONCE_LEN bytes.
        match <&[u8; NONCE_LEN]>::try_from(nonce) {
            Ok(nonce_arr) => ccm_seal::<Aes256>(key, nonce_arr, aad, pt, ct_out),
            Err(_) => Err(CryptoError::InvalidNonce),
        }
    }

    /// Verifies and decrypts `ct` into `pt_out` and returns the plaintext
    /// length.
    ///
    /// Fails with `CryptoError::InvalidKey` unless `key` is 32 bytes long and
    /// with `CryptoError::InvalidNonce` unless `nonce` is `NONCE_LEN` bytes
    /// long (the key is checked first); other errors come from `ccm_open`.
    fn open(
        &self,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        ct: &[u8],
        pt_out: &mut [u8],
    ) -> Result<usize, CryptoError> {
        if key.len() != self.key_len() {
            return Err(CryptoError::InvalidKey);
        }
        // The slice-to-array conversion succeeds only for exactly NONCE_LEN bytes.
        match <&[u8; NONCE_LEN]>::try_from(nonce) {
            Ok(nonce_arr) => ccm_open::<Aes256>(key, nonce_arr, aad, ct, pt_out),
            Err(_) => Err(CryptoError::InvalidNonce),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const KEY_128: [u8; 16] = [0x42u8; 16];
    const KEY_256: [u8; 32] = [0x42u8; 32];
    const NONCE: [u8; NONCE_LEN] = [0x24u8; NONCE_LEN];
    const AAD: &[u8] = b"additional authenticated data";
    const PLAINTEXT: &[u8] = b"hello, oxicrypto ccm!";

    /// Seals `pt` into a buffer sized for ciphertext plus tag and returns the
    /// buffer together with the reported number of bytes written.
    fn seal_vec<A: Aead>(
        aead: &A,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        pt: &[u8],
    ) -> Result<(Vec<u8>, usize), CryptoError> {
        let mut ct = vec![0u8; pt.len() + aead.tag_len()];
        aead.seal(key, nonce, aad, pt, &mut ct)
            .map(|written| (ct, written))
    }

    /// Seals the shared fixture with `aead`/`key`, opens it again and checks
    /// that lengths and contents survive the round trip.
    fn assert_round_trip<A: Aead>(aead: &A, key: &[u8]) {
        let (ct, written) = seal_vec(aead, key, &NONCE, AAD, PLAINTEXT).expect("seal failed");
        assert_eq!(written, PLAINTEXT.len() + aead.tag_len());

        let mut pt = vec![0u8; PLAINTEXT.len()];
        let recovered = aead
            .open(key, &NONCE, AAD, &ct[..written], &mut pt)
            .expect("open failed");
        assert_eq!(recovered, PLAINTEXT.len());
        assert_eq!(&pt[..recovered], PLAINTEXT);
    }

    #[test]
    fn aes128ccm_round_trip() {
        assert_round_trip(&Aes128Ccm, &KEY_128);
    }

    #[test]
    fn aes256ccm_round_trip() {
        assert_round_trip(&Aes256Ccm, &KEY_256);
    }

    #[test]
    fn aes128ccm_tamper_ciphertext_fails() {
        let aead = Aes128Ccm;
        let (mut ct, written) = seal_vec(&aead, &KEY_128, &NONCE, AAD, PLAINTEXT).unwrap();
        // Corrupt the first payload byte.
        ct[0] ^= 0xFF;

        let mut pt = vec![0u8; PLAINTEXT.len()];
        let result = aead.open(&KEY_128, &NONCE, AAD, &ct[..written], &mut pt);
        assert_eq!(result, Err(CryptoError::InvalidTag));
    }

    #[test]
    fn aes128ccm_tamper_tag_fails() {
        let aead = Aes128Ccm;
        let (mut ct, written) = seal_vec(&aead, &KEY_128, &NONCE, AAD, PLAINTEXT).unwrap();
        // Flip one bit in the last tag byte.
        ct[written - 1] ^= 0x01;

        let mut pt = vec![0u8; PLAINTEXT.len()];
        let result = aead.open(&KEY_128, &NONCE, AAD, &ct[..written], &mut pt);
        assert_eq!(result, Err(CryptoError::InvalidTag));
    }

    #[test]
    fn aes128ccm_wrong_key_fails() {
        let aead = Aes128Ccm;
        let (ct, written) = seal_vec(&aead, &KEY_128, &NONCE, AAD, PLAINTEXT).unwrap();

        let mut pt = vec![0u8; PLAINTEXT.len()];
        let result = aead.open(&[0x00u8; 16], &NONCE, AAD, &ct[..written], &mut pt);
        assert_eq!(result, Err(CryptoError::InvalidTag));
    }

    #[test]
    fn aes128ccm_empty_plaintext() {
        let aead = Aes128Ccm;
        let (ct, written) = seal_vec(&aead, &KEY_128, &NONCE, AAD, b"").unwrap();
        assert_eq!(written, aead.tag_len());

        let mut pt = vec![0u8; 0];
        let recovered = aead
            .open(&KEY_128, &NONCE, AAD, &ct[..written], &mut pt)
            .unwrap();
        assert_eq!(recovered, 0);
    }

    #[test]
    fn aes128ccm_no_aad() {
        let aead = Aes128Ccm;
        let (ct, written) = seal_vec(&aead, &KEY_128, &NONCE, b"", PLAINTEXT).unwrap();

        let mut pt = vec![0u8; PLAINTEXT.len()];
        let recovered = aead
            .open(&KEY_128, &NONCE, b"", &ct[..written], &mut pt)
            .unwrap();
        assert_eq!(&pt[..recovered], PLAINTEXT);
    }

    #[test]
    fn aes128ccm_deterministic_output() {
        let aead = Aes128Ccm;
        let key = [0u8; 16];
        let nonce = [0u8; NONCE_LEN];
        let pt = b"test";

        let (ct, written) = seal_vec(&aead, &key, &nonce, b"", pt).unwrap();
        let mut pt_out = vec![0u8; pt.len()];
        let recovered = aead
            .open(&key, &nonce, b"", &ct[..written], &mut pt_out)
            .unwrap();
        assert_eq!(&pt_out[..recovered], pt.as_ref());

        // Sealing the same input a second time must give identical bytes.
        let (ct2, written2) = seal_vec(&aead, &key, &nonce, b"", pt).unwrap();
        assert_eq!(ct[..written], ct2[..written2]);
    }
}