pub(crate) use crate::hash::utils::{
expr_pow2_ip, expr_pow4_ip, get_even_and_odd_bits, negate_spreaded, spread,
spread_table_from_lengths, u32_in_be_limbs, MASK_EVN_64,
};
use crate::CircuitField;
/// Bit width of the words being decomposed.
const WORD: u8 = 32;

/// Maximum bit length of a single limb.
const MAX_LIMB: u8 = 11;

/// Bit length of the least-significant limb (`32 % 11 == 10`).
const LAST_LIMB: u8 = WORD % MAX_LIMB;

/// Number of limbs in a decomposition.
///
/// `(WORD - 1) / MAX_LIMB + 1` is `ceil(WORD / MAX_LIMB)`, the number of limbs
/// of at most `MAX_LIMB` bits needed to cover a word. One more limb is added
/// because the rotation boundary splits one full-size limb in two.
pub(super) const NUM_LIMBS: usize = ((WORD - 1) / MAX_LIMB + 2) as usize;
/// Computes the limb bit lengths used to decompose a `WORD`-bit value so that
/// a left rotation by `rot` falls exactly on a limb boundary.
///
/// Limbs are listed from most to least significant. Every limb has `MAX_LIMB`
/// bits except two cases:
/// - the last limb has `LAST_LIMB` bits;
/// - the full-size limb containing the rotation boundary is split into two
///   limbs of `rot % MAX_LIMB` and `MAX_LIMB - rot % MAX_LIMB` bits.
///
/// When `rot` is a multiple of `MAX_LIMB`, the first of those two limbs has
/// length 0.
///
/// Returns `(lengths, k)`. The first `k` lengths add up to `rot`, and all
/// lengths add up to `WORD`.
///
/// # Panics
///
/// Panics if `rot` is not in `1..16`.
pub(super) fn limb_lengths(rot: u8) -> ([u8; NUM_LIMBS], usize) {
    assert!(rot > 0 && rot < 16);

    // `split` is the index of the full-size limb cut by the rotation boundary.
    let split = usize::from(rot / MAX_LIMB);
    let head = rot % MAX_LIMB;
    let tail = MAX_LIMB - head;

    let mut lengths = [MAX_LIMB; NUM_LIMBS];
    lengths[NUM_LIMBS - 1] = LAST_LIMB;
    lengths[split] = head;
    lengths[split + 1] = tail;

    (lengths, split + 1)
}
/// Computes the recomposition coefficients for the limbs produced by
/// `limb_values(value, rot)`.
///
/// Returns `(coeffs, coeffs_rot)`:
/// - `coeffs[i]` is `2^(sum of the lengths of the limbs after i)`, so
///   `sum_i limbs[i] * coeffs[i] == value`.
/// - `coeffs_rot[i]` is the coefficient of limb `i` after the first `k` limbs
///   (exactly `rot` bits, as returned by `limb_lengths`) are moved to the
///   least-significant end. Hence
///   `sum_i limbs[i] * coeffs_rot[i] == value.rotate_left(rot)`.
///
/// # Panics
///
/// Panics if `rot` is not in `1..16` (via `limb_lengths`).
pub(super) fn limb_coeffs(rot: u8) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS]) {
    // Big-endian weights: the last limb gets 1, and each earlier limb is
    // shifted past all the bits of the limbs that follow it.
    fn weights(lengths: &[u8; NUM_LIMBS]) -> [u32; NUM_LIMBS] {
        let mut out = [0u32; NUM_LIMBS];
        let mut weight = 1u32;
        for (slot, &len) in out.iter_mut().zip(lengths.iter()).rev() {
            *slot = weight;
            weight = weight.wrapping_shl(u32::from(len));
        }
        out
    }

    let (mut lengths, k) = limb_lengths(rot);
    let coeffs = weights(&lengths);

    // Compute the weights in rotated limb order, then rotate them back so that
    // index `i` still refers to limb `i` of the original decomposition.
    lengths.rotate_left(k);
    let mut coeffs_rot = weights(&lengths);
    coeffs_rot.rotate_right(k);

    (coeffs, coeffs_rot)
}
/// Splits `value` into limbs whose bit lengths are given by `limb_lengths(rot)`.
///
/// Bits are consumed from the most significant end, so `limbs[0]` holds the
/// top bits of `value` and the last limb holds the bottom `LAST_LIMB` bits.
/// Zero-length limbs are 0.
///
/// # Panics
///
/// Panics if `rot` is not in `1..16` (via `limb_lengths`).
pub(super) fn limb_values(value: u32, rot: u8) -> [u32; NUM_LIMBS] {
    let (lengths, _) = limb_lengths(rot);
    let mut limbs = [0u32; NUM_LIMBS];

    // Number of low bits of `value` not yet assigned to a limb.
    let mut remaining = WORD;
    for (limb, &len) in limbs.iter_mut().zip(lengths.iter()) {
        // A zero-length limb keeps its initial value of 0.
        if len != 0 {
            remaining -= len;
            *limb = (value >> remaining) & ((1u32 << len) - 1);
        }
    }

    limbs
}
/// Builds the spread lookup table for every limb length from 0 up to 11
/// (i.e. up to `MAX_LIMB`), as triples of field elements.
///
/// The table is expected to contain `(len, x, spread(x))` for every `len` in
/// `0..=11` and every `x < 2^len`, and no entry with a larger `len`.
/// NOTE(review): the actual layout comes from `spread_table_from_lengths`;
/// confirm it there.
pub(super) fn gen_spread_table<F: CircuitField>() -> impl Iterator<Item = (F, F, F)> {
    // NOTE(review): the upper bound equals `MAX_LIMB`. It is kept as a literal
    // because the integer type expected by `spread_table_from_lengths` is not
    // visible here.
    let lengths = 0..=11;
    spread_table_from_lengths(lengths)
}
#[cfg(test)]
// Unit tests for the limb decomposition helpers, the spread lookup table, and
// the bitwise identities the spreaded representation relies on.
mod tests {
use rand::Rng;
use super::*;
// Concrete field used to instantiate `gen_spread_table`.
type F = midnight_curves::Fq;
// For every supported rotation, checks that the limb lengths cover the whole
// word and that the first `k` limbs cover exactly `rot` bits.
#[test]
fn test_limb_lengths() {
for rot in 1..16 {
let (lengths, k) = limb_lengths(rot);
let sum: u8 = lengths.iter().sum();
assert_eq!(
sum, WORD,
"Sum of lengths does not equal WORD={} for rot={}",
WORD, rot
);
let expected_rot = lengths.iter().take(k).sum::<u8>();
assert_eq!(
expected_rot, rot,
"Sum of the first k = {} lengths does not equal rot={}",
k, rot
);
}
}
// For every supported rotation and a random word, checks that recombining the
// limbs with `coeffs` gives back the word, and that recombining them with
// `coeffs_rot` gives its left rotation by `rot`.
#[test]
fn test_decomposition_and_rotation() {
for rot in 1..16 {
let mut rng = rand::thread_rng();
let val: u32 = rng.gen();
let (coeffs, coeffs_rot) = limb_coeffs(rot);
let limbs = limb_values(val, rot);
let res = limbs.iter().zip(coeffs.iter()).fold(0u32, |acc, (&limb, &coeff)| {
acc.wrapping_add(limb.wrapping_mul(coeff))
});
assert_eq!(val, res, "Failed reconstruction for rot={}", rot);
let rot_val = val.rotate_left(rot as u32);
let rot_res = limbs.iter().zip(coeffs_rot.iter()).fold(0u32, |acc, (&limb, &coeff)| {
acc.wrapping_add(limb.wrapping_mul(coeff))
});
assert_eq!(
rot_val, rot_res,
"Failed rotation reconstruction for rot={}",
rot
);
}
}
// Checks that the XOR of three words equals the even bits of the sum of their
// spreaded forms. This assumes `spread` places each data bit on an even
// position, so the per-position sum (at most 3) never carries into the next
// data bit.
#[test]
fn test_type_one() {
fn assert_even_of_spreaded_type_one(vals: [u32; 3]) {
let [a, b, c] = vals;
let ret = a ^ b ^ c;
let [a_sprdd, b_sprdd, c_sprdd]: [u64; 3] = vals.map(spread);
let (even, _odd) = get_even_and_odd_bits(a_sprdd + b_sprdd + c_sprdd);
assert_eq!(ret, even);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
let vals: [u32; 3] = [rng.gen(), rng.gen(), rng.gen()];
assert_even_of_spreaded_type_one(vals);
}
}
// Checks `(a & b) | (!a & c) == (a & b) ^ (!a & c)`. The two terms never share
// a set bit, so OR and XOR agree.
#[test]
fn test_type_two() {
fn assert_type_two(vals: [u32; 3]) {
let [a, b, c] = vals;
let ret = (a & b) | ((!a) & c);
let expected_ret = (a & b) ^ ((!a) & c);
assert_eq!(ret, expected_ret);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
let vals: [u32; 3] = [rng.gen(), rng.gen(), rng.gen()];
assert_type_two(vals);
}
}
// Checks `(a | !b) ^ c == (a ^ !b ^ c) ^ (a & !b)`, which follows from
// `x | y == x ^ y ^ (x & y)`.
#[test]
fn test_type_three() {
fn assert_type_three(vals: [u32; 3]) {
let [a, b, c] = vals;
let ret = (a | (!b)) ^ c;
let expected_ret = (a ^ (!b) ^ c) ^ (a & (!b));
assert_eq!(ret, expected_ret);
}
let mut rng = rand::thread_rng();
for _ in 0..10 {
let vals: [u32; 3] = [rng.gen(), rng.gen(), rng.gen()];
assert_type_three(vals);
}
}
// Checks membership in the spread table:
// - random valid `(tag, plain, spread(plain))` triples with tag in 0..=11 are
//   present;
// - uniformly random triples are absent (false positives are astronomically
//   unlikely but not impossible);
// - a triple with tag 12 is absent.
#[test]
fn test_gen_spread_table() {
let table: Vec<_> = gen_spread_table::<F>().collect();
let mut rng = rand::thread_rng();
// Converts an integer `(tag, plain, spreaded)` triple into the table's
// `(F, F, F)` item type. Parameter types are left to inference from the
// call sites below.
let to_fe = |(tag, plain, spreaded)| {
(
F::from(tag as u64),
F::from(plain as u64),
F::from(spreaded),
)
};
assert!(table.contains(&to_fe((0, 0, 0))));
for _ in 0..10 {
let tag = rng.gen_range(0..=11);
let plain = rng.gen_range(0..(1 << tag));
let spreaded = spread(plain);
let triple = to_fe((tag, plain, spreaded));
assert!(table.contains(&triple));
let random_triple = to_fe((rng.gen(), rng.gen(), rng.gen()));
assert!(!table.contains(&random_triple));
}
// Tag 12 is one past the largest length the table is built for (0..=11),
// so no triple with this tag may appear.
let tag = 12; let plain = rng.gen_range(0..(1 << tag));
let spreaded = spread(plain);
let triple = to_fe((tag, plain, spreaded));
assert!(!table.contains(&triple));
}
}