use alloc::vec::Vec;
use core::cmp::{max, min};
use plonky2_util::{log2_strict, reverse_index_bits_in_place};
use unroll::unroll_for_loops;
use crate::packable::Packable;
use crate::packed::PackedField;
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::types::Field;
/// Precomputed powers of roots of unity consumed by the FFT routines in this file.
///
/// The table is a list of rows; `fft_root_table` builds one row per value of `lg_m` in
/// `1..=lg_n`, and the butterfly code indexes it as `root_table[lg_half_m][j]`.
pub type FftRootTable<F> = Vec<Vec<F>>;
/// Builds the root table used by an FFT over `n` elements.
///
/// With `lg_n = log2_strict(n)` and `g = F::primitive_root_of_unity(lg_n)`, row `lg_m - 1`
/// (for `lg_m` in `1..=lg_n`) holds the leading powers of `g^(2^(lg_n - lg_m))`. A row has
/// `max(2^(lg_m - 1), 2)` entries, so every row contains at least two elements.
///
/// The returned table has exactly `lg_n` rows (none when `lg_n == 0`).
///
/// NOTE(review): `n` is passed to `log2_strict`, so it is presumably required to be a power
/// of two — confirm against `plonky2_util::log2_strict`.
pub fn fft_root_table<F: Field>(n: usize) -> FftRootTable<F> {
    let lg_n = log2_strict(n);

    // squarings[i] = g^(2^i), obtained by repeatedly squaring the primitive root `g`.
    let mut squarings = Vec::with_capacity(lg_n);
    let mut current = F::primitive_root_of_unity(lg_n);
    squarings.push(current);
    for _ in 1..lg_n {
        current = current.square();
        squarings.push(current);
    }

    (1..=lg_n)
        .map(|lg_m| {
            let half_m: usize = 1 << (lg_m - 1);
            // Rows are never shorter than two entries, even for the smallest layer.
            let row_len = half_m.max(2);
            let generator = squarings[lg_n - lg_m];
            generator.powers().take(row_len).collect::<Vec<F>>()
        })
        .collect()
}
/// Runs `fft_classic` in place on `input`, filling in defaults for the optional arguments.
///
/// * `zero_factor` — forwarded as the `r` argument of `fft_classic`; `None` is treated as `0`.
/// * `root_table` — used as-is when provided; otherwise a table for `input.len()` is built
///   with `fft_root_table` before the transform runs.
#[inline]
fn fft_dispatch<F: Field>(
    input: &mut [F],
    zero_factor: Option<usize>,
    root_table: Option<&FftRootTable<F>>,
) {
    let r = zero_factor.unwrap_or(0);
    match root_table {
        Some(table) => fft_classic(input, r, table),
        None => {
            // No table supplied by the caller: compute one sized for this input.
            let table = fft_root_table(input.len());
            fft_classic(input, r, &table);
        }
    }
}
/// Transforms a polynomial from coefficient form to value form using default options.
///
/// Equivalent to `fft_with_options(poly, None, None)`: no zero-factor hint is given and no
/// precomputed root table is supplied. The test in this file checks the result against a
/// naive evaluation over `F::two_adic_subgroup`.
#[inline]
pub fn fft<F: Field>(poly: PolynomialCoeffs<F>) -> PolynomialValues<F> {
fft_with_options(poly, None, None)
}
/// Transforms `poly` from coefficient form to value form, reusing its buffer in place.
///
/// * `zero_factor` — optional hint forwarded to `fft_dispatch` (treated as `0` when `None`).
///   NOTE(review): the test in this file passes `Some(r)` for inputs produced by `lde(r)`;
///   presumably it is the log2 of the zero-extension factor — confirm.
/// * `root_table` — optional precomputed table; when `None`, one is computed for this size.
#[inline]
pub fn fft_with_options<F: Field>(
    poly: PolynomialCoeffs<F>,
    zero_factor: Option<usize>,
    root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
    // Take ownership of the coefficient vector so the transform can overwrite it.
    let mut values = poly.coeffs;
    fft_dispatch(values.as_mut_slice(), zero_factor, root_table);
    PolynomialValues::new(values)
}
/// Transforms a polynomial from value form back to coefficient form using default options.
///
/// Equivalent to `ifft_with_options(poly, None, None)`. The test in this file checks that
/// `ifft(fft(p))` reproduces the coefficients of `p`.
#[inline]
pub fn ifft<F: Field>(poly: PolynomialValues<F>) -> PolynomialCoeffs<F> {
ifft_with_options(poly, None, None)
}
/// Transforms `poly` from value form to coefficient form, reusing its buffer in place.
///
/// The inverse is computed by running the forward transform (`fft_dispatch`) and then
/// post-processing the output: every entry is multiplied by `F::inverse_2exp(lg_n)`, and the
/// entries at indices `i` and `n - i` are exchanged for `1 <= i < n / 2`. Indices `0` and
/// `n / 2` stay where they are.
///
/// * `zero_factor` — optional hint forwarded to `fft_dispatch` (treated as `0` when `None`).
/// * `root_table` — optional precomputed table; when `None`, one is computed for this size.
pub fn ifft_with_options<F: Field>(
    poly: PolynomialValues<F>,
    zero_factor: Option<usize>,
    root_table: Option<&FftRootTable<F>>,
) -> PolynomialCoeffs<F> {
    let n = poly.len();
    let lg_n = log2_strict(n);
    let n_inv = F::inverse_2exp(lg_n);

    let mut coeffs = poly.values;
    fft_dispatch(&mut coeffs, zero_factor, root_table);

    // The two fixed points of the index reversal only need scaling.
    coeffs[0] *= n_inv;
    coeffs[n / 2] *= n_inv;

    // Every other index i is paired with n - i: exchange the pair, then scale both.
    for i in 1..(n / 2) {
        let mirror = n - i;
        coeffs.swap(i, mirror);
        coeffs[i] *= n_inv;
        coeffs[mirror] *= n_inv;
    }

    PolynomialCoeffs { coeffs }
}
/// Runs butterfly layers `r..lg_n` of the FFT over `values`, working on packed vectors `P`.
///
/// `values` is viewed as a slice of `P` via `P::pack_slice_mut`. For each layer,
/// `root_table[lg_half_m]` supplies the multipliers, where `2^lg_half_m` is the half-block
/// size of that layer. Layers below `r` are skipped: the first loop starts at `r` and the
/// second at `max(r, lg_packed_width)`.
///
/// The caller in this file (`fft_classic`) applies `reverse_index_bits_in_place` and the
/// block replication for `r > 0` before invoking this function.
///
/// # Panics
///
/// Panics if `log2_strict(P::WIDTH) > 4`.
#[unroll_for_loops]
fn fft_classic_simd<P: PackedField>(
values: &mut [P::Scalar],
r: usize,
lg_n: usize,
root_table: &FftRootTable<P::Scalar>,
) {
// log2 of the number of scalars per packed vector, and `values` viewed as a slice of `P`.
let lg_packed_width = log2_strict(P::WIDTH); let packed_values = P::pack_slice_mut(values);
let packed_n = packed_values.len();
// The packed view is expected to hold 2^(lg_n - lg_packed_width) vectors.
debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
// The loop below only visits lg_half_m in 0..4; a wider packing would leave layers
// 4..lg_packed_width unprocessed, since the second phase starts at lg_packed_width.
// NOTE(review): the literal bound is presumably what lets `#[unroll_for_loops]` unroll
// this loop — confirm against the `unroll` crate.
assert!(lg_packed_width <= 4);
for lg_half_m in 0..4 {
// Phase 1: layers whose half-block size is smaller than one packed vector (and >= 2^r).
if (r..min(lg_n, lg_packed_width)).contains(&lg_half_m) {
let half_m = 1 << lg_half_m;
// Fill every lane of `omega` with root_table[lg_half_m][0..half_m], repeated cyclically.
let mut omega = P::default();
for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() {
*omega_j = root_table[lg_half_m][j % half_m];
}
// Each iteration combines the adjacent vectors k and k + 1.
for k in (0..packed_n).step_by(2) {
// NOTE(review): `interleave` is applied before and after the butterfly; presumably it
// regroups lanes in blocks of `half_m` so butterfly partners line up, and undoes the
// regrouping afterwards — confirm against `PackedField::interleave`.
let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m);
let t = omega * v;
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m);
}
}
}
// Phase 2 starts at the first layer not eligible for phase 1 (and never below r).
let s = max(r, lg_packed_width);
for lg_half_m in s..lg_n {
let lg_m = lg_half_m + 1;
// m: block size in scalars; packed_m: block size in vectors; half_packed_m: half of that.
let m = 1 << lg_m; let packed_m = m >> lg_packed_width; let half_packed_m = packed_m / 2;
debug_assert!(half_packed_m != 0);
// Multipliers for this layer, viewed as a slice of packed vectors.
let omega_table = P::pack_slice(&root_table[lg_half_m][..]);
// Walk the blocks of `packed_m` vectors; within a block, pair vector j of the first half
// with vector j of the second half.
for k in (0..packed_n).step_by(packed_m) {
for j in 0..half_packed_m {
let omega = omega_table[j];
let t = omega * packed_values[k + half_packed_m + j];
let u = packed_values[k + j];
// Butterfly: (u, v) -> (u + omega * v, u - omega * v).
packed_values[k + j] = u + t;
packed_values[k + half_packed_m + j] = u - t;
}
}
}
}
/// Performs the FFT of `values` in place.
///
/// Steps, in order:
/// 1. `values` is permuted with `reverse_index_bits_in_place`.
/// 2. If `r > 0`, every aligned block of `2^r` entries is overwritten with the block's first
///    entry. NOTE(review): this presumably stands in for the first `r` butterfly layers when
///    the remaining entries of each block are zero — confirm against callers.
/// 3. The remaining layers run in `fft_classic_simd`, using the scalar type when
///    `lg_n <= log2(Packing::WIDTH)` and the packed type otherwise.
///
/// # Panics
///
/// Panics if `root_table.len()` differs from `log2_strict(values.len())`.
pub(crate) fn fft_classic<F: Field>(values: &mut [F], r: usize, root_table: &FftRootTable<F>) {
    reverse_index_bits_in_place(values);

    let n = values.len();
    let lg_n = log2_strict(n);
    assert!(
        root_table.len() == lg_n,
        "Expected root table of length {}, but it was {}.",
        lg_n,
        root_table.len()
    );

    if r > 0 {
        // Replicate the leading element of each 2^r-sized block across the whole block.
        let block_len: usize = 1 << r;
        for block in values.chunks_mut(block_len) {
            let head = block[0];
            block.fill(head);
        }
    }

    let lg_packed_width = log2_strict(<F as Packable>::Packing::WIDTH);
    if lg_n > lg_packed_width {
        fft_classic_simd::<<F as Packable>::Packing>(values, r, lg_n, root_table);
    } else {
        // Input no larger than a single packed vector: run with the scalar type instead.
        fft_classic_simd::<F>(values, r, lg_n, root_table);
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use plonky2_util::{log2_ceil, log2_strict};

    use crate::fft::{fft, fft_with_options, ifft};
    use crate::goldilocks_field::GoldilocksField;
    use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
    use crate::types::Field;

    /// Checks `fft` against naive evaluation, that `ifft` undoes `fft`, and that supplying a
    /// zero factor to `fft_with_options` gives the same result as plain `fft` on `lde(r)` inputs.
    #[test]
    fn fft_and_ifft() {
        type F = GoldilocksField;

        let degree = 200usize;
        let degree_padded = degree.next_power_of_two();

        // `degree` deterministic coefficients, zero-padded up to the next power of two.
        let mut coeffs: Vec<F> = (0..degree)
            .map(|i| F::from_canonical_usize(i * 1337 % 100))
            .collect();
        coeffs.resize(degree_padded, F::ZERO);
        assert_eq!(coeffs.len(), degree_padded);
        let coefficients = PolynomialCoeffs { coeffs };

        let points = fft(coefficients.clone());
        assert_eq!(points, evaluate_naive(&coefficients));

        let interpolated_coefficients = ifft(points);
        let recovered = &interpolated_coefficients.coeffs;
        let expected = &coefficients.coeffs;
        for (got, want) in recovered[..degree].iter().zip(&expected[..degree]) {
            assert_eq!(*got, *want);
        }
        for got in &recovered[degree..degree_padded] {
            assert_eq!(*got, F::ZERO);
        }

        for r in 0..4 {
            let zero_tail = coefficients.lde(r);
            let hinted = fft_with_options(zero_tail.clone(), Some(r), None);
            assert_eq!(fft(zero_tail), hinted);
        }
    }

    /// Pads `coefficients` to a power-of-two length and evaluates it point by point.
    fn evaluate_naive<F: Field>(coefficients: &PolynomialCoeffs<F>) -> PolynomialValues<F> {
        let padded_len = 1 << log2_ceil(coefficients.len());
        evaluate_naive_power_of_2(&coefficients.padded(padded_len))
    }

    /// Evaluates `coefficients` at every element of `F::two_adic_subgroup(log2(len))`.
    fn evaluate_naive_power_of_2<F: Field>(
        coefficients: &PolynomialCoeffs<F>,
    ) -> PolynomialValues<F> {
        let degree_log = log2_strict(coefficients.len());
        let mut values = Vec::new();
        for x in F::two_adic_subgroup(degree_log).into_iter() {
            values.push(evaluate_at_naive(coefficients, x));
        }
        PolynomialValues::new(values)
    }

    /// Evaluates the polynomial at `point` by summing `c_i * point^i` term by term.
    fn evaluate_at_naive<F: Field>(coefficients: &PolynomialCoeffs<F>, point: F) -> F {
        let mut acc = F::ZERO;
        let mut power = F::ONE;
        coefficients.coeffs.iter().for_each(|&c| {
            acc += c * power;
            power *= point;
        });
        acc
    }
}