use p3_field::{Algebra, Field};
/// Decimation-in-time butterfly on the pair `values[idx_1]`, `values[idx_2]`.
///
/// Writing `x = values[idx_1]` and `y = values[idx_2]`, the twiddle scales `y` first and the
/// pair becomes `(x + y * twiddle, x - y * twiddle)`. Both inputs are read before either output
/// is stored, and `values[idx_1]` is stored before `values[idx_2]`.
///
/// # Panics
///
/// Panics if `idx_1` or `idx_2` is out of bounds for `values`.
#[inline]
pub(crate) fn dit_butterfly<F: Field, A: Algebra<F>, const N: usize>(
    values: &mut [A; N],
    idx_1: usize,
    idx_2: usize,
    twiddle: F,
) {
    let top = values[idx_1].clone();
    let scaled = values[idx_2].clone() * twiddle;
    let difference = top.clone() - scaled.clone();
    values[idx_1] = top + scaled;
    values[idx_2] = difference;
}
/// Decimation-in-frequency butterfly on the pair `values[idx_1]`, `values[idx_2]`.
///
/// Writing `x = values[idx_1]` and `y = values[idx_2]`, the pair becomes
/// `(x + y, (x - y) * twiddle)`: the twiddle is applied after the difference is formed.
/// Both inputs are read before either output is stored, and `values[idx_1]` is stored first.
///
/// # Panics
///
/// Panics if `idx_1` or `idx_2` is out of bounds for `values`.
#[inline]
pub(crate) fn dif_butterfly<F: Field, A: Algebra<F>, const N: usize>(
    values: &mut [A; N],
    idx_1: usize,
    idx_2: usize,
    twiddle: F,
) {
    let (x, y) = (values[idx_1].clone(), values[idx_2].clone());
    let (sum, scaled_difference) = (x.clone() + y.clone(), (x - y) * twiddle);
    values[idx_1] = sum;
    values[idx_2] = scaled_difference;
}
/// Butterfly with an implicit unit twiddle: replaces `(x, y)` with `(x + y, x - y)`.
///
/// Here `x = values[idx_1]` and `y = values[idx_2]`. No multiplication is performed. Both
/// inputs are read before either output is stored, and `values[idx_1]` is stored first.
///
/// # Panics
///
/// Panics if `idx_1` or `idx_2` is out of bounds for `values`.
#[inline]
pub(crate) fn twiddle_free_butterfly<F: Field, A: Algebra<F>, const N: usize>(
    values: &mut [A; N],
    idx_1: usize,
    idx_2: usize,
) {
    let (x, y) = (values[idx_1].clone(), values[idx_2].clone());
    let (sum, difference) = (x.clone() + y.clone(), x - y);
    values[idx_1] = sum;
    values[idx_2] = difference;
}
/// Applies one layer of a Bowers G network to `values` in place, using DIF butterflies.
///
/// `values` is split into consecutive blocks of `2 << log_half_block_size` elements. Within
/// block `b`, index `hi` in the first half is paired with `hi + half_block_size` and combined
/// with `dif_butterfly` using `twiddles[b]`. Block 0 always uses `twiddle_free_butterfly`
/// (its twiddle is taken to be one), so `twiddles[0]` is never read and `twiddles` may be empty.
///
/// Blocks `1..num_blocks` are paired with `twiddles[1..]`. If `twiddles` is shorter than the
/// number of blocks, blocks with no matching twiddle (other than block 0) are left untouched.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if the first block does not fit in `values`, i.e. if
/// `N < 2 << log_half_block_size`.
#[inline]
pub(crate) fn bowers_g_layer<F: Field, A: Algebra<F>, const N: usize>(
    values: &mut [A; N],
    log_half_block_size: usize,
    twiddles: &[F],
) {
    let log_block_size = log_half_block_size + 1;
    let half_block_size = 1 << log_half_block_size;
    let num_blocks = N >> log_block_size;

    // Block 0: the twiddle is one, so skip the multiplication entirely.
    for hi in 0..half_block_size {
        let lo = hi + half_block_size;
        twiddle_free_butterfly(values, hi, lo);
    }

    // Remaining blocks. `skip(1)` is used rather than slicing `&twiddles[1..]`, which would
    // panic on an empty `twiddles` even though block 0 never needs a twiddle.
    for (block, &twiddle) in (1..num_blocks).zip(twiddles.iter().skip(1)) {
        let block_start = block << log_block_size;
        for hi in block_start..block_start + half_block_size {
            let lo = hi + half_block_size;
            dif_butterfly(values, hi, lo, twiddle);
        }
    }
}
/// Applies one layer of a Bowers G^T network to `values` in place, using DIT butterflies.
///
/// `values` is split into consecutive blocks of `2 << log_half_block_size` elements. Within
/// block `b`, index `hi` in the first half is paired with `hi + half_block_size` and combined
/// with `dit_butterfly` using `twiddles[b]`. Block 0 always uses `twiddle_free_butterfly`
/// (its twiddle is taken to be one), so `twiddles[0]` is never read and `twiddles` may be empty.
///
/// Blocks `1..num_blocks` are paired with `twiddles[1..]`. If `twiddles` is shorter than the
/// number of blocks, blocks with no matching twiddle (other than block 0) are left untouched.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if the first block does not fit in `values`, i.e. if
/// `N < 2 << log_half_block_size`.
#[inline]
pub(crate) fn bowers_g_t_layer<F: Field, A: Algebra<F>, const N: usize>(
    values: &mut [A; N],
    log_half_block_size: usize,
    twiddles: &[F],
) {
    let log_block_size = log_half_block_size + 1;
    let half_block_size = 1 << log_half_block_size;
    let num_blocks = N >> log_block_size;

    // Block 0: the twiddle is one, so skip the multiplication entirely.
    for hi in 0..half_block_size {
        let lo = hi + half_block_size;
        twiddle_free_butterfly(values, hi, lo);
    }

    // Remaining blocks. `skip(1)` is used rather than slicing `&twiddles[1..]`, which would
    // panic on an empty `twiddles` even though block 0 never needs a twiddle.
    for (block, &twiddle) in (1..num_blocks).zip(twiddles.iter().skip(1)) {
        let block_start = block << log_block_size;
        for hi in block_start..block_start + half_block_size {
            let lo = hi + half_block_size;
            dit_butterfly(values, hi, lo, twiddle);
        }
    }
}
/// One layer of a Bowers G^T network in which every block, including block 0, takes its
/// twiddle from `twiddles`.
///
/// `values` is split into consecutive blocks of `2 << log_half_block_size` elements. Block `b`
/// applies `dit_butterfly` with `twiddles[b]` to each pair `(hi, hi + half)` whose `hi` lies in
/// the block's first half. Block 0 gets no special treatment. Only the first
/// `min(num_blocks, twiddles.len())` blocks are processed; any later blocks are left untouched.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if a processed block does not fit in `values`.
#[inline]
pub(crate) fn bowers_g_t_layer_integrated<F: Field, A: Algebra<F>, const N: usize>(
    values: &mut [A; N],
    log_half_block_size: usize,
    twiddles: &[F],
) {
    let half = 1 << log_half_block_size;
    let log_block_size = log_half_block_size + 1;
    let block_count = N >> log_block_size;

    // `take(..).enumerate()` yields exactly the `(block, twiddle)` pairs that zipping
    // `0..block_count` with `twiddles` would.
    for (block, &twiddle) in twiddles.iter().take(block_count).enumerate() {
        let start = block << log_block_size;
        (start..start + half).for_each(|hi| dit_butterfly(values, hi, hi + half, twiddle));
    }
}
#[cfg(test)]
mod tests {
    // Unit and property tests for the butterflies and Bowers layers, instantiated over
    // BabyBear with the algebra `A` equal to the field `F`.
    use p3_baby_bear::BabyBear;
    use p3_field::{PrimeCharacteristicRing, TwoAdicField};
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// Strategy producing arbitrary field elements (any `u32`, reduced into the field).
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    #[test]
    fn dit_butterfly_manual() {
        let a = F::from_u32(7);
        let b = F::from_u32(11);
        let t = F::from_u32(3);
        let mut vals = [a, b];
        dit_butterfly::<F, F, 2>(&mut vals, 0, 1, t);
        assert_eq!(vals[0], a + b * t);
        assert_eq!(vals[1], a - b * t);
    }

