use core::array;
use p3_field::integers::QuotientMap;
use p3_field::{Field, PrimeCharacteristicRing};
use crate::AirBuilder;
/// Pack a little-endian sequence of values into a single ring element.
///
/// For items `v_0, v_1, ..., v_n` (lowest weight first) this returns
/// `v_0 + 2*v_1 + ... + 2^n * v_n`, evaluated Horner-style from the last item
/// down to the first. An empty iterator yields `R::ZERO`.
///
/// The items are not checked to be `0` or `1`; whatever values are supplied are
/// combined with the weights above.
#[inline]
pub fn pack_bits_le<R, Var, I>(iter: I) -> R
where
    R: PrimeCharacteristicRing,
    Var: Into<R> + Clone,
    I: DoubleEndedIterator<Item = Var>,
{
    // Walk from the highest-weight item down to the lowest.
    let mut high_to_low = iter.rev();

    // The highest-weight item seeds the accumulator; no items means zero.
    let Some(top) = high_to_low.next() else {
        return R::ZERO;
    };
    let mut packed: R = top.into();

    // Each step doubles what has been accumulated and adds the next lower item.
    for item in high_to_low {
        let item: R = item.into();
        packed = packed.double() + item;
    }
    packed
}
/// Fold a slice of field elements with the ring's `xor`, starting from `F::ZERO`.
///
/// An empty slice returns `F::ZERO`. In debug builds every element is asserted
/// to be zero or one before it is folded in; release builds skip that check.
///
/// The const parameter `N` is not used by the body; callers still have to name it.
#[inline(always)]
pub fn checked_xor<F: Field, const N: usize>(xs: &[F]) -> F {
    let mut running = F::ZERO;
    for x in xs {
        // Only boolean-valued inputs are expected here.
        debug_assert!(x.is_zero() || x.is_one());
        running = running.xor(x);
    }
    running
}
/// Apply the ring's `andn` to two field elements that are expected to be boolean.
///
/// For `0`/`1` inputs the result is `1` only when `x` is `0` and `y` is `1`
/// (i.e. `(!x) & y`), and `0` otherwise.
///
/// In debug builds both operands are asserted to be zero or one; release builds
/// skip that check.
#[inline(always)]
pub fn checked_andn<F: Field>(x: F, y: F) -> F {
    // Validate the operands first, `x` before `y`.
    debug_assert!(x.is_zero() || x.is_one());
    debug_assert!(y.is_zero() || y.is_one());

    F::andn(&x, &y)
}
/// Expand a `u32` into 32 ring elements, one per bit, least significant bit first.
///
/// Entry `i` is `R::from_bool(true)` when bit `i` of `val` is set and
/// `R::from_bool(false)` otherwise.
#[inline]
pub fn u32_to_bits_le<R: PrimeCharacteristicRing>(val: u32) -> [R; 32] {
    array::from_fn(|bit_index| {
        // Move the requested bit into the lowest position and test it.
        let is_set = (val >> bit_index) & 1 == 1;
        R::from_bool(is_set)
    })
}
/// Expand a `u64` into 64 ring elements, one per bit, least significant bit first.
///
/// Entry `i` is `R::from_bool(true)` when bit `i` of `val` is set and
/// `R::from_bool(false)` otherwise.
#[inline]
pub fn u64_to_bits_le<R: PrimeCharacteristicRing>(val: u64) -> [R; 64] {
    array::from_fn(|bit_index| {
        // Move the requested bit into the lowest position and test it.
        let is_set = (val >> bit_index) & 1 == 1;
        R::from_bool(is_set)
    })
}
/// Split a `u64` into four 16-bit limbs, each converted to a ring element.
///
/// The limbs are in little-endian order: entry `0` holds bits `0..16` of `val`
/// and entry `3` holds bits `48..64`.
#[inline]
pub fn u64_to_16_bit_limbs<R: PrimeCharacteristicRing>(val: u64) -> [R; 4] {
    // Little-endian bytes: limb `k` is made of bytes `2k` (low) and `2k + 1` (high).
    let le_bytes = val.to_le_bytes();
    array::from_fn(|limb_index| {
        let pair = [le_bytes[2 * limb_index], le_bytes[2 * limb_index + 1]];
        R::from_u16(u16::from_le_bytes(pair))
    })
}
/// Add the constraints relating `a` to the three-term sum `b + c + d`, where every
/// value is supplied as two limbs (index `0` is the low limb, index `1` is weighted
/// by `2^16`).
///
/// With
///
/// * `diff_lo = a[0] - b[0] - c[0] - d[0]`
/// * `diff    = diff_lo + 2^16 * (a[1] - b[1] - c[1] - d[1])`
///
/// two expressions are asserted to be zero:
///
/// 1. `diff    * (diff    + 2^32) * (diff    + 2 * 2^32)`
/// 2. `diff_lo * (diff_lo + 2^16) * (diff_lo + 2 * 2^16)`
///
/// So `diff` must be one of `0, -2^32, -2 * 2^32` and `diff_lo` one of
/// `0, -2^16, -2 * 2^16` in the field.
///
/// NOTE(review): this looks intended to enforce `a = b + c + d mod 2^32` on 16-bit
/// limbs. This function does not range check any limb, so callers presumably have
/// to guarantee that separately — confirm against the call sites.
///
/// # Panics
///
/// Panics if `from_canonical_checked(1 << 16)` returns `None` for the prime
/// subfield of `AB::Expr`.
#[inline]
pub fn add3<AB: AirBuilder>(
    builder: &mut AB,
    a: &[AB::Var; 2],
    b: &[AB::Var; 2],
    c: &[AB::Expr; 2],
    d: &[AB::Expr; 2],
) {
    let [a_lo, a_hi] = a;
    let [b_lo, b_hi] = b;
    let [c_lo, c_hi] = c;
    let [d_lo, d_hi] = d;

    // The constants 2^16 and 2^32, built in the prime subfield.
    let two_16 =
        <AB::Expr as PrimeCharacteristicRing>::PrimeSubfield::from_canonical_checked(1 << 16)
            .unwrap();
    let two_32 = two_16.square();

    // Limb-wise differences and the recombined difference.
    let diff_lo = a_lo.clone() - b_lo.clone() - c_lo.clone() - d_lo.clone();
    let diff_hi = a_hi.clone() - b_hi.clone() - c_hi.clone() - d_hi.clone();
    let diff = diff_lo.clone() + diff_hi.mul_2exp_u64(16);

    // The recombined difference must be 0, -2^32 or -2 * 2^32.
    let full_width_check = diff.clone()
        * (diff.clone() + AB::Expr::from_prime_subfield(two_32))
        * (diff + AB::Expr::from_prime_subfield(two_32.double()));

    // The low-limb difference must be 0, -2^16 or -2 * 2^16.
    let low_limb_check = diff_lo.clone()
        * (diff_lo.clone() + AB::Expr::from_prime_subfield(two_16))
        * (diff_lo + AB::Expr::from_prime_subfield(two_16.double()));

    builder.assert_zeros([full_width_check, low_limb_check]);
}
/// Add the constraints relating `a` to the two-term sum `b + c`, where every value
/// is supplied as two limbs (index `0` is the low limb, index `1` is weighted by
/// `2^16`).
///
/// With
///
/// * `diff_lo = a[0] - b[0] - c[0]`
/// * `diff    = diff_lo + 2^16 * (a[1] - b[1] - c[1])`
///
/// two expressions are asserted to be zero:
///
/// 1. `diff    * (diff    + 2^32)`
/// 2. `diff_lo * (diff_lo + 2^16)`
///
/// So `diff` must be `0` or `-2^32` and `diff_lo` must be `0` or `-2^16` in the
/// field.
///
/// NOTE(review): this looks intended to enforce `a = b + c mod 2^32` on 16-bit
/// limbs. This function does not range check any limb, so callers presumably have
/// to guarantee that separately — confirm against the call sites.
///
/// # Panics
///
/// Panics if `from_canonical_checked(1 << 16)` returns `None` for the prime
/// subfield of `AB::Expr`.
#[inline]
pub fn add2<AB: AirBuilder>(
    builder: &mut AB,
    a: &[AB::Var; 2],
    b: &[AB::Var; 2],
    c: &[AB::Expr; 2],
) {
    let [a_lo, a_hi] = a;
    let [b_lo, b_hi] = b;
    let [c_lo, c_hi] = c;

    // The constants 2^16 and 2^32, built in the prime subfield.
    let two_16 =
        <AB::Expr as PrimeCharacteristicRing>::PrimeSubfield::from_canonical_checked(1 << 16)
            .unwrap();
    let two_32 = two_16.square();

    // Limb-wise differences and the recombined difference.
    let diff_lo = a_lo.clone() - b_lo.clone() - c_lo.clone();
    let diff_hi = a_hi.clone() - b_hi.clone() - c_hi.clone();
    let diff = diff_lo.clone() + diff_hi.mul_2exp_u64(16);

    // The recombined difference must be 0 or -2^32.
    let full_width_check = diff.clone() * (diff + AB::Expr::from_prime_subfield(two_32));

    // The low-limb difference must be 0 or -2^16.
    let low_limb_check = diff_lo.clone() * (diff_lo + AB::Expr::from_prime_subfield(two_16));

    builder.assert_zeros([full_width_check, low_limb_check]);
}
/// Add the constraints tying `a` to `b XOR rotl(c, shift)`, where `rotl` is a
/// 32-bit left rotation.
///
/// * `a` is two limbs: `a[0]` is constrained to equal the little-endian packing
///   of result bits `0..16`, and `a[1]` the packing of result bits `16..32`.
/// * `b` and `c` are 32 individual bits, least significant first.
/// * Result bit `i` is `b[i] xor c[(i - shift) mod 32]`.
///
/// Every entry of `c` is asserted to be boolean here. The entries of `b` are not
/// checked by this function — presumably the caller has already constrained them;
/// confirm against the call sites.
///
/// `shift` is reduced modulo 32 before use, so any `usize` is accepted. Without
/// that reduction the index `32 + i - shift` underflows once `shift > 32 + i`,
/// which panics in builds with overflow checks enabled.
#[inline]
pub fn xor_32_shift<AB: AirBuilder>(
    builder: &mut AB,
    a: &[AB::Var; 2],
    b: &[AB::Var; 32],
    c: &[AB::Var; 32],
    shift: usize,
) {
    // Range check every bit of `c`.
    builder.assert_bools(c.clone());

    // A rotation by `shift` equals a rotation by `shift mod 32`; reducing first
    // keeps `32 + i - rotation` from underflowing.
    let rotation = shift % 32;

    // Bit `i` of `b ^ rotl(c, shift)` as an expression.
    let xor_bit = |i: usize| -> AB::Expr {
        let b_bit: AB::Expr = b[i].clone().into();
        let c_bit: AB::Expr = c[(32 + i - rotation) % 32].clone().into();
        b_bit.xor(&c_bit)
    };

    // Pack the low and high halves of the result into one expression each.
    let sum_0_16: AB::Expr = pack_bits_le((0..16).map(&xor_bit));
    let sum_16_32: AB::Expr = pack_bits_le((16..32).map(&xor_bit));

    builder.assert_zeros([a[0].clone() - sum_0_16, a[1].clone() - sum_16_32]);
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_pack_bits_le_various_patterns() {
        // Each case: little-endian bits and the integer they encode.
        let cases: [(&[F], u8); 4] = [
            // 1 + 0*2 + 1*4 = 5
            (&[F::ONE, F::ZERO, F::ONE], 5),
            // 1 + 1*2 + 0*4 + 1*8 = 11
            (&[F::ONE, F::ONE, F::ZERO, F::ONE], 11),
            // All zeros pack to zero.
            (&[F::ZERO; 5], 0),
            // Only the highest of five bits is set: 2^4 = 16.
            (&[F::ZERO, F::ZERO, F::ZERO, F::ZERO, F::ONE], 16),
        ];

        for (bits, expected) in cases {
            let packed = pack_bits_le::<F, _, _>(bits.iter().copied());
            assert_eq!(packed, F::from_u8(expected));
        }
    }

    #[test]
    fn test_checked_xor_multiple_cases() {
        // 1 ^ 0 ^ 1 = 0
        assert_eq!(checked_xor::<F, 3>(&vec![F::ONE, F::ZERO, F::ONE]), F::ZERO);

        // 1 ^ 1 ^ 1 = 1
        assert_eq!(checked_xor::<F, 3>(&vec![F::ONE, F::ONE, F::ONE]), F::ONE);

        // 0 ^ 0 ^ 0 = 0
        assert_eq!(
            checked_xor::<F, 3>(&vec![F::ZERO, F::ZERO, F::ZERO]),
            F::ZERO
        );

        // 1 ^ 0 ^ 1 ^ 0 = 0
        assert_eq!(
            checked_xor::<F, 4>(&vec![F::ONE, F::ZERO, F::ONE, F::ZERO]),
            F::ZERO
        );
    }

    #[test]
    fn test_checked_andn() {
        // Rows of (x, y, expected andn(x, y)).
        let truth_table = [
            (F::ONE, F::ZERO, F::ZERO),
            (F::ZERO, F::ONE, F::ONE),
            (F::ZERO, F::ZERO, F::ZERO),
            (F::ONE, F::ONE, F::ZERO),
        ];

        for (x, y, expected) in truth_table {
            assert_eq!(checked_andn(x, y), expected);
        }
    }

    #[test]
    fn test_u32_to_bits_le() {
        // 10 = 0b1010: bits 1 and 3 are set, everything above is clear.
        let bits = u32_to_bits_le::<F>(10);
        assert_eq!(bits[..4], [F::ZERO, F::ONE, F::ZERO, F::ONE]);
        assert!(bits[4..].iter().all(|bit| *bit == F::ZERO));

        // No bits set.
        assert_eq!(u32_to_bits_le::<F>(0), [F::ZERO; 32]);

        // Every bit set.
        assert_eq!(u32_to_bits_le::<F>(u32::MAX), [F::ONE; 32]);
    }

    #[test]
    fn test_u64_to_bits_le() {
        // 3 = 0b11: bits 0 and 1 are set, everything above is clear.
        let bits = u64_to_bits_le::<F>(3);
        assert_eq!(bits[..3], [F::ONE, F::ONE, F::ZERO]);
        assert!(bits[3..].iter().all(|bit| *bit == F::ZERO));

        // No bits set.
        assert_eq!(u64_to_bits_le::<F>(0), [F::ZERO; 64]);

        // Every bit set.
        assert_eq!(u64_to_bits_le::<F>(u64::MAX), [F::ONE; 64]);
    }

    #[test]
    fn test_u64_to_16_bit_limbs() {
        let val: u64 = 0x123456789ABCDEF0;
        let limbs = u64_to_16_bit_limbs::<F>(val);

        // Limbs come out lowest 16 bits first.
        assert_eq!(limbs, [0xDEF0, 0x9ABC, 0x5678, 0x1234].map(F::from_u16));

        // Weighting limb `i` by 2^(16 * i) and summing gives back the value.
        let recombined: F = limbs
            .iter()
            .enumerate()
            .map(|(i, limb)| limb.mul_2exp_u64(16 * i as u64))
            .sum();
        assert_eq!(recombined, F::from_u64(val));

        // Zero splits into four zero limbs.
        assert_eq!(u64_to_16_bit_limbs::<F>(0), [F::ZERO; 4]);

        // All ones: every limb is 0xFFFF.
        assert_eq!(u64_to_16_bit_limbs::<F>(u64::MAX), [F::from_u64(0xFFFF); 4]);

        // A value that fits in 16 bits only fills the lowest limb.
        let limbs = u64_to_16_bit_limbs::<F>(0x1234);
        assert_eq!(limbs, [F::from_u64(0x1234), F::ZERO, F::ZERO, F::ZERO]);
    }
}