use core::mem::size_of;
#[cfg(not(feature = "std"))]
use alloc::{borrow::ToOwned, boxed::Box};
use generic_array::{
GenericArray,
typenum::{U3, Unsigned},
};
use crate::{
fields::{GF8, Square},
parameter::OWFParameters,
rijndael_32::{
State, bitslice, convert_from_batchblocks, inv_bitslice, mix_columns_0,
rijndael_add_round_key, rijndael_key_schedule, rijndael_shift_rows_1, rijndael_sub_bytes,
},
};
pub(crate) fn aes_extendedwitness<O>(
owf_secret: &GenericArray<u8, O::LambdaBytes>,
owf_input: &GenericArray<u8, O::InputSize>,
) -> Box<GenericArray<u8, O::LBytes>>
where
O: OWFParameters,
{
let mut input: GenericArray<u8, O::InputSize> = owf_input.to_owned();
let mut witness = GenericArray::default_boxed();
let kb = rijndael_key_schedule::<O::NSt, O::NK, O::R>(owf_secret, O::SKe::USIZE);
let mut index = 0;
if !O::IS_EM {
save_key_bits::<O>(&mut witness, owf_secret, &mut index);
save_non_lin_bits::<O>(&mut witness, &kb, &mut index);
} else {
save_key_bits::<O>(&mut witness, owf_input, &mut index);
}
for _ in 0..O::Beta::USIZE {
round_with_save::<O>(&input, &kb, &mut witness, &mut index);
input[0] ^= 1;
}
witness
}
fn save_key_bits<O>(witness: &mut [u8], key: &[u8], index: &mut usize)
where
O: OWFParameters,
{
witness[..O::LambdaBytes::USIZE].copy_from_slice(key);
*index += O::LambdaBytes::USIZE;
}
fn save_non_lin_bits<O>(witness: &mut [u8], kb: &[u32], index: &mut usize)
where
O: OWFParameters,
{
let start_off = 1 + (O::NK::USIZE / 8);
let non_lin_blocks = if O::NK::USIZE % 4 == 0 {
O::SKe::USIZE / 4
} else {
O::SKe::USIZE * 3 / 8
};
for j in start_off..start_off + non_lin_blocks {
let inside = GenericArray::<_, U3>::from_iter(
convert_from_batchblocks(inv_bitslice(&kb[8 * j..8 * (j + 1)])).take(3),
);
if O::NK::USIZE != 6 || j % 3 == 0 {
witness[*index..*index + size_of::<u32>()].copy_from_slice(&inside[0]);
*index += size_of::<u32>();
} else if j % 3 == 1 {
witness[*index..*index + size_of::<u32>()].copy_from_slice(&inside[2]);
*index += size_of::<u32>();
}
}
}
fn gf8_exp_238(x: GF8) -> GF8 {
let y = x.square(); let x = y.square(); let y = x * y;
let x = x.square(); let y = x * y;
let x = x.square(); let x = x.square(); let y = x * y;
let x = x.square(); let y = x * y;
let x = x.square(); x * y
}
fn invnorm(x: u8) -> u8 {
let x = u8::from(gf8_exp_238(x.into()));
(x & 1) ^ ((x >> 5) & 6) ^ ((x << 1) & 8)
}
#[inline]
fn store_invnorm_state(lo_idx: u8, hi_idx: u8) -> u8 {
invnorm(lo_idx) | (invnorm(hi_idx) << 4)
}
fn round_with_save<O>(
input1: &[u8], kb: &[u32], witness: &mut [u8],
index: &mut usize,
) where
O: OWFParameters,
{
let mut state = State::default();
let (input0, input1) = input1.split_at(16);
bitslice(&mut state, input0, input1);
rijndael_add_round_key(&mut state, &kb[..8]);
for j in 0..O::R::USIZE - 1 {
let even_round = (j % 2) == 0;
if even_round {
let to_take = if !O::IS_EM { 4 } else { O::NK::USIZE };
for i in convert_from_batchblocks(inv_bitslice(&state)).take(to_take) {
witness[*index] = store_invnorm_state(i[0], i[1]);
*index += 1;
witness[*index] = store_invnorm_state(i[2], i[3]);
*index += 1;
}
}
rijndael_sub_bytes(&mut state);
rijndael_shift_rows_1::<O::NSt>(&mut state);
if !even_round {
for i in convert_from_batchblocks(inv_bitslice(&state)).take(O::NSt::USIZE) {
witness[*index..*index + size_of::<u32>()].copy_from_slice(&i);
*index += size_of::<u32>();
}
}
mix_columns_0(&mut state);
rijndael_add_round_key(&mut state, &kb[8 * (j + 1)..8 * (j + 2)]);
}
}
#[cfg(test)]
mod test {
#![allow(clippy::needless_range_loop)]
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::{
fields::{GF128, GF192, GF256},
parameter::{OWF128, OWF128EM, OWF192, OWF192EM, OWF256, OWF256EM, OWFParameters},
utils::test::read_test_data,
};
use generic_array::GenericArray;
use serde::Deserialize;
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AesExtendedWitness {
lambda: u16,
em: bool,
key: Vec<u8>,
input: Vec<u8>,
w: Vec<u8>,
}
impl AesExtendedWitness {
fn test(&self) -> bool {
match self.em {
false => self.extend_witness_test(),
true => self.extend_witness_test_em(),
}
}
fn extend_witness_test(&self) -> bool {
match self.lambda {
128 => {
let wit = OWF128::<GF128>::extendwitness(
GenericArray::from_slice(&self.key),
GenericArray::from_slice(&self.input),
);
(*wit).as_slice() == self.w.as_slice()
}
192 => {
let wit = OWF192::<GF192>::extendwitness(
GenericArray::from_slice(&self.key),
GenericArray::from_slice(&self.input),
);
(*wit).as_slice() == self.w.as_slice()
}
_ => {
let wit = OWF256::<GF256>::extendwitness(
GenericArray::from_slice(&self.key),
GenericArray::from_slice(&self.input),
);
(*wit).as_slice() == self.w.as_slice()
}
}
}
fn extend_witness_test_em(&self) -> bool {
match self.lambda {
128 => {
let wit = OWF128EM::<GF128>::extendwitness(
GenericArray::from_slice(&self.key),
GenericArray::from_slice(&self.input),
);
(*wit).as_slice() == self.w.as_slice()
}
192 => {
let wit = OWF192EM::<GF192>::extendwitness(
GenericArray::from_slice(&self.key),
GenericArray::from_slice(&self.input),
);
(*wit).as_slice() == self.w.as_slice()
}
_ => {
let wit = OWF256EM::<GF256>::extendwitness(
GenericArray::from_slice(&self.key),
GenericArray::from_slice(&self.input),
);
(*wit).as_slice() == self.w.as_slice()
}
}
}
}
#[test]
fn aes_extended_witness_test() {
let database: Vec<AesExtendedWitness> = read_test_data("AesExtendedWitness.json");
for data in database {
assert!(data.test())
}
}
}