    #[test]
    fn dif_butterfly_manual() {
        let a = F::from_u32(7);
        let b = F::from_u32(11);
        let t = F::from_u32(3);
        let mut vals = [a, b];
        dif_butterfly::<F, F, 2>(&mut vals, 0, 1, t);
        assert_eq!(vals[0], a + b);
        assert_eq!(vals[1], (a - b) * t);
    }

    #[test]
    fn twiddle_free_butterfly_manual() {
        let a = F::from_u32(7);
        let b = F::from_u32(11);
        let mut vals = [a, b];
        twiddle_free_butterfly::<F, F, 2>(&mut vals, 0, 1);
        assert_eq!(vals[0], a + b);
        assert_eq!(vals[1], a - b);
    }

    #[test]
    fn dit_with_twiddle_one_equals_twiddle_free() {
        let a = F::from_u32(42);
        let b = F::from_u32(99);
        let mut vals_dit = [a, b];
        dit_butterfly::<F, F, 2>(&mut vals_dit, 0, 1, F::ONE);
        let mut vals_free = [a, b];
        twiddle_free_butterfly::<F, F, 2>(&mut vals_free, 0, 1);
        assert_eq!(vals_dit, vals_free);
    }

    #[test]
    fn dif_with_twiddle_one_equals_twiddle_free() {
        let a = F::from_u32(42);
        let b = F::from_u32(99);
        let mut vals_dif = [a, b];
        dif_butterfly::<F, F, 2>(&mut vals_dif, 0, 1, F::ONE);
        let mut vals_free = [a, b];
        twiddle_free_butterfly::<F, F, 2>(&mut vals_free, 0, 1);
        assert_eq!(vals_dif, vals_free);
    }

