use ff::PrimeField;
pub(crate) use crate::hash::utils::{
assert_in_valid_spreaded_form, expr_pow2_ip, expr_pow4_ip, get_even_and_odd_bits,
negate_spreaded, spread, spread_table_from_lengths, u32_in_be_limbs, MASK_EVN_64,
};
const LOOKUP_LENGTHS: [u32; 10] = [2, 3, 4, 5, 6, 7, 9, 10, 11, 12];
pub(super) fn gen_spread_table<F: PrimeField>() -> impl Iterator<Item = (F, F, F)> {
spread_table_from_lengths(std::iter::once(0).chain(LOOKUP_LENGTHS))
}
pub(super) fn spreaded_maj(spreaded_forms: [u64; 3]) -> u64 {
spreaded_forms.into_iter().for_each(assert_in_valid_spreaded_form);
let [sA, sB, sC] = spreaded_forms;
sA + sB + sC
}
pub(super) fn spreaded_Sigma_0(spreaded_limbs: [u64; 4]) -> u64 {
spreaded_limbs.into_iter().for_each(assert_in_valid_spreaded_form);
let [sA_10, sA_09, sA_11, sA_02] = spreaded_limbs;
pow4_ip([30, 20, 11, 0], [sA_02, sA_10, sA_09, sA_11])
+ pow4_ip([21, 19, 9, 0], [sA_11, sA_02, sA_10, sA_09])
+ pow4_ip([23, 12, 10, 0], [sA_09, sA_11, sA_02, sA_10])
}
pub(super) fn spreaded_Sigma_1(spreaded_limbs: [u64; 5]) -> u64 {
spreaded_limbs.into_iter().for_each(assert_in_valid_spreaded_form);
let [sE_07, sE_12, sE_02, sE_05, sE_06] = spreaded_limbs;
pow4_ip([26, 19, 7, 5, 0], [sE_06, sE_07, sE_12, sE_02, sE_05])
+ pow4_ip([27, 21, 14, 2, 0], [sE_05, sE_06, sE_07, sE_12, sE_02])
+ pow4_ip([20, 18, 13, 7, 0], [sE_12, sE_02, sE_05, sE_06, sE_07])
}
pub(super) fn spreaded_sigma_0(spreaded_limbs: [u64; 8]) -> u64 {
spreaded_limbs.into_iter().for_each(assert_in_valid_spreaded_form);
let [sW_12, sW_1a, sW_1b, sW_1c, sW_07, sW_3a, sW_04, sW_3b] = spreaded_limbs;
pow4_ip(
[17, 16, 15, 14, 7, 4, 0],
[sW_12, sW_1a, sW_1b, sW_1c, sW_07, sW_3a, sW_04],
) + pow4_ip(
[28, 25, 13, 12, 11, 10, 3, 0],
[sW_04, sW_3b, sW_12, sW_1a, sW_1b, sW_1c, sW_07, sW_3a],
) + pow4_ip(
[31, 24, 21, 17, 14, 2, 1, 0],
[sW_1c, sW_07, sW_3a, sW_04, sW_3b, sW_12, sW_1a, sW_1b],
)
}
pub(super) fn spreaded_sigma_1(spreaded_limbs: [u64; 8]) -> u64 {
spreaded_limbs.into_iter().for_each(assert_in_valid_spreaded_form);
let [sW_12, sW_1a, sW_1b, sW_1c, sW_07, sW_3a, sW_04, sW_3b] = spreaded_limbs;
pow4_ip([10, 9, 8, 7, 0], [sW_12, sW_1a, sW_1b, sW_1c, sW_07])
+ pow4_ip(
[25, 22, 18, 15, 3, 2, 1, 0],
[sW_07, sW_3a, sW_04, sW_3b, sW_12, sW_1a, sW_1b, sW_1c],
)
+ pow4_ip(
[31, 30, 23, 20, 16, 13, 1, 0],
[sW_1b, sW_1c, sW_07, sW_3a, sW_04, sW_3b, sW_12, sW_1a],
)
}
fn pow4_ip<const N: usize>(exponents: [u8; N], terms: [u64; N]) -> u64 {
exponents.iter().zip(terms.iter()).map(|(e, t)| (1 << (2 * e)) * t).sum()
}
#[cfg(test)]
mod tests {
use rand::{seq::SliceRandom, Rng};
use super::*;
type F = midnight_curves::Fq;
#[test]
fn test_gen_spread_table() {
let table: Vec<_> = gen_spread_table::<F>().collect();
let mut rng = rand::thread_rng();
let to_fe = |(tag, plain, spreaded)| {
(
F::from(tag as u64),
F::from(plain as u64),
F::from(spreaded),
)
};
assert!(table.contains(&to_fe((0, 0, 0))));
for _ in 0..10 {
let tag = *LOOKUP_LENGTHS.choose(&mut rng).unwrap();
let plain = rng.gen_range(0..(1 << tag));
let spreaded = spread(plain);
let triple = to_fe((tag, plain, spreaded));
assert!(table.contains(&triple));
let random_triple = to_fe((rng.gen(), rng.gen(), rng.gen()));
assert!(!table.contains(&random_triple));
}
let tag = 16; let plain = rng.gen_range(0..(1 << tag));
let spreaded = spread(plain);
let triple = to_fe((tag, plain, spreaded));
assert!(!table.contains(&triple));
}
#[test]
fn test_spreaded_maj() {
fn assert_odd_of_spreaded_maj(vals: [u32; 3]) {
let [a, b, c] = vals;
let ret = (a & b) ^ (a & c) ^ (b & c);
let spreaded_forms: [u64; 3] = vals.map(spread);
let (_even, odd) = get_even_and_odd_bits(spreaded_maj(spreaded_forms));
assert_eq!(ret, odd);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
let vals: [u32; 3] = [rng.gen(), rng.gen(), rng.gen()];
assert_odd_of_spreaded_maj(vals);
}
}
#[test]
fn test_spreaded_Sigma_0() {
fn assert_even_of_spreaded_Sigma_0(val: u32) {
let rot_by_2 = val.rotate_right(2);
let rot_by_13 = val.rotate_right(13);
let rot_by_22 = val.rotate_right(22);
let ret = rot_by_2 ^ rot_by_13 ^ rot_by_22;
let plain_limbs: [u32; 4] = u32_in_be_limbs(val, [10, 9, 11, 2]);
let spreaded_limbs: [u64; 4] = plain_limbs.map(spread);
let (even, _) = get_even_and_odd_bits(spreaded_Sigma_0(spreaded_limbs));
assert_eq!(ret, even);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
assert_even_of_spreaded_Sigma_0(rng.gen());
}
}
#[test]
fn test_spreaded_Sigma_1() {
fn assert_even_of_spreaded_Sigma_1(val: u32) {
let rot_by_6 = val.rotate_right(6);
let rot_by_11 = val.rotate_right(11);
let rot_by_25 = val.rotate_right(25);
let ret = rot_by_6 ^ rot_by_11 ^ rot_by_25;
let plain_limbs: [u32; 5] = u32_in_be_limbs(val, [7, 12, 2, 5, 6]);
let spreaded_limbs: [u64; 5] = plain_limbs.map(spread);
let (even, _) = get_even_and_odd_bits(spreaded_Sigma_1(spreaded_limbs));
assert_eq!(ret, even);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
assert_even_of_spreaded_Sigma_1(rng.gen());
}
}
#[test]
fn test_spreaded_sigma_0() {
fn assert_even_of_spreaded_sigma_0(val: u32) {
let shifted_by_3 = val >> 3;
let rot_by_7 = val.rotate_right(7);
let rot_by_18 = val.rotate_right(18);
let ret = shifted_by_3 ^ rot_by_7 ^ rot_by_18;
let plain_limbs: [u32; 8] = u32_in_be_limbs(val, [12, 1, 1, 1, 7, 3, 4, 3]);
let spreaded_limbs: [u64; 8] = plain_limbs.map(spread);
let (even, _) = get_even_and_odd_bits(spreaded_sigma_0(spreaded_limbs));
assert_eq!(ret, even);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
assert_even_of_spreaded_sigma_0(rng.gen());
}
}
#[test]
fn test_spreaded_sigma_1() {
fn assert_even_of_spreaded_sigma_1(val: u32) {
let shifted_by_10 = val >> 10;
let rot_by_17 = val.rotate_right(17);
let rot_by_19 = val.rotate_right(19);
let ret = shifted_by_10 ^ rot_by_17 ^ rot_by_19;
let plain_limbs: [u32; 8] = u32_in_be_limbs(val, [12, 1, 1, 1, 7, 3, 4, 3]);
let spreaded_limbs: [u64; 8] = plain_limbs.map(spread);
let (even, _) = get_even_and_odd_bits(spreaded_sigma_1(spreaded_limbs));
assert_eq!(ret, even);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
assert_even_of_spreaded_sigma_1(rng.gen());
}
}
}