use std::num::NonZeroUsize;
use std::ops::MulAssign;
use std::sync::OnceLock;
use num_traits::ConstOne;
use super::b_field_element::BFieldElement;
use super::traits::FiniteField;
use super::traits::Inverse;
use super::traits::ModPowU32;
use super::traits::PrimitiveRootOfUnity;
/// Number of distinct domain sizes for which twiddle factors and swap indices are
/// cached, i.e., the length of every per-domain cache array in this module.
///
/// The caches are indexed by `log2(slice_len)`, with lengths 0 and 1 sharing index 0.
/// On 64-bit targets, `slice_len` rejects slices longer than `u32::MAX`, so the largest
/// accepted power-of-two length is `2^31` and 32 entries (indices `0..=31`) suffice.
// NOTE(review): the rationale for 29 entries on 32-bit targets is not visible here;
// presumably it bounds the largest slice that can actually be allocated there — confirm.
const NUM_DOMAINS: usize = {
    // 16-bit targets are rejected at compile time.
    #[cfg(target_pointer_width = "16")]
    compile_error!("pointer width 16 is not supported");
    #[cfg(target_pointer_width = "32")]
    {
        29
    }
    #[cfg(target_pointer_width = "64")]
    {
        32
    }
};
/// Computes the number-theoretic transform (NTT) of `x` in place.
///
/// Interpreting `x` as the coefficients of a polynomial `p`, position `i` of the result
/// holds `p(omega^i)`, where `omega` is the primitive `x.len()`-th root of unity given by
/// `BFieldElement::primitive_root_of_unity`. The twiddle factors for each domain size are
/// computed on first use and cached for all later calls.
///
/// # Panics
///
/// Panics if `x.len()` exceeds `u32::MAX` or is neither zero nor a power of two.
pub fn ntt<FF>(x: &mut [FF])
where
    FF: FiniteField + MulAssign<BFieldElement>,
{
    static ALL_TWIDDLE_FACTORS: [OnceLock<Vec<Vec<BFieldElement>>>; NUM_DOMAINS] =
        [const { OnceLock::new() }; NUM_DOMAINS];

    let len = slice_len(x);

    // Lengths 0 and 1 share cache slot 0; neither has any butterfly layers.
    let cache_slot = len.checked_ilog2().map_or(0, |log2_len| log2_len as usize);
    let forward_twiddles = ALL_TWIDDLE_FACTORS[cache_slot].get_or_init(|| {
        let root = BFieldElement::primitive_root_of_unity(u64::from(len)).unwrap();
        twiddle_factors(len, root)
    });

    ntt_unchecked(x, forward_twiddles);
}
/// Computes the inverse number-theoretic transform of `x` in place, undoing [`ntt`].
///
/// This runs the same butterfly network as [`ntt`], but with twiddle factors derived
/// from the inverse of the primitive `x.len()`-th root of unity. It then multiplies every
/// element by the inverse of `x.len()`. The twiddle factors are cached per domain size,
/// independently of the cache used by [`ntt`].
///
/// # Panics
///
/// Panics if `x.len()` exceeds `u32::MAX` or is neither zero nor a power of two.
pub fn intt<FF>(x: &mut [FF])
where
    FF: FiniteField + MulAssign<BFieldElement>,
{
    static ALL_TWIDDLE_FACTORS: [OnceLock<Vec<Vec<BFieldElement>>>; NUM_DOMAINS] =
        [const { OnceLock::new() }; NUM_DOMAINS];

    let len = slice_len(x);

    // Lengths 0 and 1 share cache slot 0; neither has any butterfly layers.
    let cache_slot = match len.checked_ilog2() {
        Some(log2_len) => log2_len as usize,
        None => 0,
    };
    let inverse_twiddles = ALL_TWIDDLE_FACTORS[cache_slot].get_or_init(|| {
        let root = BFieldElement::primitive_root_of_unity(u64::from(len)).unwrap();
        let inverse_root = root.inverse();
        twiddle_factors(len, inverse_root)
    });

    ntt_unchecked(x, inverse_twiddles);
    unscale(x);
}
/// Returns the length of `x` as a `u32` after checking that it is usable as an NTT domain.
///
/// # Panics
///
/// Panics if `x.len()` does not fit into a `u32`, or if it is neither zero nor a power
/// of two.
fn slice_len<FF>(x: &[FF]) -> u32 {
    let len = x.len();
    let slice_len = u32::try_from(len).expect("slice should be no longer than u32::MAX");

    assert!(slice_len == 0 || slice_len.is_power_of_two());
    slice_len
}
/// Performs an in-place iterative radix-2 NTT of `x` using precomputed twiddle factors.
///
/// First, `x` is permuted into bit-reversed order using swap indices that are cached per
/// domain size. Then one butterfly layer is applied per entry of `twiddle_factors`; layer
/// `i` combines pairs of adjacent blocks of width `2^i` and must hold exactly `2^i`
/// twiddle factors. An empty `x` is returned immediately.
///
/// "Unchecked" because the caller must ensure two things:
/// - `x.len()` is zero or a power of two.
/// - `twiddle_factors` was produced by the `twiddle_factors` function for exactly that
///   length.
///
/// These preconditions are only checked by debug assertions.
#[inline]
fn ntt_unchecked<FF>(x: &mut [FF], twiddle_factors: &[Vec<BFieldElement>])
where
    FF: FiniteField + MulAssign<BFieldElement>,
{
    static ALL_SWAP_INDICES: [OnceLock<Vec<Option<NonZeroUsize>>>; NUM_DOMAINS] =
        [const { OnceLock::new() }; NUM_DOMAINS];

    let slice_len = x.len();
    let Some(log2_slice_len) = slice_len.checked_ilog2() else {
        return;
    };
    debug_assert_eq!(twiddle_factors.len(), log2_slice_len as usize);

    // Bring the input into bit-reversed order. Each index pair appears exactly once in the
    // swap table, so applying the swaps in order yields the permutation.
    let swap_indices =
        ALL_SWAP_INDICES[log2_slice_len as usize].get_or_init(|| swap_indices(slice_len));
    debug_assert_eq!(swap_indices.len(), slice_len);
    for (index, maybe_partner) in swap_indices.iter().enumerate() {
        if let Some(partner) = maybe_partner {
            x.swap(index, partner.get());
        }
    }

    // Iterative Cooley-Tukey butterflies. At each layer, every block of `2 * half_width`
    // consecutive elements is split into a lower and an upper half. The `j`-th elements of
    // the two halves are combined using the `j`-th twiddle factor of that layer. Slicing
    // via `chunks_exact_mut` / `split_at_mut` avoids per-element bounds checks and any
    // narrowing of indices to `u32`.
    let mut half_width = 1;
    for layer_twiddles in twiddle_factors {
        debug_assert_eq!(layer_twiddles.len(), half_width);
        for block in x.chunks_exact_mut(2 * half_width) {
            let (lower, upper) = block.split_at_mut(half_width);
            for ((low, high), &twiddle) in lower.iter_mut().zip(upper).zip(layer_twiddles) {
                let even = *low;
                let mut odd = *high;
                odd *= twiddle;
                *low = even + odd;
                *high = even - odd;
            }
        }
        half_width *= 2;
    }
}
/// Multiplies every element of `array` in place by the inverse of `array.len()`.
///
/// `intt` uses this to normalize the output of the butterfly network. An empty `array`
/// has nothing to scale and is left untouched.
#[inline]
fn unscale<FF>(array: &mut [FF])
where
    FF: FiniteField + MulAssign<BFieldElement>,
{
    let length = BFieldElement::from(array.len());
    let scale = length.inverse_or_zero();
    array.iter_mut().for_each(|element| *element *= scale);
}
/// Computes the swaps that permute a slice of length `len` into bit-reversed order.
///
/// Let `rev_k` be `k` with its lowest `log2(len)` bits reversed. Entry `k` of the result
/// is `Some(rev_k)` exactly when `k < rev_k`, and `None` otherwise. Each pair of indices
/// is therefore listed once. Swapping `k` with `rev_k` for every `Some` entry, in any
/// order, applies the bit-reversal permutation.
///
/// `len` is expected to be zero or a power of two. Lengths 0 and 1 produce no swaps.
#[doc(hidden)]
pub fn swap_indices(len: usize) -> Vec<Option<NonZeroUsize>> {
    let log_2_len = len.checked_ilog2().unwrap_or(0);
    (0..len)
        .map(|k| {
            // Reverse all bits of `k`, then shift so that the reversed low `log_2_len` bits
            // end up at the bottom. When `log_2_len == 0`, the shift would equal the bit
            // width, which `checked_shr` maps to `None` and hence to `0`. The only index in
            // that case is `0` anyway.
            let rev_k = k
                .reverse_bits()
                .checked_shr(usize::BITS - log_2_len)
                .unwrap_or(0);

            // `k < rev_k` implies `rev_k >= 1`, so the filter never discards a needed swap.
            NonZeroUsize::new(rev_k).filter(|_| k < rev_k)
        })
        .collect()
}
/// Precomputes the per-layer twiddle factors for a radix-2 NTT of length `slice_len`.
///
/// The result has one vector per butterfly layer, `floor(log2(slice_len))` layers in
/// total, and none for `slice_len` 0 or 1. Layer `i` holds the `2^i` consecutive powers
/// `w^0, w^1, …` of `w = root_of_unity^(slice_len / 2^(i + 1))`.
#[doc(hidden)]
pub fn twiddle_factors(slice_len: u32, root_of_unity: BFieldElement) -> Vec<Vec<BFieldElement>> {
    let num_layers = slice_len.checked_ilog2().unwrap_or(0);
    let mut layers = Vec::with_capacity(num_layers as usize);

    for layer in 0..num_layers {
        let half_width: u32 = 1 << layer;
        let layer_root = root_of_unity.mod_pow_u32(slice_len / (2 * half_width));

        let mut powers = Vec::with_capacity(half_width as usize);
        let mut power = BFieldElement::ONE;
        for _ in 0..half_width {
            powers.push(power);
            power *= layer_root;
        }
        layers.push(powers);
    }

    layers
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use itertools::Itertools;
    use num_traits::ConstZero;
    use num_traits::Zero;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_adapter::arb;
    use super::*;
    use crate::math::other::random_elements;
    use crate::math::traits::PrimitiveRootOfUnity;
    use crate::math::x_field_element::EXTENSION_DEGREE;
    use crate::prelude::*;
    use crate::tests::proptest;
    use crate::tests::test;
    use crate::xfe;

    // Round-trips random base-field vectors of lengths 2^1 through 2^9 through `ntt` and
    // `intt`. Each vector is tested as generated and again with its first element set to
    // `BFieldElement::MAX`.
    #[macro_rules_attr::apply(test)]
    fn chu_ntt_b_field_prop_test() {
        for log_2_n in 1..10 {
            let n = 1 << log_2_n;
            for _ in 0..10 {
                let mut values = random_elements(n);
                let original_values = values.clone();
                ntt::<BFieldElement>(&mut values);
                assert_ne!(original_values, values);
                intt::<BFieldElement>(&mut values);
                assert_eq!(original_values, values);

                // Repeat with an extreme element to exercise the upper end of the field.
                values[0] = bfe!(BFieldElement::MAX);
                let original_values_with_max_element = values.clone();
                ntt::<BFieldElement>(&mut values);
                assert_ne!(original_values, values);
                intt::<BFieldElement>(&mut values);
                assert_eq!(original_values_with_max_element, values);
            }
        }
    }

    // The extension-field counterpart of `chu_ntt_b_field_prop_test`.
    #[macro_rules_attr::apply(test)]
    fn chu_ntt_x_field_prop_test() {
        for log_2_n in 1..10 {
            let n = 1 << log_2_n;
            for _ in 0..10 {
                let mut values = random_elements(n);
                let original_values = values.clone();
                ntt::<XFieldElement>(&mut values);
                assert_ne!(original_values, values);
                intt::<XFieldElement>(&mut values);
                assert_eq!(original_values, values);

                // Sanity check that the random input genuinely uses the extension
                // coefficients rather than only the base-field component.
                assert!(
                    !original_values[1].coefficients[1].is_zero()
                        || !original_values[1].coefficients[2].is_zero()
                );

                // Repeat with an element whose coefficients are all `BFieldElement::MAX`.
                values[0] = xfe!([BFieldElement::MAX; EXTENSION_DEGREE]);
                let original_values_with_max_element = values.clone();
                ntt::<XFieldElement>(&mut values);
                assert_ne!(original_values, values);
                intt::<XFieldElement>(&mut values);
                assert_eq!(original_values_with_max_element, values);
            }
        }
    }

    // The NTT of the unit impulse [1, 0, 0, 0] is the all-ones vector, and `intt` maps it
    // back.
    #[macro_rules_attr::apply(test)]
    fn xfield_basic_test_of_chu_ntt() {
        let mut input_output = vec![
            XFieldElement::new_const(BFieldElement::ONE),
            XFieldElement::new_const(BFieldElement::ZERO),
            XFieldElement::new_const(BFieldElement::ZERO),
            XFieldElement::new_const(BFieldElement::ZERO),
        ];
        let original_input = input_output.clone();
        let expected = vec![
            XFieldElement::new_const(BFieldElement::ONE),
            XFieldElement::new_const(BFieldElement::ONE),
            XFieldElement::new_const(BFieldElement::ONE),
            XFieldElement::new_const(BFieldElement::ONE),
        ];
        println!("input_output = {input_output:?}");
        ntt::<XFieldElement>(&mut input_output);
        assert_eq!(expected, input_output);
        println!("input_output = {input_output:?}");
        intt::<XFieldElement>(&mut input_output);
        assert_eq!(original_input, input_output);
    }

    // Known-answer test for a length-4 base-field NTT, plus the round trip.
    #[macro_rules_attr::apply(test)]
    fn bfield_basic_test_of_chu_ntt() {
        let mut input_output = vec![
            BFieldElement::new(1),
            BFieldElement::new(4),
            BFieldElement::new(0),
            BFieldElement::new(0),
        ];
        let original_input = input_output.clone();
        let expected = vec![
            BFieldElement::new(5),
            BFieldElement::new(1125899906842625),
            BFieldElement::new(18446744069414584318),
            BFieldElement::new(18445618169507741698),
        ];
        ntt::<BFieldElement>(&mut input_output);
        assert_eq!(expected, input_output);
        intt::<BFieldElement>(&mut input_output);
        assert_eq!(original_input, input_output);
    }

    // An impulse of `BFieldElement::MAX` transforms into a constant vector of
    // `BFieldElement::MAX`.
    #[macro_rules_attr::apply(test)]
    fn bfield_max_value_test_of_chu_ntt() {
        let mut input_output = vec![
            BFieldElement::new(BFieldElement::MAX),
            BFieldElement::new(0),
            BFieldElement::new(0),
            BFieldElement::new(0),
        ];
        let original_input = input_output.clone();
        let expected = vec![
            BFieldElement::new(BFieldElement::MAX),
            BFieldElement::new(BFieldElement::MAX),
            BFieldElement::new(BFieldElement::MAX),
            BFieldElement::new(BFieldElement::MAX),
        ];
        ntt::<BFieldElement>(&mut input_output);
        assert_eq!(expected, input_output);
        intt::<BFieldElement>(&mut input_output);
        assert_eq!(original_input, input_output);
    }

    // Both transforms must accept an empty slice and leave it empty.
    #[macro_rules_attr::apply(test)]
    fn ntt_on_empty_input() {
        let mut input_output = vec![];
        let original_input = input_output.clone();
        ntt::<BFieldElement>(&mut input_output);
        assert_eq!(0, input_output.len());
        intt::<BFieldElement>(&mut input_output);
        assert_eq!(original_input, input_output);
    }

    // A length-1 NTT is the identity.
    #[macro_rules_attr::apply(proptest(cases = 10))]
    fn ntt_on_input_of_length_one(bfe: BFieldElement) {
        let mut test_vector = vec![bfe];
        ntt(&mut test_vector);
        assert_eq!(vec![bfe], test_vector);
    }

    // Lengths 0 and 1 map to the same cache slot (index 0), so interleaving them must not
    // panic. This test only checks that no panic occurs.
    #[macro_rules_attr::apply(test)]
    fn ntt_on_input_of_length_0_then_1_then_0() {
        let mut empty = Vec::<BFieldElement>::new();
        ntt(&mut empty);
        ntt(&mut [BFieldElement::new(0)]);
        ntt(&mut empty);
    }

    // Property: `intt` inverts `ntt` for random inputs of length 2^0 through 2^17.
    #[macro_rules_attr::apply(proptest(cases = 10))]
    fn ntt_then_intt_is_identity_operation(
        #[strategy((0_usize..18).prop_map(|l| 1 << l))] _vector_length: usize,
        #[strategy(vec(arb(), #_vector_length))] mut input: Vec<BFieldElement>,
    ) {
        let original_input = input.clone();
        ntt::<BFieldElement>(&mut input);
        intt::<BFieldElement>(&mut input);
        assert_eq!(original_input, input);
    }

    // Known-answer test for a length-32 NTT. The input repeats with period 8, so its
    // transform is non-zero only at every 4th position.
    #[macro_rules_attr::apply(test)]
    fn b_field_ntt_with_length_32() {
        let mut input_output = bfe_vec![
            1, 4, 0, 0, 0, 0, 0, 0, 1, 4, 0, 0, 0, 0, 0, 0, 1, 4, 0, 0, 0, 0, 0, 0, 1, 4, 0, 0, 0,
            0, 0, 0,
        ];
        let original_input = input_output.clone();
        ntt::<BFieldElement>(&mut input_output);
        println!("actual_output = {input_output:?}");
        let expected = bfe_vec![
            20,
            0,
            0,
            0,
            18446744069146148869_u64,
            0,
            0,
            0,
            4503599627370500_u64,
            0,
            0,
            0,
            18446726477228544005_u64,
            0,
            0,
            0,
            18446744069414584309_u64,
            0,
            0,
            0,
            268435460,
            0,
            0,
            0,
            18442240469787213829_u64,
            0,
            0,
            0,
            17592186040324_u64,
            0,
            0,
            0,
        ];
        assert_eq!(expected, input_output);
        intt::<BFieldElement>(&mut input_output);
        assert_eq!(original_input, input_output);
    }

    // `ntt` must agree with direct polynomial evaluation: position `i` of the output equals
    // the coefficient polynomial evaluated at `omega^i`, where `omega` is the primitive
    // `size`-th root of unity.
    #[macro_rules_attr::apply(test)]
    fn test_compare_ntt_to_eval() {
        for log_size in 1..10 {
            let size = 1 << log_size;
            let mut coefficients = random_elements(size);
            let polynomial = Polynomial::new(coefficients.clone());
            let omega = BFieldElement::primitive_root_of_unity(size.try_into().unwrap()).unwrap();
            ntt(&mut coefficients);
            let evals = (0..size)
                .map(|i| omega.mod_pow(i.try_into().unwrap()))
                .map(|p| polynomial.evaluate_in_same_field(p))
                .collect_vec();
            assert_eq!(evals, coefficients);
        }
    }

    // Smoke test: swap indices can be computed for every power-of-two length up to
    // 2^(NUM_DOMAINS - 3).
    // NOTE(review): on 64-bit targets the largest case builds 2^29 entries of
    // `Option<NonZeroUsize>`, a multi-GiB allocation.
    #[macro_rules_attr::apply(test)]
    fn swap_indices_can_be_computed() {
        for log_size in 0..NUM_DOMAINS - 2 {
            swap_indices(1 << log_size);
        }
    }

    // Smoke test: twiddle factors can be computed for every power-of-two length up to
    // 2^(NUM_DOMAINS - 6), using the matching primitive root of unity.
    #[macro_rules_attr::apply(test)]
    fn twiddle_factors_can_be_computed() {
        for log_size in 0..NUM_DOMAINS - 5 {
            let size = 1 << log_size;
            let root = BFieldElement::primitive_root_of_unity(size.into()).unwrap();
            twiddle_factors(size, root);
        }
    }
}