use std::cmp::min;
use super::{AeadParam, AtcaAesCcmCtx, AtcaStatus, AteccDevice, KeyType, NonceTarget};
use super::{
ATCA_AES_DATA_SIZE, ATCA_ATECC_SLOTS_COUNT, ATCA_ATECC_TEMPKEY_KEYID, ATCA_NONCE_SIZE,
};
// AES-CCM authenticated encryption for `AteccDevice`, composed from the device's AES block
// encryption, AES-CTR and AES-CBC-MAC primitives.
impl AteccDevice {
    /// Encrypts `data` in place with AES-CCM and returns the authentication tag.
    ///
    /// `slot_id` selects an AES key slot; `ATCA_ATECC_SLOTS_COUNT` selects TempKey, in which
    /// case `aead_param.key` must be provided and is loaded into TempKey first. The tag length
    /// is `aead_param.tag_length`, else the length of `aead_param.tag`, else
    /// `ATCA_AES_DATA_SIZE`.
    ///
    /// # Errors
    /// `AtcaInvalidId`, `AtcaBadParam` or `AtcaInvalidSize` for invalid parameters (see
    /// `common_aes_ccm`), or any status reported by the underlying AES operations.
    pub(crate) fn encrypt_aes_ccm(
        &self,
        aead_param: AeadParam,
        slot_id: u8,
        data: &mut Vec<u8>,
    ) -> Result<Vec<u8>, AtcaStatus> {
        let ctx = self.common_aes_ccm(aead_param, slot_id, data)?;
        let ctx = self.aes_ccm_update(ctx, data, true)?;
        self.aes_ccm_finish(ctx)
    }

    /// Decrypts `data` in place with AES-CCM and checks it against `aead_param.tag`.
    ///
    /// Returns `Ok(true)` when the computed tag matches the supplied one. If the tag does not
    /// match, or a device error occurs once decryption has started, `data` is zeroed so that
    /// unauthenticated plaintext is never left in the caller's buffer.
    ///
    /// # Errors
    /// `AtcaBadParam` when `aead_param.tag` is `None`; otherwise the same errors as
    /// `encrypt_aes_ccm`.
    pub(crate) fn decrypt_aes_ccm(
        &self,
        aead_param: AeadParam,
        slot_id: u8,
        data: &mut Vec<u8>,
    ) -> Result<bool, AtcaStatus> {
        let tag_to_check = aead_param.tag.clone().ok_or(AtcaStatus::AtcaBadParam)?;
        let ctx = self.common_aes_ccm(aead_param, slot_id, data)?;
        let outcome = self
            .aes_ccm_update(ctx, data, false)
            .and_then(|ctx| self.aes_ccm_decrypt_finish(ctx, &tag_to_check));
        if !matches!(outcome, Ok(true)) {
            // CCM (NIST SP 800-38C) must not release the payload when verification fails;
            // `data` currently holds unauthenticated plaintext, so wipe it.
            for byte in data.iter_mut() {
                *byte = 0x00;
            }
        }
        outcome
    }

