use super::aes::Aes;
use crate::BlockCipher;
/// Validates the CCM parameters before any cryptographic work is done.
///
/// * `m` — tag length in bytes; must be an even value from 4 through 16.
/// * `l` — size in bytes of the message-length field; must be from 2 through 8.
///
/// # Errors
/// Returns a static description of the offending parameter. `m` is checked first, so when
/// both are invalid the tag-length message is reported.
fn check_params(m: usize, l: usize) -> Result<(), &'static str> {
    let tag_len_ok = (4..=16).contains(&m) && m % 2 == 0;
    let len_field_ok = (2..=8).contains(&l);
    match (tag_len_ok, len_field_ok) {
        (false, _) => Err("CCM: tag length M must be in {4,6,8,10,12,14,16}"),
        (true, false) => Err("CCM: length-of-length L must be in {2..=8}"),
        (true, true) => Ok(()),
    }
}
/// Encrypts and authenticates `plaintext` with AES-CCM.
///
/// * `m` — tag length in bytes (even, 4..=16).
/// * `l` — size of the length field in bytes (2..=8); `nonce` must be exactly `15 - l` bytes.
/// * `aad` — additional authenticated data; authenticated but not encrypted.
///
/// Returns `(ciphertext, tag)`, where the ciphertext has the plaintext's length and the tag
/// is `m` bytes. Returns `None` if a parameter is invalid, the nonce length is wrong, the
/// plaintext does not fit in `l` bytes of length, or `aad` is 0xFF00 bytes or longer
/// (this entry point only accepts AAD lengths that use the 2-byte length prefix).
pub fn ccm_encrypt(
    aes: &Aes,
    m: usize,
    l: usize,
    nonce: &[u8],
    aad: &[u8],
    plaintext: &[u8],
) -> Option<(Vec<u8>, Vec<u8>)> {
    if check_params(m, l).is_err() || nonce.len() != 15 - l {
        return None;
    }
    // With l == 8 every usize length fits, so the bound is only checked for l < 8.
    let payload_fits = l >= 8 || (plaintext.len() as u128) < (1u128 << (8 * l));
    if !payload_fits || aad.len() >= 0xFF00 {
        return None;
    }

    // Authentication: CBC-MAC over the plaintext, masked with the first keystream block.
    let mac = cbc_mac(aes, m, l, nonce, aad, plaintext);
    let mut s0 = ctr_block(l, nonce, 0);
    aes.encrypt_block(&mut s0);
    let tag: Vec<u8> = mac
        .iter()
        .zip(s0.iter())
        .take(m)
        .map(|(t, s)| t ^ s)
        .collect();

    // Encryption: CTR mode with counters starting at 1 (counter 0 is reserved for the tag).
    let mut ciphertext = plaintext.to_vec();
    for (chunk, counter) in ciphertext.chunks_mut(16).zip(1u64..) {
        let mut keystream = ctr_block(l, nonce, counter);
        aes.encrypt_block(&mut keystream);
        chunk
            .iter_mut()
            .zip(keystream.iter())
            .for_each(|(byte, k)| *byte ^= k);
    }
    Some((ciphertext, tag))
}
/// Decrypts `ciphertext` and verifies `tag` with AES-CCM.
///
/// Parameters mirror [`ccm_encrypt`]; `tag` must be exactly `m` bytes long.
///
/// Returns the plaintext only if the recomputed tag matches `tag`. Returns `None` on any
/// parameter/length error or on authentication failure. CCM authenticates the plaintext,
/// so the payload is decrypted first and then discarded if the tag does not match.
pub fn ccm_decrypt(
    aes: &Aes,
    m: usize,
    l: usize,
    nonce: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    tag: &[u8],
) -> Option<Vec<u8>> {
    check_params(m, l).ok()?;
    let lengths_ok = nonce.len() == 15 - l
        && tag.len() == m
        && (l >= 8 || (ciphertext.len() as u128) < (1u128 << (8 * l)))
        && aad.len() < 0xFF00;
    if !lengths_ok {
        return None;
    }

    // CTR-mode decryption; keystream counters start at 1.
    let mut plaintext = ciphertext.to_vec();
    for (chunk, counter) in plaintext.chunks_mut(16).zip(1u64..) {
        let mut keystream = ctr_block(l, nonce, counter);
        aes.encrypt_block(&mut keystream);
        chunk
            .iter_mut()
            .zip(keystream.iter())
            .for_each(|(byte, k)| *byte ^= k);
    }

    let mac = cbc_mac(aes, m, l, nonce, aad, &plaintext);
    let mut s0 = ctr_block(l, nonce, 0);
    aes.encrypt_block(&mut s0);

    // OR-accumulate every byte difference so the comparison has no data-dependent early exit.
    let mismatch = mac
        .iter()
        .zip(s0.iter())
        .zip(tag.iter())
        .fold(0u8, |acc, ((t, s), given)| acc | (t ^ s ^ given));

    if mismatch == 0 {
        Some(plaintext)
    } else {
        None
    }
}
/// Builds the CTR-mode counter block `A_i`.
///
/// Layout: byte 0 holds `l - 1`, followed by the nonce, followed by the low `l` bytes
/// (at most 8) of `counter` in big-endian order.
fn ctr_block(l: usize, nonce: &[u8], counter: u64) -> [u8; 16] {
    debug_assert_eq!(nonce.len(), 15 - l);
    let mut out = [0u8; 16];
    out[0] = (l - 1) as u8;
    let nonce_end = 1 + nonce.len();
    out[1..nonce_end].copy_from_slice(nonce);
    let width = if l < 8 { l } else { 8 };
    let counter_bytes = counter.to_be_bytes();
    out[16 - width..].copy_from_slice(&counter_bytes[8 - width..]);
    out
}
/// Builds the first CBC-MAC input block `B_0`.
///
/// Byte 0 is the flags byte: bit 6 is set when `aad_len > 0`, bits 5..3 hold `(m - 2) / 2`,
/// and bits 2..0 hold `l - 1` (bit 7 is not set for valid parameters). The nonce follows,
/// and the trailing `l` bytes (at most 8) hold `payload_len` in big-endian order.
fn b0_block(m: usize, l: usize, nonce: &[u8], aad_len: usize, payload_len: usize) -> [u8; 16] {
    debug_assert_eq!(nonce.len(), 15 - l);
    let adata_bit: u8 = u8::from(aad_len > 0) << 6;
    let tag_bits: u8 = (((m as u8) - 2) / 2) << 3;
    let len_bits: u8 = (l as u8) - 1;

    let mut out = [0u8; 16];
    out[0] = adata_bit | tag_bits | len_bits;
    out[1..1 + nonce.len()].copy_from_slice(nonce);

    let width = l.min(8);
    let q = (payload_len as u64).to_be_bytes();
    out[16 - width..].copy_from_slice(&q[8 - width..]);
    out
}
/// Computes the raw 16-byte CBC-MAC state over `B_0`, the length-prefixed AAD and the
/// payload. Callers truncate the result to `m` bytes and XOR it with the encrypted `A_0`.
///
/// When `aad` is non-empty, it is preceded by its length, encoded per RFC 3610 §2.2:
/// * `len < 0xFF00`: 2-byte big-endian length;
/// * `0xFF00 <= len <= u32::MAX`: `0xFF 0xFE` followed by a 4-byte big-endian length;
/// * larger: `0xFF 0xFF` followed by an 8-byte big-endian length.
///
/// The header plus AAD, and separately the payload, are zero-padded to 16-byte boundaries.
/// Empty AAD contributes no blocks.
fn cbc_mac(aes: &Aes, m: usize, l: usize, nonce: &[u8], aad: &[u8], payload: &[u8]) -> [u8; 16] {
    /// XORs up to 16 bytes into `state` (missing bytes act as zero padding), then encrypts
    /// the state in place.
    fn absorb(aes: &Aes, state: &mut [u8; 16], data: &[u8]) {
        debug_assert!(data.len() <= 16);
        for (s, d) in state.iter_mut().zip(data) {
            *s ^= d;
        }
        aes.encrypt_block(state);
    }

    let mut state = b0_block(m, l, nonce, aad.len(), payload.len());
    aes.encrypt_block(&mut state);

    if !aad.is_empty() {
        // usize -> u64 is lossless on all supported (<= 64-bit) targets.
        let len = aad.len() as u64;
        let mut header = [0u8; 10];
        // The previous `as u16` cast silently truncated long AAD lengths; use the full encoding.
        let header_len = if len < 0xFF00 {
            header[..2].copy_from_slice(&(len as u16).to_be_bytes());
            2
        } else if len <= u64::from(u32::MAX) {
            header[..2].copy_from_slice(&[0xFF, 0xFE]);
            header[2..6].copy_from_slice(&(len as u32).to_be_bytes());
            6
        } else {
            header[..2].copy_from_slice(&[0xFF, 0xFF]);
            header[2..10].copy_from_slice(&len.to_be_bytes());
            10
        };

        // First AAD block: the length header followed by as much AAD as fits.
        let mut first = [0u8; 16];
        first[..header_len].copy_from_slice(&header[..header_len]);
        let head = (16 - header_len).min(aad.len());
        first[header_len..header_len + head].copy_from_slice(&aad[..head]);
        absorb(aes, &mut state, &first);

        for chunk in aad[head..].chunks(16) {
            absorb(aes, &mut state, chunk);
        }
    }

    for chunk in payload.chunks(16) {
        absorb(aes, &mut state, chunk);
    }
    state
}
/// AES-CCM with a fixed 13-byte nonce (so the length field is `L = 15 - 13 = 2` bytes) and a
/// tag length chosen at construction.
pub struct AesCcm {
    /// Block cipher keyed once in [`AesCcm::new`].
    aes: Aes,
    /// Tag length `M` in bytes; one of 4, 6, 8, 10, 12, 14, 16.
    m: usize,
}

