use core::{iter::zip, mem, ops::Mul};
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
use hybrid_array::{
Array, ArraySize,
typenum::{U8, marker_traits::Unsigned},
};
use crate::{
aes::{
AddRoundKey, AddRoundKeyAssign, AddRoundKeyBytes, BytewiseMixColumns, InverseAffine,
InverseShiftRows, MixColumns, SBoxAffine, ShiftRows, StateToBytes,
},
fields::{
BigGaloisField, ByteCombine, ByteCombineConstants, ByteCombineSquaredConstants, Sigmas,
Square,
},
parameter::{OWFField, OWFParameters},
verifier::{VoleCommits, VoleCommitsRef},
};
impl<O> StateToBytes<O> for VoleCommits<'_, OWFField<O>, O::NStBits>
where
O: OWFParameters,
{
type Output = Array<OWFField<O>, O::NStBytes>;
fn state_to_bytes(&self) -> Self::Output {
self.scalars
.chunks_exact(8)
.map(OWFField::<O>::byte_combine_slice)
.collect()
}
}
impl<O> StateToBytes<O> for VoleCommitsRef<'_, OWFField<O>, O::NStBits>
where
O: OWFParameters,
{
type Output = Array<OWFField<O>, O::NStBytes>;
fn state_to_bytes(&self) -> Self::Output {
self.scalars
.chunks_exact(8)
.map(OWFField::<O>::byte_combine_slice)
.collect()
}
}
impl<'a, F, L> AddRoundKey<Self> for &VoleCommits<'a, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
type Output = VoleCommits<'a, F, L>;
fn add_round_key(&self, rhs: Self) -> Self::Output {
Self::Output {
scalars: Box::new(
zip(self.scalars.iter(), rhs.scalars.iter())
.map(|(x, y)| *x + y)
.collect(),
),
delta: self.delta,
}
}
}
impl<'a, F, L> AddRoundKey<&VoleCommits<'a, F, L>> for VoleCommitsRef<'a, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
type Output = VoleCommits<'a, F, L>;
fn add_round_key(&self, rhs: &VoleCommits<'a, F, L>) -> Self::Output {
Self::Output {
scalars: Box::new(
zip(self.scalars, rhs.scalars.iter())
.map(|(x, y)| *x + y)
.collect(),
),
delta: self.delta,
}
}
}
impl<'a, F, L> AddRoundKey<&Self> for VoleCommitsRef<'a, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
type Output = VoleCommits<'a, F, L>;
fn add_round_key(&self, rhs: &Self) -> Self::Output {
Self::Output {
scalars: Box::new(
zip(self.scalars, rhs.scalars)
.map(|(x, y)| *x + y)
.collect(),
),
delta: self.delta,
}
}
}
impl<'a, F, L> AddRoundKey<&Array<F, L>> for VoleCommitsRef<'a, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
type Output = VoleCommits<'a, F, L>;
fn add_round_key(&self, rhs: &Array<F, L>) -> Self::Output {
Self::Output {
scalars: Box::new(zip(self.scalars, rhs).map(|(x, y)| *x + y).collect()),
delta: self.delta,
}
}
}
impl<'a, F, L, L2> AddRoundKey<&Array<u8, L>> for VoleCommitsRef<'a, F, L2>
where
F: BigGaloisField,
L: ArraySize + Mul<U8, Output = L2>,
L2: ArraySize,
{
type Output = VoleCommits<'a, F, L2>;
fn add_round_key(&self, rhs: &Array<u8, L>) -> Self::Output {
let scalars = Box::new(
self.scalars
.iter()
.enumerate()
.map(|(i, comm_i)| *comm_i + *self.delta * ((rhs[i / 8] >> (i % 8)) & 1))
.collect(),
);
VoleCommits {
scalars,
delta: self.delta,
}
}
}
impl<F, L> AddRoundKeyAssign<&Self> for VoleCommits<'_, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
fn add_round_key_assign(&mut self, rhs: &Self) {
for (x, y) in zip(self.scalars.iter_mut(), rhs.scalars.iter()) {
*x += y;
}
}
}
impl<F, L> AddRoundKeyAssign<&VoleCommitsRef<'_, F, L>> for VoleCommits<'_, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
fn add_round_key_assign(&mut self, rhs: &VoleCommitsRef<'_, F, L>) {
for (x, y) in zip(self.scalars.iter_mut(), rhs.scalars.iter()) {
*x += y;
}
}
}
impl<F, L> AddRoundKeyAssign<&Array<F, L>> for VoleCommits<'_, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
fn add_round_key_assign(&mut self, rhs: &Array<F, L>) {
for (x, y) in zip(self.scalars.iter_mut(), rhs.iter()) {
*x += y;
}
}
}
impl<'a, O> InverseShiftRows<O> for VoleCommitsRef<'a, OWFField<O>, O::NStBits>
where
O: OWFParameters,
{
type Output = VoleCommits<'a, OWFField<O>, O::NStBits>;
fn inverse_shift_rows(&self) -> Self::Output {
let mut state_prime: Box<Array<_, O::NStBits>> = Box::default();
for r in 0..4 {
for c in 0..O::NSt::USIZE {
let i = if (O::NSt::USIZE != 8) || (r <= 1) {
4 * ((O::NSt::USIZE + c - r) % O::NSt::USIZE) + r
} else {
4 * ((O::NSt::USIZE + c - r - 1) % O::NSt::USIZE) + r
};
state_prime[8 * (4 * c + r)..8 * (4 * c + r) + 8]
.copy_from_slice(&self.scalars[8 * i..8 * i + 8]);
}
}
Self::Output {
scalars: state_prime,
delta: self.delta,
}
}
}
impl<'a, O> InverseShiftRows<O> for &VoleCommits<'a, OWFField<O>, O::NStBits>
where
O: OWFParameters,
{
type Output = VoleCommits<'a, OWFField<O>, O::NStBits>;
fn inverse_shift_rows(&self) -> Self::Output {
let mut state_prime: Box<Array<_, O::NStBits>> = Box::default();
for r in 0..4 {
for c in 0..O::NSt::USIZE {
let i = if (O::NSt::USIZE != 8) || (r <= 1) {
4 * ((O::NSt::USIZE + c - r) % O::NSt::USIZE) + r
} else {
4 * ((O::NSt::USIZE + c - r - 1) % O::NSt::USIZE) + r
};
state_prime[8 * (4 * c + r)..8 * (4 * c + r) + 8]
.copy_from_slice(&self.scalars[8 * i..8 * i + 8]);
}
}
Self::Output {
scalars: state_prime,
delta: self.delta,
}
}
}
impl<'a, O> BytewiseMixColumns<O> for VoleCommitsRef<'a, OWFField<O>, O::NStBits>
where
O: OWFParameters,
{
type Output = VoleCommits<'a, OWFField<O>, O::NStBits>;
fn bytewise_mix_columns(&self) -> Self::Output {
let mut o: Box<Array<_, O::NStBits>> = Box::default();
for c in 0..O::NSt::USIZE {
for r in 0..4 {
let a_key = &self.scalars[32 * c + 8 * r..32 * c + 8 * r + 8];
let b_key = [
a_key[7],
a_key[0] + a_key[7],
a_key[1],
a_key[2] + a_key[7],
a_key[3] + a_key[7],
a_key[4],
a_key[5],
a_key[6],
];
for j in 0..2 {
let off = 32 * c + 8 * ((4 + r - j) % 4);
for (o, b) in zip(o[off..off + 8].iter_mut(), b_key.iter()) {
*o += b;
}
}
for j in 1..4 {
let off = 32 * c + 8 * ((r + j) % 4);
for (o, a) in zip(o[off..off + 8].iter_mut(), a_key.iter()) {
*o += a;
}
}
}
}
VoleCommits {
scalars: o,
delta: self.delta,
}
}
}
impl<'a, O> SBoxAffine<O> for VoleCommits<'a, OWFField<O>, O::NStBits>
where
O: OWFParameters,
{
type Output = VoleCommits<'a, OWFField<O>, O::NStBytes>;
fn s_box_affine(&self, sq: bool) -> Self::Output {
let sigmas = if sq {
&OWFField::<O>::SIGMA_SQUARES
} else {
&OWFField::<O>::SIGMA
};
let t = sq as usize;
let scalars = Box::new(
(0..O::NStBytes::USIZE)
.map(|i| {
let mut y_i = sigmas[8] * self.delta.square();
for (sigma_idx, sigma) in sigmas.iter().enumerate().take(8) {
y_i += self.scalars[i * 8 + (sigma_idx + t) % 8] * sigma;
}
y_i
})
.collect(),
);
Self::Output {
scalars,
delta: self.delta,
}
}
}
impl<F, L> ShiftRows for VoleCommits<'_, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
fn shift_rows(&mut self) {
let mut tmp = self.scalars.clone();
let nst = L::USIZE / 4;
for r in 0..4 {
for c in 0..nst {
let off = if (nst != 8) || (r <= 1) { 0 } else { 1 };
mem::swap(
&mut self.scalars[4 * c + r],
&mut tmp[4 * ((c + r + off) % nst) + r],
);
}
}
}
}
impl<F, L> InverseAffine for VoleCommits<'_, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
fn inverse_affine(&mut self) {
for scalars in self.scalars.as_chunks_mut::<8>().0 {
let xi_tags = *scalars;
for bit_i in 0..8 {
scalars[bit_i] = xi_tags[(bit_i + 8 - 1) % 8]
+ xi_tags[(bit_i + 8 - 3) % 8]
+ xi_tags[(bit_i + 8 - 6) % 8];
if bit_i == 0 || bit_i == 2 {
scalars[bit_i] += self.delta;
}
}
}
}
}
impl<O> MixColumns<O> for VoleCommits<'_, OWFField<O>, O::NStBytes>
where
O: OWFParameters,
{
fn mix_columns(&mut self, sq: bool) {
let (v2, v3) = if sq {
(
&OWFField::<O>::BYTE_COMBINE_SQ_2,
&OWFField::<O>::BYTE_COMBINE_SQ_3,
)
} else {
(
&OWFField::<O>::BYTE_COMBINE_2,
&OWFField::<O>::BYTE_COMBINE_3,
)
};
for scalars in self.scalars.as_chunks_mut::<4>().0 {
let tmp = *scalars;
scalars[0] = tmp[0] * v2 + tmp[1] * v3 + tmp[2] + tmp[3];
scalars[1] = tmp[1] * v2 + tmp[2] * v3 + tmp[0] + tmp[3];
scalars[2] = tmp[2] * v2 + tmp[3] * v3 + tmp[0] + tmp[1];
scalars[3] = tmp[0] * v3 + tmp[3] * v2 + tmp[1] + tmp[2];
}
}
}
impl<F, L> AddRoundKeyBytes<&Array<F, L>> for VoleCommits<'_, F, L>
where
F: BigGaloisField,
L: ArraySize,
{
fn add_round_key_bytes(&mut self, rhs: &Array<F, L>, sq: bool) {
if sq {
for (st, k) in zip(self.scalars.iter_mut(), rhs) {
*st += k;
}
} else {
for (st, k) in zip(self.scalars.iter_mut(), rhs) {
*st += *k * self.delta;
}
}
}
}