use alloc::rc::Rc;
use alloc::vec::Vec;
use core::cell::RefCell;
use p3_commit::PolynomialSpace;
use p3_dft::divide_by_height;
use p3_field::extension::{Complex, ComplexExtendable};
use p3_field::{AbstractField, Field, PackedValue};
use p3_matrix::dense::{DenseMatrix, RowMajorMatrix};
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use p3_util::log2_strict_usize;
use tracing::instrument;
use crate::domain::CircleDomain;
use crate::twiddles::TwiddleCache;
/// Batched circle FFT, inverse circle FFT and low-degree extension over `F`.
///
/// The only state is a [`TwiddleCache`] held behind `Rc<RefCell<..>>`:
/// - the derived `Clone` copies the `Rc`, so every clone shares the same cache;
/// - the `Rc` makes this type neither `Send` nor `Sync`;
/// - the transform methods take `&self` and reach the cache through `RefCell::borrow_mut`, which
///   panics if the cache is already borrowed at that moment.
#[derive(Default, Clone, Debug)]
pub struct Cfft<F: Field>(Rc<RefCell<TwiddleCache<F>>>);
impl<F: ComplexExtendable> Cfft<F> {
    /// Forward circle FFT of a single column of evaluations.
    ///
    /// Wraps `vec` as a one-column matrix, runs [`Self::cfft_batch`] on it and returns the
    /// resulting column.
    pub fn cfft(&self, vec: Vec<F>) -> Vec<F> {
        let column = RowMajorMatrix::new_col(vec);
        self.cfft_batch(column).values
    }

    /// Forward circle FFT of every column of `mat` over the default coset for its height.
    ///
    /// The shift is `F::circle_two_adic_generator(log_n + 1)`, where `log_n` is the base-2 log of
    /// `mat.height()`. See [`Self::coset_cfft_batch`] for the transform itself.
    pub fn cfft_batch(&self, mat: DenseMatrix<F>) -> DenseMatrix<F> {
        let log_n = log2_strict_usize(mat.height());
        let default_shift = F::circle_two_adic_generator(log_n + 1);
        self.coset_cfft_batch(mat, default_shift)
    }

    /// Forward circle FFT of every column of `mat` over the coset shifted by `shift`.
    ///
    /// Each column holds evaluations over that coset, in the order of `CircleDomain::points`.
    /// It is overwritten with its circle-polynomial coefficients. The result is divided by the
    /// height, so [`Self::coset_icfft_batch`] with the same `shift` recovers the input.
    ///
    /// # Panics
    /// - Expected to panic (via `log2_strict_usize`) if the height is not a power of two.
    /// - Panics if the shared twiddle cache is already borrowed. It stays mutably borrowed for
    ///   the whole transform.
    #[instrument(skip_all, fields(dims = %mat.dimensions()))]
    pub fn coset_cfft_batch(&self, mut mat: DenseMatrix<F>, shift: Complex<F>) -> DenseMatrix<F> {
        let log_n = log2_strict_usize(mat.height());

        let mut twiddle_cache = self.0.borrow_mut();
        let layers = twiddle_cache.get_twiddles(log_n, shift, true);

        // Decimation in frequency, largest block first: layer `layer` cuts the rows into blocks
        // of `2^(log_n - layer)` rows and pairs row `j` of each block with its mirror row.
        for (layer, layer_twiddles) in layers.iter().enumerate() {
            let chunk_rows = 1 << (log_n - layer);
            assert_eq!(layer_twiddles.len(), chunk_rows >> 1);

            mat.par_row_chunks_exact_mut(chunk_rows).for_each(|mut chunk| {
                for (j, &t) in layer_twiddles.iter().enumerate() {
                    let (lo, hi) = chunk.row_pair_mut(j, chunk_rows - j - 1);
                    // Packed butterfly on the prefix that fills whole `F::Packing` values,
                    // scalar butterfly on whatever is left over.
                    let (lo_packed, lo_tail) = F::Packing::pack_slice_with_suffix_mut(lo);
                    let (hi_packed, hi_tail) = F::Packing::pack_slice_with_suffix_mut(hi);
                    dif_butterfly(lo_packed, hi_packed, t.into());
                    dif_butterfly(lo_tail, hi_tail, t);
                }
            });
        }

        divide_by_height(&mut mat);
        mat
    }

    /// Inverse circle FFT of a single column of coefficients.
    ///
    /// Wraps `vec` as a one-column matrix, runs [`Self::icfft_batch`] on it and returns the
    /// resulting column.
    pub fn icfft(&self, vec: Vec<F>) -> Vec<F> {
        let column = RowMajorMatrix::new_col(vec);
        self.icfft_batch(column).values
    }