    #[test]
    fn dit_preserves_trace() {
        let a = F::from_u32(123);
        let b = F::from_u32(456);
        let t = F::from_u32(789);
        let mut vals = [a, b];
        dit_butterfly::<F, F, 2>(&mut vals, 0, 1, t);
        assert_eq!(vals[0] + vals[1], a.double());
    }

    #[test]
    fn dif_sum_property() {
        let a = F::from_u32(123);
        let b = F::from_u32(456);
        let t = F::from_u32(789);
        let mut vals = [a, b];
        dif_butterfly::<F, F, 2>(&mut vals, 0, 1, t);
        assert_eq!(vals[0], a + b);
    }

    #[test]
    fn butterfly_on_non_adjacent_indices() {
        let mut vals: [F; 4] = [10, 20, 30, 40].map(F::from_u32);
        let original = vals;
        let t = F::from_u32(5);
        dit_butterfly::<F, F, 4>(&mut vals, 0, 3, t);
        assert_eq!(vals[1], original[1]);
        assert_eq!(vals[2], original[2]);
        assert_eq!(vals[0], original[0] + original[3] * t);
        assert_eq!(vals[3], original[0] - original[3] * t);
    }

    #[test]
    fn dit_zero_twiddle() {
        let a = F::from_u32(7);
        let b = F::from_u32(11);
        let mut vals = [a, b];
        dit_butterfly::<F, F, 2>(&mut vals, 0, 1, F::ZERO);
        assert_eq!(vals[0], a);
        assert_eq!(vals[1], a);
    }

    #[test]
    fn dif_zero_twiddle() {
        let a = F::from_u32(7);
        let b = F::from_u32(11);
        let mut vals = [a, b];
        dif_butterfly::<F, F, 2>(&mut vals, 0, 1, F::ZERO);
        assert_eq!(vals[0], a + b);
        assert_eq!(vals[1], F::ZERO);
    }