impl AesCcm {
    /// Creates a cipher from a 16-, 24- or 32-byte AES key and a tag length `m`.
    ///
    /// Returns `None` if `m` is not an even value in `4..=16`, or if the key length is not
    /// 16, 24 or 32 bytes.
    pub fn new(key: &[u8], m: usize) -> Option<Self> {
        let tag_len_valid = (4..=16).contains(&m) && m % 2 == 0;
        let key_len_valid = [16, 24, 32].contains(&key.len());
        if tag_len_valid && key_len_valid {
            Some(Self {
                aes: <Aes as BlockCipher>::new(key),
                m,
            })
        } else {
            None
        }
    }

    /// Encrypts `plaintext` and authenticates it together with `aad`.
    ///
    /// Returns `(ciphertext, tag)`, or `None` when `ccm_encrypt` rejects the input sizes.
    pub fn encrypt(&self, nonce: &[u8; 13], aad: &[u8], plaintext: &[u8]) -> Option<(Vec<u8>, Vec<u8>)> {
        let l = 2; // 15 - 13-byte nonce
        ccm_encrypt(&self.aes, self.m, l, &nonce[..], aad, plaintext)
    }

    /// Verifies `tag` over `aad` and the decrypted payload.
    ///
    /// Returns the plaintext only when the tag matches.
    pub fn decrypt(&self, nonce: &[u8; 13], aad: &[u8], ciphertext: &[u8], tag: &[u8]) -> Option<Vec<u8>> {
        let l = 2; // 15 - 13-byte nonce
        ccm_decrypt(&self.aes, self.m, l, &nonce[..], aad, ciphertext, tag)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string, ignoring whitespace. Panics on odd length or non-hex digits.
    fn hex(s: &str) -> Vec<u8> {
        let s: String = s.chars().filter(|c| !c.is_whitespace()).collect();
        assert!(s.len() % 2 == 0);
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Decodes a 13-byte nonce for the `AesCcm` API.
    fn nonce13(s: &str) -> [u8; 13] {
        hex(s).try_into().expect("nonce must be exactly 13 bytes")
    }

    /// Runs a known-answer vector through the low-level API with `l = 15 - nonce.len()`.
    /// `ct_and_tag` is the ciphertext immediately followed by the tag.
    /// It also checks that decryption recovers the plaintext.
    fn check_known_answer(key: &str, m: usize, nonce: &str, aad: &str, pt: &str, ct_and_tag: &str) {
        let key = hex(key);
        let nonce = hex(nonce);
        let aad = hex(aad);
        let pt = hex(pt);
        let expected = hex(ct_and_tag);
        let l = 15 - nonce.len();
        let ccm = AesCcm::new(&key, m).unwrap();
        let (ct, tag) = ccm_encrypt(&ccm.aes, m, l, &nonce, &aad, &pt).unwrap();
        let (expected_ct, expected_tag) = expected.split_at(pt.len());
        assert_eq!(ct, expected_ct);
        assert_eq!(tag, expected_tag);
        let back = ccm_decrypt(&ccm.aes, m, l, &nonce, &aad, &ct, &tag).unwrap();
        assert_eq!(back, pt);
    }

    #[test]
    fn rfc3610_packet_vector_1() {
        let key = hex("c0c1c2c3c4c5c6c7c8c9cacbcccdcecf");
        let nonce = nonce13("00000003020100a0a1a2a3a4a5");
        let aad = hex("0001020304050607");
        let plaintext = hex("08090a0b0c0d0e0f101112131415161718191a1b1c1d1e");
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, &aad, &plaintext).unwrap();
        let expected_ct = hex("588c979a61c663d2f066d0c2c0f989806d5f6b61dac384");
        let expected_tag = hex("17e8d12cfdf926e0");
        assert_eq!(ct, expected_ct);
        assert_eq!(tag, expected_tag);
        let pt = ccm.decrypt(&nonce, &aad, &ct, &tag).unwrap();
        assert_eq!(pt, plaintext);
    }

    #[test]
    fn rfc3610_packet_vector_2() {
        let key = hex("c0c1c2c3c4c5c6c7c8c9cacbcccdcecf");
        let nonce = nonce13("00000004030201a0a1a2a3a4a5");
        let aad = hex("0001020304050607");
        let plaintext = hex("08090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f");
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, &aad, &plaintext).unwrap();
        let expected_ct = hex("72c91a36e135f8cf291ca894085c87e3cc15c439c9e43a3b");
        let expected_tag = hex("a091d56e10400916");
        assert_eq!(ct, expected_ct);
        assert_eq!(tag, expected_tag);
        let pt = ccm.decrypt(&nonce, &aad, &ct, &tag).unwrap();
        assert_eq!(pt, plaintext);
    }