    /// Inverse circle FFT of every column of `mat` over the same default coset that
    /// [`Self::cfft_batch`] uses (shift `F::circle_two_adic_generator(log_n + 1)`).
    pub fn icfft_batch(&self, mat: DenseMatrix<F>) -> DenseMatrix<F> {
        let log_n = log2_strict_usize(mat.height());
        let default_shift = F::circle_two_adic_generator(log_n + 1);
        self.coset_icfft_batch(mat, default_shift)
    }

    /// Inverse circle FFT of every column of `mat` over the coset shifted by `shift`.
    ///
    /// Each column holds coefficients and is overwritten with evaluations over the coset. No
    /// scaling happens here; the `1 / height` factor is applied by [`Self::coset_cfft_batch`].
    #[instrument(skip_all, fields(dims = %mat.dimensions()))]
    pub fn coset_icfft_batch(&self, mat: DenseMatrix<F>, shift: Complex<F>) -> DenseMatrix<F> {
        // Full transform: no layers are skipped.
        self.coset_icfft_batch_skipping_first_layers(mat, shift, 0)
    }

    /// Decimation-in-time inverse circle FFT that omits its first `num_skipped_layers` layers.
    ///
    /// Layers run smallest block first (2, 4, 8, ... rows). Skipping `k` of them is only valid
    /// when the input already equals what those `k` layers would have produced, which is how
    /// [`Self::lde`] uses it. With `num_skipped_layers == 0` this is the full inverse transform.
    #[instrument(skip_all, fields(dims = %mat.dimensions()))]
    fn coset_icfft_batch_skipping_first_layers(
        &self,
        mut mat: DenseMatrix<F>,
        shift: Complex<F>,
        num_skipped_layers: usize,
    ) -> DenseMatrix<F> {
        let log_n = log2_strict_usize(mat.height());

        let mut twiddle_cache = self.0.borrow_mut();
        let layers = twiddle_cache.get_twiddles(log_n, shift, false);

        // Judging by the length check below, the cache lists layers largest block first.
        // Reversing makes `layer` count up from the 2-row blocks, so `skip` drops the smallest.
        let remaining = layers.iter().rev().enumerate().skip(num_skipped_layers);
        for (layer, layer_twiddles) in remaining {
            let chunk_rows = 1 << (layer + 1);
            assert_eq!(layer_twiddles.len(), chunk_rows >> 1);

            mat.par_row_chunks_exact_mut(chunk_rows).for_each(|mut chunk| {
                for (j, &t) in layer_twiddles.iter().enumerate() {
                    let (lo, hi) = chunk.row_pair_mut(j, chunk_rows - j - 1);
                    let (lo_packed, lo_tail) = F::Packing::pack_slice_with_suffix_mut(lo);
                    let (hi_packed, hi_tail) = F::Packing::pack_slice_with_suffix_mut(hi);
                    dit_butterfly(lo_packed, hi_packed, t.into());
                    dit_butterfly(lo_tail, hi_tail, t);
                }
            });
        }

        mat
    }

