1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{Itertools, iterate, izip};
5use p3_commit::PolynomialSpace;
6use p3_dft::{Butterfly, DifButterfly, DitButterfly, divide_by_height};
7use p3_field::extension::ComplexExtendable;
8use p3_field::{ExtensionField, Field, batch_multiplicative_inverse};
9use p3_matrix::Matrix;
10use p3_matrix::dense::RowMajorMatrix;
11use p3_maybe_rayon::prelude::*;
12use p3_util::{log2_ceil_usize, log2_strict_usize, reverse_slice_index_bits};
13use tracing::{debug_span, instrument};
14
15use crate::domain::CircleDomain;
16use crate::point::{Point, compute_lagrange_den_batched};
17use crate::{CfftPermutable, CfftView, cfft_permute_index, cfft_permute_slice};
18
/// Evaluations of one or more polynomials (one per column of `values`) over a
/// [`CircleDomain`], with the rows stored in cfft order.
#[derive(Clone)]
pub struct CircleEvaluations<F, M = RowMajorMatrix<F>> {
    /// The domain the columns are evaluated over.
    pub(crate) domain: CircleDomain<F>,
    /// Evaluation matrix with one row per domain point, in cfft order. When built through
    /// `from_cfft_order` / `from_natural_order` its height is `2^domain.log_n`.
    pub(crate) values: M,
}
24
25impl<F: Copy + Send + Sync, M: Matrix<F>> CircleEvaluations<F, M> {
26 pub(crate) fn from_cfft_order(domain: CircleDomain<F>, values: M) -> Self {
27 assert_eq!(1 << domain.log_n, values.height());
28 Self { domain, values }
29 }
30 pub fn from_natural_order(
31 domain: CircleDomain<F>,
32 values: M,
33 ) -> CircleEvaluations<F, CfftView<M>> {
34 CircleEvaluations::from_cfft_order(domain, values.cfft_perm_rows())
35 }
36 pub fn to_cfft_order(self) -> M {
37 self.values
38 }
39 pub fn to_natural_order(self) -> CfftView<M> {
40 self.values.cfft_perm_rows()
41 }
42}
43
/// Interpolation, extrapolation and out-of-domain evaluation of cfft-ordered evaluations.
impl<F: ComplexExtendable, M: Matrix<F>> CircleEvaluations<F, M> {
    /// Computes the circle-FFT coefficients of every column. This is the inverse of
    /// `CircleEvaluations::evaluate`.
    ///
    /// Consumes `self` and copies the values into a row-major matrix. It then applies the
    /// butterfly layers with inverted twiddles and finally divides by the height. The dot
    /// product of a resulting coefficient column with `circle_basis(p, log_n)` evaluates
    /// that column's polynomial at `p`.
    #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
    pub fn interpolate(self) -> RowMajorMatrix<F> {
        let Self { domain, values } = self;
        let mut values = debug_span!("to_rmm").in_scope(|| values.to_row_major_matrix());

        // Inverse transform: DIF butterflies with the multiplicative inverses of the forward
        // twiddles. `compute_twiddles` yields the largest layer first (one twiddle per pair of
        // adjacent rows); each later layer has half as many twiddles.
        let mut twiddles = debug_span!("twiddles").in_scope(|| {
            compute_twiddles(domain)
                .into_iter()
                .map(|ts| {
                    batch_multiplicative_inverse(&ts)
                        .into_iter()
                        .map(|t| DifButterfly(t))
                        .collect_vec()
                })
                .peekable()
        });

        assert_eq!(twiddles.len(), domain.log_n);

        // Leading layers with at least `desired_num_jobs()` butterfly blocks are parallelized
        // across rows. The matrix is split into `min_blks` contiguous chunks, and each chunk
        // runs all of those layers serially on its own slice of every twiddle vector.
        let par_twiddles = twiddles
            .peeking_take_while(|ts| ts.len() >= desired_num_jobs())
            .collect_vec();
        if let Some(min_blks) = par_twiddles.last().map(|ts| ts.len()) {
            // `min_blks` is the smallest parallel layer. Twiddle lengths are powers of two, so
            // each chunk of `max_blk_sz` rows holds whole butterfly blocks for every such layer.
            let max_blk_sz = values.height() / min_blks;
            debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
                values
                    .par_row_chunks_exact_mut(max_blk_sz)
                    .enumerate()
                    .for_each(|(chunk_i, submat)| {
                        for ts in &par_twiddles {
                            // Twiddles for this chunk's blocks form a contiguous sub-slice.
                            let twiddle_chunk_sz = ts.len() / min_blks;
                            let twiddle_chunk = &ts
                                [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
                            serial_layer(submat.values, twiddle_chunk);
                        }
                    });
            });
        }

