p3_circle/
domain.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{Itertools, iterate};
5use p3_commit::{LagrangeSelectors, PolynomialSpace};
6use p3_field::extension::ComplexExtendable;
7use p3_field::{ExtensionField, batch_multiplicative_inverse};
8use p3_matrix::Matrix;
9use p3_matrix::dense::RowMajorMatrix;
10use p3_util::{log2_ceil_usize, log2_strict_usize};
11use tracing::instrument;
12
13use crate::point::Point;
14
/// A twin-coset of the circle group on F. It has a power-of-two size and an arbitrary shift.
///
/// X is generator, O is the first coset, goes counterclockwise
/// ```text
///   O X .
///  .     .
/// .       O <- start = shift
/// .   .   - (1,0)
/// O       .
///  .     .
///   . . O
/// ```
///
/// For ordering reasons, the other half will start at gen / shift:
/// ```text
///   . X O  <- start = gen/shift
///  .     .
/// O       .
/// .   .   - (1,0)
/// .       O
///  .     .
///   O . .
/// ```
///
/// The full domain is the interleaving of these two cosets: points at even
/// indices come from the first coset, points at odd indices from the second.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct CircleDomain<F> {
    /// Log2 of the size of the WHOLE domain (both cosets together); each
    /// coset therefore holds `2^(log_n - 1)` points.
    pub(crate) log_n: usize,
    /// Starting point of the first coset, and hence the first point of the domain.
    pub(crate) shift: Point<F>,
}
46
47impl<F: ComplexExtendable> CircleDomain<F> {
48    pub const fn new(log_n: usize, shift: Point<F>) -> Self {
49        Self { log_n, shift }
50    }
51    pub fn standard(log_n: usize) -> Self {
52        Self {
53            log_n,
54            shift: Point::generator(log_n + 1),
55        }
56    }
57    fn is_standard(&self) -> bool {
58        self.shift == Point::generator(self.log_n + 1)
59    }
60    pub(crate) fn subgroup_generator(&self) -> Point<F> {
61        Point::generator(self.log_n - 1)
62    }
63    pub(crate) fn coset0(&self) -> impl Iterator<Item = Point<F>> {
64        let g = self.subgroup_generator();
65        iterate(self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
66    }
67    fn coset1(&self) -> impl Iterator<Item = Point<F>> {
68        let g = self.subgroup_generator();
69        iterate(g - self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
70    }
71    pub(crate) fn points(&self) -> impl Iterator<Item = Point<F>> {
72        self.coset0().interleave(self.coset1())
73    }
74    pub(crate) fn nth_point(&self, idx: usize) -> Point<F> {
75        let (idx, lsb) = (idx >> 1, idx & 1);
76        if lsb == 0 {
77            self.shift + self.subgroup_generator() * idx
78        } else {
79            -self.shift + self.subgroup_generator() * (idx + 1)
80        }
81    }
82
83    pub(crate) fn vanishing_poly<EF: ExtensionField<F>>(&self, at: Point<EF>) -> EF {
84        at.v_n(self.log_n) - self.shift.v_n(self.log_n)
85    }
86
87    pub(crate) fn s_p<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
88        self.vanishing_poly(at) / p.v_tilde_p(at)
89    }
90
91    pub(crate) fn s_p_normalized<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
92        self.vanishing_poly(at) / (p.v_tilde_p(at) * p.s_p_at_p(self.log_n))
93    }
94}
95
96impl<F: ComplexExtendable> PolynomialSpace for CircleDomain<F> {
97    type Val = F;
98
99    fn size(&self) -> usize {
100        1 << self.log_n
101    }
102
103    fn first_point(&self) -> Self::Val {
104        self.shift.to_projective_line().unwrap()
105    }
106
107    fn next_point<Ext: ExtensionField<Self::Val>>(&self, x: Ext) -> Option<Ext> {
108        // Only in standard position do we have an algebraic expression to access the next point.
109        if self.is_standard() {
110            (Point::from_projective_line(x) + Point::generator(self.log_n)).to_projective_line()
111        } else {
112            None
113        }
114    }
115
116    fn create_disjoint_domain(&self, min_size: usize) -> Self {
117        // Right now we simply guarantee the domain is disjoint by returning a
118        // larger standard position coset, which is fine because we always ask for a larger
119        // domain. If we wanted good performance for a disjoint domain of the same size,
120        // we could change the shift. Also we could support nonstandard twin cosets.
121        assert!(
122            self.is_standard(),
123            "create_disjoint_domain not currently supported for nonstandard twin cosets"
124        );
125        let log_n = log2_ceil_usize(min_size);
126        // Any standard position coset that is not the same size as us will be disjoint.
127        Self::standard(if log_n == self.log_n {
128            log_n + 1
129        } else {
130            log_n
131        })
132    }
133
134    /// Decompose a domain into disjoint twin-cosets.
135    fn split_domains(&self, num_chunks: usize) -> Vec<Self> {
136        assert!(self.is_standard());
137        let log_chunks = log2_strict_usize(num_chunks);
138        assert!(log_chunks <= self.log_n);
139        self.points()
140            .take(num_chunks)
141            .map(|shift| Self {
142                log_n: self.log_n - log_chunks,
143                shift,
144            })
145            .collect()
146    }
147
148    fn split_evals(
149        &self,
150        num_chunks: usize,
151        evals: RowMajorMatrix<Self::Val>,
152    ) -> Vec<RowMajorMatrix<Self::Val>> {
153        let log_chunks = log2_strict_usize(num_chunks);
154        assert!(evals.height() >> (log_chunks + 1) >= 1);
155        let width = evals.width();
156        let mut values: Vec<Vec<Self::Val>> = vec![vec![]; num_chunks];
157        evals
158            .rows()
159            .enumerate()
160            .for_each(|(i, row)| values[forward_backward_index(i, num_chunks)].extend(row));
161        values
162            .into_iter()
163            .map(|v| RowMajorMatrix::new(v, width))
164            .collect()
165    }
166
167    fn vanishing_poly_at_point<Ext: ExtensionField<Self::Val>>(&self, point: Ext) -> Ext {
168        self.vanishing_poly(Point::from_projective_line(point))
169    }
170
171    fn selectors_at_point<Ext: ExtensionField<Self::Val>>(
172        &self,
173        point: Ext,
174    ) -> LagrangeSelectors<Ext> {
175        let point = Point::from_projective_line(point);
176        LagrangeSelectors {
177            is_first_row: self.s_p(self.shift, point),
178            is_last_row: self.s_p(-self.shift, point),
179            is_transition: Ext::ONE - self.s_p_normalized(-self.shift, point),
180            inv_vanishing: self.vanishing_poly(point).inverse(),
181        }
182    }
183
184    /*
185    chunks=2:
186
187          1 . 1
188         .     .
189        0       0 <-- start
190        .   .   - (1,0)
191        0       0
192         .     .
193          1 . 1
194
195
196    idx -> which chunk to put it in:
197    chunks=2: 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0
198    chunks=4: 0 1 2 3 3 2 1 0 0 1 2 3 3 2 1 0
199    */
200    #[instrument(skip_all, fields(log_n = %coset.log_n))]
201    fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors<Vec<Self::Val>> {
202        let pts = coset.points().collect_vec();
203
204        // Precompute constants and per-point numerators/denominators
205        let neg_shift = -self.shift;
206        let k = neg_shift.s_p_at_p(self.log_n);
207
208        let z_vals: Vec<Self::Val> = pts.iter().map(|&at| self.vanishing_poly(at)).collect();
209        let den_shift: Vec<Self::Val> = pts.iter().map(|&at| self.shift.v_tilde_p(at)).collect();
210        let den_negshift_k: Vec<Self::Val> =
211            pts.iter().map(|&at| neg_shift.v_tilde_p(at) * k).collect();
212
213        // Batch inverses
214        let inv_vanishing = batch_multiplicative_inverse(&z_vals);
215        let inv_den_shift = batch_multiplicative_inverse(&den_shift);
216        let inv_den_negshift_k = batch_multiplicative_inverse(&den_negshift_k);
217
218        // Build selectors
219        // TODO: If we need to make this faster we could look into using packed fields.
220        let is_first_row = z_vals
221            .iter()
222            .zip(inv_den_shift.iter())
223            .map(|(&z, &inv_d)| z * inv_d)
224            .collect();
225        let is_last_row = z_vals
226            .iter()
227            .zip(inv_den_negshift_k.iter())
228            .map(|(&z, &inv_dk)| z * inv_dk * k)
229            .collect();
230        let is_transition = z_vals
231            .iter()
232            .zip(inv_den_negshift_k.iter())
233            .map(|(&z, &inv_dk)| Self::Val::ONE - z * inv_dk)
234            .collect();
235
236        LagrangeSelectors {
237            is_first_row,
238            is_last_row,
239            is_transition,
240            inv_vanishing,
241        }
242    }
243}
244
/// Maps a running index onto a "forward then backward" sweep over `0..len`:
///
/// `0 1 2 .. len-1 len-1 .. 2 1 0 0 1 2 ..`
///
/// The pattern repeats with period `2 * len`. A `len` of zero panics (remainder by zero).
const fn forward_backward_index(mut i: usize, len: usize) -> usize {
    let period = 2 * len;
    i %= period;
    // Position mirrored around the middle of the period; `i < period` so this cannot underflow.
    let mirrored = period - 1 - i;
    if i >= len { mirrored } else { i }
}
250
#[cfg(test)]
mod tests {
    use core::iter;

