use core::{cmp::min, marker::PhantomData};
use aead::{
Error,
array::{
Array, ArraySize,
typenum::{IsGreaterOrEqual, IsLessOrEqual, Unsigned},
},
consts::{True, U4, U16},
inout::InOutBuf,
};
use ascon_core::State;
use subtle::ConstantTimeEq;
/// Returns a 64-bit word whose only set bit is the lowest bit of byte `n`
/// (little-endian byte numbering).
///
/// This is the `0x01` padding byte that follows `n` data bytes in a partially
/// filled lane. All call sites in this file pass `n < 8`.
#[inline(always)]
const fn pad(n: usize) -> u64 {
    let shift = n * 8;
    1u64 << shift
}
/// Zeroes the `n` least-significant bytes of `word` and keeps the rest.
///
/// In little-endian terms, these are the first `n` bytes of the lane.
/// `n == 0` returns `word` unchanged, and any `n >= 8` returns `0`.
///
/// The previous mask, `0x00ffffffffffffff << (8 * n)`, was only correct for
/// `n` in `1..=7`. It dropped the top byte when `n == 0` and overflowed the
/// shift when `n == 8`. For `1..=7` the result here is bit-identical.
#[inline(always)]
const fn clear(word: u64, n: usize) -> u64 {
    if n >= 8 {
        // Every byte is cleared; also avoids an out-of-range shift.
        0
    } else {
        word & (u64::MAX << (8 * n))
    }
}
/// Decodes exactly eight bytes as a little-endian `u64`.
///
/// Panics if `input` is not exactly eight bytes long. Every caller in this
/// file passes an 8-byte slice.
#[inline]
fn u64_from_bytes(input: &[u8]) -> u64 {
    <[u8; 8]>::try_from(input)
        .map(u64::from_le_bytes)
        .unwrap()
}
/// Decodes up to eight bytes as a little-endian `u64`, zero-filling the
/// missing high-order bytes.
///
/// Panics if `input` is longer than eight bytes.
#[inline]
fn u64_from_bytes_partial(input: &[u8]) -> u64 {
    let used = input.len();
    let mut word = [0u8; 8];
    word[..used].copy_from_slice(input);
    u64::from_le_bytes(word)
}
/// Key material split into the two 64-bit words that the cipher XORs into
/// its state.
///
/// Implementors are constructed from the raw `KS`-byte key through
/// `From<&Array<u8, KS>>`.
pub(crate) trait InternalKey<KS>: Sized + Clone + for<'a> From<&'a Array<u8, KS>>
where
    KS: ArraySize,
{
    /// First 64-bit key word (`k1`).
    fn get_k1(&self) -> u64;

    /// Second 64-bit key word (`k2`).
    fn get_k2(&self) -> u64;
}
/// A 16-byte key stored as two little-endian 64-bit words `(k1, k2)`.
///
/// `k1` comes from key bytes `0..8` and `k2` from bytes `8..16`. When the
/// `zeroize` feature is enabled, the words are zeroized on drop.
#[derive(Clone)]
#[cfg_attr(feature = "zeroize", derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop))]
pub(crate) struct InternalKey16(u64, u64);

impl InternalKey<U16> for InternalKey16 {
    #[inline(always)]
    fn get_k1(&self) -> u64 {
        let Self(first, _) = self;
        *first
    }

    #[inline(always)]
    fn get_k2(&self) -> u64 {
        let Self(_, second) = self;
        *second
    }
}

impl From<&Array<u8, U16>> for InternalKey16 {
    fn from(key: &Array<u8, U16>) -> Self {
        let bytes: &[u8] = &key[..];
        let (low, high) = bytes.split_at(8);
        Self(u64_from_bytes(low), u64_from_bytes(high))
    }
}
/// Compile-time configuration of an Ascon AEAD instance.
pub(crate) trait Parameters {
    /// Initial value, passed as the first word to `State::new`.
    const IV: u64;

    /// Key length in bytes.
    type KeySize: ArraySize;

    /// Tag length in bytes, constrained to the range `4..=16`.
    type TagSize: ArraySize
        + IsGreaterOrEqual<U4, Output = True>
        + IsLessOrEqual<U16, Output = True>;

