p3_circle/
cfft.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{iterate, izip, Itertools};
5use p3_commit::PolynomialSpace;
6use p3_dft::{divide_by_height, Butterfly, DifButterfly, DitButterfly};
7use p3_field::extension::ComplexExtendable;
8use p3_field::{batch_multiplicative_inverse, ExtensionField, Field};
9use p3_matrix::dense::RowMajorMatrix;
10use p3_matrix::Matrix;
11use p3_maybe_rayon::prelude::*;
12use p3_util::{log2_ceil_usize, log2_strict_usize, reverse_slice_index_bits};
13use tracing::{debug_span, instrument};
14
15use crate::domain::CircleDomain;
16use crate::point::Point;
17use crate::{cfft_permute_index, cfft_permute_slice, CfftPermutable, CfftView};
18
19#[derive(Clone)]
20pub struct CircleEvaluations<F, M = RowMajorMatrix<F>> {
21    pub(crate) domain: CircleDomain<F>,
22    pub(crate) values: M,
23}
24
25impl<F: Copy + Send + Sync, M: Matrix<F>> CircleEvaluations<F, M> {
26    pub(crate) fn from_cfft_order(domain: CircleDomain<F>, values: M) -> Self {
27        assert_eq!(1 << domain.log_n, values.height());
28        Self { domain, values }
29    }
30    pub fn from_natural_order(
31        domain: CircleDomain<F>,
32        values: M,
33    ) -> CircleEvaluations<F, CfftView<M>> {
34        CircleEvaluations::from_cfft_order(domain, values.cfft_perm_rows())
35    }
36    pub fn to_cfft_order(self) -> M {
37        self.values
38    }
39    pub fn to_natural_order(self) -> CfftView<M> {
40        self.values.cfft_perm_rows()
41    }
42}
43
44impl<F: ComplexExtendable, M: Matrix<F>> CircleEvaluations<F, M> {
45    #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
46    pub fn interpolate(self) -> RowMajorMatrix<F> {
47        let CircleEvaluations { domain, values } = self;
48        let mut values = debug_span!("to_rmm").in_scope(|| values.to_row_major_matrix());
49
50        let mut twiddles = debug_span!("twiddles").in_scope(|| {
51            compute_twiddles(domain)
52                .into_iter()
53                .map(|ts| {
54                    batch_multiplicative_inverse(&ts)
55                        .into_iter()
56                        .map(|t| DifButterfly(t))
57                        .collect_vec()
58                })
59                .peekable()
60        });
61
62        assert_eq!(twiddles.len(), domain.log_n);
63
64        let par_twiddles = twiddles
65            .peeking_take_while(|ts| ts.len() >= desired_num_jobs())
66            .collect_vec();
67        if let Some(min_blks) = par_twiddles.last().map(|ts| ts.len()) {
68            let max_blk_sz = values.height() / min_blks;
69            debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
70                values
71                    .par_row_chunks_exact_mut(max_blk_sz)
72                    .enumerate()
73                    .for_each(|(chunk_i, submat)| {
74                        for ts in &par_twiddles {
75                            let twiddle_chunk_sz = ts.len() / min_blks;
76                            let twiddle_chunk = &ts
77                                [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
78                            serial_layer(submat.values, twiddle_chunk);
79                        }
80                    });
81            });
82        }
83
84        for ts in twiddles {
85            par_within_blk_layer(&mut values.values, &ts);
86        }
87
88        // TODO: omit this?
89        divide_by_height(&mut values);
90        values
91    }
92
93    #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
94    pub fn extrapolate(
95        self,
96        target_domain: CircleDomain<F>,
97    ) -> CircleEvaluations<F, RowMajorMatrix<F>> {
98        assert!(target_domain.log_n >= self.domain.log_n);
99        CircleEvaluations::<F>::evaluate(target_domain, self.interpolate())
100    }
101
102    pub fn evaluate_at_point<EF: ExtensionField<F>>(&self, point: Point<EF>) -> Vec<EF> {
103        let lagrange_num = self.domain.zeroifier(point);
104        let lagrange_den = cfft_permute_slice(&self.domain.points().collect_vec())
105            .into_iter()
106            .map(|p| p.v_tilde_p(point) * p.s_p_at_p(self.domain.log_n))
107            .collect_vec();
108        self.values
109            .columnwise_dot_product(&batch_multiplicative_inverse(&lagrange_den))
110            .into_iter()
111            .map(|x| x * lagrange_num)
112            .collect_vec()
113    }
114
115    #[cfg(test)]
116    pub(crate) fn dim(&self) -> usize
117    where
118        M: Clone,
119    {
120        let coeffs = self.clone().interpolate();
121        for (i, mut row) in coeffs.rows().enumerate() {
122            if row.all(|x| x.is_zero()) {
123                return i;
124            }
125        }
126        coeffs.height()
127    }
128}
129
130impl<F: ComplexExtendable> CircleEvaluations<F, RowMajorMatrix<F>> {
131    #[instrument(skip_all, fields(dims = %coeffs.dimensions()))]
132    pub fn evaluate(domain: CircleDomain<F>, mut coeffs: RowMajorMatrix<F>) -> Self {
133        let log_n = log2_strict_usize(coeffs.height());
134        assert!(log_n <= domain.log_n);
135
136        if log_n < domain.log_n {
137            // We could simply pad coeffs like this:
138            // coeffs.pad_to_height(target_domain.size(), F::zero());
139            // But the first `added_bits` layers will simply fill out the zeros
140            // with the lower order values. (In `DitButterfly`, `x_2` is 0, so
141            // both `x_1` and `x_2` are set to `x_1`).
142            // So instead we directly repeat the coeffs and skip the initial layers.
143            debug_span!("extend coeffs").in_scope(|| {
144                coeffs.values.reserve(domain.size() * coeffs.width());
145                for _ in log_n..domain.log_n {
146                    coeffs.values.extend_from_within(..);
147                }
148            });
149        }
150        assert_eq!(coeffs.height(), 1 << domain.log_n);
151
152        let mut twiddles = debug_span!("twiddles").in_scope(|| {
153            compute_twiddles(domain)
154                .into_iter()
155                .map(|ts| ts.into_iter().map(|t| DitButterfly(t)).collect_vec())
156                .rev()
157                .skip(domain.log_n - log_n)
158                .peekable()
159        });
160
161        for ts in twiddles.peeking_take_while(|ts| ts.len() < desired_num_jobs()) {
162            par_within_blk_layer(&mut coeffs.values, &ts);
163        }
164
165        let par_twiddles = twiddles.collect_vec();
166        if let Some(min_blks) = par_twiddles.first().map(|ts| ts.len()) {
167            let max_blk_sz = coeffs.height() / min_blks;
168            debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
169                coeffs
170                    .par_row_chunks_exact_mut(max_blk_sz)
171                    .enumerate()
172                    .for_each(|(chunk_i, submat)| {
173                        for ts in &par_twiddles {
174                            let twiddle_chunk_sz = ts.len() / min_blks;
175                            let twiddle_chunk = &ts
176                                [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
177                            serial_layer(submat.values, twiddle_chunk);
178                        }
179                    });
180            });
181        }
182
183        Self::from_cfft_order(domain, coeffs)
184    }
185}
186
187#[inline]
188fn serial_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
189    let blk_sz = values.len() / twiddles.len();
190    for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
191        let (lo, hi) = blk.split_at_mut(blk_sz / 2);
192        t.apply_to_rows(lo, hi);
193    }
194}
195
196#[inline]
197#[instrument(level = "debug", skip_all, fields(log_blks = log2_strict_usize(twiddles.len())))]
198fn par_within_blk_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
199    let blk_sz = values.len() / twiddles.len();
200    for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
201        let (lo, hi) = blk.split_at_mut(blk_sz / 2);
202        let job_sz = core::cmp::max(1, lo.len() >> log2_ceil_usize(desired_num_jobs()));
203        lo.par_chunks_mut(job_sz)
204            .zip(hi.par_chunks_mut(job_sz))
205            .for_each(|(lo_job, hi_job)| t.apply_to_rows(lo_job, hi_job));
206    }
207}
208
209#[inline]
210fn desired_num_jobs() -> usize {
211    16 * current_num_threads()
212}
213
214impl<F: ComplexExtendable> CircleDomain<F> {
215    pub(crate) fn y_twiddles(&self) -> Vec<F> {
216        let mut ys = self.coset0().map(|p| p.y).collect_vec();
217        reverse_slice_index_bits(&mut ys);
218        ys
219    }
220    pub(crate) fn nth_y_twiddle(&self, index: usize) -> F {
221        self.nth_point(cfft_permute_index(index << 1, self.log_n)).y
222    }
223    pub(crate) fn x_twiddles(&self, layer: usize) -> Vec<F> {
224        let gen = self.gen() * (1 << layer);
225        let shift = self.shift * (1 << layer);
226        let mut xs = iterate(shift, move |&p| p + gen)
227            .map(|p| p.x)
228            .take(1 << (self.log_n - layer - 2))
229            .collect_vec();
230        reverse_slice_index_bits(&mut xs);
231        xs
232    }
233    pub(crate) fn nth_x_twiddle(&self, index: usize) -> F {
234        (self.shift + self.gen() * index).x
235    }
236}
237
238fn compute_twiddles<F: ComplexExtendable>(domain: CircleDomain<F>) -> Vec<Vec<F>> {
239    assert!(domain.log_n >= 1);
240    let mut pts = domain.coset0().collect_vec();
241    reverse_slice_index_bits(&mut pts);
242    let mut twiddles = vec![pts.iter().map(|p| p.y).collect_vec()];
243    if domain.log_n >= 2 {
244        twiddles.push(pts.iter().step_by(2).map(|p| p.x).collect_vec());
245        for i in 0..(domain.log_n - 2) {
246            let prev = twiddles.last().unwrap();
247            assert_eq!(prev.len(), 1 << (domain.log_n - 2 - i));
248            let cur = prev
249                .iter()
250                .step_by(2)
251                .map(|x| x.square().double() - F::one())
252                .collect_vec();
253            twiddles.push(cur);
254        }
255    }
256    twiddles
257}
258
259pub fn circle_basis<F: Field>(p: Point<F>, log_n: usize) -> Vec<F> {
260    let mut b = vec![F::one(), p.y];
261    let mut x = p.x;
262    for _ in 0..(log_n - 1) {
263        for i in 0..b.len() {
264            b.push(b[i] * x);
265        }
266        x = x.square().double() - F::one();
267    }
268    assert_eq!(b.len(), 1 << log_n);
269    b
270}
271
#[cfg(test)]
mod tests {
    use itertools::iproduct;
    use p3_field::extension::BinomialExtensionField;
    use p3_mersenne_31::Mersenne31;
    use rand::{random, thread_rng};

