1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{iterate, Itertools};
5use p3_commit::{LagrangeSelectors, PolynomialSpace};
6use p3_field::extension::ComplexExtendable;
7use p3_field::ExtensionField;
8use p3_matrix::dense::RowMajorMatrix;
9use p3_matrix::Matrix;
10use p3_util::{log2_ceil_usize, log2_strict_usize};
11use tracing::instrument;
12
13use crate::point::Point;
14
/// A twin-coset evaluation domain on the circle with `2^log_n` points.
///
/// The points are `shift + k * g` and `-shift + (k + 1) * g` for `k` in `0..2^(log_n - 1)`,
/// where `g = Point::generator(log_n - 1)`. The natural order interleaves these two halves
/// (see `CircleDomain::nth_point`).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct CircleDomain<F> {
    /// Base-2 logarithm of the number of points in the domain.
    pub(crate) log_n: usize,
    /// The coset shift; in natural order it is the first point of the domain.
    pub(crate) shift: Point<F>,
}
46
47impl<F: ComplexExtendable> CircleDomain<F> {
48 pub const fn new(log_n: usize, shift: Point<F>) -> Self {
49 Self { log_n, shift }
50 }
51 pub fn standard(log_n: usize) -> Self {
52 Self {
53 log_n,
54 shift: Point::generator(log_n + 1),
55 }
56 }
57 fn is_standard(&self) -> bool {
58 self.shift == Point::generator(self.log_n + 1)
59 }
60 pub(crate) fn gen(&self) -> Point<F> {
61 Point::generator(self.log_n - 1)
62 }
63 pub(crate) fn coset0(&self) -> impl Iterator<Item = Point<F>> {
64 let g = self.gen();
65 iterate(self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
66 }
67 fn coset1(&self) -> impl Iterator<Item = Point<F>> {
68 let g = self.gen();
69 iterate(g - self.shift, move |&p| p + g).take(1 << (self.log_n - 1))
70 }
71 pub(crate) fn points(&self) -> impl Iterator<Item = Point<F>> {
72 self.coset0().interleave(self.coset1())
73 }
74 pub(crate) fn nth_point(&self, idx: usize) -> Point<F> {
75 let (idx, lsb) = (idx >> 1, idx & 1);
76 if lsb == 0 {
77 self.shift + self.gen() * idx
78 } else {
79 -self.shift + self.gen() * (idx + 1)
80 }
81 }
82
83 pub(crate) fn zeroifier<EF: ExtensionField<F>>(&self, at: Point<EF>) -> EF {
84 at.v_n(self.log_n) - self.shift.v_n(self.log_n)
85 }
86
87 pub(crate) fn s_p<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
88 self.zeroifier(at) / p.v_tilde_p(at)
89 }
90
91 pub(crate) fn s_p_normalized<EF: ExtensionField<F>>(&self, p: Point<F>, at: Point<EF>) -> EF {
92 self.zeroifier(at) / (p.v_tilde_p(at) * p.s_p_at_p(self.log_n))
93 }
94}
95
96impl<F: ComplexExtendable> PolynomialSpace for CircleDomain<F> {
97 type Val = F;
98
99 fn size(&self) -> usize {
100 1 << self.log_n
101 }
102
103 fn first_point(&self) -> Self::Val {
104 self.shift.to_projective_line().unwrap()
105 }
106
107 fn next_point<Ext: ExtensionField<Self::Val>>(&self, x: Ext) -> Option<Ext> {
108 if self.is_standard() {
110 Some(
111 (Point::from_projective_line(x) + Point::generator(self.log_n))
112 .to_projective_line()
113 .unwrap(),
114 )
115 } else {
116 None
117 }
118 }
119
120 fn create_disjoint_domain(&self, min_size: usize) -> Self {
121 assert!(
126 self.is_standard(),
127 "create_disjoint_domain not currently supported for nonstandard twin cosets"
128 );
129 let log_n = log2_ceil_usize(min_size);
130 Self::standard(if log_n == self.log_n {
132 log_n + 1
133 } else {
134 log_n
135 })
136 }
137
138 fn zp_at_point<Ext: ExtensionField<Self::Val>>(&self, point: Ext) -> Ext {
139 self.zeroifier(Point::from_projective_line(point))
140 }
141
142 fn selectors_at_point<Ext: ExtensionField<Self::Val>>(
143 &self,
144 point: Ext,
145 ) -> LagrangeSelectors<Ext> {
146 let point = Point::from_projective_line(point);
147 LagrangeSelectors {
148 is_first_row: self.s_p(self.shift, point),
149 is_last_row: self.s_p(-self.shift, point),
150 is_transition: Ext::one() - self.s_p_normalized(-self.shift, point),
151 inv_zeroifier: self.zeroifier(point).inverse(),
152 }
153 }
154
155 #[instrument(skip_all, fields(log_n = %coset.log_n))]
158 fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors<Vec<Self::Val>> {
159 let sels = coset
160 .points()
161 .map(|p| self.selectors_at_point(p.to_projective_line().unwrap()))
162 .collect_vec();
163 LagrangeSelectors {
164 is_first_row: sels.iter().map(|s| s.is_first_row).collect(),
165 is_last_row: sels.iter().map(|s| s.is_last_row).collect(),
166 is_transition: sels.iter().map(|s| s.is_transition).collect(),
167 inv_zeroifier: sels.iter().map(|s| s.inv_zeroifier).collect(),
168 }
169 }
170
171 fn split_domains(&self, num_chunks: usize) -> Vec<Self> {
173 assert!(self.is_standard());
174 let log_chunks = log2_strict_usize(num_chunks);
175 self.points()
176 .take(num_chunks)
177 .map(|shift| CircleDomain {
178 log_n: self.log_n - log_chunks,
179 shift,
180 })
181 .collect()
182 }
183
184 fn split_evals(
201 &self,
202 num_chunks: usize,
203 evals: RowMajorMatrix<Self::Val>,
204 ) -> Vec<RowMajorMatrix<Self::Val>> {
205 let log_chunks = log2_strict_usize(num_chunks);
206 assert!(evals.height() >> (log_chunks + 1) >= 1);
207 let width = evals.width();
208 let mut values: Vec<Vec<Self::Val>> = vec![vec![]; num_chunks];
209 evals
210 .rows()
211 .enumerate()
212 .for_each(|(i, row)| values[forward_backward_index(i, num_chunks)].extend(row));
213 values
214 .into_iter()
215 .map(|v| RowMajorMatrix::new(v, width))
216 .collect()
217 }
218}
219
/// Zig-zag index into `0..len`.
///
/// As `i` counts up, the result runs `0, 1, .., len - 1, len - 1, .., 1, 0` and then repeats
/// with period `2 * len`.
fn forward_backward_index(i: usize, len: usize) -> usize {
    let period = 2 * len;
    let pos = i % period;
    if pos >= len {
        // Second half of the period: walk back down.
        period - 1 - pos
    } else {
        pos
    }
}
229
#[cfg(test)]
mod tests {
    use core::iter;