        // Remaining layers have too few blocks to split across jobs, so parallelize within
        // each block instead.
        for ts in twiddles {
            par_within_blk_layer(&mut values.values, &ts);
        }

        // Presumably scales every entry by 1/height to finish the inverse transform.
        divide_by_height(&mut values);
        values
    }

    /// Re-evaluates the same polynomials over `target_domain`, which must be at least as
    /// large as the current domain (a low-degree extension). It interpolates first, then
    /// evaluates.
    ///
    /// # Panics
    /// Panics if `target_domain.log_n < self.domain.log_n`.
    #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
    pub fn extrapolate(
        self,
        target_domain: CircleDomain<F>,
    ) -> CircleEvaluations<F, RowMajorMatrix<F>> {
        assert!(target_domain.log_n >= self.domain.log_n);
        CircleEvaluations::evaluate(target_domain, self.interpolate())
    }

    /// Evaluates every column's polynomial at `point` (possibly in an extension field)
    /// without interpolating.
    ///
    /// For each column the result is
    /// `domain.vanishing_poly(point) * Σ_i values[i] * lagrange_den[i]`.
    /// Here `lagrange_den` comes from `compute_lagrange_den_batched` over the domain points,
    /// permuted into cfft order to match the storage order of `values`.
    pub fn evaluate_at_point<EF: ExtensionField<F>>(&self, point: Point<EF>) -> Vec<EF> {
        let lagrange_num = self.domain.vanishing_poly(point);

        // Domain points reordered to line up with the cfft-ordered rows of `values`.
        let permuted_points = cfft_permute_slice(&self.domain.points().collect_vec());

        let lagrange_den = compute_lagrange_den_batched(&permuted_points, point, self.domain.log_n);

        self.values
            .columnwise_dot_product(&lagrange_den)
            .into_iter()
            .map(|x| x * lagrange_num)
            .collect_vec()
    }

    /// Test helper: interpolates and returns the index of the first all-zero coefficient
    /// row, or the height if no row is entirely zero.
    ///
    /// NOTE(review): this equals the number of significant coefficients only when no
    /// all-zero row precedes a non-zero one.
    #[cfg(test)]
    pub(crate) fn dim(&self) -> usize
    where
        M: Clone,
    {
        let coeffs = self.clone().interpolate();
        for (i, mut row) in coeffs.rows().enumerate() {
            if row.all(|x| x.is_zero()) {
                return i;
            }
        }
        coeffs.height()
    }
}
135
136impl<F: ComplexExtendable> CircleEvaluations<F, RowMajorMatrix<F>> {
137 #[instrument(skip_all, fields(dims = %coeffs.dimensions()))]
138 pub fn evaluate(domain: CircleDomain<F>, mut coeffs: RowMajorMatrix<F>) -> Self {
139 let log_n = log2_strict_usize(coeffs.height());
140 assert!(log_n <= domain.log_n);
141
142 if log_n < domain.log_n {
143 debug_span!("extend coeffs").in_scope(|| {
150 coeffs.values.reserve(domain.size() * coeffs.width());
151 for _ in log_n..domain.log_n {
152 coeffs.values.extend_from_within(..);
153 }
154 });
155 }
156 assert_eq!(coeffs.height(), 1 << domain.log_n);
157
158 let mut twiddles = debug_span!("twiddles").in_scope(|| {
159 compute_twiddles(domain)
160 .into_iter()
161 .map(|ts| ts.into_iter().map(|t| DitButterfly(t)).collect_vec())
162 .rev()
163 .skip(domain.log_n - log_n)
164 .peekable()
165 });
166
167 for ts in twiddles.peeking_take_while(|ts| ts.len() < desired_num_jobs()) {
168 par_within_blk_layer(&mut coeffs.values, &ts);
169 }
170
171 let par_twiddles = twiddles.collect_vec();
172 if let Some(min_blks) = par_twiddles.first().map(|ts| ts.len()) {
173 let max_blk_sz = coeffs.height() / min_blks;
174 debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
175 coeffs
176 .par_row_chunks_exact_mut(max_blk_sz)
177 .enumerate()
178 .for_each(|(chunk_i, submat)| {
179 for ts in &par_twiddles {
180 let twiddle_chunk_sz = ts.len() / min_blks;
181 let twiddle_chunk = &ts
182 [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
183 serial_layer(submat.values, twiddle_chunk);
184 }
185 });
186 });
187 }
188
189 Self::from_cfft_order(domain, coeffs)
190 }
191}
192
193#[inline]
194fn serial_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
195 let blk_sz = values.len() / twiddles.len();
196 for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
197 let (lo, hi) = blk.split_at_mut(blk_sz / 2);
198 t.apply_to_rows(lo, hi);
199 }
200}
201
202#[inline]
203#[instrument(level = "debug", skip_all, fields(log_blks = log2_strict_usize(twiddles.len())))]
204fn par_within_blk_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
205 let blk_sz = values.len() / twiddles.len();
206 for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
207 let (lo, hi) = blk.split_at_mut(blk_sz / 2);
208 let job_sz = core::cmp::max(1, lo.len() >> log2_ceil_usize(desired_num_jobs()));
209 lo.par_chunks_mut(job_sz)
210 .zip(hi.par_chunks_mut(job_sz))
211 .for_each(|(lo_job, hi_job)| t.apply_to_rows(lo_job, hi_job));
212 }
213}
214
215#[inline]
216#[allow(clippy::missing_const_for_fn)]
217fn desired_num_jobs() -> usize {
218 16 * current_num_threads()
219}
220
221impl<F: ComplexExtendable> CircleDomain<F> {
222 pub(crate) fn y_twiddles(&self) -> Vec<F> {
223 let mut ys = self.coset0().map(|p| p.y).collect_vec();
224 reverse_slice_index_bits(&mut ys);
225 ys
226 }
227 pub(crate) fn nth_y_twiddle(&self, index: usize) -> F {
228 self.nth_point(cfft_permute_index(index << 1, self.log_n)).y
229 }
230 pub(crate) fn x_twiddles(&self, layer: usize) -> Vec<F> {
231 let generator = self.subgroup_generator() * (1 << layer);
232 let shift = self.shift * (1 << layer);
233 let mut xs = iterate(shift, move |&p| p + generator)
234 .map(|p| p.x)
235 .take(1 << (self.log_n - layer - 2))
236 .collect_vec();
237 reverse_slice_index_bits(&mut xs);
238 xs
239 }
240 pub(crate) fn nth_x_twiddle(&self, index: usize) -> F {
241 (self.shift + self.subgroup_generator() * index).x
242 }
243}
244
245fn compute_twiddles<F: ComplexExtendable>(domain: CircleDomain<F>) -> Vec<Vec<F>> {
246 assert!(domain.log_n >= 1);
247 let mut pts = domain.coset0().collect_vec();
248 reverse_slice_index_bits(&mut pts);
249 let mut twiddles = vec![pts.iter().map(|p| p.y).collect_vec()];
250 if domain.log_n >= 2 {
251 twiddles.push(pts.iter().step_by(2).map(|p| p.x).collect_vec());
252 for i in 0..(domain.log_n - 2) {
253 let prev = twiddles.last().unwrap();
254 assert_eq!(prev.len(), 1 << (domain.log_n - 2 - i));
255 let cur = prev
256 .iter()
257 .step_by(2)
258 .map(|x| x.square().double() - F::ONE)
259 .collect_vec();
260 twiddles.push(cur);
261 }
262 }
263 twiddles
264}
265
266pub fn circle_basis<F: Field>(p: Point<F>, log_n: usize) -> Vec<F> {
267 let mut b = vec![F::ONE, p.y];
268 let mut x = p.x;
269 for _ in 0..(log_n - 1) {
270 for i in 0..b.len() {
271 b.push(b[i] * x);
272 }
273 x = x.square().double() - F::ONE;
274 }
275 assert_eq!(b.len(), 1 << log_n);
276 b
277}
278
/// Round-trip and consistency tests for the circle FFT. All randomness comes from a
/// fixed-seed `SmallRng`, so statement order determines which inputs each test sees.
#[cfg(test)]
mod tests {
    use itertools::iproduct;
    use p3_field::extension::BinomialExtensionField;
    use p3_mersenne_31::Mersenne31;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    type F = Mersenne31;
    type EF = BinomialExtensionField<F, 3>;