    use hashbrown::HashSet;
    use itertools::izip;
    use p3_field::{PrimeCharacteristicRing, batch_multiplicative_inverse};
    use p3_mersenne_31::Mersenne31;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;
    use crate::CircleEvaluations;

    /// Checks that the point list of `d` is symmetric under negation: the
    /// first half, read forwards, is the negation of the second half read
    /// backwards.
    fn assert_is_twin_coset<F: ComplexExtendable>(d: CircleDomain<F>) {
        let pts = d.points().collect_vec();
        let half_n = pts.len() >> 1;
        for (&l, &r) in izip!(&pts[..half_n], pts[half_n..].iter().rev()) {
            assert_eq!(l, -r);
        }
    }

    /// Exercises point iteration, disjoint-domain creation, the vanishing
    /// polynomial and domain/evaluation splitting on a standard domain of
    /// `2^log_n` points with `width`-column random evaluations.
    fn do_test_circle_domain(log_n: usize, width: usize) {
        let n = 1 << log_n;

        type F = Mersenne31;
        let d = CircleDomain::<F>::standard(log_n);

        // we can move around the circle and end up where we started
        let p0 = d.first_point();
        let mut p1 = p0;
        for i in 0..(n - 1) {
            // nth_point is correct
            assert_eq!(Point::from_projective_line(p1), d.nth_point(i));
            p1 = d.next_point(p1).unwrap();
            assert_ne!(p1, p0);
        }
        // After n - 1 steps we sit on the last point; nth_point must agree there too.
        assert_eq!(Point::from_projective_line(p1), d.nth_point(n - 1));
        assert_eq!(d.next_point(p1).unwrap(), p0);

        // .points() is the same as first_point -> next_point
        let mut uni_point = d.first_point();
        for p in d.points() {
            assert_eq!(Point::from_projective_line(uni_point), p);
            uni_point = d.next_point(uni_point).unwrap();
        }

        // disjoint domain is actually disjoint, and large enough
        let seen: HashSet<Point<F>> = d.points().collect();
        for disjoint_size in [10, 100, n - 5, n + 15] {
            let dd = d.create_disjoint_domain(disjoint_size);
            assert!(dd.size() >= disjoint_size);
            for pt in dd.points() {
                assert!(!seen.contains(&pt));
            }
        }

        // zp is zero
        for p in d.points() {
            assert_eq!(
                d.vanishing_poly_at_point(p.to_projective_line().unwrap()),
                F::ZERO
            );
        }

        let mut rng = SmallRng::seed_from_u64(1);

        // split domains
        let evals = RowMajorMatrix::rand(&mut rng, n, width);
        let orig: Vec<(Point<F>, Vec<F>)> = d
            .points()
            .zip(evals.rows().map(|r| r.collect_vec()))
            .collect();
        for num_chunks in [1, 2, 4, 8] {
            let mut combined = vec![];

            let sds = d.split_domains(num_chunks);
            assert_eq!(sds.len(), num_chunks);
            let ses = d.split_evals(num_chunks, evals.clone());
            assert_eq!(ses.len(), num_chunks);
            for (sd, se) in izip!(sds, ses) {
                // Split domains are twin cosets
                assert_is_twin_coset(sd);
                // Split domains have correct size wrt original domain
                assert_eq!(sd.size() * num_chunks, d.size());
                assert_eq!(se.width(), evals.width());
                assert_eq!(se.height() * num_chunks, d.size());
                combined.extend(sd.points().zip(se.rows().map(|r| r.collect_vec())));
            }
            // Union of split domains and evals is the original domain and evals
            assert_eq!(
                orig.iter().map(|x| x.0).collect::<HashSet<_>>(),
                combined.iter().map(|x| x.0).collect::<HashSet<_>>(),
                "union of split domains is orig domain"
            );
            assert_eq!(
                orig.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                combined.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                "union of split evals is orig evals"
            );
            assert_eq!(
                orig.iter().collect::<HashSet<_>>(),
                combined.iter().collect::<HashSet<_>>(),
                "split domains and evals correspond to orig domains and evals"
            );
        }
    }

