use std::ops::MulAssign;
use num_traits::Zero;
use rand_distr::num_traits::One;
use crate::shared_math::traits::{FiniteField, ModPowU32};
use super::{
b_field_element::BFieldElement,
traits::{Inverse, New},
};
/// In-place number-theoretic transform of `x` with respect to `omega`.
///
/// The slice is first permuted into bit-reversed order and then combined by
/// `log_2_of_n` passes of radix-2 butterflies, so input and output are both in
/// natural order. The tests in this file check that entry `i` of the result
/// equals the polynomial built from the input, evaluated at `omega^i`.
///
/// # Arguments
///
/// * `x` - values to transform; its length must be `2^log_2_of_n`.
/// * `omega` - a primitive `x.len()`-th root of unity.
/// * `log_2_of_n` - base-2 logarithm of `x.len()`.
///
/// # Panics
///
/// In debug builds, panics if the length does not match `log_2_of_n` or if
/// `omega` is not a primitive root of unity of order `x.len()`.
#[allow(clippy::many_single_char_names)]
pub fn ntt<FF: FiniteField + MulAssign<BFieldElement>>(
    x: &mut [FF],
    omega: BFieldElement,
    log_2_of_n: u32,
) {
    // NOTE: the cast truncates for slices longer than `u32::MAX`; the exponent
    // type of `mod_pow_u32` limits the supported length anyway.
    let n = x.len() as u32;

    debug_assert_eq!(n, 1 << log_2_of_n, "2^log2(n) == n");
    debug_assert!(
        omega.mod_pow_u32(n).is_one(),
        "Got {omega} which is not a {n}th root of 1"
    );
    // Primitivity check. It is skipped for n < 2: there the exponent `n / 2`
    // is 0 and `omega^0 == 1` would reject every `omega`, although a
    // length-1 transform is perfectly valid (it is the identity).
    debug_assert!(
        n < 2 || !omega.mod_pow_u32(n / 2).is_one(),
        "Got {omega} which is not a primitive {n}th root of 1"
    );

    // Permute into bit-reversed order. The `k < rk` guard makes sure every
    // pair of positions is exchanged exactly once.
    for k in 0..n {
        let rk = bitreverse(k, log_2_of_n);
        if k < rk {
            x.swap(rk as usize, k as usize);
        }
    }

    // `m` is half the width of the blocks merged in the current pass.
    let mut m = 1;
    for _ in 0..log_2_of_n {
        // Generator of this pass's twiddle factors: omega^(n / 2m).
        let w_m = omega.mod_pow_u32(n / (2 * m));

        let mut k = 0;
        while k < n {
            // `w` walks through the powers of `w_m` inside one block.
            let mut w = BFieldElement::one();
            for j in 0..m {
                let u = x[(k + j) as usize];
                let mut v = x[(k + j + m) as usize];
                v *= w;
                x[(k + j) as usize] = u + v;
                x[(k + j + m) as usize] = u - v;
                w *= w_m;
            }
            k += 2 * m;
        }

        m *= 2;
    }
}
/// In-place inverse of [`ntt`].
///
/// Runs `ntt` with the inverse of `omega` and afterwards multiplies every
/// element by the field inverse of `x.len()`.
///
/// # Arguments
///
/// * `x` - values to transform back; handed to `ntt` together with
///   `log_2_of_n`, so the same length requirement applies.
/// * `omega` - the root of unity that was used for the forward transform.
/// * `log_2_of_n` - base-2 logarithm of `x.len()`.
pub fn intt<FF: FiniteField + MulAssign<BFieldElement>>(
    x: &mut [FF],
    omega: BFieldElement,
    log_2_of_n: u32,
) {
    // Scaling factor 1 / len, computed in the base field.
    let length_in_field: BFieldElement = omega.new_from_usize(x.len());
    let scale: BFieldElement = BFieldElement::one() / length_in_field;

    let omega_inverse = omega.inverse();
    ntt::<FF>(x, omega_inverse, log_2_of_n);

    x.iter_mut().for_each(|element| *element *= scale);
}
/// Reverses the `l` least-significant bits of `n`.
///
/// The lowest bit of `n` becomes bit `l - 1` of the result, the next one bit
/// `l - 2`, and so on. Bits of `n` beyond the lowest `l` are ignored.
#[inline]
fn bitreverse_usize(n: usize, l: usize) -> usize {
    // Carry (reversed-so-far, bits-still-to-consume) through `l` steps: each
    // step moves the current lowest input bit onto the low end of the result.
    let (reversed, _) = (0..l).fold((0, n), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
/// Rearranges `array` in place into bit-reversed index order.
///
/// Every index `k` is paired with the index obtained by reversing its lowest
/// `bits` bits, where `bits` is the smallest number with
/// `2^bits >= array.len()`, and the two elements are exchanged.
///
/// NOTE(review): for lengths that are not a power of two a reversed index can
/// lie beyond the end of the slice, in which case `swap` panics.
pub fn bitreverse_order<FF>(array: &mut [FF]) {
    let len = array.len();
    let bits = len.next_power_of_two().trailing_zeros() as usize;

    for index in 0..len {
        let mirrored = bitreverse_usize(index, bits);
        // Swap only from the smaller index so each pair is exchanged once.
        if index < mirrored {
            array.swap(mirrored, index);
        }
    }
}
/// In-place number-theoretic transform that skips the bit-reversal
/// permutation.
///
/// The input is taken in natural order. The tests in this file show that
/// applying [`bitreverse_order`] to the result yields the same values as
/// [`ntt`], i.e. the output is left in bit-reversed order.
///
/// # Arguments
///
/// * `x` - values to transform; the length must be a power of two. Slices of
///   length 0 or 1 are returned unchanged.
/// * `omega` - a primitive `x.len()`-th root of unity.
///
/// # Panics
///
/// In debug builds, panics if the length is not a power of two or if `omega`
/// is not a primitive root of unity of order `x.len()`.
pub fn ntt_noswap<FF: FiniteField + MulAssign<BFieldElement>>(x: &mut [FF], omega: BFieldElement) {
    let n: usize = x.len();

    // Nothing to do for fewer than two elements. Returning early also keeps
    // the debug checks sound: for n == 1 the primitivity check would compare
    // `omega^0` against 1 and reject every `omega`.
    if n <= 1 {
        return;
    }

    debug_assert!(n.is_power_of_two(), "array length must be power of 2");
    debug_assert!(
        omega.mod_pow_u32(n.try_into().unwrap()).is_one(),
        "Got {omega} which is not a {n}th root of 1"
    );
    debug_assert!(!omega.mod_pow_u32((n / 2).try_into().unwrap()).is_one());

    // Smallest `logn` with 2^logn >= n. Since n >= 2 here, logn >= 1 and
    // `logn - 1` below cannot underflow.
    let mut logn: usize = 0;
    while (1 << logn) < n {
        logn += 1;
    }

    // Twiddle table: slot bitreverse(i) holds omega^i for i in 0..n/2.
    let mut powers_of_omega_bitreversed = vec![BFieldElement::zero(); n];
    let mut omegai = BFieldElement::one();
    for i in 0..n / 2 {
        powers_of_omega_bitreversed[bitreverse_usize(i, logn - 1)] = omegai;
        omegai *= omega;
    }

    // `m` counts the blocks of the current pass, `t` is half a block's width.
    let mut m: usize = 1;
    let mut t: usize = n;
    while m < n {
        t >>= 1;
        // Block `i` starts at `2 * i * t` and uses the i-th table entry as
        // its single twiddle factor.
        for (i, zeta) in powers_of_omega_bitreversed.iter().enumerate().take(m) {
            let s = i * t * 2;
            for j in s..(s + t) {
                let u = x[j];
                let mut v = x[j + t];
                v *= *zeta;
                x[j] = u + v;
                x[j + t] = u - v;
            }
        }
        m *= 2;
    }
}
/// In-place inverse transform without bit-reversal permutation and without
/// the final scaling.
///
/// Runs the same butterfly passes as [`ntt`], using the inverse of `omega`,
/// but does not permute the input first. In this file's tests it is fed data
/// in bit-reversed order, and the caller multiplies every element by the
/// inverse of the length afterwards to complete the inversion.
///
/// # Arguments
///
/// * `x` - values to transform; the length must be a power of two. Slices of
///   length 0 or 1 are returned unchanged.
/// * `omega` - the primitive `x.len()`-th root of unity of the forward
///   transform; it is inverted internally.
///
/// # Panics
///
/// In debug builds, panics if the length is not a power of two or if the
/// inverse of `omega` is not a primitive root of unity of order `x.len()`.
pub fn intt_noswap<FF: FiniteField + MulAssign<BFieldElement>>(x: &mut [FF], omega: BFieldElement) {
    let n = x.len();

    // Nothing to do for fewer than two elements. Returning early also keeps
    // the debug checks sound: `n - 1`-style power-of-two tests underflow for
    // an empty slice, and for n == 1 the primitivity check would compare a
    // zeroth power against 1 and reject every root.
    if n <= 1 {
        return;
    }

    let omega_inverse = omega.inverse();
    debug_assert!(n.is_power_of_two(), "array length must be power of 2");
    debug_assert!(
        omega_inverse.mod_pow_u32(n.try_into().unwrap()).is_one(),
        "Got {omega_inverse} which is not a {n}th root of 1"
    );
    debug_assert!(!omega_inverse
        .mod_pow_u32((n / 2).try_into().unwrap())
        .is_one());

    // Number of butterfly passes: smallest `logn` with 2^logn >= n.
    let mut logn: usize = 0;
    while (1 << logn) < n {
        logn += 1;
    }

    // `m` is half the width of the blocks merged in the current pass.
    let mut m = 1;
    for _ in 0..logn {
        // Generator of this pass's twiddle factors: omega^-(n / 2m).
        let w_m = omega_inverse.mod_pow_u32((n / (2 * m)).try_into().unwrap());

        let mut k = 0;
        while k < n {
            // `w` walks through the powers of `w_m` inside one block.
            let mut w = BFieldElement::one();
            for j in 0..m {
                let u = x[k + j];
                let mut v = x[k + j + m];
                v *= w;
                x[k + j] = u + v;
                x[k + j + m] = u - v;
                w *= w_m;
            }
            k += 2 * m;
        }

        m *= 2;
    }
}
/// Reverses the `l` least-significant bits of `n`.
///
/// The lowest bit of `n` ends up as bit `l - 1` of the result; bits of `n`
/// beyond the lowest `l` do not contribute.
#[inline]
fn bitreverse(n: u32, l: u32) -> u32 {
    let mut remaining = n;
    let mut reversed = 0;
    let mut bits_left = l;

    // Peel bits off the low end of the input and push them onto the low end
    // of the result, which moves earlier bits towards the high end.
    while bits_left > 0 {
        reversed = (reversed << 1) | (remaining & 1);
        remaining >>= 1;
        bits_left -= 1;
    }

    reversed
}
#[cfg(test)]
mod fast_ntt_attempt_tests {
    use itertools::Itertools;
    use num_traits::{One, Zero};

    use crate::shared_math::b_field_element::BFieldElement;
    use crate::shared_math::other::random_elements;
    use crate::shared_math::polynomial::Polynomial;
    use crate::shared_math::traits::PrimitiveRootOfUnity;
    use crate::shared_math::x_field_element::XFieldElement;

    use super::*;

    /// Builds a vector of base-field elements from raw `u64` values.
    fn b_field_elements(values: &[u64]) -> Vec<BFieldElement> {
        values.iter().copied().map(BFieldElement::new).collect()
    }

    /// `intt` undoes `ntt` on random base-field inputs, including inputs
    /// whose first element is `BFieldElement::MAX`.
    #[test]
    fn chu_ntt_b_field_prop_test() {
        for log_2_n in 1..10 {
            let n = 1 << log_2_n;
            for _ in 0..10 {
                let mut values = random_elements(n);
                let original_values = values.clone();
                let omega = BFieldElement::primitive_root_of_unity(n as u64).unwrap();

                ntt::<BFieldElement>(&mut values, omega, log_2_n);
                assert_ne!(original_values, values);
                intt::<BFieldElement>(&mut values, omega, log_2_n);
                assert_eq!(original_values, values);

                // Second round trip with the largest representable value in
                // the first slot.
                values[0] = BFieldElement::new(BFieldElement::MAX);
                let original_values_with_max_element = values.clone();

                ntt::<BFieldElement>(&mut values, omega, log_2_n);
                // Compare against the input of *this* round trip; comparing
                // against `original_values` would not show that the transform
                // changed anything.
                assert_ne!(original_values_with_max_element, values);
                intt::<BFieldElement>(&mut values, omega, log_2_n);
                assert_eq!(original_values_with_max_element, values);
            }
        }
    }

    /// Same round-trip property as above, for extension-field elements
    /// transformed with a base-field root of unity.
    #[test]
    fn chu_ntt_x_field_prop_test() {
        for log_2_n in 1..10 {
            let n = 1 << log_2_n;
            for _ in 0..10 {
                let mut values = random_elements(n);
                let original_values = values.clone();
                let omega = XFieldElement::primitive_root_of_unity(n as u64).unwrap();

                ntt::<XFieldElement>(&mut values, omega.unlift().unwrap(), log_2_n);
                assert_ne!(original_values, values);
                intt::<XFieldElement>(&mut values, omega.unlift().unwrap(), log_2_n);
                assert_eq!(original_values, values);

                // Sanity check on the random input: at least one of the two
                // higher coefficients of the second element is non-zero.
                assert!(
                    !original_values[1].coefficients[1].is_zero()
                        || !original_values[1].coefficients[2].is_zero()
                );

                // Second round trip with every coefficient of the first
                // element set to the largest representable value.
                values[0] = XFieldElement::new([
                    BFieldElement::new(BFieldElement::MAX),
                    BFieldElement::new(BFieldElement::MAX),
                    BFieldElement::new(BFieldElement::MAX),
                ]);
                let original_values_with_max_element = values.clone();

                ntt::<XFieldElement>(&mut values, omega.unlift().unwrap(), log_2_n);
                // Compare against the input of *this* round trip.
                assert_ne!(original_values_with_max_element, values);
                intt::<XFieldElement>(&mut values, omega.unlift().unwrap(), log_2_n);
                assert_eq!(original_values_with_max_element, values);
            }
        }
    }

    /// The transform of (1, 0, 0, 0) is the all-ones vector.
    #[test]
    fn xfield_basic_test_of_chu_ntt() {
        let one = XFieldElement::new_const(BFieldElement::one());
        let zero = XFieldElement::new_const(BFieldElement::zero());

        let mut input_output = vec![one, zero, zero, zero];
        let original_input = input_output.clone();
        let expected = vec![one, one, one, one];
        let omega = XFieldElement::primitive_root_of_unity(4).unwrap();

        println!("input_output = {input_output:?}");
        ntt::<XFieldElement>(&mut input_output, omega.unlift().unwrap(), 2);
        assert_eq!(expected, input_output);

        println!("input_output = {input_output:?}");
        intt::<XFieldElement>(&mut input_output, omega.unlift().unwrap(), 2);
        assert_eq!(original_input, input_output);
    }

    /// Fixed length-4 example over the base field, checked against known
    /// output values.
    #[test]
    fn bfield_basic_test_of_chu_ntt() {
        let mut input_output = b_field_elements(&[1, 4, 0, 0]);
        let original_input = input_output.clone();
        let expected = b_field_elements(&[
            5,
            1125899906842625,
            18446744069414584318,
            18445618169507741698,
        ]);
        let omega = BFieldElement::primitive_root_of_unity(4).unwrap();

        ntt::<BFieldElement>(&mut input_output, omega, 2);
        assert_eq!(expected, input_output);

        intt::<BFieldElement>(&mut input_output, omega, 2);
        assert_eq!(original_input, input_output);
    }

    /// The transform of (MAX, 0, 0, 0) repeats MAX in every position.
    #[test]
    fn bfield_max_value_test_of_chu_ntt() {
        let mut input_output = b_field_elements(&[BFieldElement::MAX, 0, 0, 0]);
        let original_input = input_output.clone();
        let expected = b_field_elements(&[BFieldElement::MAX; 4]);
        let omega = BFieldElement::primitive_root_of_unity(4).unwrap();

        ntt::<BFieldElement>(&mut input_output, omega, 2);
        assert_eq!(expected, input_output);

        intt::<BFieldElement>(&mut input_output, omega, 2);
        assert_eq!(original_input, input_output);
    }

    /// Length-32 example: the input repeats the pattern (1, 4, 0, 0, 0, 0,
    /// 0, 0) four times, and the expected output is non-zero only at indices
    /// that are multiples of four.
    #[test]
    fn b_field_ntt_with_length_32() {
        let mut input_output = b_field_elements(&[1_u64, 4, 0, 0, 0, 0, 0, 0].repeat(4));
        let original_input = input_output.clone();
        let omega = BFieldElement::primitive_root_of_unity(32).unwrap();

        ntt::<BFieldElement>(&mut input_output, omega, 5);
        println!("actual_output = {input_output:?}");

        #[rustfmt::skip]
        let expected = b_field_elements(&[
            20, 0, 0, 0,
            18446744069146148869, 0, 0, 0,
            4503599627370500, 0, 0, 0,
            18446726477228544005, 0, 0, 0,
            18446744069414584309, 0, 0, 0,
            268435460, 0, 0, 0,
            18442240469787213829, 0, 0, 0,
            17592186040324, 0, 0, 0,
        ]);
        assert_eq!(expected, input_output);

        intt::<BFieldElement>(&mut input_output, omega, 5);
        assert_eq!(original_input, input_output);
    }

    /// `ntt` agrees with evaluating the polynomial built from the input at
    /// the successive powers of `omega`.
    #[test]
    pub fn test_compare_ntt_to_eval() {
        for log_size in 1..10 {
            let size = 1 << log_size;
            let mut array: Vec<BFieldElement> = random_elements(size);
            let polynomial = Polynomial::new(array.to_vec());
            let omega = BFieldElement::primitive_root_of_unity(size.try_into().unwrap()).unwrap();

            ntt(&mut array, omega, log_size.try_into().unwrap());

            let evals = (0..size)
                .map(|i| omega.mod_pow(i.try_into().unwrap()))
                .map(|p| polynomial.evaluate(&p))
                .collect_vec();
            assert_eq!(evals, array);
        }
    }

    /// The no-swap variants agree with `ntt` / `intt` once the bit-reversal
    /// permutation and the scaling by the inverse length are applied by hand.
    #[test]
    fn test_ntt_noswap() {
        for log_size in 1..8 {
            let size = 1 << log_size;
            println!("size: {size}");
            let a: Vec<BFieldElement> = random_elements(size);
            let omega = BFieldElement::primitive_root_of_unity(size.try_into().unwrap()).unwrap();

            // Forward direction: ntt_noswap followed by a bit-reversal
            // permutation must match ntt.
            let mut a1 = a.clone();
            ntt(&mut a1, omega, log_size);
            let mut a2 = a.clone();
            ntt_noswap(&mut a2, omega);
            bitreverse_order(&mut a2);
            assert_eq!(a1, a2);

            // Backward direction: undo the permutation, run intt_noswap and
            // scale by 1 / size; the result must match intt.
            intt(&mut a1, omega, log_size);
            bitreverse_order(&mut a2);
            intt_noswap(&mut a2, omega);
            for a2e in a2.iter_mut() {
                *a2e *= BFieldElement::new(size.try_into().unwrap()).inverse();
            }
            assert_eq!(a1, a2);
        }
    }
}