    use super::*;

    type F = Mersenne31;
    type EF = BinomialExtensionField<F, 3>;

    /// Interpolating then evaluating over a random coset is the identity, and
    /// the coefficients agree with `circle_basis` at every domain point.
    #[test]
    fn test_cfft_icfft() {
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let shift = Point::generator(F::CIRCLE_TWO_ADICITY) * random();
            let domain = CircleDomain::<F>::new(log_n, shift);
            let evals = RowMajorMatrix::<F>::rand(&mut thread_rng(), 1 << log_n, width);

            let coeffs = CircleEvaluations::from_natural_order(domain, evals.clone()).interpolate();
            let round_trip = CircleEvaluations::evaluate(domain, coeffs.clone())
                .to_natural_order()
                .to_row_major_matrix();
            assert_eq!(round_trip, evals, "icfft(cfft(evals)) is identity");

            for (row, pt) in domain.points().enumerate() {
                let expected = &*evals.row_slice(row);
                let via_basis = coeffs.columnwise_dot_product(&circle_basis(pt, log_n));
                assert_eq!(
                    expected, via_basis,
                    "coeffs can be evaluated with circle_basis",
                );
            }
        }
    }

    /// Extrapolation keeps the low-order coefficients and zeroes the rest.
    #[test]
    fn test_extrapolation() {
        for (log_n, log_blowup) in iproduct!(2..5, [1, 2, 3]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut thread_rng(), 1 << log_n, 11),
            );
            let lde = evals
                .clone()
                .extrapolate(CircleDomain::standard(log_n + log_blowup));

            let coeffs = evals.interpolate();
            let lde_coeffs = lde.interpolate();

            // Rows below the original height must match exactly...
            for r in 0..coeffs.height() {
                assert_eq!(&*coeffs.row_slice(r), &*lde_coeffs.row_slice(r));
            }
            // ...and every higher-order row must vanish.
            for r in coeffs.height()..lde_coeffs.height() {
                assert!(lde_coeffs.row(r).all(|x| x.is_zero()));
            }
        }
    }

    /// Direct out-of-domain evaluation matches evaluation via the coefficients.
    #[test]
    fn eval_at_point_matches_cfft() {
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut thread_rng(), 1 << log_n, width),
            );
            let pt = Point::<EF>::from_projective_line(random());

            let direct = evals.clone().evaluate_at_point(pt);
            let via_coeffs = evals
                .interpolate()
                .columnwise_dot_product(&circle_basis(pt, log_n));
            assert_eq!(direct, via_coeffs);
        }
    }

    /// Out-of-domain evaluation agrees between a trace, its LDE, and both
    /// coefficient forms.
    #[test]
    fn eval_at_point_matches_lde() {
        for (log_n, width, log_blowup) in iproduct!(2..8, [1, 4, 11], [1, 2]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut thread_rng(), 1 << log_n, width),
            );
            let lde = evals
                .clone()
                .extrapolate(CircleDomain::standard(log_n + log_blowup));
            let zeta = Point::<EF>::from_projective_line(random());

            let at_zeta = evals.evaluate_at_point(zeta);
            assert_eq!(at_zeta, lde.evaluate_at_point(zeta));
            assert_eq!(
                at_zeta,
                evals
                    .interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n))
            );
            assert_eq!(
                lde.evaluate_at_point(zeta),
                lde.interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n + log_blowup))
            );
        }
    }
}