    /// Validates the AEAD parameters, loads an explicit key into TempKey when one is given,
    /// and returns a CCM context that has already absorbed B0 and all additional data.
    ///
    /// # Errors
    /// * `AtcaInvalidId` - `slot_id` is greater than `ATCA_ATECC_SLOTS_COUNT`, or names a slot
    ///   whose configured key type is not AES.
    /// * `AtcaBadParam` - TempKey selected without `aead_param.key`, or both `tag_length` and
    ///   `tag` were supplied.
    /// * `AtcaInvalidSize` - no data and no additional data; nonce outside
    ///   `MIN_IV_SIZE..=MAX_IV_SIZE`; tag length odd or outside `MIN_TAG_SIZE..=MAX_TAG_SIZE`;
    ///   additional data longer than `MAX_AAD_SIZE`.
    /// * Any status returned by the TempKey `nonce` write or the CCM initialisation.
    fn common_aes_ccm(
        &self,
        aead_param: AeadParam,
        slot_id: u8,
        data: &[u8],
    ) -> Result<AtcaAesCcmCtx, AtcaStatus> {
        const MAX_IV_SIZE: usize = 13;
        const MIN_IV_SIZE: usize = 7;
        const MAX_TAG_SIZE: usize = ATCA_AES_DATA_SIZE;
        const MIN_TAG_SIZE: usize = 4;
        // The AAD length prefix written by `aes_ccm_init` is the 2-byte form, which CCM only
        // permits for lengths below 0xFF00.
        const MAX_AAD_SIZE: usize = 0xFEFF;

        // CCM tags must be an even number of bytes within [MIN_TAG_SIZE, MAX_TAG_SIZE].
        let is_valid_tag_size =
            |size: usize| (MIN_TAG_SIZE..=MAX_TAG_SIZE).contains(&size) && size % 2 == 0;

        let uses_tempkey = slot_id == ATCA_ATECC_SLOTS_COUNT;
        if slot_id > ATCA_ATECC_SLOTS_COUNT
            || (!uses_tempkey && self.slots[slot_id as usize].config.key_type != KeyType::Aes)
        {
            return Err(AtcaStatus::AtcaInvalidId);
        }
        if (uses_tempkey && aead_param.key.is_none())
            || (aead_param.tag_length.is_some() && aead_param.tag.is_some())
        {
            return Err(AtcaStatus::AtcaBadParam);
        }

        let tag_length_invalid = aead_param
            .tag_length
            .map_or(false, |len| !is_valid_tag_size(len as usize));
        let tag_invalid = aead_param
            .tag
            .as_ref()
            .map_or(false, |tag| !is_valid_tag_size(tag.len()));
        let nonce_size = aead_param.nonce.len();
        if (data.is_empty() && aead_param.additional_data.is_none())
            || !(MIN_IV_SIZE..=MAX_IV_SIZE).contains(&nonce_size)
            || tag_length_invalid
            || tag_invalid
        {
            return Err(AtcaStatus::AtcaInvalidSize);
        }

        let additional_data_size = aead_param
            .additional_data
            .as_ref()
            .map_or(0, |aad| aad.len());
        // Checked before the TempKey write below so invalid input has no device side effect.
        if additional_data_size > MAX_AAD_SIZE {
            return Err(AtcaStatus::AtcaInvalidSize);
        }

        let tag_length: usize = match (aead_param.tag_length, aead_param.tag.as_ref()) {
            (Some(len), _) => len as usize,
            (None, Some(tag)) => tag.len(),
            (None, None) => ATCA_AES_DATA_SIZE,
        };

        if let Some(key) = &aead_param.key {
            // The key is zero-padded (or truncated) to ATCA_NONCE_SIZE bytes and written to
            // TempKey through a nonce operation.
            let mut key_buffer: Vec<u8> = key.to_vec();
            key_buffer.resize(ATCA_NONCE_SIZE, 0x00);
            let status = self.nonce(NonceTarget::TempKey, &key_buffer);
            if status != AtcaStatus::AtcaSuccess {
                return Err(status);
            }
        }

        let mut ctx = self.aes_ccm_init(
            slot_id,
            &aead_param.nonce,
            additional_data_size,
            data.len(),
            tag_length,
        )?;
        if let Some(aad) = &aead_param.additional_data {
            ctx = self.aes_ccm_aad_update(ctx, aad)?;
        }
        Ok(ctx)
    }

    /// Starts a CCM operation.
    ///
    /// Builds B0 (flags || nonce || message length) and feeds it to a fresh CBC-MAC. It queues
    /// the 2-byte AAD length prefix when `aad_size > 0` and stores Ctr0 for the final tag
    /// encryption. It also initialises the CTR stream so the first payload block uses Ctr1.
    ///
    /// # Errors
    /// * `AtcaBadParam` - nonce not 7..=13 bytes, or `tag_size` odd or outside
    ///   4..=`ATCA_AES_DATA_SIZE`.
    /// * `AtcaInvalidSize` - `text_size` does not fit in the L-byte length field of B0.
    /// * Any status from the CBC-MAC / CTR primitives.
    fn aes_ccm_init(
        &self,
        slot_id: u8,
        iv: &[u8],
        aad_size: usize,
        text_size: usize,
        tag_size: usize,
    ) -> Result<AtcaAesCcmCtx, AtcaStatus> {
        if !(7..=13).contains(&iv.len()) {
            return Err(AtcaStatus::AtcaBadParam);
        }
        if !(4..=ATCA_AES_DATA_SIZE).contains(&tag_size) || tag_size % 2 != 0 {
            return Err(AtcaStatus::AtcaBadParam);
        }

        // L: width in bytes of the message-length field, which is also the CTR counter width.
        let length_field_size = ATCA_AES_DATA_SIZE - 1 - iv.len();
        // Encoded flag fields: M' = (M - 2) / 2 and L' = L - 1.
        let m = ((tag_size - 2) / 2) as u8;
        let l = (length_field_size - 1) as u8;

        // B0 = flags || nonce || message length (big-endian, filling the trailing L bytes).
        let mut b0 = [0x00u8; ATCA_AES_DATA_SIZE];
        b0[0] = l | (m << 3) | (u8::from(aad_size > 0) << 6);
        b0[1..=iv.len()].copy_from_slice(iv);
        let mut size_left: usize = text_size;
        for byte in b0[iv.len() + 1..].iter_mut().rev() {
            *byte = (size_left & 0xFF) as u8;
            size_left >>= 8;
        }
        if size_left != 0 {
            // The message length would be silently truncated in B0, yielding a wrong tag.
            return Err(AtcaStatus::AtcaInvalidSize);
        }

        let mut ctx: AtcaAesCcmCtx = AtcaAesCcmCtx {
            iv_size: iv.len() as u8,
            ..Default::default()
        };
        ctx.m = m;
        ctx.cbc_mac_ctx = self.aes_cbcmac_init(slot_id);
        ctx.cbc_mac_ctx = self.aes_cbcmac_update(ctx.cbc_mac_ctx, &b0)?;
        if aad_size > 0 {
            // Big-endian 2-byte AAD length; the AAD bytes are appended by aes_ccm_aad_update.
            ctx.partial_aad[0] = ((aad_size >> 8) & 0xFF) as u8;
            ctx.partial_aad[1] = (aad_size & 0xFF) as u8;
            ctx.partial_aad_size = 2;
        }
        ctx.text_size = text_size;

        // Ctr0 = L' || nonce || 0..0. It is kept for encrypting the tag in aes_ccm_finish,
        // while the payload key stream starts at Ctr1.
        let mut counter = [0x00u8; ATCA_AES_DATA_SIZE];
        counter[0] = l;
        counter[1..=iv.len()].copy_from_slice(iv);
        ctx.counter[..].copy_from_slice(&counter);
        ctx.ctr_ctx = self.aes_ctr_init(slot_id, length_field_size as u8, &counter)?;
        ctx.ctr_ctx = self.aes_ctr_increment(ctx.ctr_ctx)?;
        Ok(ctx)
    }