    use hashbrown::HashSet;
    use itertools::izip;
    use p3_field::{batch_multiplicative_inverse, AbstractField};
    use p3_mersenne_31::Mersenne31;
    use rand::thread_rng;

    use super::*;
    use crate::CircleEvaluations;

    /// Asserts that reading the first half of `d`'s points forward and the second half
    /// backward pairs every point with its negation.
    fn assert_is_twin_coset<F: ComplexExtendable>(d: CircleDomain<F>) {
        let pts = d.points().collect_vec();
        let half_n = pts.len() >> 1;
        for (&l, &r) in izip!(&pts[..half_n], pts[half_n..].iter().rev()) {
            assert_eq!(l, -r);
        }
    }

    fn do_test_circle_domain(log_n: usize, width: usize) {
        let n = 1 << log_n;

        type F = Mersenne31;
        let d = CircleDomain::<F>::standard(log_n);

        // Walking with `next_point` from the first point visits `nth_point(i)` in order and
        // wraps back to the start after `n` steps.
        let p0 = d.first_point();
        let mut p1 = p0;
        for i in 0..(n - 1) {
            assert_eq!(Point::from_projective_line(p1), d.nth_point(i));
            p1 = d.next_point(p1).unwrap();
            assert_ne!(p1, p0);
        }
        // The loop stops one short, so check the final index explicitly.
        assert_eq!(Point::from_projective_line(p1), d.nth_point(n - 1));
        assert_eq!(d.next_point(p1).unwrap(), p0);

        // `next_point` agrees with the order of `points()`.
        let mut uni_point = d.first_point();
        for p in d.points() {
            assert_eq!(Point::from_projective_line(uni_point), p);
            uni_point = d.next_point(uni_point).unwrap();
        }

        // Disjoint domains are large enough and share no point with `d`.
        let seen: HashSet<Point<F>> = d.points().collect();
        for disjoint_size in [10, 100, n - 5, n + 15] {
            let dd = d.create_disjoint_domain(disjoint_size);
            assert!(dd.size() >= disjoint_size);
            for pt in dd.points() {
                assert!(!seen.contains(&pt));
            }
        }

        // The vanishing polynomial is zero on every point of `d`.
        for p in d.points() {
            assert_eq!(d.zp_at_point(p.to_projective_line().unwrap()), F::zero());
        }

        // Splitting the domain and its evaluations keeps every (point, row) pair together.
        let evals = RowMajorMatrix::rand(&mut thread_rng(), n, width);
        let orig: Vec<(Point<F>, Vec<F>)> = d
            .points()
            .zip(evals.rows().map(|r| r.collect_vec()))
            .collect();
        for num_chunks in [1, 2, 4, 8] {
            let mut combined = vec![];

            let sds = d.split_domains(num_chunks);
            assert_eq!(sds.len(), num_chunks);
            let ses = d.split_evals(num_chunks, evals.clone());
            assert_eq!(ses.len(), num_chunks);
            for (sd, se) in izip!(sds, ses) {
                assert_is_twin_coset(sd);
                assert_eq!(sd.size() * num_chunks, d.size());
                assert_eq!(se.width(), evals.width());
                assert_eq!(se.height() * num_chunks, d.size());
                combined.extend(sd.points().zip(se.rows().map(|r| r.collect_vec())));
            }
            assert_eq!(
                orig.iter().map(|x| x.0).collect::<HashSet<_>>(),
                combined.iter().map(|x| x.0).collect::<HashSet<_>>(),
                "union of split domains is orig domain"
            );
            assert_eq!(
                orig.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                combined.iter().map(|x| &x.1).collect::<HashSet<_>>(),
                "union of split evals is orig evals"
            );
            assert_eq!(
                orig.iter().collect::<HashSet<_>>(),
                combined.iter().collect::<HashSet<_>>(),
                "split domains and evals correspond to orig domains and evals"
            );
        }
    }

