p3_circle/
domain.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{iterate, Itertools};
5use p3_commit::{LagrangeSelectors, PolynomialSpace};
6use p3_field::extension::ComplexExtendable;
7use p3_field::ExtensionField;
8use p3_matrix::dense::RowMajorMatrix;
9use p3_matrix::Matrix;
10use p3_util::{log2_ceil_usize, log2_strict_usize};
11use tracing::instrument;
12
13use crate::point::Point;
14
15/// A twin-coset of the circle group on F. It has a power-of-two size and an arbitrary shift.
16///
17/// X is generator, O is the first coset, goes counterclockwise
18/// ```text
19///   O X .
20///  .     .
21/// .       O <- start = shift
22/// .   .   - (1,0)
23/// O       .
24///  .     .
25///   . . O
26/// ```
27///
28/// For ordering reasons, the other half will start at gen / shift:
29/// ```text
30///   . X O  <- start = gen/shift
31///  .     .
32/// O       .
33/// .   .   - (1,0)
34/// .       O
35///  .     .
36///   O . .
37/// ```
38///
39/// The full domain is the interleaving of these two cosets
40#[derive(Copy, Clone, PartialEq, Eq, Debug)]
41pub struct CircleDomain<F> {
42    // log_n corresponds to the log size of the WHOLE domain
43    pub(crate) log_n: usize,
44    pub(crate) shift: Point<F>,
45}
46
47impl<F: ComplexExtendable> CircleDomain<F> {
48    pub const fn new(log_n: usize, shift: Point<F>) -> Self {
49        Self { log_n, shift }
50    }
51    pub fn standard(log_n: usize) -> Self {
52        Self {
53            log_n,
54            shift: Point::generator(log_n + 1),
55        }
56    }
57    fn is_standard(&self) -> bool {
58        self.shift == Point::generator(self.log_n + 1)
59    }
60    pub(crate) fn gen(&self) -> Point<F> {
61        Point::generator(self.log_n - 1)
62    }
63    pub(crate) fn coset0(&self) -> impl Iterator<Item = Point<F>> {
64        let g = self.gen();
65        iterate(self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
66    }
67    fn coset1(&self) -> impl Iterator<Item = Point<F>> {
68        let g = self.gen();
69        iterate(g - self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
70    }
71    pub(crate) fn points(&self) -> impl Iterator<Item = Point<F>> {
72        self.coset0().interleave(self.coset1())
73    }
74    pub(crate) fn nth_point(&self, idx: usize) -> Point<F> {
75        let (idx, lsb) = (idx >> 1, idx & 1);
76        if lsb == 0 {
77            self.shift + self.gen() * idx
78        } else {
79            -self.shift + self.gen() * (idx + 1)
80        }
81    }
82
83    pub(crate) fn zeroifier<EF: ExtensionField<F>>(&self, at: Point<EF>) -> EF {
84        at.v_n(self.log_n) - self.shift.v_n(self.log_n)
85    }
86
87    pub(crate) fn s_p<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
88        self.zeroifier(at) / p.v_tilde_p(at)
89    }
90
91    pub(crate) fn s_p_normalized<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
92        self.zeroifier(at) / (p.v_tilde_p(at) * p.s_p_at_p(self.log_n))
93    }
94}
95
96impl<F: ComplexExtendable> PolynomialSpace for CircleDomain<F> {
97    type Val = F;
98
99    fn size(&self) -> usize {
100        1 << self.log_n
101    }
102
103    fn first_point(&self) -> Self::Val {
104        self.shift.to_projective_line().unwrap()
105    }
106
107    fn next_point<Ext: ExtensionField<Self::Val>>(&self, x: Ext) -> Option<Ext> {
108        // Only in standard position do we have an algebraic expression to access the next point.
109        if self.is_standard() {
110            Some(
111                (Point::from_projective_line(x) + Point::generator(self.log_n))
112                    .to_projective_line()
113                    .unwrap(),
114            )
115        } else {
116            None
117        }
118    }
119
120    fn create_disjoint_domain(&self, min_size: usize) -> Self {
121        // Right now we simply guarantee the domain is disjoint by returning a
122        // larger standard position coset, which is fine because we always ask for a larger
123        // domain. If we wanted good performance for a disjoint domain of the same size,
124        // we could change the shift. Also we could support nonstandard twin cosets.
125        assert!(
126            self.is_standard(),
127            "create_disjoint_domain not currently supported for nonstandard twin cosets"
128        );
129        let log_n = log2_ceil_usize(min_size);
130        // Any standard position coset that is not the same size as us will be disjoint.
131        Self::standard(if log_n == self.log_n {
132            log_n + 1
133        } else {
134            log_n
135        })
136    }
137
138    fn zp_at_point<Ext: ExtensionField<Self::Val>>(&self, point: Ext) -> Ext {
139        self.zeroifier(Point::from_projective_line(point))
140    }
141
142    fn selectors_at_point<Ext: ExtensionField<Self::Val>>(
143        &self,
144        point: Ext,
145    ) -> LagrangeSelectors<Ext> {
146        let point = Point::from_projective_line(point);
147        LagrangeSelectors {
148            is_first_row: self.s_p(self.shift, point),
149            is_last_row: self.s_p(-self.shift, point),
150            is_transition: Ext::one() - self.s_p_normalized(-self.shift, point),
151            inv_zeroifier: self.zeroifier(point).inverse(),
152        }
153    }
154
155    // wow, really slow!
156    // todo: batch inverses
157    #[instrument(skip_all, fields(log_n = %coset.log_n))]
158    fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors<Vec<Self::Val>> {
159        let sels = coset
160            .points()
161            .map(|p| self.selectors_at_point(p.to_projective_line().unwrap()))
162            .collect_vec();
163        LagrangeSelectors {
164            is_first_row: sels.iter().map(|s| s.is_first_row).collect(),
165            is_last_row: sels.iter().map(|s| s.is_last_row).collect(),
166            is_transition: sels.iter().map(|s| s.is_transition).collect(),
167            inv_zeroifier: sels.iter().map(|s| s.inv_zeroifier).collect(),
168        }
169    }
170
171    /// Decompose a domain into disjoint twin-cosets.
172    fn split_domains(&self, num_chunks: usize) -> Vec<Self> {
173        assert!(self.is_standard());
174        let log_chunks = log2_strict_usize(num_chunks);
175        self.points()
176            .take(num_chunks)
177            .map(|shift| CircleDomain {
178                log_n: self.log_n - log_chunks,
179                shift,
180            })
181            .collect()
182    }
183
184    /*
185    chunks=2:
186
187          1 . 1
188         .     .
189        0       0 <-- start
190        .   .   - (1,0)
191        0       0
192         .     .
193          1 . 1
194
195
196    idx -> which chunk to put it in:
197    chunks=2: 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0
198    chunks=4: 0 1 2 3 3 2 1 0 0 1 2 3 3 2 1 0
199    */
200    fn split_evals(
201        &self,
202        num_chunks: usize,
203        evals: RowMajorMatrix<Self::Val>,
204    ) -> Vec<RowMajorMatrix<Self::Val>> {
205        let log_chunks = log2_strict_usize(num_chunks);
206        assert!(evals.height() >> (log_chunks + 1) >= 1);
207        let width = evals.width();
208        let mut values: Vec<Vec<Self::Val>> = vec![vec![]; num_chunks];
209        evals
210            .rows()
211            .enumerate()
212            .for_each(|(i, row)| values[forward_backward_index(i, num_chunks)].extend(row));
213        values
214            .into_iter()
215            .map(|v| RowMajorMatrix::new(v, width))
216            .collect()
217    }
218}
219
/// Maps a running index onto `0..len` by walking forward and then backward, repeating with
/// period `2 * len`:
///
/// `0 1 2 .. len-1 len-1 .. 2 1 0 0 1 2 ..`
///
/// Panics if `len` is zero (remainder by zero).
fn forward_backward_index(i: usize, len: usize) -> usize {
    let period = 2 * len;
    let pos = i % period;
    match pos.checked_sub(len) {
        // First half of the period: walking forward.
        None => pos,
        // Second half of the period: walking backward from `len - 1`.
        Some(steps_back) => len - 1 - steps_back,
    }
}
229
#[cfg(test)]
mod tests {
    use core::iter;