    /// Encrypts (`is_encrypt == true`) or decrypts `data` in place with the CTR key stream and
    /// feeds the plaintext into the CBC-MAC.
    ///
    /// Any pending additional data is flushed first. Only complete blocks are MACed here; a
    /// trailing partial block stays in `ciphertext_block` and is flushed by `aes_ccm_finish`.
    fn aes_ccm_update(
        &self,
        ctx: AtcaAesCcmCtx,
        data: &mut [u8],
        is_encrypt: bool,
    ) -> Result<AtcaAesCcmCtx, AtcaStatus> {
        let mut ctx = self.aes_ccm_aad_finish(ctx)?;
        let input_size = data.len();
        let mut data_idx: usize = 0;
        while data_idx < input_size {
            if 0 == (ctx.data_size % (ATCA_AES_DATA_SIZE as u32)) {
                // Start of a new block: produce the next key-stream block and advance the counter.
                ctx.enc_cb =
                    self.aes_encrypt_block(ctx.ctr_ctx.key_id, ctx.ctr_ctx.key_block, &ctx.ctr_ctx.cb)?;
                ctx.ctr_ctx = self.aes_ctr_increment(ctx.ctr_ctx)?;
            }
            let block_offset = (ctx.data_size as usize) % ATCA_AES_DATA_SIZE;
            // Stop at the end of the current block or of the input, whichever comes first.
            let chunk_len = min(ATCA_AES_DATA_SIZE - block_offset, input_size - data_idx);
            for idx in block_offset..block_offset + chunk_len {
                let byte = &mut data[data_idx];
                // `ciphertext_block` collects the plaintext (the CBC-MAC input): captured before
                // the XOR when encrypting and after it when decrypting.
                if is_encrypt {
                    ctx.ciphertext_block[idx] = *byte;
                }
                *byte ^= ctx.enc_cb[idx];
                if !is_encrypt {
                    ctx.ciphertext_block[idx] = *byte;
                }
                ctx.data_size += 1;
                data_idx += 1;
            }
            if 0 == (ctx.data_size % (ATCA_AES_DATA_SIZE as u32)) {
                ctx.cbc_mac_ctx =
                    self.aes_cbcmac_update(ctx.cbc_mac_ctx, &ctx.ciphertext_block[..])?;
            }
        }
        Ok(ctx)
    }

    /// Finishes the CCM operation and reports whether the computed tag equals `tag`.
    ///
    /// Differences are accumulated over every byte instead of stopping at the first mismatch.
    /// This avoids an obvious data-dependent early exit; it does not guarantee constant time.
    fn aes_ccm_decrypt_finish(&self, ctx: AtcaAesCcmCtx, tag: &[u8]) -> Result<bool, AtcaStatus> {
        let computed = self.aes_ccm_finish(ctx)?;
        let difference = tag
            .iter()
            .zip(computed.iter())
            .fold(0u8, |acc, (expected, actual)| acc | (expected ^ actual));
        Ok(tag.len() == computed.len() && difference == 0)
    }

