use core::{iter::zip, ops::AddAssign};
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
use generic_array::{
GenericArray,
typenum::{Quot, U2, U4, U8, Unsigned},
};
use super::{
ByteCommits, ByteCommitsRef, FieldCommitDegOne, FieldCommitDegTwo,
byte_commitments::ByteCommitment,
};
use crate::{
aes::*,
fields::{BigGaloisField, Square},
parameter::{OWFField, OWFParameters},
prover::aes::{StateBitsSquaredCommits, StateBytesSquaredCommits},
universal_hashing::ZKProofHasher,
};
pub(crate) fn enc_cstrnts<O, K, S>(
zk_hasher: &mut ZKProofHasher<OWFField<O>>,
input: impl for<'a> AddRoundKey<&'a K, Output = ByteCommits<OWFField<O>, O::NStBytes>>,
output: impl for<'a> AddRoundKey<&'a K, Output = ByteCommits<OWFField<O>, O::NStBytes>>,
w: ByteCommitsRef<OWFField<O>, O::LEncBytes>,
extended_key: &[K],
) where
O: OWFParameters,
K: StateToBytes<O, Output = Box<GenericArray<S, O::NStBytes>>>,
S: Square + Clone,
for<'a> FieldCommitDegTwo<OWFField<O>>: AddAssign<&'a S>,
for<'a> FieldCommitDegTwo<OWFField<O>>: AddAssign<&'a <S as Square>::Output>,
ByteCommits<OWFField<O>, O::NStBytes>: for<'a> AddRoundKeyAssign<&'a K>,
for<'a> ByteCommitsRef<'a, OWFField<O>, O::NStBytes>:
BytewiseMixColumns<O, Output = ByteCommits<OWFField<O>, O::NStBytes>>,
{
let mut state = input.add_round_key(&extended_key[0]);
for r in 0..O::R::USIZE / 2 {
let state_prime = enc_cstrnts_even::<O>(
zk_hasher,
&state,
w.get_commits_ref::<Quot<O::NStBytes, U2>>(3 * O::NStBytes::USIZE * r / 2),
);
let round_key = extended_key[2 * r + 1].state_to_bytes();
let round_key_sq = (&round_key).square();
let st_0 = aes_round::<O, _>(&state_prime, &round_key, false);
let st_1 = aes_round::<O, _>(&state_prime, &round_key_sq, true);
let round_key = &extended_key[2 * r + 2];
if r != O::R::USIZE / 2 - 1 {
let s_tilde = w.get_commits_ref::<O::NStBytes>(
O::NStBytes::USIZE / 2 + 3 * O::NStBytes::USIZE * r / 2,
);
enc_cstrnts_odd::<O>(zk_hasher, s_tilde, &st_0, &st_1);
state = s_tilde.bytewise_mix_columns();
state.add_round_key_assign(round_key);
} else {
let s_tilde = output.add_round_key(round_key);
enc_cstrnts_odd::<O>(zk_hasher, s_tilde.to_ref(), &st_0, &st_1);
}
}
}
/// Handles the first round of a pair: adds one constraint per byte of `state` to
/// `zk_hasher` and returns eight products per byte.
///
/// For byte `i` of `state`:
///
/// * a 4-bit value is read from `w.keys[i / 2]` (low nibble when `i` is even, high nibble
///   when `i` is odd) and, together with `w.tags[4 * i..4 * i + 4]`, expanded into four
///   values `ys` by `invnorm_to_conjugates`;
/// * `f256_f2_conjugates` fills an eight-entry scratch array `c` from the byte's commitment;
/// * `(ys[0] * c[4]) * c[1] + c[0]` is passed to `zk_hasher.update`;
/// * entry `j` of the byte's eight-entry output chunk is set to the product of
///   `c[(j + 4) % 8]` and `ys[j % 4]`, for `j` in `0..8`.
fn enc_cstrnts_even<O>(
    zk_hasher: &mut ZKProofHasher<OWFField<O>>,
    state: &ByteCommits<OWFField<O>, O::NStBytes>,
    w: ByteCommitsRef<OWFField<O>, Quot<O::NStBytes, U2>>,
) -> StateBitsSquaredCommits<O>
where
    O: OWFParameters,
{
    let mut inv_bits = StateBitsSquaredCommits::<O>::default();
    // Scratch buffer; fully overwritten for every byte.
    let mut conjugates = GenericArray::default();

    for (i, (bits, mut byte_commit)) in
        zip(inv_bits.chunks_exact_mut(8), state.iter()).enumerate()
    {
        // Two 4-bit values are packed per key byte: even `i` takes the low nibble.
        let nibble_shift = (i % 2) * 4;
        let norm = (w.keys[i / 2] >> nibble_shift) & 0xf;
        let tag_start = 4 * i;
        let ys = invnorm_to_conjugates(norm, &w.tags[tag_start..tag_start + 4]);

        f256_f2_conjugates(&mut conjugates, &mut byte_commit);

        // Entry 0 is computed first because the hashed constraint reuses it.
        bits[0] = &ys[0] * &conjugates[4];
        let cnstr = &bits[0] * &conjugates[1] + &conjugates[0];
        zk_hasher.update(&cnstr);

        // Remaining seven entries of this byte's chunk.
        for (j, slot) in bits.iter_mut().enumerate().skip(1) {
            *slot = &conjugates[(j + 4) % 8] * &ys[j % 4];
        }
    }

    inv_bits
}
/// Handles the second round of a pair: adds the per-byte constraints that relate
/// `s_tilde` to `st_0` and `st_1`.
///
/// `s_tilde` is first passed through `inverse_shift_rows` and then `inverse_affine`.
/// For every index shared by `st_0` and `st_1`, the resulting state's `get_field_commit`
/// and `get_field_commit_sq` values at that index are handed to
/// `zk_hasher.odd_round_cstrnts` together with the matching entries of `st_0` and `st_1`.
fn enc_cstrnts_odd<O>(
    zk_hasher: &mut ZKProofHasher<OWFField<O>>,
    s_tilde: ByteCommitsRef<OWFField<O>, O::NStBytes>,
    st_0: &StateBytesSquaredCommits<O>,
    st_1: &StateBytesSquaredCommits<O>,
) where
    O: OWFParameters,
    ByteCommits<OWFField<O>, O::NStBytes>: InverseAffine,
    for<'a> ByteCommitsRef<'a, OWFField<O>, O::NStBytes>:
        InverseShiftRows<O, Output = ByteCommits<OWFField<O>, O::NStBytes>>,
{
    // Undo the shift-rows and affine steps on the reference state.
    let mut expected = s_tilde.inverse_shift_rows();
    expected.inverse_affine();

    zip(st_0.as_ref(), st_1.as_ref())
        .enumerate()
        .for_each(|(idx, (lhs, rhs))| {
            let byte = expected.get_field_commit(idx);
            let byte_sq = expected.get_field_commit_sq(idx);
            zk_hasher.odd_round_cstrnts(&byte, &byte_sq, lhs, rhs);
        });
}
/// Applies one round pipeline to `state` and returns the result.
///
/// The steps run in this order: `s_box_affine`, `shift_rows`, `mix_columns`,
/// `add_round_key_bytes` with `key_bytes`. The `sq` flag is forwarded unchanged to every
/// step except `shift_rows`, which takes no arguments. `enc_cstrnts` calls this with
/// `sq = true` exactly when `key_bytes` holds the squared round-key bytes.
fn aes_round<O, T>(
    state: &StateBitsSquaredCommits<O>,
    key_bytes: &GenericArray<T, O::NStBytes>,
    sq: bool,
) -> StateBytesSquaredCommits<O>
where
    O: OWFParameters,
    StateBitsSquaredCommits<O>: SBoxAffine<O, Output = StateBytesSquaredCommits<O>>,
    FieldCommitDegTwo<OWFField<O>>: for<'a> AddAssign<&'a T>,
    StateBytesSquaredCommits<O>: MixColumns<O>,
{
    // `s_box_affine` produces a fresh value; the later steps mutate it in place.
    let mut round_out = state.s_box_affine(sq);
    round_out.shift_rows();
    round_out.mix_columns(sq);
    round_out.add_round_key_bytes(key_bytes, sq);
    round_out
}
/// Expands a 4-bit value and its four tags into four degree-one commitments.
///
/// With `b0..b3` the four low bits of `x_val` (least significant first), output `j`
/// (for `j` in `0..4`) has
///
/// * key `b0 + BETA_SQUARES[j] * b1 + BETA_SQUARES[j + 1] * b2 + BETA_CUBES[j] * b3`,
/// * tag `x_tag[0] + BETA_SQUARES[j] * x_tag[1] + BETA_SQUARES[j + 1] * x_tag[2]
///   + BETA_CUBES[j] * x_tag[3]`.
///
/// Bits of `x_val` above the fourth are ignored.
///
/// # Panics
///
/// Slice indexing panics if `x_tag` has fewer than four elements.
fn invnorm_to_conjugates<F>(x_val: u8, x_tag: &[F]) -> GenericArray<FieldCommitDegOne<F>, U4>
where
    F: BigGaloisField,
{
    // The four low bits of `x_val`, each kept as a 0/1 multiplier.
    let bit0 = x_val & 1;
    let bit1 = (x_val >> 1) & 1;
    let bit2 = (x_val >> 2) & 1;
    let bit3 = (x_val >> 3) & 1;

    let conjugate = |j: usize| {
        // Key and tag are the same linear combination over these three constants.
        let (sq, sq_next, cube) = (F::BETA_SQUARES[j], F::BETA_SQUARES[j + 1], F::BETA_CUBES[j]);

        let key = F::from_bit(bit0) + sq * bit1 + sq_next * bit2 + cube * bit3;
        let tag = x_tag[0] + sq * x_tag[1] + sq_next * x_tag[2] + cube * x_tag[3];

        FieldCommitDegOne::new(key, tag)
    };

    (0..4).map(conjugate).collect()
}
/// Fills `state_conj` with eight values derived from `commit`.
///
/// Entry `k` receives `commit.combine()` evaluated after `commit.square_inplace()` has been
/// called `k` times. `commit` is modified: when the function returns, `square_inplace` has
/// been applied to it eight times in total.
pub(crate) fn f256_f2_conjugates<F>(
    state_conj: &mut GenericArray<FieldCommitDegOne<F>, U8>,
    commit: &mut ByteCommitment<F>,
) where
    F: BigGaloisField,
{
    state_conj.iter_mut().for_each(|slot| {
        // Record the current value, then advance `commit` for the next slot.
        *slot = commit.combine();
        commit.square_inplace();
    });
}