    use hashbrown::HashSet;
    use itertools::izip;
    use p3_field::{batch_multiplicative_inverse, AbstractField};
    use p3_mersenne_31::Mersenne31;
    use rand::thread_rng;

    use super::*;
    use crate::CircleEvaluations;

    /// Checks the twin-coset symmetry: pairing the first half of the points with the second
    /// half reversed, each point is the negation of its partner.
    fn assert_is_twin_coset<F: ComplexExtendable>(d: CircleDomain<F>) {
        let pts = d.points().collect_vec();
        let half_n = pts.len() >> 1;
        for (&l, &r) in izip!(&pts[..half_n], pts[half_n..].iter().rev()) {
            assert_eq!(l, -r);
        }
    }

    /// Exercises point iteration, disjoint domains, the zeroifier and domain/evaluation
    /// splitting on the standard domain of size `2^log_n`, using random evaluation matrices
    /// of the given `width`.
    fn do_test_circle_domain(log_n: usize, width: usize) {
        let n = 1 << log_n;

        type F = Mersenne31;
        let d = CircleDomain::<F>::standard(log_n);

        // we can move around the circle and end up where we started
        let p0 = d.first_point();
        let mut p1 = p0;
        for i in 0..(n - 1) {
            // nth_point is correct
            assert_eq!(Point::from_projective_line(p1), d.nth_point(i));
            p1 = d.next_point(p1).unwrap();
            assert_ne!(p1, p0);
        }
        assert_eq!(d.next_point(p1).unwrap(), p0);

        // .points() is the same as first_point -> next_point
        let mut uni_point = d.first_point();
        for p in d.points() {
            assert_eq!(Point::from_projective_line(uni_point), p);
            uni_point = d.next_point(uni_point).unwrap();
        }

        // disjoint domain is actually disjoint, and large enough
        let seen: HashSet<Point<F>> = d.points().collect();
        for disjoint_size in [10, 100, n - 5, n + 15] {
            let dd = d.create_disjoint_domain(disjoint_size);
            assert!(dd.size() >= disjoint_size);
            for pt in dd.points() {
                assert!(!seen.contains(&pt));
            }
        }

        // zp is zero
        for p in d.points() {
            assert_eq!(d.zp_at_point(p.to_projective_line().unwrap()), F::zero());
        }

        // split domains
        let evals = RowMajorMatrix::rand(&mut thread_rng(), n, width);
        let orig: Vec<(Point<F>, Vec<F>)> = d
            .points()
            .zip(evals.rows().map(|r| r.collect_vec()))
            .collect();
        for num_chunks in [1, 2, 4, 8] {
            let mut combined = vec![];

            let sds = d.split_domains(num_chunks);
            assert_eq!(sds.len(), num_chunks);
            let ses = d.split_evals(num_chunks, evals.clone());
            assert_eq!(ses.len(), num_chunks);
            for (sd, se) in izip!(sds, ses) {
                // Split domains are twin cosets
                assert_is_twin_coset(sd);
                // Split domains have correct size wrt original domain
                assert_eq!(sd.size() * num_chunks, d.size());
                assert_eq!(se.width(), evals.width());
                assert_eq!(se.height() * num_chunks, d.size());
                combined.extend(sd.points().zip(se.rows().map(|r| r.collect_vec())));
            }
            // Union of split domains and evals is the original domain and evals
            assert_eq!(
                orig.iter().map(|x| x.0).collect::<HashSet<_>>(),
                combined.iter().map(|x| x.0).collect::<HashSet<_>>(),
                "union of split domains is orig domain"
            );
            assert_eq!(
                orig.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                combined.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                "union of split evals is orig evals"
            );
            assert_eq!(
                orig.iter().collect::<HashSet<_>>(),
                combined.iter().collect::<HashSet<_>>(),
                "split domains and evals correspond to orig domains and evals"
            );
        }
    }