    /// Completes the CBC-MAC and encrypts it with the Ctr0 key stream.
    ///
    /// Any trailing partial payload block is zero-padded and MACed first. Returns the
    /// `M = 2 * m + 2` byte authentication tag.
    fn aes_ccm_finish(&self, mut ctx: AtcaAesCcmCtx) -> Result<Vec<u8>, AtcaStatus> {
        let pending = (ctx.data_size as usize) % ATCA_AES_DATA_SIZE;
        if pending != 0 {
            let mut buffer = [0x00u8; ATCA_AES_DATA_SIZE];
            buffer[..pending].copy_from_slice(&ctx.ciphertext_block[..pending]);
            ctx.cbc_mac_ctx = self.aes_cbcmac_update(ctx.cbc_mac_ctx, &buffer)?;
        }

        let tag_size = ((ctx.m * 2) + 2) as usize;
        let mac = self.aes_cbcmac_finish(ctx.cbc_mac_ctx, tag_size)?;
        let mut t = [0x00u8; ATCA_AES_DATA_SIZE];
        t[..mac.len()].copy_from_slice(&mac[..]);

        // Map the TempKey key id back to the slot number that selects TempKey.
        let slot = if ctx.ctr_ctx.key_id == ATCA_ATECC_TEMPKEY_KEYID {
            ATCA_ATECC_SLOTS_COUNT
        } else {
            ctx.ctr_ctx.key_id as u8
        };
        // Same counter width aes_ccm_init used (L = 15 - nonce length). The original code passed
        // the key block in this counter-size position.
        let counter_size = (ATCA_AES_DATA_SIZE - 1 - ctx.iv_size as usize) as u8;
        ctx.ctr_ctx = self.aes_ctr_init(slot, counter_size, &ctx.counter)?;
        let mut encrypted_tag = [0x00u8; ATCA_AES_DATA_SIZE];
        self.aes_ctr_block(ctx.ctr_ctx, &t, &mut encrypted_tag)?;
        Ok(encrypted_tag[..tag_size].to_vec())
    }

    /// Appends additional authenticated data to the CBC-MAC.
    ///
    /// Bytes that do not complete a block are buffered in `partial_aad` until more data
    /// arrives or `aes_ccm_aad_finish` pads and flushes them.
    fn aes_ccm_aad_update(
        &self,
        mut ctx: AtcaAesCcmCtx,
        data: &[u8],
    ) -> Result<AtcaAesCcmCtx, AtcaStatus> {
        if data.is_empty() {
            return Ok(ctx);
        }
        let start_pos = ctx.partial_aad_size;

        // Not enough to complete the pending block: just buffer the bytes.
        if start_pos + data.len() < ATCA_AES_DATA_SIZE {
            ctx.partial_aad[start_pos..start_pos + data.len()].copy_from_slice(data);
            ctx.partial_aad_size += data.len();
            return Ok(ctx);
        }

        // Complete the pending block and MAC it.
        let fill_size = ATCA_AES_DATA_SIZE - start_pos;
        ctx.partial_aad[start_pos..ATCA_AES_DATA_SIZE].copy_from_slice(&data[..fill_size]);
        ctx.cbc_mac_ctx = self.aes_cbcmac_update(ctx.cbc_mac_ctx, &ctx.partial_aad)?;

        // MAC every further complete block directly from the input.
        let remaining = &data[fill_size..];
        let full_blocks_len = remaining.len() - (remaining.len() % ATCA_AES_DATA_SIZE);
        if full_blocks_len > 0 {
            ctx.cbc_mac_ctx =
                self.aes_cbcmac_update(ctx.cbc_mac_ctx, &remaining[..full_blocks_len])?;
        }

        // Keep the trailing partial block for later.
        let tail = &remaining[full_blocks_len..];
        ctx.partial_aad[..tail.len()].copy_from_slice(tail);
        ctx.partial_aad_size = tail.len();
        Ok(ctx)
    }

    /// Zero-pads and MACs any buffered additional data, leaving `partial_aad_size` at 0.
    fn aes_ccm_aad_finish(&self, mut ctx: AtcaAesCcmCtx) -> Result<AtcaAesCcmCtx, AtcaStatus> {
        let pending = ctx.partial_aad_size;
        if pending == 0 {
            return Ok(ctx);
        }
        let mut block = [0x00u8; ATCA_AES_DATA_SIZE];
        block[..pending].copy_from_slice(&ctx.partial_aad[..pending]);
        ctx.cbc_mac_ctx = self.aes_cbcmac_update(ctx.cbc_mac_ctx, &block)?;
        ctx.partial_aad_size = 0;
        Ok(ctx)
    }
}