    #[test]
    fn bowers_g_then_g_t_roundtrip_n4() {
        // A DIF butterfly with twiddle t followed by a DIT butterfly with t^{-1} maps (a, b)
        // to (2a, 2b). So a G layer followed by a G^T layer with inverted twiddles (and vice
        // versa) must double every entry.
        let omega = F::two_adic_generator(2);
        let omega_inv = omega.inverse();
        assert_eq!(omega * omega_inv, F::ONE);
        let original: [F; 4] = [3, 7, 11, 13].map(F::from_u32);
        let doubled = original.map(|x| x.double());

        let mut vals = original;
        bowers_g_layer::<F, F, 4>(&mut vals, 0, &[F::ONE, omega]);
        bowers_g_t_layer::<F, F, 4>(&mut vals, 0, &[F::ONE, omega_inv]);
        assert_eq!(vals, doubled);

        let mut vals = original;
        bowers_g_t_layer::<F, F, 4>(&mut vals, 0, &[F::ONE, omega]);
        bowers_g_layer::<F, F, 4>(&mut vals, 0, &[F::ONE, omega_inv]);
        assert_eq!(vals, doubled);
    }

    #[test]
    fn integrated_matches_regular_with_unit_twiddles() {
        let twiddles = [F::ONE; 4];
        let original: [F; 8] = [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_u32);
        let mut vals_regular = original;
        bowers_g_t_layer::<F, F, 8>(&mut vals_regular, 0, &twiddles);
        let mut vals_integrated = original;
        bowers_g_t_layer_integrated::<F, F, 8>(&mut vals_integrated, 0, &twiddles);
        assert_eq!(vals_regular, vals_integrated);
    }

    #[test]
    fn integrated_matches_regular_when_first_twiddle_is_one() {
        // The regular layer hard-codes a unit twiddle for block 0, whereas the integrated
        // layer reads `twiddles[0]`. When that entry is ONE, the two layers must agree even
        // though the later twiddles are non-trivial.
        let twiddles = [
            F::ONE,
            F::two_adic_generator(2),
            F::from_u32(5),
            F::from_u32(9),
        ];
        let original: [F; 8] = [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_u32);
        let mut vals_regular = original;
        bowers_g_t_layer::<F, F, 8>(&mut vals_regular, 0, &twiddles);
        let mut vals_integrated = original;
        bowers_g_t_layer_integrated::<F, F, 8>(&mut vals_integrated, 0, &twiddles);
        assert_eq!(vals_regular, vals_integrated);
    }

    #[test]
    fn all_zeros_through_layers() {
        let twiddles = [F::ONE, F::two_adic_generator(2)];
        let mut vals = [F::ZERO; 4];
        bowers_g_layer::<F, F, 4>(&mut vals, 0, &twiddles);
        assert_eq!(vals, [F::ZERO; 4]);
        let mut vals = [F::ZERO; 4];
        bowers_g_t_layer::<F, F, 4>(&mut vals, 0, &twiddles);
        assert_eq!(vals, [F::ZERO; 4]);
        let mut vals = [F::ZERO; 4];
        bowers_g_t_layer_integrated::<F, F, 4>(&mut vals, 0, &twiddles);
        assert_eq!(vals, [F::ZERO; 4]);
    }

