use generic_array::{GenericArray, typenum::Unsigned};
#[cfg(not(feature = "std"))]
use alloc::{borrow::ToOwned, vec::Vec};
use super::{
ByteCommitsRef, FieldCommitDegThree, encryption::enc_cstrnts, key_expansion::key_exp_cstrnts,
};
use crate::{
fields::{Field, FromBit},
internal_keys::PublicKey,
parameter::{OWFField, OWFParameters},
rijndael_32::rijndael_key_schedule_unbitsliced,
universal_hashing::ZKProofHasher,
utils::xor_arrays,
};
pub(crate) fn owf_constraints<O>(
zk_hasher: &mut ZKProofHasher<OWFField<O>>,
w: ByteCommitsRef<OWFField<O>, O::LBytes>,
pk: &PublicKey<O>,
) where
O: OWFParameters,
{
let PublicKey {
owf_input: x,
owf_output: y,
} = pk;
{
let (k_0, k_1) = ((w.keys[0] & 0b1), (w.keys[0] & 0b10) >> 1);
zk_hasher.update(&FieldCommitDegThree {
key: OWFField::<O>::from_bit(k_0 & k_1),
tag: [
OWFField::<O>::ZERO,
w.tags[0] * w.tags[1],
w.tags[0] * k_1 + w.tags[1] * k_0,
],
});
}
if O::IS_EM {
let ext_key = key_schedule_bytes::<O>(x);
let owf_input = w.get_commits_ref::<O::NStBytes>(0);
let owf_output_keys: GenericArray<u8, O::NStBytes> =
xor_arrays(owf_input.keys.as_slice(), y.as_slice()).collect();
let owf_output = ByteCommitsRef::from_slices(&owf_output_keys, owf_input.tags);
let w_tilde = w.get_commits_ref::<O::LEncBytes>(O::LKeBytes::USIZE);
enc_cstrnts::<O, _, _>(
zk_hasher,
owf_input,
owf_output,
w_tilde,
ext_key.as_slice(),
);
} else {
let mut owf_input = GenericArray::from_slice(x).to_owned();
let k = key_exp_cstrnts::<O>(zk_hasher, w.get_commits_ref::<O::LKeBytes>(0));
let extended_key: Vec<_> = (0..=O::R::USIZE)
.map(|i| k.get_commits_ref::<O::NStBytes>(i * O::NStBytes::USIZE))
.collect();
for b in 0..O::Beta::USIZE {
let w_tilde =
w.get_commits_ref::<O::LEncBytes>(O::LKeBytes::USIZE + b * O::LEncBytes::USIZE);
let owf_output = GenericArray::from_slice(
&y[O::InputSize::USIZE * b..O::InputSize::USIZE * (b + 1)],
);
enc_cstrnts::<O, _, _>(
zk_hasher,
&owf_input,
owf_output,
w_tilde,
extended_key.as_slice(),
);
owf_input[0] ^= 1;
}
}
}
#[inline]
fn key_schedule_bytes<O>(key: &GenericArray<u8, O::InputSize>) -> Vec<GenericArray<u8, O::NStBytes>>
where
O: OWFParameters,
{
rijndael_key_schedule_unbitsliced::<O::NSt, O::NK, O::R>(key, O::SKe::USIZE)
.chunks_exact(32)
.take(O::R::USIZE + 1)
.map(|chunk| {
chunk
.iter()
.copied()
.take(O::NStBytes::USIZE)
.collect::<GenericArray<u8, O::NStBytes>>()
})
.collect()
}