p3_circle/
cfft.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{Itertools, iterate, izip};
5use p3_commit::PolynomialSpace;
6use p3_dft::{Butterfly, DifButterfly, DitButterfly, divide_by_height};
7use p3_field::extension::ComplexExtendable;
8use p3_field::{ExtensionField, Field, batch_multiplicative_inverse};
9use p3_matrix::Matrix;
10use p3_matrix::dense::RowMajorMatrix;
11use p3_maybe_rayon::prelude::*;
12use p3_util::{log2_ceil_usize, log2_strict_usize, reverse_slice_index_bits};
13use tracing::{debug_span, instrument};
14
15use crate::domain::CircleDomain;
16use crate::point::{Point, compute_lagrange_den_batched};
17use crate::{CfftPermutable, CfftView, cfft_permute_index, cfft_permute_slice};
18
19#[derive(Clone)]
20pub struct CircleEvaluations<F, M = RowMajorMatrix<F>> {
21    pub(crate) domain: CircleDomain<F>,
22    pub(crate) values: M,
23}
24
25impl<F: Copy + Send + Sync, M: Matrix<F>> CircleEvaluations<F, M> {
26    pub(crate) fn from_cfft_order(domain: CircleDomain<F>, values: M) -> Self {
27        assert_eq!(1 << domain.log_n, values.height());
28        Self { domain, values }
29    }
30    pub fn from_natural_order(
31        domain: CircleDomain<F>,
32        values: M,
33    ) -> CircleEvaluations<F, CfftView<M>> {
34        CircleEvaluations::from_cfft_order(domain, values.cfft_perm_rows())
35    }
36    pub fn to_cfft_order(self) -> M {
37        self.values
38    }
39    pub fn to_natural_order(self) -> CfftView<M> {
40        self.values.cfft_perm_rows()
41    }
42}
43
44impl<F: ComplexExtendable, M: Matrix<F>> CircleEvaluations<F, M> {
45    #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
46    pub fn interpolate(self) -> RowMajorMatrix<F> {
47        let Self { domain, values } = self;
48        let mut values = debug_span!("to_rmm").in_scope(|| values.to_row_major_matrix());
49
50        let mut twiddles = debug_span!("twiddles").in_scope(|| {
51            compute_twiddles(domain)
52                .into_iter()
53                .map(|ts| {
54                    batch_multiplicative_inverse(&ts)
55                        .into_iter()
56                        .map(|t| DifButterfly(t))
57                        .collect_vec()
58                })
59                .peekable()
60        });
61
62        assert_eq!(twiddles.len(), domain.log_n);
63
64        let par_twiddles = twiddles
65            .peeking_take_while(|ts| ts.len() >= desired_num_jobs())
66            .collect_vec();
67        if let Some(min_blks) = par_twiddles.last().map(|ts| ts.len()) {
68            let max_blk_sz = values.height() / min_blks;
69            debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
70                values
71                    .par_row_chunks_exact_mut(max_blk_sz)
72                    .enumerate()
73                    .for_each(|(chunk_i, submat)| {
74                        for ts in &par_twiddles {
75                            let twiddle_chunk_sz = ts.len() / min_blks;
76                            let twiddle_chunk = &ts
77                                [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
78                            serial_layer(submat.values, twiddle_chunk);
79                        }
80                    });
81            });
82        }
83
84        for ts in twiddles {
85            par_within_blk_layer(&mut values.values, &ts);
86        }
87
88        // TODO: omit this?
89        divide_by_height(&mut values);
90        values
91    }
92
93    #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
94    pub fn extrapolate(
95        self,
96        target_domain: CircleDomain<F>,
97    ) -> CircleEvaluations<F, RowMajorMatrix<F>> {
98        assert!(target_domain.log_n >= self.domain.log_n);
99        CircleEvaluations::evaluate(target_domain, self.interpolate())
100    }
101
102    pub fn evaluate_at_point<EF: ExtensionField<F>>(&self, point: Point<EF>) -> Vec<EF> {
103        // Compute z_H
104        let lagrange_num = self.domain.vanishing_poly(point);
105
106        // Permute the domain to get it into the right format.
107        let permuted_points = cfft_permute_slice(&self.domain.points().collect_vec());
108
109        // Compute the lagrange denominators. This is batched as it lets us make use of batched_multiplicative_inverse.
110        let lagrange_den = compute_lagrange_den_batched(&permuted_points, point, self.domain.log_n);
111
112        // The columnwise_dot_product here consumes about 5% of the runtime for example prove_poseidon2_m31_keccak.
113        // Definitely something worth optimising further.
114        self.values
115            .columnwise_dot_product(&lagrange_den)
116            .into_iter()
117            .map(|x| x * lagrange_num)
118            .collect_vec()
119    }
120
121    #[cfg(test)]
122    pub(crate) fn dim(&self) -> usize
123    where
124        M: Clone,
125    {
126        let coeffs = self.clone().interpolate();
127        for (i, mut row) in coeffs.rows().enumerate() {
128            if row.all(|x| x.is_zero()) {
129                return i;
130            }
131        }
132        coeffs.height()
133    }
134}
135
136impl<F: ComplexExtendable> CircleEvaluations<F, RowMajorMatrix<F>> {
137    #[instrument(skip_all, fields(dims = %coeffs.dimensions()))]
138    pub fn evaluate(domain: CircleDomain<F>, mut coeffs: RowMajorMatrix<F>) -> Self {
139        let log_n = log2_strict_usize(coeffs.height());
140        assert!(log_n <= domain.log_n);
141
142        if log_n < domain.log_n {
143            // We could simply pad coeffs like this:
144            // coeffs.pad_to_height(target_domain.size(), F::ZERO);
145            // But the first `added_bits` layers will simply fill out the zeros
146            // with the lower order values. (In `DitButterfly`, `x_2` is 0, so
147            // both `x_1` and `x_2` are set to `x_1`).
148            // So instead we directly repeat the coeffs and skip the initial layers.
149            debug_span!("extend coeffs").in_scope(|| {
150                coeffs.values.reserve(domain.size() * coeffs.width());
151                for _ in log_n..domain.log_n {
152                    coeffs.values.extend_from_within(..);
153                }
154            });
155        }
156        assert_eq!(coeffs.height(), 1 << domain.log_n);
157
158        let mut twiddles = debug_span!("twiddles").in_scope(|| {
159            compute_twiddles(domain)
160                .into_iter()
161                .map(|ts| ts.into_iter().map(|t| DitButterfly(t)).collect_vec())
162                .rev()
163                .skip(domain.log_n - log_n)
164                .peekable()
165        });
166
167        for ts in twiddles.peeking_take_while(|ts| ts.len() < desired_num_jobs()) {
168            par_within_blk_layer(&mut coeffs.values, &ts);
169        }
170
171        let par_twiddles = twiddles.collect_vec();
172        if let Some(min_blks) = par_twiddles.first().map(|ts| ts.len()) {
173            let max_blk_sz = coeffs.height() / min_blks;
174            debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
175                coeffs
176                    .par_row_chunks_exact_mut(max_blk_sz)
177                    .enumerate()
178                    .for_each(|(chunk_i, submat)| {
179                        for ts in &par_twiddles {
180                            let twiddle_chunk_sz = ts.len() / min_blks;
181                            let twiddle_chunk = &ts
182                                [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
183                            serial_layer(submat.values, twiddle_chunk);
184                        }
185                    });
186            });
187        }
188
189        Self::from_cfft_order(domain, coeffs)
190    }
191}
192
193#[inline]
194fn serial_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
195    let blk_sz = values.len() / twiddles.len();
196    for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
197        let (lo, hi) = blk.split_at_mut(blk_sz / 2);
198        t.apply_to_rows(lo, hi);
199    }
200}
201
202#[inline]
203#[instrument(level = "debug", skip_all, fields(log_blks = log2_strict_usize(twiddles.len())))]
204fn par_within_blk_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
205    let blk_sz = values.len() / twiddles.len();
206    for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
207        let (lo, hi) = blk.split_at_mut(blk_sz / 2);
208        let job_sz = core::cmp::max(1, lo.len() >> log2_ceil_usize(desired_num_jobs()));
209        lo.par_chunks_mut(job_sz)
210            .zip(hi.par_chunks_mut(job_sz))
211            .for_each(|(lo_job, hi_job)| t.apply_to_rows(lo_job, hi_job));
212    }
213}
214
215#[inline]
216#[allow(clippy::missing_const_for_fn)]
217fn desired_num_jobs() -> usize {
218    16 * current_num_threads()
219}
220
221impl<F: ComplexExtendable> CircleDomain<F> {
222    pub(crate) fn y_twiddles(&self) -> Vec<F> {
223        let mut ys = self.coset0().map(|p| p.y).collect_vec();
224        reverse_slice_index_bits(&mut ys);
225        ys
226    }
227    pub(crate) fn nth_y_twiddle(&self, index: usize) -> F {
228        self.nth_point(cfft_permute_index(index << 1, self.log_n)).y
229    }
230    pub(crate) fn x_twiddles(&self, layer: usize) -> Vec<F> {
231        let generator = self.subgroup_generator() * (1 << layer);
232        let shift = self.shift * (1 << layer);
233        let mut xs = iterate(shift, move |&p| p + generator)
234            .map(|p| p.x)
235            .take(1 << (self.log_n - layer - 2))
236            .collect_vec();
237        reverse_slice_index_bits(&mut xs);
238        xs
239    }
240    pub(crate) fn nth_x_twiddle(&self, index: usize) -> F {
241        (self.shift + self.subgroup_generator() * index).x
242    }
243}
244
245fn compute_twiddles<F: ComplexExtendable>(domain: CircleDomain<F>) -> Vec<Vec<F>> {
246    assert!(domain.log_n >= 1);
247    let mut pts = domain.coset0().collect_vec();
248    reverse_slice_index_bits(&mut pts);
249    let mut twiddles = vec![pts.iter().map(|p| p.y).collect_vec()];
250    if domain.log_n >= 2 {
251        twiddles.push(pts.iter().step_by(2).map(|p| p.x).collect_vec());
252        for i in 0..(domain.log_n - 2) {
253            let prev = twiddles.last().unwrap();
254            assert_eq!(prev.len(), 1 << (domain.log_n - 2 - i));
255            let cur = prev
256                .iter()
257                .step_by(2)
258                .map(|x| x.square().double() - F::ONE)
259                .collect_vec();
260            twiddles.push(cur);
261        }
262    }
263    twiddles
264}
265
266pub fn circle_basis<F: Field>(p: Point<F>, log_n: usize) -> Vec<F> {
267    let mut b = vec![F::ONE, p.y];
268    let mut x = p.x;
269    for _ in 0..(log_n - 1) {
270        for i in 0..b.len() {
271            b.push(b[i] * x);
272        }
273        x = x.square().double() - F::ONE;
274    }
275    assert_eq!(b.len(), 1 << log_n);
276    b
277}
278
#[cfg(test)]
mod tests {
    use itertools::iproduct;
    use p3_field::extension::BinomialExtensionField;
    use p3_mersenne_31::Mersenne31;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    type F = Mersenne31;
    type EF = BinomialExtensionField<F, 3>;