    /// NIST SP 800-38C Appendix C, Example 1: 7-byte nonce (L = 8), M = 4.
    #[test]
    fn sp800_38c_example_1() {
        check_known_answer(
            "404142434445464748494a4b4c4d4e4f",
            4,
            "10111213141516",
            "0001020304050607",
            "20212223",
            "7162015b4dac255d",
        );
    }

    /// NIST SP 800-38C Appendix C, Example 2: 8-byte nonce (L = 7), M = 6, and 16 bytes of
    /// AAD, which spill past the first CBC-MAC AAD block.
    #[test]
    fn sp800_38c_example_2() {
        check_known_answer(
            "404142434445464748494a4b4c4d4e4f",
            6,
            "1011121314151617",
            "000102030405060708090a0b0c0d0e0f",
            "202122232425262728292a2b2c2d2e2f",
            "d2a1f0e051ea5f62081a7792073d593d1fc64fbfaccd",
        );
    }

    /// NIST SP 800-38C Appendix C, Example 3: 12-byte nonce (L = 3), M = 8.
    #[test]
    fn sp800_38c_example_3() {
        check_known_answer(
            "404142434445464748494a4b4c4d4e4f",
            8,
            "101112131415161718191a1b",
            "000102030405060708090a0b0c0d0e0f10111213",
            "202122232425262728292a2b2c2d2e2f3031323334353637",
            "e3b201a9f5b71a7a9b1ceaeccd97e70b6176aad9a4428aa5484392fbc1b09951",
        );
    }

