use core::array;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use generic_array::{GenericArray, typenum::Unsigned};
use super::{
encryption,
key_expansion::key_exp_cstrnts,
vole_commitments::{VoleCommits, VoleCommitsRef},
};
use crate::{
aes::AddRoundKey,
fields::BigGaloisField,
internal_keys::PublicKey,
parameter::{OWFField, OWFParameters},
rijndael_32::rijndael_key_schedule_unbitsliced,
universal_hashing::ZKVerifyHasher,
};
pub(crate) fn owf_constraints<O>(
zk_hasher: &mut ZKVerifyHasher<OWFField<O>>,
w: VoleCommitsRef<'_, OWFField<O>, O::L>,
delta: &OWFField<O>,
pk: &PublicKey<O>,
) where
O: OWFParameters,
{
let PublicKey {
owf_input: x,
owf_output: y,
} = pk;
zk_hasher.mul_and_update(&w.scalars[0], &w.scalars[1]);
if O::IS_EM {
let extended_key = key_schedule_bytes::<O>(x, delta);
let owf_input = w.get_commits_ref::<O::NStBits>(0);
let owf_output = owf_input.add_round_key(GenericArray::<u8, O::NStBytes>::from_slice(y));
let w_tilde = w.get_commits_ref::<O::LEnc>(O::LKe::USIZE);
encryption::enc_cstrnts::<O, _>(
zk_hasher,
owf_input,
owf_output.to_ref(),
w_tilde,
extended_key.as_slice(),
);
} else {
let mut owf_input: VoleCommits<_, O::NStBits> = VoleCommits::from_constant(
GenericArray::<u8, O::NStBytes>::from_slice(x.as_slice()),
delta,
);
let k = key_exp_cstrnts::<O>(zk_hasher, w.get_commits_ref::<O::LKe>(0));
let extended_key: Vec<_> = (0..=O::R::USIZE)
.map(|i| k.get_commits_ref::<O::NStBits>(i * O::NStBits::USIZE))
.collect();
for b in 0..O::Beta::USIZE {
let w_tilde = w.get_commits_ref::<O::LEnc>(O::LKe::USIZE + b * O::LEnc::USIZE);
let owf_output = GenericArray::<u8, O::NStBytes>::from_slice(
&y[O::InputSize::USIZE * b..O::InputSize::USIZE * (b + 1)],
);
let owf_output = VoleCommits::from_constant(owf_output, delta);
encryption::enc_cstrnts::<O, _>(
zk_hasher,
owf_input.to_ref(),
owf_output.to_ref(),
w_tilde,
extended_key.as_slice(),
);
owf_input.scalars[0] += delta;
}
}
}
#[inline]
fn byte_to_field<F>(byte: u8, delta: F) -> [F; 8]
where
F: BigGaloisField,
{
array::from_fn(|i| delta * ((byte >> i) & 1))
}
#[inline]
fn key_schedule_bytes<'a, O>(
key: &GenericArray<u8, O::InputSize>,
delta: &'a OWFField<O>,
) -> Vec<VoleCommits<'a, OWFField<O>, O::NStBits>>
where
O: OWFParameters,
{
rijndael_key_schedule_unbitsliced::<O::NSt, O::NK, O::R>(key, O::SKe::USIZE)
.chunks_exact(32)
.take(O::R::USIZE + 1)
.map(|chunk| {
let scalars = chunk[..O::NStBytes::USIZE]
.iter()
.flat_map(|&byte| byte_to_field(byte, *delta))
.collect();
VoleCommits { scalars, delta }
})
.collect()
}