use core::{
iter::zip,
mem,
ops::{AddAssign, Mul},
};
#[cfg(not(feature = "std"))]
use alloc::{borrow::ToOwned, boxed::Box};
use generic_array::{
ArrayLength, GenericArray,
typenum::{U4, U8, Unsigned},
};
use super::{ByteCommits, ByteCommitsRef, FieldCommitDegOne, FieldCommitDegTwo};
use crate::{
aes::*,
fields::{
BigGaloisField, ByteCombine, ByteCombineConstants, Sigmas,
large_fields::ByteCombineSquaredConstants,
},
parameter::{OWFField, OWFParameters, SecurityParameter},
utils::xor_arrays_inplace,
};
/// Boxed fixed-length array of `O::NStBits` degree-two commitments
/// ([`FieldCommitDegTwo`]) over the field `OWFField<O>`.
pub(crate) type StateBitsSquaredCommits<O> =
Box<GenericArray<FieldCommitDegTwo<OWFField<O>>, <O as OWFParameters>::NStBits>>;
/// Boxed fixed-length array of `O::NStBytes` degree-one commitments
/// ([`FieldCommitDegOne`]) over the field `OWFField<O>`.
pub(crate) type StateBytesCommits<O> =
Box<GenericArray<FieldCommitDegOne<OWFField<O>>, <O as OWFParameters>::NStBytes>>;
/// Boxed fixed-length array of `O::NStBytes` degree-two commitments
/// ([`FieldCommitDegTwo`]) over the field `OWFField<O>`.
pub(crate) type StateBytesSquaredCommits<O> =
Box<GenericArray<FieldCommitDegTwo<OWFField<O>>, <O as OWFParameters>::NStBytes>>;
/// Converts a borrowed byte-commitment state into one field commitment per byte.
impl<O, L> StateToBytes<O> for ByteCommitsRef<'_, OWFField<O>, L>
where
    O: OWFParameters,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    type Output = StateBytesCommits<O>;

    /// Collects `self.get_field_commit(i)` for every index `i` in `0..L`.
    ///
    /// NOTE(review): the result has `O::NStBytes` entries while `L` items are
    /// produced, so this presumably expects `L == O::NStBytes` — confirm.
    fn state_to_bytes(&self) -> Self::Output {
        let byte_indices = 0..L::USIZE;
        byte_indices
            .map(|byte_idx| self.get_field_commit(byte_idx))
            .collect()
    }
}
/// Converts a plain byte state into one field element per byte.
impl<O, L> StateToBytes<O> for GenericArray<u8, L>
where
    O: OWFParameters,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    type Output = Box<GenericArray<OWFField<O>, O::NStBytes>>;

    /// Maps every byte through `OWFField::<O>::byte_combine_bits` and collects
    /// the results into a boxed array.
    ///
    /// NOTE(review): the result has `O::NStBytes` entries while `L` items are
    /// produced, so this presumably expects `L == O::NStBytes` — confirm.
    fn state_to_bytes(&self) -> Self::Output {
        let lift = |byte: u8| OWFField::<O>::byte_combine_bits(byte);
        self.iter().copied().map(lift).collect()
    }
}
/// Adds a plain round key to an owned byte-commitment state.
impl<F, L> AddRoundKey<&GenericArray<u8, L>> for ByteCommits<F, L>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    type Output = Self;

    /// Returns a new state whose key bytes are `self.keys[i] ^ rhs[i]`.
    ///
    /// The tags are cloned without modification.
    fn add_round_key(&self, rhs: &GenericArray<u8, L>) -> Self::Output {
        let masked_keys = self
            .keys
            .iter()
            .zip(rhs.iter())
            .map(|(state_byte, key_byte)| state_byte ^ key_byte)
            .collect();
        Self::new(masked_keys, self.tags.to_owned())
    }
}
/// Adds a plain round key to a borrowed byte-commitment state.
impl<F, L> AddRoundKey<&GenericArray<u8, L>> for ByteCommitsRef<'_, F, L>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    type Output = ByteCommits<F, L>;

    /// Builds an owned state whose key bytes are `self.keys[i] ^ rhs[i]`.
    ///
    /// The borrowed tags are copied into a fresh box without modification.
    fn add_round_key(&self, rhs: &GenericArray<u8, L>) -> Self::Output {
        let keys = self
            .keys
            .iter()
            .zip(rhs.iter())
            .map(|(state_byte, key_byte)| state_byte ^ key_byte)
            .collect();
        let tags = Box::new(self.tags.to_owned());
        ByteCommits { keys, tags }
    }
}
/// Adds a committed round key to a plain byte state.
impl<F, L> AddRoundKey<&ByteCommitsRef<'_, F, L>> for &GenericArray<u8, L>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    type Output = ByteCommits<F, L>;

    /// Builds an owned state whose key bytes are `self[i] ^ rhs.keys[i]`.
    ///
    /// The tags are the first `8 * L` tags of `rhs`, copied unchanged.
    fn add_round_key(&self, rhs: &ByteCommitsRef<'_, F, L>) -> Self::Output {
        // Only the leading 8 tags per byte are taken over from the committed side.
        let tag_prefix = &rhs.tags[..L::USIZE * 8];
        let keys = self
            .iter()
            .zip(rhs.keys)
            .map(|(plain_byte, committed_byte)| plain_byte ^ committed_byte)
            .collect();
        ByteCommits {
            keys,
            tags: Box::new(GenericArray::from_slice(tag_prefix).to_owned()),
        }
    }
}
/// Adds a committed round key to an owned byte-commitment state.
impl<F, L> AddRoundKey<&ByteCommitsRef<'_, F, L>> for ByteCommits<F, L>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    type Output = Self;

    /// Returns a new state with key bytes XOR-ed pairwise and tags added
    /// pairwise (`self.tags[i] + rhs.tags[i]`).
    fn add_round_key(&self, rhs: &ByteCommitsRef<'_, F, L>) -> Self::Output {
        let keys = self
            .keys
            .iter()
            .zip(rhs.keys)
            .map(|(mine, theirs)| mine ^ theirs)
            .collect();
        let tags = self
            .tags
            .iter()
            .zip(rhs.tags)
            .map(|(&mine, theirs)| mine + theirs)
            .collect();
        Self { keys, tags }
    }
}
/// In-place addition of a plain round key to an owned byte-commitment state.
impl<F, L> AddRoundKeyAssign<&GenericArray<u8, L>> for ByteCommits<F, L>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    /// Passes the key bytes and `rhs` to `xor_arrays_inplace`; the tags are
    /// not touched.
    fn add_round_key_assign(&mut self, rhs: &GenericArray<u8, L>) {
        let state_keys = self.keys.as_mut_slice();
        let round_key = rhs.as_slice();
        xor_arrays_inplace(state_keys, round_key);
    }
}
/// In-place addition of a committed round key to an owned byte-commitment state.
impl<F, L> AddRoundKeyAssign<&ByteCommitsRef<'_, F, L>> for ByteCommits<F, L>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
{
    /// Passes both key arrays to `xor_arrays_inplace`, then adds every tag of
    /// `rhs` onto the tag at the same position in `self`.
    fn add_round_key_assign(&mut self, rhs: &ByteCommitsRef<'_, F, L>) {
        xor_arrays_inplace(self.keys.as_mut_slice(), rhs.keys.as_slice());
        self.tags
            .iter_mut()
            .zip(rhs.tags.iter())
            .for_each(|(own_tag, other_tag)| *own_tag += other_tag);
    }
}
/// Row rotation on a state stored as consecutive 4-element columns
/// (element `(row, col)` lives at index `4 * col + row`).
impl<T, L> ShiftRows for GenericArray<T, L>
where
    T: Clone,
    L: ArrayLength,
{
    /// Rotates row `r` left by `r` columns; when the state has exactly 8
    /// columns, rows 2 and 3 are rotated by one extra column.
    fn shift_rows(&mut self) {
        let columns = L::USIZE / 4;
        // Every source element is read exactly once, so it can be swapped out
        // of the snapshot instead of being cloned a second time.
        let mut snapshot = self.clone();
        for row in 0..4 {
            let shift = if columns == 8 && row > 1 { row + 1 } else { row };
            for col in 0..columns {
                let src = 4 * ((col + shift) % columns) + row;
                mem::swap(&mut self[4 * col + row], &mut snapshot[src]);
            }
        }
    }
}
/// Element-wise in-place addition of a key array onto degree-two commitments.
impl<F, L, T> AddRoundKeyBytes<&GenericArray<T, L>> for Box<GenericArray<FieldCommitDegTwo<F>, L>>
where
    F: BigGaloisField,
    L: ArrayLength + Mul<U8, Output: ArrayLength>,
    for<'a> FieldCommitDegTwo<F>: AddAssign<&'a T>,
{
    /// Adds `key[i]` onto `self[i]` for every position; the `_sq` flag is
    /// ignored by this implementation.
    fn add_round_key_bytes(&mut self, key: &GenericArray<T, L>, _sq: bool) {
        self.iter_mut()
            .zip(key.iter())
            .for_each(|(state_elem, key_elem)| *state_elem += key_elem);
    }
}
/// Inverse row rotation on a borrowed byte-commitment state stored as
/// consecutive 4-byte columns.
impl<O> InverseShiftRows<O> for ByteCommitsRef<'_, OWFField<O>, O::NStBytes>
where
    O: OWFParameters,
{
    type Output = ByteCommits<OWFField<O>, O::NStBytes>;

    /// Returns a new state in which cell `(r, c)` holds the input cell
    /// `(r, (c - shift) mod NSt)`, where `shift` is `r`, or `r + 1` for rows
    /// 2 and 3 when `NSt == 8`. Each key byte moves together with its 8 tags.
    fn inverse_shift_rows(&self) -> Self::Output {
        let columns = O::NSt::USIZE;
        let mut shifted = ByteCommits::<OWFField<O>, O::NStBytes>::default();
        for r in 0..4 {
            let shift = if columns == 8 && r > 1 { r + 1 } else { r };
            for c in 0..columns {
                // `columns` is added first so the subtraction cannot underflow.
                let src = 4 * ((columns + c - shift) % columns) + r;
                let dst = 4 * c + r;
                shifted.keys[dst] = self.keys[src];
                shifted.tags[8 * dst..8 * dst + 8]
                    .copy_from_slice(&self.tags[8 * src..8 * src + 8]);
            }
        }
        shifted
    }
}
/// Column mixing on degree-two byte commitments.
impl<O> MixColumns<O> for StateBytesSquaredCommits<O>
where
    O: OWFParameters,
{
    /// Replaces every 4-element column `(a0, a1, a2, a3)` by
    ///
    /// ```text
    /// v2*a0 + v3*a1 +    a2 +    a3
    ///    a0 + v2*a1 + v3*a2 +    a3
    ///    a0 +    a1 + v2*a2 + v3*a3
    /// v3*a0 +    a1 +    a2 + v2*a3
    /// ```
    ///
    /// where `(v2, v3)` are `BYTE_COMBINE_SQ_2/3` when `sq` is set and
    /// `BYTE_COMBINE_2/3` otherwise.
    fn mix_columns(&mut self, sq: bool) {
        let v2 = if sq {
            &OWFField::<O>::BYTE_COMBINE_SQ_2
        } else {
            &OWFField::<O>::BYTE_COMBINE_2
        };
        let v3 = if sq {
            &OWFField::<O>::BYTE_COMBINE_SQ_3
        } else {
            &OWFField::<O>::BYTE_COMBINE_3
        };
        for col in 0..O::NSt::USIZE {
            let base = 4 * col;
            // Work on a copy of the column so all outputs see the old values.
            let column = GenericArray::<_, U4>::from_slice(&self[base..base + 4]).to_owned();
            let row0 = &column[0] * v2 + &column[1] * v3 + &column[2] + &column[3];
            let row1 = &column[1] * v2 + &column[2] * v3 + &column[0] + &column[3];
            let row2 = &column[2] * v2 + &column[3] * v3 + &column[0] + &column[1];
            // The last row consumes the copied column by value.
            let (a0, a1, a2, a3) = column.into();
            let row3 = a0 * v3 + a3 * v2 + &a1 + &a2;
            self[base] = row0;
            self[base + 1] = row1;
            self[base + 2] = row2;
            self[base + 3] = row3;
        }
    }
}
/// Column mixing carried out directly on committed bytes and their bit tags.
impl<O> BytewiseMixColumns<O> for ByteCommitsRef<'_, OWFField<O>, O::NStBytes>
where
    O: OWFParameters,
{
    type Output = ByteCommits<OWFField<O>, O::NStBytes>;

    /// Starting from a default output state, every input byte `a` at
    /// `(row r, column c)` contributes:
    ///
    /// * its double `b` to rows `r` and `r - 1 (mod 4)` of column `c`, and
    /// * `a` itself to rows `r + 1`, `r + 2` and `r + 3 (mod 4)`.
    ///
    /// `b` is `a` rotated left by one bit with the old top bit additionally
    /// XOR-ed into bits 1, 3 and 4. The 8 tags of each byte are combined the
    /// same way, with field addition in place of XOR.
    fn bytewise_mix_columns(&self) -> Self::Output {
        let mut out = ByteCommits::default();
        for c in 0..O::NSt::USIZE {
            for r in 0..4 {
                let cell = 4 * c + r;
                let a_key = self.keys[cell];
                let a_tags = &self.tags[8 * cell..8 * cell + 8];

                // Doubling: rotate, then fold the old top bit into bits 1, 3, 4.
                let top_bit = a_key & 0x80;
                let doubled_key =
                    a_key.rotate_left(1) ^ (top_bit >> 6) ^ (top_bit >> 4) ^ (top_bit >> 3);
                let doubled_tags = [
                    a_tags[7],
                    a_tags[0] + a_tags[7],
                    a_tags[1],
                    a_tags[2] + a_tags[7],
                    a_tags[3] + a_tags[7],
                    a_tags[4],
                    a_tags[5],
                    a_tags[6],
                ];

                // The doubled byte goes to row r and to the row above it.
                for row in [r, (r + 3) % 4] {
                    let dst = 4 * c + row;
                    out.keys[dst] ^= doubled_key;
                    for (acc, tag) in out.tags[8 * dst..8 * dst + 8]
                        .iter_mut()
                        .zip(doubled_tags)
                    {
                        *acc += tag;
                    }
                }

                // The plain byte goes to the three other rows.
                for step in 1..4 {
                    let dst = 4 * c + (r + step) % 4;
                    out.keys[dst] ^= a_key;
                    for (acc, tag) in out.tags[8 * dst..8 * dst + 8].iter_mut().zip(a_tags) {
                        *acc += tag;
                    }
                }
            }
        }
        out
    }
}
/// Affine combination of committed state bits into one commitment per byte.
impl<O> SBoxAffine<O> for StateBitsSquaredCommits<O>
where
    O: OWFParameters,
{
    type Output = StateBytesSquaredCommits<O>;

    /// For every byte index `i` computes
    /// `sum_{k = 0..8} self[8 * i + (k + t) % 8] * sigmas[k] + sigmas[8]`,
    /// where `t` is 1 and `sigmas` is `SIGMA_SQUARES` when `sq` is set, and
    /// `t` is 0 and `sigmas` is `SIGMA` otherwise.
    fn s_box_affine(&self, sq: bool) -> Self::Output {
        let sigmas = if sq {
            &OWFField::<O>::SIGMA_SQUARES
        } else {
            &OWFField::<O>::SIGMA
        };
        // Bit rotation applied inside each byte: 1 in squared mode, else 0.
        let rot = usize::from(sq);
        (0..O::NStBytes::USIZE)
            .map(|byte_idx| {
                let first_bit = byte_idx * 8;
                let leading = &self[first_bit + rot % 8] * sigmas[0];
                let mut combined = (1..8).fold(leading, |mut acc, k| {
                    acc += &self[first_bit + (k + rot) % 8] * sigmas[k];
                    acc
                });
                // The ninth coefficient is the constant term.
                combined += sigmas[8];
                combined
            })
            .collect()
    }
}
/// In-place inverse affine step on committed bytes and their bit tags.
impl<F, L> InverseAffine for ByteCommits<F, L>
where
    F: BigGaloisField,
    L: SecurityParameter,
{
    /// For each of the first `L` bytes:
    ///
    /// * the key byte `x` becomes `rotr(x, 7) ^ rotr(x, 5) ^ rotr(x, 2) ^ 0x5`;
    /// * tag `j` of the byte becomes the sum of the old tags at positions
    ///   `j - 1`, `j - 3` and `j - 6` (mod 8). The constant has no tag part.
    fn inverse_affine(&mut self) {
        for byte_idx in 0..L::USIZE {
            let x = self.keys[byte_idx];
            self.keys[byte_idx] =
                x.rotate_right(7) ^ x.rotate_right(5) ^ x.rotate_right(2) ^ 0x5;

            let first_tag = 8 * byte_idx;
            // Copy the old tags so the writes below do not disturb later reads.
            let old_tags =
                GenericArray::<_, U8>::from_slice(&self.tags[first_tag..first_tag + 8]).to_owned();
            for bit in 0..8 {
                // +7, +5, +2 are -1, -3, -6 modulo 8.
                self.tags[first_tag + bit] =
                    old_tags[(bit + 7) % 8] + old_tags[(bit + 5) % 8] + old_tags[(bit + 2) % 8];
            }
        }
    }
}