1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{iterate, izip, Itertools};
5use p3_commit::PolynomialSpace;
6use p3_dft::{divide_by_height, Butterfly, DifButterfly, DitButterfly};
7use p3_field::extension::ComplexExtendable;
8use p3_field::{batch_multiplicative_inverse, ExtensionField, Field};
9use p3_matrix::dense::RowMajorMatrix;
10use p3_matrix::Matrix;
11use p3_maybe_rayon::prelude::*;
12use p3_util::{log2_ceil_usize, log2_strict_usize, reverse_slice_index_bits};
13use tracing::{debug_span, instrument};
14
15use crate::domain::CircleDomain;
16use crate::point::Point;
17use crate::{cfft_permute_index, cfft_permute_slice, CfftPermutable, CfftView};
18
/// A matrix of values attached to a [`CircleDomain`].
///
/// The rows of `values` are kept in "CFFT order": `from_cfft_order` stores the
/// matrix unchanged, whereas `from_natural_order` first applies
/// `cfft_perm_rows`. `from_cfft_order` asserts that the matrix has exactly
/// `1 << domain.log_n` rows.
#[derive(Clone)]
pub struct CircleEvaluations<F, M = RowMajorMatrix<F>> {
    /// The domain the rows of `values` are associated with.
    pub(crate) domain: CircleDomain<F>,
    /// The value matrix, rows in CFFT order.
    pub(crate) values: M,
}
24
25impl<F: Copy + Send + Sync, M: Matrix<F>> CircleEvaluations<F, M> {
26 pub(crate) fn from_cfft_order(domain: CircleDomain<F>, values: M) -> Self {
27 assert_eq!(1 << domain.log_n, values.height());
28 Self { domain, values }
29 }
30 pub fn from_natural_order(
31 domain: CircleDomain<F>,
32 values: M,
33 ) -> CircleEvaluations<F, CfftView<M>> {
34 CircleEvaluations::from_cfft_order(domain, values.cfft_perm_rows())
35 }
36 pub fn to_cfft_order(self) -> M {
37 self.values
38 }
39 pub fn to_natural_order(self) -> CfftView<M> {
40 self.values.cfft_perm_rows()
41 }
42}
43
44impl<F: ComplexExtendable, M: Matrix<F>> CircleEvaluations<F, M> {
45 #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
46 pub fn interpolate(self) -> RowMajorMatrix<F> {
47 let CircleEvaluations { domain, values } = self;
48 let mut values = debug_span!("to_rmm").in_scope(|| values.to_row_major_matrix());
49
50 let mut twiddles = debug_span!("twiddles").in_scope(|| {
51 compute_twiddles(domain)
52 .into_iter()
53 .map(|ts| {
54 batch_multiplicative_inverse(&ts)
55 .into_iter()
56 .map(|t| DifButterfly(t))
57 .collect_vec()
58 })
59 .peekable()
60 });
61
62 assert_eq!(twiddles.len(), domain.log_n);
63
64 let par_twiddles = twiddles
65 .peeking_take_while(|ts| ts.len() >= desired_num_jobs())
66 .collect_vec();
67 if let Some(min_blks) = par_twiddles.last().map(|ts| ts.len()) {
68 let max_blk_sz = values.height() / min_blks;
69 debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
70 values
71 .par_row_chunks_exact_mut(max_blk_sz)
72 .enumerate()
73 .for_each(|(chunk_i, submat)| {
74 for ts in &par_twiddles {
75 let twiddle_chunk_sz = ts.len() / min_blks;
76 let twiddle_chunk = &ts
77 [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
78 serial_layer(submat.values, twiddle_chunk);
79 }
80 });
81 });
82 }
83
84 for ts in twiddles {
85 par_within_blk_layer(&mut values.values, &ts);
86 }
87
88 divide_by_height(&mut values);
90 values
91 }
92
93 #[instrument(skip_all, fields(dims = %self.values.dimensions()))]
94 pub fn extrapolate(
95 self,
96 target_domain: CircleDomain<F>,
97 ) -> CircleEvaluations<F, RowMajorMatrix<F>> {
98 assert!(target_domain.log_n >= self.domain.log_n);
99 CircleEvaluations::<F>::evaluate(target_domain, self.interpolate())
100 }
101
102 pub fn evaluate_at_point<EF: ExtensionField<F>>(&self, point: Point<EF>) -> Vec<EF> {
103 let lagrange_num = self.domain.zeroifier(point);
104 let lagrange_den = cfft_permute_slice(&self.domain.points().collect_vec())
105 .into_iter()
106 .map(|p| p.v_tilde_p(point) * p.s_p_at_p(self.domain.log_n))
107 .collect_vec();
108 self.values
109 .columnwise_dot_product(&batch_multiplicative_inverse(&lagrange_den))
110 .into_iter()
111 .map(|x| x * lagrange_num)
112 .collect_vec()
113 }
114
115 #[cfg(test)]
116 pub(crate) fn dim(&self) -> usize
117 where
118 M: Clone,
119 {
120 let coeffs = self.clone().interpolate();
121 for (i, mut row) in coeffs.rows().enumerate() {
122 if row.all(|x| x.is_zero()) {
123 return i;
124 }
125 }
126 coeffs.height()
127 }
128}
129
130impl<F: ComplexExtendable> CircleEvaluations<F, RowMajorMatrix<F>> {
131 #[instrument(skip_all, fields(dims = %coeffs.dimensions()))]
132 pub fn evaluate(domain: CircleDomain<F>, mut coeffs: RowMajorMatrix<F>) -> Self {
133 let log_n = log2_strict_usize(coeffs.height());
134 assert!(log_n <= domain.log_n);
135
136 if log_n < domain.log_n {
137 debug_span!("extend coeffs").in_scope(|| {
144 coeffs.values.reserve(domain.size() * coeffs.width());
145 for _ in log_n..domain.log_n {
146 coeffs.values.extend_from_within(..);
147 }
148 });
149 }
150 assert_eq!(coeffs.height(), 1 << domain.log_n);
151
152 let mut twiddles = debug_span!("twiddles").in_scope(|| {
153 compute_twiddles(domain)
154 .into_iter()
155 .map(|ts| ts.into_iter().map(|t| DitButterfly(t)).collect_vec())
156 .rev()
157 .skip(domain.log_n - log_n)
158 .peekable()
159 });
160
161 for ts in twiddles.peeking_take_while(|ts| ts.len() < desired_num_jobs()) {
162 par_within_blk_layer(&mut coeffs.values, &ts);
163 }
164
165 let par_twiddles = twiddles.collect_vec();
166 if let Some(min_blks) = par_twiddles.first().map(|ts| ts.len()) {
167 let max_blk_sz = coeffs.height() / min_blks;
168 debug_span!("par_layers", log_min_blks = log2_strict_usize(min_blks)).in_scope(|| {
169 coeffs
170 .par_row_chunks_exact_mut(max_blk_sz)
171 .enumerate()
172 .for_each(|(chunk_i, submat)| {
173 for ts in &par_twiddles {
174 let twiddle_chunk_sz = ts.len() / min_blks;
175 let twiddle_chunk = &ts
176 [(twiddle_chunk_sz * chunk_i)..(twiddle_chunk_sz * (chunk_i + 1))];
177 serial_layer(submat.values, twiddle_chunk);
178 }
179 });
180 });
181 }
182
183 Self::from_cfft_order(domain, coeffs)
184 }
185}
186
187#[inline]
188fn serial_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
189 let blk_sz = values.len() / twiddles.len();
190 for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
191 let (lo, hi) = blk.split_at_mut(blk_sz / 2);
192 t.apply_to_rows(lo, hi);
193 }
194}
195
196#[inline]
197#[instrument(level = "debug", skip_all, fields(log_blks = log2_strict_usize(twiddles.len())))]
198fn par_within_blk_layer<F: Field, B: Butterfly<F>>(values: &mut [F], twiddles: &[B]) {
199 let blk_sz = values.len() / twiddles.len();
200 for (&t, blk) in izip!(twiddles, values.chunks_exact_mut(blk_sz)) {
201 let (lo, hi) = blk.split_at_mut(blk_sz / 2);
202 let job_sz = core::cmp::max(1, lo.len() >> log2_ceil_usize(desired_num_jobs()));
203 lo.par_chunks_mut(job_sz)
204 .zip(hi.par_chunks_mut(job_sz))
205 .for_each(|(lo_job, hi_job)| t.apply_to_rows(lo_job, hi_job));
206 }
207}
208
209#[inline]
210fn desired_num_jobs() -> usize {
211 16 * current_num_threads()
212}
213
214impl<F: ComplexExtendable> CircleDomain<F> {
215 pub(crate) fn y_twiddles(&self) -> Vec<F> {
216 let mut ys = self.coset0().map(|p| p.y).collect_vec();
217 reverse_slice_index_bits(&mut ys);
218 ys
219 }
220 pub(crate) fn nth_y_twiddle(&self, index: usize) -> F {
221 self.nth_point(cfft_permute_index(index << 1, self.log_n)).y
222 }
223 pub(crate) fn x_twiddles(&self, layer: usize) -> Vec<F> {
224 let gen = self.gen() * (1 << layer);
225 let shift = self.shift * (1 << layer);
226 let mut xs = iterate(shift, move |&p| p + gen)
227 .map(|p| p.x)
228 .take(1 << (self.log_n - layer - 2))
229 .collect_vec();
230 reverse_slice_index_bits(&mut xs);
231 xs
232 }
233 pub(crate) fn nth_x_twiddle(&self, index: usize) -> F {
234 (self.shift + self.gen() * index).x
235 }
236}
237
238fn compute_twiddles<F: ComplexExtendable>(domain: CircleDomain<F>) -> Vec<Vec<F>> {
239 assert!(domain.log_n >= 1);
240 let mut pts = domain.coset0().collect_vec();
241 reverse_slice_index_bits(&mut pts);
242 let mut twiddles = vec![pts.iter().map(|p| p.y).collect_vec()];
243 if domain.log_n >= 2 {
244 twiddles.push(pts.iter().step_by(2).map(|p| p.x).collect_vec());
245 for i in 0..(domain.log_n - 2) {
246 let prev = twiddles.last().unwrap();
247 assert_eq!(prev.len(), 1 << (domain.log_n - 2 - i));
248 let cur = prev
249 .iter()
250 .step_by(2)
251 .map(|x| x.square().double() - F::one())
252 .collect_vec();
253 twiddles.push(cur);
254 }
255 }
256 twiddles
257}
258
259pub fn circle_basis<F: Field>(p: Point<F>, log_n: usize) -> Vec<F> {
260 let mut b = vec![F::one(), p.y];
261 let mut x = p.x;
262 for _ in 0..(log_n - 1) {
263 for i in 0..b.len() {
264 b.push(b[i] * x);
265 }
266 x = x.square().double() - F::one();
267 }
268 assert_eq!(b.len(), 1 << log_n);
269 b
270}
271
#[cfg(test)]
mod tests {
    use itertools::iproduct;
    use p3_field::extension::BinomialExtensionField;
    use p3_mersenne_31::Mersenne31;
    use rand::{random, thread_rng};