    /// Checks that `selectors_on_coset` agrees with `selectors_at_point`, and
    /// that each selector, once brought back onto the original domain, is
    /// zero / nonzero on exactly the expected rows.
    #[test]
    fn selectors() {
        type F = Mersenne31;
        let log_n = 8;
        let n = 1 << log_n;

        let d = CircleDomain::<F>::standard(log_n);
        let coset = d.create_disjoint_domain(n);
        let sels = d.selectors_on_coset(coset);

        // selectors_on_coset matches selectors_at_point
        let mut pt = coset.first_point();
        for i in 0..coset.size() {
            let pt_sels = d.selectors_at_point(pt);
            assert_eq!(sels.is_first_row[i], pt_sels.is_first_row);
            assert_eq!(sels.is_last_row[i], pt_sels.is_last_row);
            assert_eq!(sels.is_transition[i], pt_sels.is_transition);
            assert_eq!(sels.inv_vanishing[i], pt_sels.inv_vanishing);
            pt = coset.next_point(pt).unwrap();
        }

        // Interpolates evaluations given on `coset`, checks the upper half of the
        // coefficients is zero, and re-evaluates the lower half on `d`.
        let coset_to_d = |evals: &[F]| {
            let evals = CircleEvaluations::from_natural_order(
                coset,
                RowMajorMatrix::new_col(evals.to_vec()),
            );
            let coeffs = evals.interpolate().to_row_major_matrix();
            let (lo, hi) = coeffs.split_rows(n);
            assert_eq!(hi.values, vec![F::ZERO; n]);
            CircleEvaluations::evaluate(d, lo.to_row_major_matrix())
                .to_natural_order()
                .to_row_major_matrix()
                .values
        };

        // Nonzero at first point, zero everywhere else on domain
        let is_first_row = coset_to_d(&sels.is_first_row);
        assert_ne!(is_first_row[0], F::ZERO);
        assert_eq!(&is_first_row[1..], &vec![F::ZERO; n - 1]);

        // Nonzero at last point, zero everywhere else on domain
        let is_last_row = coset_to_d(&sels.is_last_row);
        assert_eq!(&is_last_row[..n - 1], &vec![F::ZERO; n - 1]);
        assert_ne!(is_last_row[n - 1], F::ZERO);

        // Nonzero everywhere on domain but last point.
        // Each entry is checked individually: comparing the whole slice against an
        // all-zero vector would only prove that at least one entry is nonzero.
        let is_transition = coset_to_d(&sels.is_transition);
        assert!(is_transition[..n - 1].iter().all(|&x| x != F::ZERO));
        assert_eq!(is_transition[n - 1], F::ZERO);

        // Vanishing polynomial coefficients look like [0.. (n times), 1, 0.. (n-1 times)]
        let z_coeffs = CircleEvaluations::from_natural_order(
            coset,
            RowMajorMatrix::new_col(batch_multiplicative_inverse(&sels.inv_vanishing)),
        )
        .interpolate()
        .to_row_major_matrix()
        .values;
        assert_eq!(
            z_coeffs,
            iter::empty()
                .chain(iter::repeat_n(F::ZERO, n))
                .chain(iter::once(F::ONE))
                .chain(iter::repeat_n(F::ZERO, n - 1))
                .collect_vec()
        );
    }

    #[test]
    fn test_circle_domain() {
        do_test_circle_domain(4, 8);
        do_test_circle_domain(10, 32);
    }
}