    /// Low-degree extension of every column of `mat` from `src_domain` to `target_domain`.
    ///
    /// `mat` holds evaluations over `src_domain`; the result holds evaluations over
    /// `target_domain`. The output equals three steps done in full:
    /// 1. interpolate with [`Self::coset_cfft_batch`];
    /// 2. zero-pad the coefficients with `p3_dft::bit_reversed_zero_pad`;
    /// 3. evaluate with [`Self::coset_icfft_batch`].
    ///
    /// Unlike that sequence, it never runs the inverse layers that would only copy data.
    ///
    /// # Panics
    /// Panics if `mat.height() != src_domain.size()` or if `target_domain` is smaller than
    /// `src_domain`.
    #[instrument(skip_all, fields(dims = %mat.dimensions()))]
    pub fn lde(
        &self,
        mat: DenseMatrix<F>,
        src_domain: CircleDomain<F>,
        target_domain: CircleDomain<F>,
    ) -> DenseMatrix<F> {
        assert_eq!(mat.height(), src_domain.size());
        assert!(target_domain.size() >= src_domain.size());
        let added_bits = target_domain.log_n - src_domain.log_n;

        let coeffs = self.coset_cfft_batch(mat, src_domain.shift);

        // `dit_butterfly` with `hi == 0` writes `lo` to both rows. So on zero-padded coefficients
        // the first `added_bits` inverse layers only spread each coefficient row across its
        // block of `2^added_bits` rows. Tiling builds that state directly, and those layers are
        // skipped.
        let tiled = tile_rows(coeffs, 1 << added_bits);
        debug_assert_eq!(tiled.height(), target_domain.size());

        self.coset_icfft_batch_skipping_first_layers(tiled, target_domain.shift, added_bits)
    }
}
/// Decimation-in-frequency butterfly applied position-wise to two rows.
///
/// For every index `k` up to the length of the shorter slice, using the old values:
/// `lo[k] <- lo[k] + hi[k]` and `hi[k] <- (lo[k] - hi[k]) * twiddle`.
/// Elements past the shorter slice's length are left untouched.
#[inline(always)]
fn dif_butterfly<F: AbstractField + Copy>(lo_chunk: &mut [F], hi_chunk: &mut [F], twiddle: F) {
    lo_chunk
        .iter_mut()
        .zip(hi_chunk.iter_mut())
        .for_each(|(x, y)| {
            let (a, b) = (*x, *y);
            *x = a + b;
            *y = (a - b) * twiddle;
        });
}
/// Decimation-in-time butterfly applied position-wise to two rows.
///
/// For every index `k` up to the length of the shorter slice, with `s = hi[k] * twiddle`
/// computed from the old `hi[k]`: `lo[k] <- lo[k] + s` and `hi[k] <- lo[k] - s`.
/// Elements past the shorter slice's length are left untouched.
#[inline(always)]
fn dit_butterfly<F: AbstractField + Copy>(lo_chunk: &mut [F], hi_chunk: &mut [F], twiddle: F) {
    let pairs = lo_chunk.iter_mut().zip(hi_chunk.iter_mut());
    for (x, y) in pairs {
        let a = *x;
        let scaled_b = *y * twiddle;
        *y = a - scaled_b;
        *x = a + scaled_b;
    }
}
/// Builds a row-major matrix in which every row of `mat` is repeated `repetitions` times.
///
/// Copies are consecutive and keep the original row order (`r0, r0, .., r1, r1, ..`). The
/// result has the same width as `mat` and height `mat.height() * repetitions`.
fn tile_rows<F: Field>(mat: impl Matrix<F>, repetitions: usize) -> RowMajorMatrix<F> {
    let width = mat.width();
    let height = mat.height();
    let mut tiled = Vec::with_capacity(width * height * repetitions);
    for row_idx in 0..height {
        let guard = mat.row_slice(row_idx);
        let row: &[F] = guard.as_ref();
        (0..repetitions).for_each(|_| tiled.extend_from_slice(row));
    }
    RowMajorMatrix::new(tiled, width)
}
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use p3_dft::bit_reversed_zero_pad;
    use p3_mersenne_31::Mersenne31;
    use rand::{random, thread_rng};

    use super::*;
    use crate::util::{eval_circle_polys, univariate_to_point};

    type F = Mersenne31;

    /// Round-trips a random `2^log_n x 32` matrix through `coset_cfft_batch` and
    /// `coset_icfft_batch` on a random coset. Also checks that the coefficients evaluate back to
    /// the inputs at every domain point.
    fn do_test_cfft(log_n: usize) {
        let n = 1 << log_n;
        let cfft = Cfft::default();
        let shift: Complex<F> = univariate_to_point(random()).unwrap();
        let evals = RowMajorMatrix::<F>::rand(&mut thread_rng(), n, 1 << 5);
        let coeffs = cfft.coset_cfft_batch(evals.clone(), shift);
        // `assert_eq!` compares by reference, so `evals` needs no clone here.
        assert_eq!(evals, cfft.coset_icfft_batch(coeffs.clone(), shift));
        let d = CircleDomain { shift, log_n };
        for (pt, ys) in d.points().zip(evals.rows()) {
            assert_eq!(ys.collect_vec(), eval_circle_polys(&coeffs, pt));
        }
    }

    #[test]
    fn test_cfft() {
        do_test_cfft(5);
        do_test_cfft(8);
    }

    /// Checks `lde` against the reference route, using a random `2^log_n x width` matrix on a
    /// random source coset. The reference route is: interpolate, zero-pad the coefficients in
    /// bit-reversed order, and evaluate on the full target domain.
    fn do_test_lde(log_n: usize, added_bits: usize, width: usize) {
        let n = 1 << log_n;
        let cfft = Cfft::<F>::default();
        let shift: Complex<F> = univariate_to_point(random()).unwrap();
        let evals = RowMajorMatrix::<F>::rand(&mut thread_rng(), n, width);
        let src_domain = CircleDomain { log_n, shift };
        let target_domain = CircleDomain::standard(log_n + added_bits);
        let mut coeffs = cfft.coset_cfft_batch(evals.clone(), src_domain.shift);
        bit_reversed_zero_pad(&mut coeffs, added_bits);
        let expected = cfft.coset_icfft_batch(coeffs, target_domain.shift);
        let actual = cfft.lde(evals, src_domain, target_domain);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_lde() {
        do_test_lde(3, 1, 1);
    }

    /// With no added bits, `lde` just re-evaluates on a different coset of the same size.
    #[test]
    fn test_lde_no_added_bits() {
        do_test_lde(3, 0, 1);
    }

    /// Skipping several inverse layers at once, including a width that reaches the packed
    /// butterfly path.
    #[test]
    fn test_lde_multiple_skipped_layers() {
        do_test_lde(4, 2, 1);
        do_test_lde(5, 3, 1 << 4);
    }
}