    /// Round-trips random evaluations through `interpolate` and `evaluate` over randomly
    /// shifted domains. Also checks that dotting the coefficients with `circle_basis`
    /// reproduces the value at every domain point.
    #[test]
    fn test_cfft_icfft() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            // Random multiple of `Point::generator(F::CIRCLE_TWO_ADICITY)` as the coset shift.
            let shift = Point::generator(F::CIRCLE_TWO_ADICITY) * (rng.random::<u16>() as usize);
            let domain = CircleDomain::new(log_n, shift);
            let trace = RowMajorMatrix::<F>::rand(&mut rng, 1 << log_n, width);
            let coeffs = CircleEvaluations::from_natural_order(domain, trace.clone()).interpolate();
            assert_eq!(
                CircleEvaluations::evaluate(domain, coeffs.clone())
                    .to_natural_order()
                    .to_row_major_matrix(),
                trace,
                "icfft(cfft(evals)) is identity",
            );
            for (i, pt) in domain.points().enumerate() {
                assert_eq!(
                    &*trace.row_slice(i).unwrap(),
                    coeffs.columnwise_dot_product(&circle_basis(pt, log_n)),
                    "coeffs can be evaluated with circle_basis",
                );
            }
        }
    }

    /// Extrapolating to a larger standard domain must keep the same low-order coefficients
    /// and leave every higher-order coefficient zero.
    #[test]
    fn test_extrapolation() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, log_blowup) in iproduct!(2..5, [1, 2, 3]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut rng, 1 << log_n, 11),
            );
            let lde = evals
                .clone()
                .extrapolate(CircleDomain::standard(log_n + log_blowup));

            let coeffs = evals.interpolate();
            let lde_coeffs = lde.interpolate();

            for r in 0..coeffs.height() {
                assert_eq!(
                    &*coeffs.row_slice(r).unwrap(),
                    &*lde_coeffs.row_slice(r).unwrap()
                );
            }
            for r in coeffs.height()..lde_coeffs.height() {
                assert!(lde_coeffs.row(r).unwrap().into_iter().all(|x| x.is_zero()));
            }
        }
    }

    /// `evaluate_at_point` at a random extension-field point must match interpolating
    /// and dotting with `circle_basis`.
    #[test]
    fn eval_at_point_matches_cfft() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut rng, 1 << log_n, width),
            );

            let pt = Point::<EF>::from_projective_line(rng.random());

            assert_eq!(
                evals.clone().evaluate_at_point(pt),
                evals
                    .interpolate()
                    .columnwise_dot_product(&circle_basis(pt, log_n))
            );
        }
    }

    /// Out-of-domain evaluation must agree between the original evaluations and their LDE.
    /// Each must also agree with its own `circle_basis` evaluation.
    #[test]
    fn eval_at_point_matches_lde() {
        let mut rng = SmallRng::seed_from_u64(1);
        for (log_n, width, log_blowup) in iproduct!(2..8, [1, 4, 11], [1, 2]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut rng, 1 << log_n, width),
            );
            let lde = evals
                .clone()
                .extrapolate(CircleDomain::standard(log_n + log_blowup));
            let zeta = Point::<EF>::from_projective_line(rng.random());
            assert_eq!(evals.evaluate_at_point(zeta), lde.evaluate_at_point(zeta));
            assert_eq!(
                evals.evaluate_at_point(zeta),
                evals
                    .interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n))
            );
            assert_eq!(
                lde.evaluate_at_point(zeta),
                lde.interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n + log_blowup))
            );
        }
    }
}