use ff::PrimeField;
use midnight_proofs::plonk::Expression;
pub const MASK_EVN_64: u64 = 0x5555_5555_5555_5555;
pub(crate) const MASK_ODD_64: u64 = 0xAAAA_AAAA_AAAA_AAAA;
pub(crate) fn get_even_and_odd_bits(value: u64) -> (u32, u32) {
(compact_even(value), compact_even(value >> 1))
}
fn compact_even(mut x: u64) -> u32 {
x &= 0x5555_5555_5555_5555;
x = (x | (x >> 1)) & 0x3333_3333_3333_3333;
x = (x | (x >> 2)) & 0x0f0f_0f0f_0f0f_0f0f;
x = (x | (x >> 4)) & 0x00ff_00ff_00ff_00ff;
x = (x | (x >> 8)) & 0x0000_ffff_0000_ffff;
x = (x | (x >> 16)) & 0x0000_0000_ffff_ffff;
x as u32
}
pub(crate) fn assert_in_valid_spreaded_form(x: u64) {
assert_eq!(MASK_ODD_64 & x, 0, "Input must be in valid spreaded form")
}
pub(crate) fn spread(x: u32) -> u64 {
(0..32).fold(0u64, |acc, i| acc | (((x as u64 >> i) & 1) << (2 * i)))
}
pub(crate) fn negate_spreaded(x: u64) -> u64 {
assert_in_valid_spreaded_form(x);
x ^ MASK_EVN_64
}
pub(crate) fn u32_in_be_limbs<const N: usize>(value: u32, limb_lengths: [usize; N]) -> [u32; N] {
assert_eq!(limb_lengths.iter().sum::<usize>(), 32);
let mut result = [0u32; N];
let mut shift = 32;
for (i, &len) in limb_lengths.iter().enumerate() {
assert!(len != 0);
shift -= len;
result[i] = (value >> shift) & ((1 << len) - 1);
}
result
}
pub(crate) fn spread_table_from_lengths<F: PrimeField>(
lookup_lengths: impl IntoIterator<Item = u32>,
) -> impl Iterator<Item = (F, F, F)> {
lookup_lengths.into_iter().flat_map(|len| {
let tag = F::from(len as u64);
(0..(1u64 << len)).map(move |i| (tag, F::from(i), F::from(spread(i as u32))))
})
}
pub(crate) fn expr_pow2_ip<F: PrimeField, const N: usize>(
exponents: [u8; N],
terms: [&Expression<F>; N],
) -> Expression<F> {
let mut expr = Expression::Constant(F::ZERO);
for (pow, term) in exponents.into_iter().zip(terms.into_iter()) {
expr = expr + Expression::Constant(F::from(1 << pow)) * term.clone();
}
expr
}
pub(crate) fn expr_pow4_ip<F: PrimeField, const N: usize>(
exponents: [u8; N],
terms: [&Expression<F>; N],
) -> Expression<F> {
expr_pow2_ip(exponents.map(|l| 2 * l), terms)
}
#[cfg(test)]
mod tests {
use rand::Rng;
use super::*;
#[test]
fn test_get_even_and_odd_bits() {
[
(0, 0, 0),
(1, 1, 0),
(2, 0, 1),
(1 << 3, 0, 2),
(u64::MAX, 0xFFFF_FFFF, 0xFFFF_FFFF),
(MASK_EVN_64, 0xFFFF_FFFF, 0),
(MASK_ODD_64, 0, 0xFFFF_FFFF),
(0b110101101u64, 19, 14),
]
.into_iter()
.for_each(|(n, expected_even, expected_odd)| {
let (even, odd) = get_even_and_odd_bits(n);
assert_eq!(even, expected_even);
assert_eq!(odd, expected_odd);
});
}
#[test]
fn test_spread() {
[(0, 0), (1, 1), (0b10, 0b0100), (0b11, 0b0101)]
.into_iter()
.for_each(|(plain, spreaded)| assert_eq!(spread(plain), spreaded));
}
#[test]
fn test_negate_spreaded() {
assert_eq!(negate_spreaded(MASK_EVN_64), 0);
assert_eq!(negate_spreaded(1), MASK_EVN_64 - 1);
assert_ne!(negate_spreaded(0), 0);
}
#[test]
fn test_u32_in_be_limbs() {
[
(0x12345678u32, [8, 8, 8, 8], [0x12, 0x34, 0x56, 0x78]),
(0x12345678u32, [4, 8, 12, 8], [0x1, 0x23, 0x456, 0x78]),
]
.into_iter()
.for_each(|(value, limb_lengths, expected)| {
assert_eq!(u32_in_be_limbs(value, limb_lengths), expected)
});
let mut rng = rand::thread_rng();
let value: u32 = rng.gen();
let limb_lengths = [1; 32];
let result = u32_in_be_limbs(value, limb_lengths);
let expected: [u32; 32] = core::array::from_fn(|i| (value >> (31 - i)) & 1);
assert_eq!(result, expected);
}
}