    /// Expanded key representation built from a `KeySize`-byte key.
    type InternalKey: InternalKey<Self::KeySize>;
}
/// Parameter set for the 128-bit-key Ascon AEAD variant, with a tag size
/// chosen by the type parameter.
///
/// This is a zero-sized marker type. The tag-size constraints (`4..=16`
/// bytes) are enforced where they matter, on the `Parameters` impl. They are
/// not repeated here, following the Rust API guideline against trait bounds
/// on data structures (C-STRUCT-BOUNDS). This only loosens the definition;
/// every existing use still type-checks.
pub(crate) struct Parameters128<TagSize>(PhantomData<TagSize>);
/// Ascon-AEAD128 configuration: 16-byte key, `InternalKey16` key expansion,
/// and the tag size taken from `T`.
impl<T> Parameters for Parameters128<T>
where
    T: ArraySize + IsGreaterOrEqual<U4, Output = True> + IsLessOrEqual<U16, Output = True>,
{
    const IV: u64 = 0x0000_1000_808c_0001;

    type KeySize = U16;
    type TagSize = T;
    type InternalKey = InternalKey16;
}
/// Running permutation state for one AEAD operation, together with a
/// borrowed key.
pub(crate) struct AsconCore<'a, P: Parameters> {
    // Permutation state; the methods of this type read and write lanes 0..=4.
    state: State,
    // Expanded key, XORed back into the state at initialization and
    // finalization.
    key: &'a P::InternalKey,
}
// NOTE(review): `encrypt_inout` / `decrypt_inout` advance `state` and never
// reset it, so an `AsconCore` presumably handles a single message. Confirm
// that callers build a fresh one per message.
impl<'a, P: Parameters> AsconCore<'a, P> {
    /// Initializes the state for one message.
    ///
    /// Loads `P::IV`, both key words and both nonce words (little-endian) into
    /// a new `State` and applies `permute_12`. It then XORs the key into
    /// lanes 3 and 4.
    pub(crate) fn new(internal_key: &'a P::InternalKey, nonce: &Array<u8, U16>) -> Self {
        let mut state = State::new(
            P::IV,
            internal_key.get_k1(),
            internal_key.get_k2(),
            u64_from_bytes(&nonce[..8]),
            u64_from_bytes(&nonce[8..]),
        );
        state.permute_12();
        // Feed the key forward after the initial permutation.
        state[3] ^= internal_key.get_k1();
        state[4] ^= internal_key.get_k2();
        Self {
            state,
            key: internal_key,
        }
    }

    /// Applies `permute_12`, then XORs the stored key into lanes 3 and 4.
    fn permute_12_and_apply_key(&mut self) {
        self.state.permute_12();
        self.state[3] ^= self.key.get_k1();
        self.state[4] ^= self.key.get_k2();
    }

    /// Applies `permute_8`, the permutation used between 16-byte data blocks.
    #[inline(always)]
    fn permute_state(&mut self) {
        self.state.permute_8();
    }

    /// Absorbs the associated data into lanes 0 and 1.
    ///
    /// Each full 16-byte block is followed by `permute_state`. When the AD is
    /// non-empty, one final padded block is always absorbed. If the AD length
    /// is a multiple of 16, that final block holds only the padding byte.
    /// Empty AD skips absorption entirely. In every case the method finishes
    /// by flipping the top bit of lane 4 (domain separation).
    fn process_associated_data(&mut self, associated_data: &[u8]) {
        if !associated_data.is_empty() {
            let mut blocks = associated_data.chunks_exact(16);
            for block in blocks.by_ref() {
                self.state[0] ^= u64_from_bytes(&block[..8]);
                self.state[1] ^= u64_from_bytes(&block[8..16]);
                self.permute_state();
            }
            let mut last_block = blocks.remainder();
            // A remainder of at least 8 bytes fills lane 0 completely, so the
            // padding (and any tail) goes into lane 1 instead.
            let sidx = if last_block.len() >= 8 {
                self.state[0] ^= u64_from_bytes(&last_block[..8]);
                last_block = &last_block[8..];
                1
            } else {
                0
            };
            // The 0x01 pad byte sits right after the remaining data bytes.
            self.state[sidx] ^= pad(last_block.len());
            if !last_block.is_empty() {
                self.state[sidx] ^= u64_from_bytes_partial(last_block);
            }
            self.permute_state();
        }
        self.state[4] ^= 0x8000000000000000;
    }

