1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{Itertools, iterate};
5use p3_commit::{LagrangeSelectors, PolynomialSpace};
6use p3_field::extension::ComplexExtendable;
7use p3_field::{ExtensionField, batch_multiplicative_inverse};
8use p3_matrix::Matrix;
9use p3_matrix::dense::RowMajorMatrix;
10use p3_util::{log2_ceil_usize, log2_strict_usize};
11use tracing::instrument;
12
13use crate::point::Point;
14
/// A domain of `2^log_n` points on the circle group, described by `log_n` and a `shift` point.
///
/// In natural order (see `points` / `nth_point`), the even-indexed points are
/// `shift, shift + g, shift + 2g, ...` and the odd-indexed points are
/// `-shift + g, -shift + 2g, ...`, where `g = Point::generator(log_n - 1)`.
/// Each half has `2^(log_n - 1)` points.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct CircleDomain<F> {
    /// Base-2 logarithm of the number of points in the domain.
    pub(crate) log_n: usize,
    /// First point of the domain. A "standard" domain uses `Point::generator(log_n + 1)`.
    pub(crate) shift: Point<F>,
}
46
47impl<F: ComplexExtendable> CircleDomain<F> {
48 pub const fn new(log_n: usize, shift: Point<F>) -> Self {
49 Self { log_n, shift }
50 }
51 pub fn standard(log_n: usize) -> Self {
52 Self {
53 log_n,
54 shift: Point::generator(log_n + 1),
55 }
56 }
57 fn is_standard(&self) -> bool {
58 self.shift == Point::generator(self.log_n + 1)
59 }
60 pub(crate) fn subgroup_generator(&self) -> Point<F> {
61 Point::generator(self.log_n - 1)
62 }
63 pub(crate) fn coset0(&self) -> impl Iterator<Item = Point<F>> {
64 let g = self.subgroup_generator();
65 iterate(self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
66 }
67 fn coset1(&self) -> impl Iterator<Item = Point<F>> {
68 let g = self.subgroup_generator();
69 iterate(g - self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
70 }
71 pub(crate) fn points(&self) -> impl Iterator<Item = Point<F>> {
72 self.coset0().interleave(self.coset1())
73 }
74 pub(crate) fn nth_point(&self, idx: usize) -> Point<F> {
75 let (idx, lsb) = (idx >> 1, idx & 1);
76 if lsb == 0 {
77 self.shift + self.subgroup_generator() * idx
78 } else {
79 -self.shift + self.subgroup_generator() * (idx + 1)
80 }
81 }
82
83 pub(crate) fn vanishing_poly<EF: ExtensionField<F>>(&self, at: Point<EF>) -> EF {
84 at.v_n(self.log_n) - self.shift.v_n(self.log_n)
85 }
86
87 pub(crate) fn s_p<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
88 self.vanishing_poly(at) / p.v_tilde_p(at)
89 }
90
91 pub(crate) fn s_p_normalized<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
92 self.vanishing_poly(at) / (p.v_tilde_p(at) * p.s_p_at_p(self.log_n))
93 }
94}
95
96impl<F: ComplexExtendable> PolynomialSpace for CircleDomain<F> {
97 type Val = F;
98
99 fn size(&self) -> usize {
100 1 << self.log_n
101 }
102
103 fn first_point(&self) -> Self::Val {
104 self.shift.to_projective_line().unwrap()
105 }
106
107 fn next_point<Ext: ExtensionField<Self::Val>>(&self, x: Ext) -> Option<Ext> {
108 if self.is_standard() {
110 (Point::from_projective_line(x) + Point::generator(self.log_n)).to_projective_line()
111 } else {
112 None
113 }
114 }
115
116 fn create_disjoint_domain(&self, min_size: usize) -> Self {
117 assert!(
122 self.is_standard(),
123 "create_disjoint_domain not currently supported for nonstandard twin cosets"
124 );
125 let log_n = log2_ceil_usize(min_size);
126 Self::standard(if log_n == self.log_n {
128 log_n + 1
129 } else {
130 log_n
131 })
132 }
133
134 fn split_domains(&self, num_chunks: usize) -> Vec<Self> {
136 assert!(self.is_standard());
137 let log_chunks = log2_strict_usize(num_chunks);
138 assert!(log_chunks <= self.log_n);
139 self.points()
140 .take(num_chunks)
141 .map(|shift| Self {
142 log_n: self.log_n - log_chunks,
143 shift,
144 })
145 .collect()
146 }
147
148 fn split_evals(
149 &self,
150 num_chunks: usize,
151 evals: RowMajorMatrix<Self::Val>,
152 ) -> Vec<RowMajorMatrix<Self::Val>> {
153 let log_chunks = log2_strict_usize(num_chunks);
154 assert!(evals.height() >> (log_chunks + 1) >= 1);
155 let width = evals.width();
156 let mut values: Vec<Vec<Self::Val>> = vec![vec![]; num_chunks];
157 evals
158 .rows()
159 .enumerate()
160 .for_each(|(i, row)| values[forward_backward_index(i, num_chunks)].extend(row));
161 values
162 .into_iter()
163 .map(|v| RowMajorMatrix::new(v, width))
164 .collect()
165 }
166
167 fn vanishing_poly_at_point<Ext: ExtensionField<Self::Val>>(&self, point: Ext) -> Ext {
168 self.vanishing_poly(Point::from_projective_line(point))
169 }
170
171 fn selectors_at_point<Ext: ExtensionField<Self::Val>>(
172 &self,
173 point: Ext,
174 ) -> LagrangeSelectors<Ext> {
175 let point = Point::from_projective_line(point);
176 LagrangeSelectors {
177 is_first_row: self.s_p(self.shift, point),
178 is_last_row: self.s_p(-self.shift, point),
179 is_transition: Ext::ONE - self.s_p_normalized(-self.shift, point),
180 inv_vanishing: self.vanishing_poly(point).inverse(),
181 }
182 }
183
184 #[instrument(skip_all, fields(log_n = %coset.log_n))]
201 fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors<Vec<Self::Val>> {
202 let pts = coset.points().collect_vec();
203
204 let neg_shift = -self.shift;
206 let k = neg_shift.s_p_at_p(self.log_n);
207
208 let z_vals: Vec<Self::Val> = pts.iter().map(|&at| self.vanishing_poly(at)).collect();
209 let den_shift: Vec<Self::Val> = pts.iter().map(|&at| self.shift.v_tilde_p(at)).collect();
210 let den_negshift_k: Vec<Self::Val> =
211 pts.iter().map(|&at| neg_shift.v_tilde_p(at) * k).collect();
212
213 let inv_vanishing = batch_multiplicative_inverse(&z_vals);
215 let inv_den_shift = batch_multiplicative_inverse(&den_shift);
216 let inv_den_negshift_k = batch_multiplicative_inverse(&den_negshift_k);
217
218 let is_first_row = z_vals
221 .iter()
222 .zip(inv_den_shift.iter())
223 .map(|(&z, &inv_d)| z * inv_d)
224 .collect();
225 let is_last_row = z_vals
226 .iter()
227 .zip(inv_den_negshift_k.iter())
228 .map(|(&z, &inv_dk)| z * inv_dk * k)
229 .collect();
230 let is_transition = z_vals
231 .iter()
232 .zip(inv_den_negshift_k.iter())
233 .map(|(&z, &inv_dk)| Self::Val::ONE - z * inv_dk)
234 .collect();
235
236 LagrangeSelectors {
237 is_first_row,
238 is_last_row,
239 is_transition,
240 inv_vanishing,
241 }
242 }
243}
244
/// Maps `i` onto `0..len` by zig-zagging: `0, 1, ..., len-1, len-1, ..., 0`, repeating every
/// `2 * len` steps.
///
/// # Panics
/// Panics if `len == 0` (the modulus would be zero).
const fn forward_backward_index(i: usize, len: usize) -> usize {
    let period = 2 * len;
    let pos = i % period;
    if pos >= len {
        // Backward half of the period: mirror the position.
        period - 1 - pos
    } else {
        pos
    }
}
250
#[cfg(test)]
mod tests {
    use core::iter;