    proptest! {
        #[test]
        fn dit_is_linear(
            a1 in arb_f(), b1 in arb_f(),
            a2 in arb_f(), b2 in arb_f(),
            t in arb_f(),
        ) {
            let mut sum_then_dit = [a1 + a2, b1 + b2];
            dit_butterfly::<F, F, 2>(&mut sum_then_dit, 0, 1, t);
            let mut dit1 = [a1, b1];
            dit_butterfly::<F, F, 2>(&mut dit1, 0, 1, t);
            let mut dit2 = [a2, b2];
            dit_butterfly::<F, F, 2>(&mut dit2, 0, 1, t);
            prop_assert_eq!(sum_then_dit[0], dit1[0] + dit2[0]);
            prop_assert_eq!(sum_then_dit[1], dit1[1] + dit2[1]);
        }

        #[test]
        fn dif_is_linear(
            a1 in arb_f(), b1 in arb_f(),
            a2 in arb_f(), b2 in arb_f(),
            t in arb_f(),
        ) {
            let mut sum_then_dif = [a1 + a2, b1 + b2];
            dif_butterfly::<F, F, 2>(&mut sum_then_dif, 0, 1, t);
            let mut dif1 = [a1, b1];
            dif_butterfly::<F, F, 2>(&mut dif1, 0, 1, t);
            let mut dif2 = [a2, b2];
            dif_butterfly::<F, F, 2>(&mut dif2, 0, 1, t);
            prop_assert_eq!(sum_then_dif[0], dif1[0] + dif2[0]);
            prop_assert_eq!(sum_then_dif[1], dif1[1] + dif2[1]);
        }

        #[test]
        fn dit_twiddle_one_squared_is_double(a in arb_f(), b in arb_f()) {
            let mut vals = [a, b];
            dit_butterfly::<F, F, 2>(&mut vals, 0, 1, F::ONE);
            dit_butterfly::<F, F, 2>(&mut vals, 0, 1, F::ONE);
            prop_assert_eq!(vals[0], a.double());
            prop_assert_eq!(vals[1], b.double());
        }

        #[test]
        fn dif_twiddle_one_squared_is_double(a in arb_f(), b in arb_f()) {
            let mut vals = [a, b];
            dif_butterfly::<F, F, 2>(&mut vals, 0, 1, F::ONE);
            dif_butterfly::<F, F, 2>(&mut vals, 0, 1, F::ONE);
            prop_assert_eq!(vals[0], a.double());
            prop_assert_eq!(vals[1], b.double());
        }

        #[test]
        fn twiddle_free_matches_dit_and_dif_unit(a in arb_f(), b in arb_f()) {
            let mut free = [a, b];
            twiddle_free_butterfly::<F, F, 2>(&mut free, 0, 1);
            let mut dit = [a, b];
            dit_butterfly::<F, F, 2>(&mut dit, 0, 1, F::ONE);
            let mut dif = [a, b];
            dif_butterfly::<F, F, 2>(&mut dif, 0, 1, F::ONE);
            prop_assert_eq!(free, dit);
            prop_assert_eq!(free, dif);
        }

        #[test]
        fn bowers_g_layer_is_linear(
            v1 in prop::array::uniform4(arb_f()),
            v2 in prop::array::uniform4(arb_f()),
        ) {
            let twiddles = [F::ONE, F::two_adic_generator(2)];
            let mut sum = core::array::from_fn::<F, 4, _>(|i| v1[i] + v2[i]);
            bowers_g_layer::<F, F, 4>(&mut sum, 0, &twiddles);
            let mut r1 = v1;
            bowers_g_layer::<F, F, 4>(&mut r1, 0, &twiddles);
            let mut r2 = v2;
            bowers_g_layer::<F, F, 4>(&mut r2, 0, &twiddles);
            for i in 0..4 {
                prop_assert_eq!(sum[i], r1[i] + r2[i]);
            }
        }

        #[test]
        fn bowers_g_t_layer_is_linear(
            v1 in prop::array::uniform4(arb_f()),
            v2 in prop::array::uniform4(arb_f()),
        ) {
            let twiddles = [F::ONE, F::two_adic_generator(2)];
            let mut sum = core::array::from_fn::<F, 4, _>(|i| v1[i] + v2[i]);
            bowers_g_t_layer::<F, F, 4>(&mut sum, 0, &twiddles);
            let mut r1 = v1;
            bowers_g_t_layer::<F, F, 4>(&mut r1, 0, &twiddles);
            let mut r2 = v2;
            bowers_g_t_layer::<F, F, 4>(&mut r2, 0, &twiddles);
            for i in 0..4 {
                prop_assert_eq!(sum[i], r1[i] + r2[i]);
            }
        }

        #[test]
        fn bowers_g_t_layer_integrated_is_linear(
            v1 in prop::array::uniform4(arb_f()),
            v2 in prop::array::uniform4(arb_f()),
        ) {
            let twiddles = [F::ONE, F::two_adic_generator(2)];
            let mut sum = core::array::from_fn::<F, 4, _>(|i| v1[i] + v2[i]);
            bowers_g_t_layer_integrated::<F, F, 4>(&mut sum, 0, &twiddles);
            let mut r1 = v1;
            bowers_g_t_layer_integrated::<F, F, 4>(&mut r1, 0, &twiddles);
            let mut r2 = v2;
            bowers_g_t_layer_integrated::<F, F, 4>(&mut r2, 0, &twiddles);
            for i in 0..4 {
                prop_assert_eq!(sum[i], r1[i] + r2[i]);
            }
        }
    }
}