    #[test]
    fn selectors() {
        type F = Mersenne31;
        let log_n = 8;
        let n = 1 << log_n;

        let d = CircleDomain::<F>::standard(log_n);
        let coset = d.create_disjoint_domain(n);
        let sels = d.selectors_on_coset(coset);

        // selectors_on_coset matches selectors_at_point
        let mut pt = coset.first_point();
        for i in 0..coset.size() {
            let pt_sels = d.selectors_at_point(pt);
            assert_eq!(sels.is_first_row[i], pt_sels.is_first_row);
            assert_eq!(sels.is_last_row[i], pt_sels.is_last_row);
            assert_eq!(sels.is_transition[i], pt_sels.is_transition);
            assert_eq!(sels.inv_zeroifier[i], pt_sels.inv_zeroifier);
            pt = coset.next_point(pt).unwrap();
        }

        // Moves evaluations given on `coset` to evaluations on `d`: interpolate, check that
        // the upper half of the coefficients vanishes, then evaluate the lower half on `d`.
        let coset_to_d = |evals: &[F]| {
            let evals = CircleEvaluations::from_natural_order(
                coset,
                RowMajorMatrix::new_col(evals.to_vec()),
            );
            let coeffs = evals.interpolate().to_row_major_matrix();
            let (lo, hi) = coeffs.split_rows(n);
            assert_eq!(hi.values, vec![F::zero(); n]);
            CircleEvaluations::evaluate(d, lo.to_row_major_matrix())
                .to_natural_order()
                .to_row_major_matrix()
                .values
        };

        // Nonzero at first point, zero everywhere else on domain
        let is_first_row = coset_to_d(&sels.is_first_row);
        assert_ne!(is_first_row[0], F::zero());
        assert_eq!(&is_first_row[1..], &vec![F::zero(); n - 1]);

        // Nonzero at last point, zero everywhere else on domain
        let is_last_row = coset_to_d(&sels.is_last_row);
        assert_eq!(&is_last_row[..n - 1], &vec![F::zero(); n - 1]);
        assert_ne!(is_last_row[n - 1], F::zero());

        // Nonzero everywhere on domain but last point.
        // Check every entry individually: comparing the whole slice against an all-zero
        // vector would already pass if a single entry were nonzero.
        let is_transition = coset_to_d(&sels.is_transition);
        assert!(is_transition[..n - 1].iter().all(|&x| x != F::zero()));
        assert_eq!(is_transition[n - 1], F::zero());

        // Zeroifier coefficients look like [0.. (n times), 1, 0.. (n-1 times)]
        let z_coeffs = CircleEvaluations::from_natural_order(
            coset,
            RowMajorMatrix::new_col(batch_multiplicative_inverse(&sels.inv_zeroifier)),
        )
        .interpolate()
        .to_row_major_matrix()
        .values;
        assert_eq!(
            z_coeffs,
            iter::empty()
                .chain(iter::repeat(F::zero()).take(n))
                .chain(iter::once(F::one()))
                .chain(iter::repeat(F::zero()).take(n - 1))
                .collect_vec()
        );
    }

    #[test]
    fn test_circle_domain() {
        do_test_circle_domain(4, 8);
        do_test_circle_domain(10, 32);
    }

    #[test]
    fn test_forward_backward_index() {
        // The index walks forward over 0..len, then backward, with period 2 * len.
        let chunks2 = (0..16).map(|i| forward_backward_index(i, 2)).collect_vec();
        assert_eq!(
            chunks2,
            vec![0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0]
        );
        let chunks4 = (0..16).map(|i| forward_backward_index(i, 4)).collect_vec();
        assert_eq!(
            chunks4,
            vec![0, 1, 2, 3, 3, 2, 1, 0, 0, 1, 2, 3, 3, 2, 1, 0]
        );
    }
}