    /// Random evaluations of `width` columns over the standard domain of size `2^log_n`,
    /// supplied in natural order.
    fn random_standard_evals(
        rng: &mut SmallRng,
        log_n: usize,
        width: usize,
    ) -> CircleEvaluations<F, CfftView<RowMajorMatrix<F>>> {
        CircleEvaluations::<F>::from_natural_order(
            CircleDomain::standard(log_n),
            RowMajorMatrix::rand(rng, 1 << log_n, width),
        )
    }

    /// Interpolating and evaluating again is the identity, and the coefficients agree with
    /// `circle_basis` at every domain point.
    #[test]
    fn test_cfft_icfft() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let shift = Point::generator(F::CIRCLE_TWO_ADICITY) * (rng.random::<u16>() as usize);
            let domain = CircleDomain::new(log_n, shift);
            let trace = RowMajorMatrix::<F>::rand(&mut rng, 1 << log_n, width);
            let coeffs = CircleEvaluations::from_natural_order(domain, trace.clone()).interpolate();

            let round_trip = CircleEvaluations::evaluate(domain, coeffs.clone())
                .to_natural_order()
                .to_row_major_matrix();
            assert_eq!(round_trip, trace, "icfft(cfft(evals)) is identity");

            for (i, pt) in domain.points().enumerate() {
                assert_eq!(
                    &*trace.row_slice(i).unwrap(),
                    coeffs.columnwise_dot_product(&circle_basis(pt, log_n)),
                    "coeffs can be evaluated with circle_basis",
                );
            }
        }
    }

    /// Extrapolating to a larger domain keeps the original coefficients and only adds zero
    /// rows after them.
    #[test]
    fn test_extrapolation() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, log_blowup) in iproduct!(2..5, [1, 2, 3]) {
            let evals = random_standard_evals(&mut rng, log_n, 11);
            let target = CircleDomain::standard(log_n + log_blowup);
            let lde = evals.clone().extrapolate(target);

            let coeffs = evals.interpolate();
            let lde_coeffs = lde.interpolate();

            // Shared rows must match exactly.
            for r in 0..coeffs.height() {
                assert_eq!(
                    &*coeffs.row_slice(r).unwrap(),
                    &*lde_coeffs.row_slice(r).unwrap()
                );
            }
            // Everything past the original height must be zero.
            for r in coeffs.height()..lde_coeffs.height() {
                assert!(lde_coeffs.row(r).unwrap().into_iter().all(|x| x.is_zero()));
            }
        }
    }

    /// `evaluate_at_point` agrees with evaluating the interpolated coefficients directly.
    #[test]
    fn eval_at_point_matches_cfft() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let evals = random_standard_evals(&mut rng, log_n, width);
            let pt = Point::<EF>::from_projective_line(rng.random());

            let direct = evals.evaluate_at_point(pt);
            let via_coeffs = evals
                .interpolate()
                .columnwise_dot_product(&circle_basis(pt, log_n));
            assert_eq!(direct, via_coeffs);
        }
    }

    /// `evaluate_at_point` gives the same answer on the original evaluations and on their
    /// low-degree extension, and both agree with their interpolated coefficients.
    #[test]
    fn eval_at_point_matches_lde() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, width, log_blowup) in iproduct!(2..8, [1, 4, 11], [1, 2]) {
            let evals = random_standard_evals(&mut rng, log_n, width);
            let lde_log_n = log_n + log_blowup;
            let lde = evals.clone().extrapolate(CircleDomain::standard(lde_log_n));
            let zeta = Point::<EF>::from_projective_line(rng.random());

            let evals_at_zeta = evals.evaluate_at_point(zeta);
            let lde_at_zeta = lde.evaluate_at_point(zeta);
            assert_eq!(evals_at_zeta, lde_at_zeta);

            assert_eq!(
                evals_at_zeta,
                evals
                    .interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n))
            );
            assert_eq!(
                lde_at_zeta,
                lde.interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, lde_log_n))
            );
        }
    }
}