    use hashbrown::HashSet;
    use itertools::izip;
    use p3_field::{PrimeCharacteristicRing, batch_multiplicative_inverse};
    use p3_mersenne_31::Mersenne31;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;
    use crate::CircleEvaluations;

    /// Asserts that, in natural order, the `i`-th point of `d` is the negation of the
    /// `i`-th point counted from the end.
    fn assert_is_twin_coset<F: ComplexExtendable>(d: CircleDomain<F>) {
        let pts = d.points().collect_vec();
        let half_n = pts.len() >> 1;
        for (&l, &r) in izip!(&pts[..half_n], pts[half_n..].iter().rev()) {
            assert_eq!(l, -r);
        }
    }

    /// Exercises a standard domain of size `2^log_n` together with `width`-column random
    /// evaluations. It checks point enumeration, disjoint-domain creation, the vanishing
    /// polynomial, and domain/evaluation splitting.
    fn do_test_circle_domain(log_n: usize, width: usize) {
        let n = 1 << log_n;

        type F = Mersenne31;
        let d = CircleDomain::<F>::standard(log_n);

        // Walking with `next_point` matches `nth_point`, never revisits the start early,
        // and wraps back to the first point after `n` steps.
        let p0 = d.first_point();
        let mut p1 = p0;
        for i in 0..(n - 1) {
            assert_eq!(Point::from_projective_line(p1), d.nth_point(i));
            p1 = d.next_point(p1).unwrap();
            assert_ne!(p1, p0);
        }
        assert_eq!(d.next_point(p1).unwrap(), p0);

        // `next_point` enumerates points in the same order as `points()`.
        let mut uni_point = d.first_point();
        for p in d.points() {
            assert_eq!(Point::from_projective_line(uni_point), p);
            uni_point = d.next_point(uni_point).unwrap();
        }

        // Disjoint domains are large enough and share no point with `d`.
        let seen: HashSet<Point<F>> = d.points().collect();
        for disjoint_size in [10, 100, n - 5, n + 15] {
            let dd = d.create_disjoint_domain(disjoint_size);
            assert!(dd.size() >= disjoint_size);
            for pt in dd.points() {
                assert!(!seen.contains(&pt));
            }
        }

        // The vanishing polynomial is zero on every domain point.
        for p in d.points() {
            assert_eq!(
                d.vanishing_poly_at_point(p.to_projective_line().unwrap()),
                F::ZERO
            );
        }

        let mut rng = SmallRng::seed_from_u64(1);

        let evals = RowMajorMatrix::rand(&mut rng, n, width);
        // Pair each domain point with its row of evaluations.
        let orig: Vec<(Point<F>, Vec<F>)> = d
            .points()
            .zip(evals.rows().map(|r| r.collect_vec()))
            .collect();
        for num_chunks in [1, 2, 4, 8] {
            let mut combined = vec![];

            let sds = d.split_domains(num_chunks);
            assert_eq!(sds.len(), num_chunks);
            let ses = d.split_evals(num_chunks, evals.clone());
            assert_eq!(ses.len(), num_chunks);
            for (sd, se) in izip!(sds, ses) {
                // Each sub-domain is itself a twin coset of the expected size.
                assert_is_twin_coset(sd);
                assert_eq!(sd.size() * num_chunks, d.size());
                // Each evaluation chunk keeps the full width and a matching height.
                assert_eq!(se.width(), evals.width());
                assert_eq!(se.height() * num_chunks, d.size());
                combined.extend(sd.points().zip(se.rows().map(|r| r.collect_vec())));
            }
            // Splitting must be a re-partition of the original (point, row) pairs.
            assert_eq!(
                orig.iter().map(|x| x.0).collect::<HashSet<_>>(),
                combined.iter().map(|x| x.0).collect::<HashSet<_>>(),
                "union of split domains is orig domain"
            );
            assert_eq!(
                orig.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                combined.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                "union of split evals is orig evals"
            );
            assert_eq!(
                orig.iter().collect::<HashSet<_>>(),
                combined.iter().collect::<HashSet<_>>(),
                "split domains and evals correspond to orig domains and evals"
            );
        }
    }