    #[test]
    fn selectors() {
        type F = Mersenne31;
        let log_n = 8;
        let n = 1 << log_n;

        let d = CircleDomain::<F>::standard(log_n);
        let coset = d.create_disjoint_domain(n);
        let sels = d.selectors_on_coset(coset);

        // Batch evaluation on the coset agrees with point-by-point evaluation.
        let mut pt = coset.first_point();
        for i in 0..coset.size() {
            let pt_sels = d.selectors_at_point(pt);
            assert_eq!(sels.is_first_row[i], pt_sels.is_first_row);
            assert_eq!(sels.is_last_row[i], pt_sels.is_last_row);
            assert_eq!(sels.is_transition[i], pt_sels.is_transition);
            assert_eq!(sels.inv_zeroifier[i], pt_sels.inv_zeroifier);
            pt = coset.next_point(pt).unwrap();
        }

        // Interpolates values given on `coset`, checks that the upper `n` coefficients
        // vanish, and evaluates the low part back on `d` in natural order.
        let coset_to_d = |evals: &[F]| {
            let evals = CircleEvaluations::from_natural_order(
                coset,
                RowMajorMatrix::new_col(evals.to_vec()),
            );
            let coeffs = evals.interpolate().to_row_major_matrix();
            let (lo, hi) = coeffs.split_rows(n);
            assert_eq!(hi.values, vec![F::zero(); n]);
            CircleEvaluations::evaluate(d, lo.to_row_major_matrix())
                .to_natural_order()
                .to_row_major_matrix()
                .values
        };

        // Nonzero only on the first row.
        let is_first_row = coset_to_d(&sels.is_first_row);
        assert_ne!(is_first_row[0], F::zero());
        assert_eq!(&is_first_row[1..], &vec![F::zero(); n - 1]);

        // Nonzero only on the last row.
        let is_last_row = coset_to_d(&sels.is_last_row);
        assert_eq!(&is_last_row[..n - 1], &vec![F::zero(); n - 1]);
        assert_ne!(is_last_row[n - 1], F::zero());

        // Nonzero on every row except the last one. Checking each entry, rather than only
        // that the slice is not entirely zero, is what actually pins this down.
        let is_transition = coset_to_d(&sels.is_transition);
        assert!(
            is_transition[..n - 1].iter().all(|&x| x != F::zero()),
            "is_transition must be nonzero on every row but the last"
        );
        assert_eq!(is_transition[n - 1], F::zero());

        // The zeroifier, recovered from its inverses on the coset, has a single coefficient
        // equal to one, at index `n`.
        let z_coeffs = CircleEvaluations::from_natural_order(
            coset,
            RowMajorMatrix::new_col(batch_multiplicative_inverse(&sels.inv_zeroifier)),
        )
        .interpolate()
        .to_row_major_matrix()
        .values;
        assert_eq!(
            z_coeffs,
            iter::empty()
                .chain(iter::repeat(F::zero()).take(n))
                .chain(iter::once(F::one()))
                .chain(iter::repeat(F::zero()).take(n - 1))
                .collect_vec()
        );
    }

    #[test]
    fn test_circle_domain() {
        do_test_circle_domain(4, 8);
        do_test_circle_domain(10, 32);
    }
}