    use super::*;

    type F = Mersenne31;
    type EF = BinomialExtensionField<F, 3>;

    /// Interpolating and then evaluating on the same (randomly shifted) domain
    /// returns the original values, and the coefficients agree with
    /// `circle_basis` at every domain point.
    #[test]
    fn test_cfft_icfft() {
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let shift = Point::generator(F::CIRCLE_TWO_ADICITY) * random();
            let domain = CircleDomain::<F>::new(log_n, shift);
            let trace = RowMajorMatrix::<F>::rand(&mut thread_rng(), 1 << log_n, width);
            let coeffs = CircleEvaluations::from_natural_order(domain, trace.clone()).interpolate();
            assert_eq!(
                CircleEvaluations::evaluate(domain, coeffs.clone())
                    .to_natural_order()
                    .to_row_major_matrix(),
                trace,
                "icfft(cfft(evals)) is identity",
            );
            for (i, pt) in domain.points().enumerate() {
                assert_eq!(
                    &*trace.row_slice(i),
                    coeffs.columnwise_dot_product(&circle_basis(pt, log_n)),
                    "coeffs can be evaluated with circle_basis",
                );
            }
        }
    }

    /// Extrapolating to a larger domain keeps the original coefficients and
    /// leaves every additional coefficient row at zero.
    #[test]
    fn test_extrapolation() {
        for (log_n, log_blowup) in iproduct!(2..5, [1, 2, 3]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut thread_rng(), 1 << log_n, 11),
            );
            // `extrapolate` consumes its receiver, and `evals` is needed again below.
            let lde = evals
                .clone()
                .extrapolate(CircleDomain::standard(log_n + log_blowup));

            let coeffs = evals.interpolate();
            let lde_coeffs = lde.interpolate();

            for r in 0..coeffs.height() {
                assert_eq!(&*coeffs.row_slice(r), &*lde_coeffs.row_slice(r));
            }
            for r in coeffs.height()..lde_coeffs.height() {
                assert!(lde_coeffs.row(r).all(|x| x.is_zero()));
            }
        }
    }

    /// `evaluate_at_point` agrees with evaluating the interpolated coefficients
    /// through `circle_basis`.
    #[test]
    fn eval_at_point_matches_cfft() {
        for (log_n, width) in iproduct!(2..5, [1, 4, 11]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut thread_rng(), 1 << log_n, width),
            );

            let pt = Point::<EF>::from_projective_line(random());

            // `evaluate_at_point` only borrows `evals`, so no clone is needed before
            // `interpolate` consumes it.
            assert_eq!(
                evals.evaluate_at_point(pt),
                evals
                    .interpolate()
                    .columnwise_dot_product(&circle_basis(pt, log_n))
            );
        }
    }

    /// `evaluate_at_point` gives the same answer on the original evaluations and
    /// on their low-degree extension, and both agree with the coefficient form.
    #[test]
    fn eval_at_point_matches_lde() {
        for (log_n, width, log_blowup) in iproduct!(2..8, [1, 4, 11], [1, 2]) {
            let evals = CircleEvaluations::<F>::from_natural_order(
                CircleDomain::standard(log_n),
                RowMajorMatrix::rand(&mut thread_rng(), 1 << log_n, width),
            );
            let lde = evals
                .clone()
                .extrapolate(CircleDomain::standard(log_n + log_blowup));
            let zeta = Point::<EF>::from_projective_line(random());
            assert_eq!(evals.evaluate_at_point(zeta), lde.evaluate_at_point(zeta));
            assert_eq!(
                evals.evaluate_at_point(zeta),
                evals
                    .interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n))
            );
            assert_eq!(
                lde.evaluate_at_point(zeta),
                lde.interpolate()
                    .columnwise_dot_product(&circle_basis(zeta, log_n + log_blowup))
            );
        }
    }
}