    #[test]
    fn selectors() {
        type F = Mersenne31;
        let log_n = 8;
        let n = 1 << log_n;

        let d = CircleDomain::<F>::standard(log_n);
        let coset = d.create_disjoint_domain(n);
        let sels = d.selectors_on_coset(coset);

        // The batched coset selectors agree with pointwise evaluation.
        let mut pt = coset.first_point();
        for i in 0..coset.size() {
            let pt_sels = d.selectors_at_point(pt);
            assert_eq!(sels.is_first_row[i], pt_sels.is_first_row);
            assert_eq!(sels.is_last_row[i], pt_sels.is_last_row);
            assert_eq!(sels.is_transition[i], pt_sels.is_transition);
            assert_eq!(sels.inv_vanishing[i], pt_sels.inv_vanishing);
            pt = coset.next_point(pt).unwrap();
        }

        // Interpolate values given on `coset`, check the degree fits in `n` coefficients,
        // and re-evaluate on `d` in natural order.
        let coset_to_d = |evals: &[F]| {
            let evals = CircleEvaluations::from_natural_order(
                coset,
                RowMajorMatrix::new_col(evals.to_vec()),
            );
            let coeffs = evals.interpolate().to_row_major_matrix();
            let (lo, hi) = coeffs.split_rows(n);
            assert_eq!(hi.values, vec![F::ZERO; n]);
            CircleEvaluations::evaluate(d, lo.to_row_major_matrix())
                .to_natural_order()
                .to_row_major_matrix()
                .values
        };

        // Nonzero at the first point, zero everywhere else on the domain.
        let is_first_row = coset_to_d(&sels.is_first_row);
        assert_ne!(is_first_row[0], F::ZERO);
        assert_eq!(&is_first_row[1..], &vec![F::ZERO; n - 1]);

        // Zero everywhere except the last point.
        let is_last_row = coset_to_d(&sels.is_last_row);
        assert_eq!(&is_last_row[..n - 1], &vec![F::ZERO; n - 1]);
        assert_ne!(is_last_row[n - 1], F::ZERO);

        // Nonzero on every row except the last, where it vanishes. Every entry is
        // checked individually; comparing the slice against all-zeros would only prove
        // that at least one entry is nonzero.
        let is_transition = coset_to_d(&sels.is_transition);
        assert!(
            is_transition[..n - 1].iter().all(|&x| x != F::ZERO),
            "is_transition must be nonzero on every row but the last"
        );
        assert_eq!(is_transition[n - 1], F::ZERO);

        // Interpolated over the coset, the vanishing polynomial has a single nonzero
        // coefficient, equal to one, at index `n`.
        let z_coeffs = CircleEvaluations::from_natural_order(
            coset,
            RowMajorMatrix::new_col(batch_multiplicative_inverse(&sels.inv_vanishing)),
        )
        .interpolate()
        .to_row_major_matrix()
        .values;
        assert_eq!(
            z_coeffs,
            iter::empty()
                .chain(iter::repeat_n(F::ZERO, n))
                .chain(iter::once(F::ONE))
                .chain(iter::repeat_n(F::ZERO, n - 1))
                .collect_vec()
        );
    }

    #[test]
    fn test_circle_domain() {
        do_test_circle_domain(4, 8);
        do_test_circle_domain(10, 32);
    }
}