    /// Encrypts `message` in place or into its output half.
    ///
    /// Plaintext is XORed into lanes 0 and 1, and the resulting lane bytes
    /// are emitted as ciphertext. Full 16-byte blocks are followed by
    /// `permute_state`. The final partial block is padded but not permuted,
    /// because `process_final` handles that.
    fn process_encrypt_inout(&mut self, message: InOutBuf<'_, '_, u8>) {
        let (blocks, mut last_block) = message.into_chunks::<U16>();
        for mut block in blocks {
            // Each input half is read before the matching output half is
            // written, so in-place buffers are handled correctly.
            self.state[0] ^= u64_from_bytes(&block.get_in()[..8]);
            block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0]));
            self.state[1] ^= u64_from_bytes(&block.get_in()[8..16]);
            block.get_out()[8..16].copy_from_slice(&u64::to_le_bytes(self.state[1]));
            self.permute_state();
        }
        let sidx = if last_block.len() >= 8 {
            self.state[0] ^= u64_from_bytes(&last_block.get_in()[..8]);
            last_block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0]));
            (_, last_block) = last_block.split_at(8);
            1
        } else {
            0
        };
        // The pad byte lands at index `len`, outside the bytes copied out below.
        self.state[sidx] ^= pad(last_block.len());
        if !last_block.is_empty() {
            self.state[sidx] ^= u64_from_bytes_partial(last_block.get_in());
            let last_block_len = last_block.len();
            last_block
                .get_out()
                .copy_from_slice(&u64::to_le_bytes(self.state[sidx])[0..last_block_len]);
        }
    }

    /// Decrypts `ciphertext` in place or into its output half.
    ///
    /// For each word, plaintext = state XOR ciphertext, and the state word
    /// is then replaced by the ciphertext word. This mirrors
    /// `process_encrypt_inout`, so both sides reach the same state.
    fn process_decrypt_inout(&mut self, ciphertext: InOutBuf<'_, '_, u8>) {
        let (blocks, mut last_block) = ciphertext.into_chunks::<U16>();
        for mut block in blocks {
            // `cx` is read before the output is written (in-place safe).
            let cx = u64_from_bytes(&block.get_in()[..8]);
            block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0] ^ cx));
            self.state[0] = cx;
            let cx = u64_from_bytes(&block.get_in()[8..16]);
            block.get_out()[8..16].copy_from_slice(&u64::to_le_bytes(self.state[1] ^ cx));
            self.state[1] = cx;
            self.permute_state();
        }
        let sidx = if last_block.len() >= 8 {
            let cx = u64_from_bytes(&last_block.get_in()[..8]);
            last_block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0] ^ cx));
            self.state[0] = cx;
            (_, last_block) = last_block.split_at(8);
            1
        } else {
            0
        };
        self.state[sidx] ^= pad(last_block.len());
        if !last_block.is_empty() {
            let cx = u64_from_bytes_partial(last_block.get_in());
            self.state[sidx] ^= cx;
            let last_block_len = last_block.len();
            last_block
                .get_out()
                .copy_from_slice(&u64::to_le_bytes(self.state[sidx])[0..last_block_len]);
            // Keep the state bytes at and above `len` (including the pad
            // byte), and replace the low `len` bytes with the ciphertext.
            self.state[sidx] = clear(self.state[sidx], last_block.len()) ^ cx;
        }
    }

    /// Finalizes the state and returns the authentication tag.
    ///
    /// XORs the key into lanes 2 and 3, applies `permute_12`, and XORs the
    /// key into lanes 3 and 4. The tag is the first `P::TagSize` bytes of
    /// lane 3 followed by lane 4, both in little-endian order.
    fn process_final(&mut self) -> Array<u8, P::TagSize> {
        self.state[2] ^= self.key.get_k1();
        self.state[3] ^= self.key.get_k2();
        self.permute_12_and_apply_key();
        let mut tag = Array::default();
        tag[..min(8, P::TagSize::USIZE)]
            .copy_from_slice(&u64::to_le_bytes(self.state[3])[..min(8, P::TagSize::USIZE)]);
        // Tags longer than 8 bytes continue into lane 4.
        if P::TagSize::USIZE > 8 {
            tag[8..min(16, P::TagSize::USIZE)]
                .copy_from_slice(&u64::to_le_bytes(self.state[4])[..min(8, P::TagSize::USIZE - 8)]);
        }
        tag
    }

    /// Encrypts `message` with `associated_data` and returns the tag.
    pub(crate) fn encrypt_inout(
        &mut self,
        message: InOutBuf<'_, '_, u8>,
        associated_data: &[u8],
    ) -> Array<u8, P::TagSize> {
        self.process_associated_data(associated_data);
        self.process_encrypt_inout(message);
        self.process_final()
    }

    /// Decrypts `ciphertext` and verifies it against `expected_tag`.
    ///
    /// The tag comparison uses `subtle::ConstantTimeEq`.
    ///
    /// # Errors
    /// Returns `Err(Error)` on tag mismatch. In that case the output buffer is
    /// zero-filled, so unauthenticated plaintext is not handed back.
    pub(crate) fn decrypt_inout(
        &mut self,
        mut ciphertext: InOutBuf<'_, '_, u8>,
        associated_data: &[u8],
        expected_tag: &Array<u8, P::TagSize>,
    ) -> Result<(), Error> {
        self.process_associated_data(associated_data);
        // Reborrow so the buffer can still be wiped if verification fails.
        self.process_decrypt_inout(ciphertext.reborrow());
        let tag = self.process_final();
        if bool::from(tag.ct_eq(expected_tag)) {
            Ok(())
        } else {
            ciphertext.get_out().fill(0);
            Err(Error)
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// `clear(w, n)` must zero exactly the `n` least-significant bytes of `w`.
    #[test]
    fn clear_0to7() {
        const WORD: u64 = 0x0123456789abcdef;
        let cases: [(usize, u64); 7] = [
            (1, 0x0123456789abcd00),
            (2, 0x0123456789ab0000),
            (3, 0x0123456789000000),
            (4, 0x0123456700000000),
            (5, 0x0123450000000000),
            (6, 0x0123000000000000),
            (7, 0x0100000000000000),
        ];
        for &(n, expected) in cases.iter() {
            assert_eq!(clear(WORD, n), expected);
        }
    }
}