    #[test]
    fn ccm_aes128_m16_roundtrip() {
        let key = [0x42u8; 16];
        let nonce = [0xa5u8; 13];
        let aad = b"some context";
        let pt = b"hello world; this is a test of moderate length to span more than one AES block.";
        let ccm = AesCcm::new(&key, 16).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, aad, pt).unwrap();
        assert_eq!(tag.len(), 16);
        assert_ne!(ct.as_slice(), pt.as_slice());
        let back = ccm.decrypt(&nonce, aad, &ct, &tag).unwrap();
        assert_eq!(back.as_slice(), pt.as_slice());
    }

    #[test]
    fn ccm_aes256_m16_roundtrip() {
        let key = [0x77u8; 32];
        let nonce = [0x11u8; 13];
        let aad = b"";
        let pt = b"AES-256-CCM message";
        let ccm = AesCcm::new(&key, 16).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, aad, pt).unwrap();
        let back = ccm.decrypt(&nonce, aad, &ct, &tag).unwrap();
        assert_eq!(back.as_slice(), pt.as_slice());
    }

    #[test]
    fn ccm_rejects_tampered_ciphertext() {
        let key = [0x01u8; 16];
        let nonce = [0x02u8; 13];
        let pt = b"do not modify";
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (mut ct, tag) = ccm.encrypt(&nonce, b"", pt).unwrap();
        ct[0] ^= 0x01;
        assert!(ccm.decrypt(&nonce, b"", &ct, &tag).is_none());
    }

    #[test]
    fn ccm_rejects_tampered_tag() {
        let key = [0x01u8; 16];
        let nonce = [0x02u8; 13];
        let pt = b"do not modify";
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (ct, mut tag) = ccm.encrypt(&nonce, b"", pt).unwrap();
        tag[0] ^= 0x01;
        assert!(ccm.decrypt(&nonce, b"", &ct, &tag).is_none());
    }

    #[test]
    fn ccm_rejects_modified_aad() {
        let key = [0xffu8; 16];
        let nonce = [0x10u8; 13];
        let aad = b"context-A";
        let pt = b"shared payload";
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, aad, pt).unwrap();
        assert!(ccm.decrypt(&nonce, b"context-B", &ct, &tag).is_none());
    }

    /// The first AAD block holds 14 AAD bytes after the 2-byte length prefix. Longer AAD
    /// continues into later blocks, and its last byte must still be authenticated.
    #[test]
    fn ccm_aad_tail_is_authenticated() {
        let key = [0x21u8; 16];
        let nonce = [0x09u8; 13];
        let ccm = AesCcm::new(&key, 8).unwrap();
        for aad_len in [1usize, 13, 14, 15, 16, 30, 31, 47] {
            let aad: Vec<u8> = (0..aad_len as u8).collect();
            let (ct, tag) = ccm.encrypt(&nonce, &aad, b"payload").unwrap();
            assert_eq!(ccm.decrypt(&nonce, &aad, &ct, &tag).unwrap(), b"payload");
            let mut tampered = aad.clone();
            *tampered.last_mut().unwrap() ^= 0x80;
            assert!(
                ccm.decrypt(&nonce, &tampered, &ct, &tag).is_none(),
                "aad_len={}",
                aad_len
            );
        }
    }

    #[test]
    fn ccm_rejects_wrong_key() {
        let key1 = [0x33u8; 16];
        let mut key2 = key1;
        key2[0] ^= 0x01;
        let nonce = [0x44u8; 13];
        let pt = b"sensitive";
        let ccm1 = AesCcm::new(&key1, 8).unwrap();
        let (ct, tag) = ccm1.encrypt(&nonce, b"", pt).unwrap();
        let ccm2 = AesCcm::new(&key2, 8).unwrap();
        assert!(ccm2.decrypt(&nonce, b"", &ct, &tag).is_none());
    }

    #[test]
    fn ccm_rejects_bad_lengths() {
        let key = [0x5au8; 16];
        let nonce = [0x3cu8; 13];
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, b"hdr", b"body").unwrap();
        // The tag must be exactly M bytes long.
        assert!(ccm.decrypt(&nonce, b"hdr", &ct, &tag[..4]).is_none());
        let mut long_tag = tag.clone();
        long_tag.push(0);
        assert!(ccm.decrypt(&nonce, b"hdr", &ct, &long_tag).is_none());
        // The nonce must be exactly 15 - L bytes long.
        assert!(ccm_encrypt(&ccm.aes, 8, 3, &nonce, b"", b"x").is_none());
        // AAD of 0xFF00 bytes or more is rejected by the public entry points.
        assert!(ccm.encrypt(&nonce, &vec![0u8; 0xFF00], b"x").is_none());
        // With L = 2 the payload must be shorter than 2^16 bytes.
        assert!(ccm.encrypt(&nonce, b"", &vec![0u8; 1 << 16]).is_none());
    }

    #[test]
    fn ccm_empty_plaintext() {
        let key = [0x55u8; 16];
        let nonce = [0x66u8; 13];
        let aad = b"only-context";
        let ccm = AesCcm::new(&key, 8).unwrap();
        let (ct, tag) = ccm.encrypt(&nonce, aad, b"").unwrap();
        assert!(ct.is_empty());
        let back = ccm.decrypt(&nonce, aad, &ct, &tag).unwrap();
        assert!(back.is_empty());
        assert!(ccm.decrypt(&nonce, b"other", &ct, &tag).is_none());
    }

    #[test]
    fn ccm_rejects_invalid_m() {
        let key = [0u8; 16];
        for bad_m in [0, 1, 2, 3, 5, 7, 9, 11, 13, 15, 17, 18, 32] {
            assert!(AesCcm::new(&key, bad_m).is_none(), "M={} should be rejected", bad_m);
        }
        for good_m in [4, 6, 8, 10, 12, 14, 16] {
            assert!(AesCcm::new(&key, good_m).is_some(), "M={} should be accepted", good_m);
        }
    }

    #[test]
    fn ccm_rejects_invalid_key_length() {
        for bad_len in [0usize, 8, 15, 17, 31, 33, 64] {
            let key = vec![0u8; bad_len];
            assert!(AesCcm::new(&key, 8).is_none(), "key length {} should be rejected", bad_len);
        }
    }
}