commonware_math/ntt.rs
1use crate::algebra::{Additive, FieldNTT, Ring};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_utils::bitmap::BitMap;
6use core::{
7 num::NonZeroU32,
8 ops::{Index, IndexMut},
9};
10use rand_core::CryptoRngCore;
11#[cfg(feature = "std")]
12use std::vec::Vec;
13
/// Determines the size of polynomials we compute naively in [`EvaluationColumn::vanishing`].
///
/// This is the base-2 logarithm of the base chunk size. Chunks of
/// `2^LG_VANISHING_BASE` points (or the whole domain, if that is smaller) get
/// their vanishing polynomial built by direct multiplication. NTT-based
/// merging of those chunks takes over from there.
///
/// Benchmarked to be optimal, based on BLS12381 threshold recovery time.
const LG_VANISHING_BASE: u32 = 8;
18
/// Reverse the first `bit_width` bits of `i`.
///
/// Any bits beyond that width will be erased. In particular, a `bit_width` of 0
/// always yields 0.
///
/// # Panics
///
/// Panics if `bit_width > 64`.
fn reverse_bits(bit_width: u32, i: u64) -> u64 {
    assert!(bit_width <= 64, "bit_width must be <= 64");
    // Shift the low `bit_width` bits to the top of the word, discarding everything
    // above them. A full 64-bit reversal then brings them back down in reversed order.
    //
    // `checked_shl` returns `None` for a shift of 64 (`bit_width == 0`), which we map
    // to 0. `wrapping_shl` would instead mask the shift amount to 0 and return the
    // entire input reversed, contradicting the erasure guarantee above.
    i.checked_shl(64 - bit_width).map_or(0, u64::reverse_bits)
}
26
27/// Turn a slice into reversed bit order in place.
28///
29/// `out` MUST have length `2^bit_width`.
30fn reverse_slice<T>(bit_width: u32, out: &mut [T]) {
31 assert_eq!(out.len(), 1 << bit_width);
32 for i in 0..out.len() {
33 let j = reverse_bits(bit_width, i as u64) as usize;
34 // Only swap once, and don't swap if the location is the same.
35 if i < j {
36 out.swap(i, j);
37 }
38 }
39}
40
/// Calculate an NTT, or an inverse NTT (with FORWARD=false), in place.
///
/// We implement this generically over anything we can index into, which allows
/// performing NTTs in place.
///
/// `matrix` is accessed as `matrix[(row, col)]`. Each of the `cols` columns is
/// transformed independently across its `rows` entries.
///
/// In the forward direction, each column holds polynomial coefficients in
/// reverse bit order and ends up holding the evaluations at `w^0, w^1, ...` in
/// natural order. The inverse direction undoes exactly that.
///
/// # Panics
///
/// Panics if `rows` is zero or not a power of two, or if the field has no root
/// of unity of order `rows`.
fn ntt<const FORWARD: bool, F: FieldNTT, M: IndexMut<(usize, usize), Output = F>>(
    rows: usize,
    cols: usize,
    matrix: &mut M,
) {
    let lg_rows = rows.ilog2() as usize;
    assert_eq!(1 << lg_rows, rows, "rows should be a power of 2");
    // A number w such that w^(2^lg_rows) = 1.
    // (Or, in the inverse case, the inverse of that number, to undo the NTT).
    let w = {
        let w = F::root_of_unity(lg_rows as u8).expect("too many rows to perform NTT");
        if FORWARD {
            w
        } else {
            // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
            // making that left-hand term the inverse of w.
            w.exp(&[(1 << lg_rows) - 1])
        }
    };
    // The inverse algorithm consists of carefully undoing the work of the
    // standard algorithm, so we describe that in detail.
    //
    // To understand the NTT algorithm, first consider the case of a single
    // column. We have a polynomial f(X), and we want to turn that into:
    //
    // [f(w^0), f(w^1), ..., f(w^(2^lg_rows - 1))]
    //
    // Our polynomial can be written as:
    //
    // f+(X^2) + X f-(X^2)
    //
    // where f+ and f- are polynomials with half the degree.
    // f+ is obtained by taking the coefficients at even indices,
    // f- is obtained by taking the coefficients at odd indices.
    //
    // w^2 is also conveniently a 2^(lg_rows - 1) root of unity. Thus,
    // we can recursively compute an NTT on f+, using w^2 as the root,
    // and an NTT on f-, using w^2 as the root, each of which is a problem
    // of half the size.
    //
    // We can then compute:
    // f+((w^i)^2) + (w^i) f-((w^i)^2)
    // f+((w^i)^2) - (w^i) f-((w^i)^2)
    // for each i.
    // (Note that (-w^i)^2 = (w^i)^2, and that w^(i + 2^(lg_rows - 1)) = -w^i,
    // so the second line is f evaluated half the domain further along.)
    //
    // Our coefficients are conveniently laid out as [f+ f-], already
    // in a neat order. When we recurse, the coefficients of f+ are, in
    // turn, already laid out as [f++ f+-], and so on.
    //
    // We just need to transform this recursive algorithm, in top down form,
    // into an iterative one, in bottom up form. For that, note that the NTT
    // for the case of 1 row is trivial: do nothing.

    // Will contain, in bottom up order, the power of w we need at that stage.
    // At the last stage, we need w itself.
    // At the stage before last, we need w^2.
    // And so on.
    // How many stages do we need? If we have 1 row, we need 0 stages.
    // In general, with 2^n rows, we need n stages.
    let stages = {
        let mut out = vec![(0usize, F::zero()); lg_rows];
        let mut w_i = w;
        for i in (0..lg_rows).rev() {
            out[i] = (i, w_i.clone());
            w_i = w_i.clone() * &w_i;
        }
        // In the case of the reverse algorithm, we undo each stage of the
        // forward algorithm, starting with the last stage.
        if !FORWARD {
            out.reverse();
        }
        out
    };
    for (stage, w) in stages.into_iter() {
        // At stage i, we have polynomials with 2^i coefficients,
        // which have already been evaluated to create 2^i entries.
        // We need to combine these evaluations to create 2^(i + 1) entries,
        // representing the evaluation of a polynomial with 2^(i + 1) coefficients.
        // If we have two of these evaluations, laid out one after the other:
        //
        // [x_0, x_1, ...] [y_0, y_1, ...]
        //
        // Then the number of elements we need to skip to get the corresponding
        // element in the other half is simply the number of elements in each half,
        // i.e. 2^i.
        let skip = 1 << stage;
        let mut i = 0;
        while i < rows {
            // w_j runs through the powers of this stage's root. In the inverse
            // direction, that root was already inverted when `stages` was built,
            // so no extra handling is needed here.
            let mut w_j = F::one();
            for j in 0..skip {
                let index_a = i + j;
                let index_b = index_a + skip;
                for k in 0..cols {
                    let (a, b) = (matrix[(index_a, k)].clone(), matrix[(index_b, k)].clone());
                    if FORWARD {
                        let w_j_b = w_j.clone() * &b;
                        matrix[(index_a, k)] = a.clone() + &w_j_b;
                        matrix[(index_b, k)] = a - &w_j_b;
                    } else {
                        // To check the math, apply the forward transformation and
                        // then this one, with w_j being the inverse of the value
                        // above, and convince yourself that you get back (a, b).
                        // (a + w_j * b) + (a - w_j * b) = 2 * a
                        matrix[(index_a, k)] = (a.clone() + &b).div_2();
                        // (a + w_j * b) - (a - w_j * b) = 2 * w_j * b.
                        // w_j in this branch is the inverse of w_j in the other branch.
                        matrix[(index_b, k)] = ((a - &b) * &w_j).div_2();
                    }
                }
                w_j *= &w;
            }
            i += 2 * skip;
        }
    }
}
163
/// A view over `N` mutable column slices, addressed as `(row, column)`.
///
/// This lets the NTT routine work in place on pieces of a larger buffer
/// without first copying them into a dedicated matrix.
struct Columns<'a, const N: usize, F> {
    data: [&'a mut [F]; N],
}

impl<const N: usize, F> Index<(usize, usize)> for Columns<'_, N, F> {
    type Output = F;

    fn index(&self, (row, col): (usize, usize)) -> &F {
        &self.data[col][row]
    }
}

impl<const N: usize, F> IndexMut<(usize, usize)> for Columns<'_, N, F> {
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut F {
        &mut self.data[col][row]
    }
}
184
185/// Used to keep track of the points at which a polynomial needs to vanish.
186///
187/// This takes care of subtle details like padding and bit ordering.
188///
189/// This struct is associated with a particular size, which is a power of two,
190/// and thus a particular root of unity.
191#[derive(Debug, PartialEq)]
192pub struct VanishingPoints {
193 lg_size: u32,
194 bits: BitMap,
195}
196
197impl VanishingPoints {
198 /// This will have size `2^lg_size`, and vanish everywhere.
199 ///
200 /// Be aware that this means all points are initially marked as vanishing.
201 pub fn new(lg_size: u32) -> Self {
202 Self {
203 lg_size,
204 bits: BitMap::zeroes(1 << lg_size),
205 }
206 }
207
208 /// This will have size `2^lg_size`, and vanish nowhere.
209 pub fn all_non_vanishing(lg_size: u32) -> Self {
210 Self {
211 lg_size,
212 bits: BitMap::ones(1 << lg_size),
213 }
214 }
215
216 pub const fn lg_size(&self) -> u32 {
217 self.lg_size
218 }
219
220 /// Set the root `w^index` to vanish, `value = false`, or not, `value = true`.
221 fn set(&mut self, index: u64, value: bool) {
222 self.bits.set(reverse_bits(self.lg_size, index), value);
223 }
224
225 /// Set the root `w^index` to not vanish.
226 ///
227 /// cf. `set`;
228 pub fn set_non_vanishing(&mut self, index: u64) {
229 self.set(index, true);
230 }
231
232 pub fn get(&self, index: u64) -> bool {
233 self.bits.get(reverse_bits(self.lg_size, index))
234 }
235
236 pub fn count_non_vanishing(&self) -> u64 {
237 self.bits.count_ones()
238 }
239
240 /// Check that a particular chunk of this set vanishes.
241 ///
242 /// `lg_chunk_size` determines the size of the chunk, which must be a power of two.
243 ///
244 /// `index` determines which chunk to use. After chunk 0, you have chunk 1, and so on.
245 ///
246 /// The chunk is taken from the set in reverse bit order. This is what methods
247 /// that create a vanishing polynomial recursively want. Take care when using
248 /// this naively.
249 fn chunk_vanishes_everywhere(&self, lg_chunk_size: u32, index: u64) -> bool {
250 assert!(lg_chunk_size <= self.lg_size);
251 let start = index << lg_chunk_size;
252 self.bits.is_unset(start..start + (1 << lg_chunk_size))
253 }
254
255 /// Yield the bits of a chunk, in reverse bit order.
256 ///
257 /// cf. `chunk_vanishes_everywhere`, which uses the same chunk indexing scheme.
258 fn get_chunk(&self, lg_chunk_size: u32, index: u64) -> impl Iterator<Item = bool> + '_ {
259 (index << lg_chunk_size..(index + 1) << lg_chunk_size).map(|i| self.bits.get(i))
260 }
261
262 #[cfg(any(test, feature = "fuzz"))]
263 fn iter_bits_in_order(&self) -> impl Iterator<Item = bool> + '_ {
264 (0..(1u64 << self.lg_size)).map(|i| self.get(i))
265 }
266}
267
/// Represents the evaluation of a single polynomial over a full domain.
#[derive(Debug)]
struct EvaluationColumn<F> {
    /// One value per point of the domain.
    evaluations: Vec<F>,
}

impl<F: FieldNTT> EvaluationColumn<F> {
    /// Evaluate the vanishing polynomial over `points` on the domain.
    ///
    /// This returns the evaluation of the polynomial at `0`, and then the evaluation
    /// of the polynomial over the whole domain.
    ///
    /// This assumes that `points` has at least one non-vanishing point. If every
    /// point vanished, the polynomial would be `X^len - 1`, whose leading term
    /// does not fit in the `len`-entry buffer used here, so the result would not
    /// be meaningful.
    ///
    /// The polynomial produced is not necessarily monic. Roots are moved by
    /// rescaling coefficients, which keeps the constant term but changes the
    /// leading one. The returned value at `0` is consistent with the returned
    /// evaluations.
    ///
    /// # Panics
    ///
    /// Panics if the field has no root of unity of order `2^points.lg_size()`.
    pub fn vanishing(points: &VanishingPoints) -> (F, Self) {
        // The goal of this function is to produce a polynomial v such that
        // v(w^j) = 0 for each index j where points.get(j) = false.
        //
        // The core idea is to split this up recursively. We split the possible
        // roots into two groups, and figure out the vanishing polynomials
        // v_L and v_R for the first and second groups, respectively. Then,
        // multiplying v_L and v_R yields a polynomial with the appropriate roots.
        //
        // We can multiply the polynomials in O(N lg N) time, by performing an
        // NTT on both of them, multiplying the evaluations point wise, and then
        // using a reverse NTT to get a polynomial back.
        //
        // Naturally, we can extend this to construct each sub-polynomial recursively
        // as well, giving an O(N lg^2 N) algorithm in total.
        //
        // This function doesn't return the polynomial directly, but rather an
        // evaluation of the polynomial. This is because many consumers often
        // need this anyways, and by providing them with this result, we avoid
        // performing a reverse NTT that they then proceed to undo. However,
        // they can also need the evaluation at 0, so we provide and calculate that
        // as well. That can also be calculated recursively, and merged with the
        // above calculation.
        //
        // One point we haven't clarified yet is how to split up the roots.
        // Let's use an example. With size 8, the roots are:
        //
        // w^0 w^1 w^2 w^3 w^4 w^5 w^6 w^7
        //
        // or, writing down just the exponent
        //
        // 0 1 2 3 4 5 6 7
        //
        // We could build up our final polynomial by merging polynomials of size
        // two, with roots chosen among the following possibilities:
        //
        // 0 1 2 3 4 5 6 7
        //
        // However, this requires using different roots for each polynomial.
        //
        // If we instead use reverse bit order, we can have things be:
        //
        // 0 4 2 6 1 5 3 7
        //
        // which is equal to (adding to an exponent = multiplying the root):
        //
        // (0 4) (2 + (0 4)) (1 + (0 4 2 6))
        //
        // So, we can start by having polynomials with the same possible roots
        // at the lowest level, and then merge by multiplying the roots by
        // the right power, for the polynomial on the right.
        //
        // The roots of a polynomial can easily be multiplied by some factor
        // by dividing its coefficients by powers of that factor.
        // cf. [`PolynomialColumn::divide_roots`]; the merge loop below does the
        // same thing through its `shift_coeffs` closure.
        //
        // Another optimization we can do for the merges is to keep track
        // of polynomials that vanish everywhere and nowhere. A polynomial
        // vanishing nowhere has no effect when merging, so we can skip a multiplication.
        // Similarly, a polynomial vanishing everywhere is of the form X^N - 1,
        // with which multiplication is simple.

        /// Used to keep track of special polynomial values.
        #[derive(Clone, Copy)]
        enum Where {
            /// Vanishes at none of the roots; i.e. is f(X) = 1.
            Nowhere,
            /// Vanishes at at least one of the roots.
            Somewhere,
            /// Vanishes at every single one of the roots.
            Everywhere,
        }

        use Where::*;

        let lg_len = points.lg_size();
        let len = 1usize << lg_len;
        // This will store our in progress polynomials, and eventually,
        // the final evaluations.
        let mut out = vec![F::zero(); len];
        // For small inputs, a single base chunk could be larger than the whole
        // domain, so we cap the chunk size at the domain size.
        let lg_chunk_size = LG_VANISHING_BASE.min(lg_len);
        // We use this to keep track of the polynomial evaluated at 0.
        let mut at_zero = F::one();

        // Populate out with polynomials up to a low degree.
        // We also get a vector with the status of each polynomial, letting
        // us accelerate the merging step.
        let mut vanishes = {
            let chunk_size = 1usize << lg_chunk_size;
            // The negations of every root a base-chunk polynomial can have,
            // i.e. -w^k for each k, stored in reverse bit order. Every chunk
            // starts out with these (chunk 0's) roots; the merge step below
            // shifts the roots of the other chunks into place.
            let minus_roots = {
                // We can panic without worry here: lg_chunk_size <= lg_len, and
                // a root of unity for the full domain is required anyway.
                let w = u8::try_from(lg_chunk_size)
                    .ok()
                    .and_then(|s| F::root_of_unity(s))
                    .expect("sub-root of unity should exist");
                // The powers of w we'll use as roots, pre-negated.
                let mut out: Vec<_> = (0..)
                    .scan(F::one(), |state, _| {
                        let out = -state.clone();
                        *state *= &w;
                        Some(out)
                    })
                    .take(chunk_size)
                    .collect();
                // Make sure the order is what the rest of this routine expects.
                reverse_slice(lg_chunk_size, out.as_mut_slice());
                out
            };
            // Instead of actually negating `at_zero` inside of the loop below,
            // we instead keep track of whether or not it needs to be negated
            // after the loop, to just perform that operation once.
            let mut negate_at_zero = false;
            // Populate each chunk with its initial polynomial, and record where
            // that polynomial vanishes.
            let vanishing = out
                .chunks_exact_mut(chunk_size)
                .enumerate()
                .map(|(i, poly)| {
                    let i_u64 = i as u64;
                    if points.chunk_vanishes_everywhere(lg_chunk_size, i_u64) {
                        // The polynomial is X^chunk_size - 1.
                        // Implicitly, there's a 1 past the end of the polynomial,
                        // which we handle when merging.
                        poly[0] = -F::one();
                        negate_at_zero ^= true;
                        return Where::Everywhere;
                    }
                    poly[0] = F::one();
                    let mut coeffs = 1;
                    for (b_j, minus_root) in points
                        .get_chunk(lg_chunk_size, i_u64)
                        .zip(minus_roots.iter())
                    {
                        if b_j {
                            continue;
                        }
                        // Multiply the polynomial by (X - r), where minus_root = -r.
                        // While building, coefficients are kept in natural order.
                        poly[coeffs] = F::one();
                        for k in (1..coeffs).rev() {
                            let (chunk_head, chunk_tail) = poly.split_at_mut(k);
                            chunk_tail[0] *= minus_root;
                            chunk_tail[0] += &chunk_head[k - 1];
                        }
                        poly[0] *= minus_root;
                        coeffs += 1;
                    }
                    if coeffs > 1 {
                        reverse_slice(lg_chunk_size, poly);
                        // Index 0 is fixed by bit reversal: it holds the constant term.
                        at_zero *= &poly[0];
                        Where::Somewhere
                    } else {
                        Where::Nowhere
                    }
                })
                .collect::<Vec<_>>();
            if negate_at_zero {
                at_zero = -at_zero.clone();
            }
            vanishing
        };
        // Avoid doing any of the subsequent work if we've already covered this case.
        if lg_chunk_size >= lg_len {
            // We do, however, need to turn the coefficients into evaluations.
            return (at_zero, PolynomialColumn { coefficients: out }.evaluate());
        }
        // Inverses of the roots of unity used at each merge level: from order
        // 2^(lg_chunk_size + 1) for the first merge up to order 2^lg_len for the last.
        let w_invs = {
            // since w^(2^lg_len) = 1, w^(2^lg_len - 1) * w = 1,
            // making that left-hand term the inverse of w.
            let mut w_inv = F::root_of_unity(lg_len as u8)
                .expect("too many rows to create vanishing polynomial")
                .exp(&[(1 << lg_len) - 1]);
            let mut out = Vec::with_capacity((lg_len - lg_chunk_size) as usize);
            for _ in lg_chunk_size..lg_len {
                out.push(w_inv.clone());
                w_inv = w_inv.clone() * &w_inv;
            }
            out.reverse();
            out
        };
        let mut lg_chunk_size = lg_chunk_size;
        let mut scratch = Vec::<F>::with_capacity(len);
        // Cached powers of the current `w_inv`. They are recomputed whenever the
        // chunk size (and therefore the level) changes.
        let mut coeff_shifts = Vec::with_capacity(1 << lg_chunk_size);
        for w_inv in w_invs.into_iter() {
            // Each iteration merges adjacent pairs of chunks of `chunk_size`
            // coefficients into chunks of `next_chunk_size` coefficients.
            let chunk_size = 1 << lg_chunk_size;
            // Closure to shift coefficients by the current power.
            // This lets us reuse the computation of the powers.
            // Scaling coefficient k by w_inv^k turns p(X) into p(w_inv * X). That
            // multiplies every root by w_inv^-1 and leaves p(0) unchanged.
            let mut shift_coeffs = |coeffs: &mut [F]| {
                if coeff_shifts.len() != chunk_size {
                    coeff_shifts.clear();
                    let mut acc = F::one();
                    for _ in 0..chunk_size {
                        coeff_shifts.push(acc.clone());
                        acc *= &w_inv;
                    }
                }
                for (i, coeff_i) in coeffs.iter_mut().enumerate() {
                    *coeff_i *= &coeff_shifts[reverse_bits(lg_chunk_size, i as u64) as usize];
                }
            };
            let next_lg_chunk_size = lg_chunk_size + 1;
            let next_chunk_size = 1 << next_lg_chunk_size;
            for (i, chunk) in out.chunks_exact_mut(1 << next_lg_chunk_size).enumerate() {
                let (left, right) = chunk.split_at_mut(1 << lg_chunk_size);
                let (vanishes_l, vanishes_r) = (vanishes[2 * i], vanishes[2 * i + 1]);
                // We keep track of whether or not the polynomial resulting from
                // the merge is evaluated or not.
                let mut evaluated = false;
                // In the merged layout (reverse bit order over one more bit), the
                // coefficient stored at old position q moves to position 2 * q,
                // and the coefficient chunk_size higher lands at 2 * q + 1.
                vanishes[i] = match (vanishes_l, vanishes_r) {
                    (Nowhere, Nowhere) => {
                        // Both polynomials consist of 1 0 0 0 ..., and we
                        // want the final result to be that, just with more zeroes,
                        // so we need to clear the 1 value on the right side.
                        right[0] = F::zero();
                        Nowhere
                    }
                    (Nowhere, Somewhere) => {
                        // Clear the one value on the left.
                        left[0] = F::zero();
                        // Adjust the roots on the right.
                        shift_coeffs(right);
                        // Spread it over the even positions of the whole merged chunk.
                        for i in 0..chunk_size {
                            chunk.swap(chunk_size + i, 2 * i);
                        }
                        Somewhere
                    }
                    (Nowhere, Everywhere) => {
                        // (X^chunk_size - 1) is on the right.
                        // First, we multiply its roots by the primitive
                        // (2 * chunk_size)-th root of unity, using the same
                        // coefficient scaling as `shift_coeffs`, yielding:
                        //
                        // -X^chunk_size - 1
                        //
                        // in reverse bit order we get the following:
                        left[0] = -F::one();
                        left[1] = -F::one();
                        // And we remove the -1 on the right side.
                        right[0] = F::zero();
                        Somewhere
                    }
                    // These two cases mirror the two above.
                    (Somewhere, Nowhere) => {
                        // Clear the one on the right side.
                        right[0] = F::zero();
                        // Spread it over the even positions of the whole merged chunk.
                        // We can skip moving index 0.
                        for i in (1..chunk_size).rev() {
                            chunk.swap(i, 2 * i);
                        }
                        Somewhere
                    }
                    (Everywhere, Nowhere) => {
                        // Like above, but with the polynomial on the left,
                        // there's no need to adjust the roots.
                        left[0] = -F::one();
                        left[1] = F::one();
                        right[0] = F::zero();
                        Somewhere
                    }
                    (Somewhere, Everywhere) => {
                        // We need to make the left side occupy the whole space.
                        // Shifting by one index has the effect of multiplying
                        // the polynomial by X^(chunk_size), which is what we want.
                        for i in (0..chunk_size).rev() {
                            chunk.swap(i, 2 * i + 1);
                            // We copy the value in i, negate it, and make it occupy
                            // both 2 * i + 1 and 2 * i, thus multiplying by -(X^chunk_size + 1).
                            chunk[2 * i + 1] = -chunk[2 * i + 1].clone();
                            chunk[2 * i] = chunk[2 * i + 1].clone();
                        }
                        Somewhere
                    }
                    (Everywhere, Somewhere) => {
                        // Adjust the roots on the right.
                        shift_coeffs(right);
                        // Like above, but moving the right side, and multiplying by
                        // (X^chunk_size - 1).
                        for i in 0..chunk_size {
                            chunk.swap(chunk_size + i, 2 * i + 1);
                            chunk[2 * i] = -chunk[2 * i + 1].clone();
                        }
                        Somewhere
                    }
                    (Everywhere, Everywhere) => {
                        // Make sure to clear the -1 on the right side.
                        right[0] = F::zero();
                        // By choosing to do things this way, we effectively
                        // negate the final polynomial, so we need to correct
                        // for this with the zero value.
                        at_zero = -at_zero.clone();
                        Everywhere
                    }
                    // In this case, we can assume nothing, and have to do
                    // the full logic for actually multiplying the polynomials.
                    (Somewhere, Somewhere) => {
                        // Adjust the roots on the right.
                        shift_coeffs(right);
                        // Populate the scratch buffer with the right side.
                        scratch.clear();
                        scratch.resize(next_chunk_size, F::zero());
                        for i in 0..chunk_size {
                            core::mem::swap(&mut right[i], &mut scratch[2 * i]);
                        }
                        // We can skip moving index 0.
                        for i in (1..chunk_size).rev() {
                            chunk.swap(i, 2 * i);
                        }
                        // Turn the polynomials into evaluations.
                        ntt::<true, _, _>(
                            next_chunk_size,
                            2,
                            &mut Columns {
                                data: [chunk, scratch.as_mut_slice()],
                            },
                        );
                        // Multiply them, into the chunk.
                        for (l, r) in chunk.iter_mut().zip(scratch.iter_mut()) {
                            *l *= r;
                        }
                        evaluated = true;
                        Somewhere
                    }
                };
                // Only the final merge should leave evaluations behind. Every
                // earlier level must hand coefficients on to the next one.
                let should_be_evaluated = next_chunk_size >= len;
                if should_be_evaluated != evaluated {
                    if evaluated {
                        ntt::<false, _, _>(next_chunk_size, 1, &mut Columns { data: [chunk] });
                    } else {
                        ntt::<true, _, _>(next_chunk_size, 1, &mut Columns { data: [chunk] });
                    }
                }
            }
            lg_chunk_size = next_lg_chunk_size;
        }
        // After the final merge, `out` already holds the evaluations over the
        // whole domain.
        (at_zero, Self { evaluations: out })
    }

    /// Interpolate these evaluations back into polynomial coefficients (in
    /// reverse bit order) with an inverse NTT.
    pub fn interpolate(self) -> PolynomialColumn<F> {
        let mut data = self.evaluations;
        ntt::<false, _, _>(
            data.len(),
            1,
            &mut Columns {
                data: [data.as_mut_slice()],
            },
        );
        PolynomialColumn { coefficients: data }
    }
}
634
635/// A column containing a single polynomial.
636#[derive(Debug)]
637struct PolynomialColumn<F> {
638 coefficients: Vec<F>,
639}
640
641impl<F: FieldNTT> PolynomialColumn<F> {
642 /// Evaluate this polynomial over the domain, returning
643 pub fn evaluate(self) -> EvaluationColumn<F> {
644 let mut data = self.coefficients;
645 ntt::<true, _, _>(
646 data.len(),
647 1,
648 &mut Columns {
649 data: [data.as_mut_slice()],
650 },
651 );
652 EvaluationColumn { evaluations: data }
653 }
654
655 #[cfg(any(test, feature = "fuzz"))]
656 fn evaluate_one(&self, point: F) -> F {
657 let mut out = F::zero();
658 let rows = self.coefficients.len();
659 let lg_rows = rows.ilog2();
660 for i in (0..rows).rev() {
661 out = out * &point + &self.coefficients[reverse_bits(lg_rows, i as u64) as usize];
662 }
663 out
664 }
665
666 #[cfg(any(test, feature = "fuzz"))]
667 fn degree(&self) -> usize {
668 let rows = self.coefficients.len();
669 let lg_rows = rows.ilog2();
670 for i in (0..rows).rev() {
671 if self.coefficients[reverse_bits(lg_rows, i as u64) as usize] != F::zero() {
672 return i;
673 }
674 }
675 0
676 }
677
678 /// Divide the roots of each polynomial by some factor.
679 ///
680 /// If f(x) = 0, then after this transformation, f(x / z) = 0 instead.
681 ///
682 /// The number of roots does not change.
683 ///
684 /// c.f. [`EvaluationColumn::vanishing`] for how this is used.
685 fn divide_roots(&mut self, factor: F) {
686 let mut factor_i = F::one();
687 let lg_rows = self.coefficients.len().ilog2();
688 for i in 0..self.coefficients.len() {
689 let index = reverse_bits(lg_rows, i as u64) as usize;
690 self.coefficients[index] *= &factor_i;
691 factor_i *= &factor;
692 }
693 }
694}
695
/// Represents a matrix of field elements, of arbitrary dimensions.
///
/// This is in row major order, so consider processing elements in the same
/// row first, for locality.
#[derive(Clone, PartialEq)]
pub struct Matrix<F> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// The entries, row after row; entry `(i, j)` lives at `cols * i + j`.
    data: Vec<F>,
}
706
707impl<F: EncodeSize> EncodeSize for Matrix<F> {
708 fn encode_size(&self) -> usize {
709 self.rows.encode_size() + self.cols.encode_size() + self.data.encode_size()
710 }
711}
712
713impl<F: Write> Write for Matrix<F> {
714 fn write(&self, buf: &mut impl bytes::BufMut) {
715 self.rows.write(buf);
716 self.cols.write(buf);
717 self.data.write(buf);
718 }
719}
720
721impl<F: Read> Read for Matrix<F> {
722 type Cfg = (usize, <F as Read>::Cfg);
723
724 fn read_cfg(
725 buf: &mut impl bytes::Buf,
726 (max_els, f_cfg): &Self::Cfg,
727 ) -> Result<Self, commonware_codec::Error> {
728 let cfg = RangeCfg::from(..=*max_els);
729 let rows = usize::read_cfg(buf, &cfg)?;
730 let cols = usize::read_cfg(buf, &cfg)?;
731 let data = Vec::<F>::read_cfg(buf, &(cfg, f_cfg.clone()))?;
732 let expected_len = rows
733 .checked_mul(cols)
734 .ok_or(commonware_codec::Error::Invalid(
735 "Matrix",
736 "matrix dimensions overflow",
737 ))?;
738 if data.len() != expected_len {
739 return Err(commonware_codec::Error::Invalid(
740 "Matrix",
741 "matrix element count does not match dimensions",
742 ));
743 }
744 Ok(Self { rows, cols, data })
745 }
746}
747
748impl<F: core::fmt::Debug> core::fmt::Debug for Matrix<F> {
749 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
750 for i in 0..self.rows {
751 let row_i = &self[i];
752 for row_i_j in row_i {
753 write!(f, "{row_i_j:?} ")?;
754 }
755 writeln!(f)?;
756 }
757 Ok(())
758 }
759}
760
761impl<F: Additive> Matrix<F> {
762 /// Create a zero matrix, with a certain number of rows and columns
763 fn zero(rows: usize, cols: usize) -> Self {
764 Self {
765 rows,
766 cols,
767 data: vec![F::zero(); rows * cols],
768 }
769 }
770
771 /// Initialize a matrix, with dimensions, and data to pull from.
772 ///
773 /// Any extra data is ignored, any data not supplied is treated as 0.
774 pub fn init(rows: usize, cols: usize, mut data: impl Iterator<Item = F>) -> Self {
775 let mut out = Self::zero(rows, cols);
776 'outer: for i in 0..rows {
777 for row_i in &mut out[i] {
778 let Some(x) = data.next() else {
779 break 'outer;
780 };
781 *row_i = x;
782 }
783 }
784 out
785 }
786
787 /// Interpret the columns of this matrix as polynomials, with at least `min_coefficients`.
788 ///
789 /// This will, in fact, produce a matrix padded to the next power of 2 of that number.
790 ///
791 /// This will return `None` if `min_coefficients < self.rows`, which would mean
792 /// discarding data, instead of padding it.
793 pub fn as_polynomials(&self, min_coefficients: usize) -> Option<PolynomialVector<F>>
794 where
795 F: Clone,
796 {
797 if min_coefficients < self.rows {
798 return None;
799 }
800 Some(PolynomialVector::new(
801 min_coefficients,
802 self.cols,
803 (0..self.rows).flat_map(|i| self[i].iter().cloned()),
804 ))
805 }
806
807 /// Multiply this matrix by another.
808 ///
809 /// This assumes that the number of columns in this matrix match the number
810 /// of rows in the other matrix.
811 pub fn mul(&self, other: &Self) -> Self
812 where
813 F: Clone + Ring,
814 {
815 assert_eq!(self.cols, other.rows);
816 let mut out = Self::zero(self.rows, other.cols);
817 for i in 0..self.rows {
818 for j in 0..self.cols {
819 let c = self[(i, j)].clone();
820 let other_j = &other[j];
821 for k in 0..other.cols {
822 out[(i, k)] += &(c.clone() * &other_j[k])
823 }
824 }
825 }
826 out
827 }
828}
829
830impl<F: FieldNTT> Matrix<F> {
831 fn ntt<const FORWARD: bool>(&mut self) {
832 ntt::<FORWARD, F, Self>(self.rows, self.cols, self)
833 }
834}
835
836impl<F> Matrix<F> {
837 pub const fn rows(&self) -> usize {
838 self.rows
839 }
840
841 pub const fn cols(&self) -> usize {
842 self.cols
843 }
844
845 /// Iterate over the rows of this matrix.
846 pub fn iter(&self) -> impl Iterator<Item = &[F]> {
847 (0..self.rows).map(|i| &self[i])
848 }
849}
850
851impl<F: crate::algebra::Random> Matrix<F> {
852 /// Create a random matrix with certain dimensions.
853 pub fn rand(mut rng: impl CryptoRngCore, rows: usize, cols: usize) -> Self
854 where
855 F: Additive,
856 {
857 Self::init(rows, cols, (0..rows * cols).map(|_| F::random(&mut rng)))
858 }
859}
860
861impl<F> Index<usize> for Matrix<F> {
862 type Output = [F];
863
864 fn index(&self, index: usize) -> &Self::Output {
865 &self.data[self.cols * index..self.cols * (index + 1)]
866 }
867}
868
869impl<F> IndexMut<usize> for Matrix<F> {
870 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
871 &mut self.data[self.cols * index..self.cols * (index + 1)]
872 }
873}
874
875impl<F> Index<(usize, usize)> for Matrix<F> {
876 type Output = F;
877
878 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
879 &self.data[self.cols * i + j]
880 }
881}
882
883impl<F> IndexMut<(usize, usize)> for Matrix<F> {
884 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
885 &mut self.data[self.cols * i + j]
886 }
887}
888
#[cfg(any(test, feature = "arbitrary"))]
impl<'a, F: arbitrary::Arbitrary<'a>> arbitrary::Arbitrary<'a> for Matrix<F> {
    /// Generates a matrix with between 1 and 16 rows and columns, filled with
    /// arbitrary entries in row-major order.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let rows: usize = u.int_in_range(1..=16)?;
        let cols: usize = u.int_in_range(1..=16)?;
        let len = rows * cols;
        let mut data = Vec::with_capacity(len);
        for _ in 0..len {
            data.push(F::arbitrary(u)?);
        }
        Ok(Self { rows, cols, data })
    }
}
900
/// A vector of polynomials sharing a power-of-two number of coefficients,
/// stored one polynomial per column.
#[derive(Clone, Debug, PartialEq)]
pub struct PolynomialVector<F> {
    // Each column of this matrix contains the coefficients of a polynomial,
    // in reverse bit order. So, the ith coefficient appears at row
    // reverse_bits(lg_rows, i): i with its lowest lg_rows bits reversed.
    //
    // For example, a polynomial a0 + a1 X + a2 X^2 + a3 X^3 is stored as:
    //
    // a0 a2 a1 a3
    //
    // This is convenient because the even coefficients and the odd coefficients
    // split nicely into halves. The first half of the rows have the property
    // that the lowest bit of their coefficient index is 0 (the even coefficients).
    // Within that subset, the first half has the second-lowest bit set to 0
    // and the second half has it set to 1, and so on, recursively.
    data: Matrix<F>,
}
917
918impl<F: Additive> PolynomialVector<F> {
919 /// Construct a new vector of polynomials, from dimensions, and coefficients.
920 ///
921 /// The coefficients should be supplied in order of increasing index,
922 /// and then for each polynomial.
923 ///
924 /// In other words, if you have 3 polynomials:
925 ///
926 /// a0 + a1 X + ...
927 /// b0 + b1 X + ...
928 /// c0 + c1 X + ...
929 ///
930 /// The iterator should yield:
931 ///
932 /// a0 b0 c0
933 /// a1 b1 c1
934 /// ...
935 ///
936 /// Any coefficients not supplied are treated as being equal to 0.
937 fn new(rows: usize, cols: usize, mut coefficients: impl Iterator<Item = F>) -> Self {
938 assert!(rows > 0);
939 let rows = rows.next_power_of_two();
940 let lg_rows = rows.ilog2();
941 let mut data = Matrix::zero(rows, cols);
942 'outer: for i in 0..rows {
943 let row_i = &mut data[reverse_bits(lg_rows, i as u64) as usize];
944 for row_i_j in row_i {
945 let Some(c) = coefficients.next() else {
946 break 'outer;
947 };
948 *row_i_j = c;
949 }
950 }
951 Self { data }
952 }
953}
954
955impl<F: FieldNTT> PolynomialVector<F> {
956 /// Evaluate each polynomial in this vector over all points in an interpolation domain.
957 pub fn evaluate(mut self) -> EvaluationVector<F> {
958 self.data.ntt::<true>();
959 let active_rows = VanishingPoints::all_non_vanishing(self.data.rows().ilog2());
960 EvaluationVector {
961 data: self.data,
962 active_rows,
963 }
964 }
965
966 /// Like [Self::evaluate], but with a simpler algorithm that's much less efficient.
967 ///
968 /// Exists as a useful tool for testing
969 #[cfg(any(test, feature = "fuzz"))]
970 fn evaluate_naive(self) -> EvaluationVector<F> {
971 let rows = self.data.rows;
972 let lg_rows = rows.ilog2();
973 let w = F::root_of_unity(lg_rows as u8).expect("too much data to calculate NTT");
974 // entry (i, j) of this matrix will contain w^ij. Thus, multiplying it
975 // with the coefficients of a polynomial, in column order, will evaluate it.
976 // We also need to re-arrange the columns of the matrix to match the same
977 // order we have for polynomial coefficients.
978 let mut vandermonde_matrix = Matrix::zero(rows, rows);
979 let mut w_i = F::one();
980 for i in 0..rows {
981 let row_i = &mut vandermonde_matrix[i];
982 let mut w_ij = F::one();
983 for j in 0..rows {
984 // Remember, the coeffients of the polynomial are in reverse bit order!
985 row_i[reverse_bits(lg_rows, j as u64) as usize] = w_ij.clone();
986 w_ij *= &w_i;
987 }
988 w_i *= &w;
989 }
990
991 EvaluationVector {
992 data: vandermonde_matrix.mul(&self.data),
993 active_rows: VanishingPoints::all_non_vanishing(lg_rows),
994 }
995 }
996
997 /// Divide the roots of each polynomial by some factor.
998 ///
999 /// c.f. [`PolynomialColumn::divide_roots`]. This performs the same operation on
1000 /// each polynomial in this vector.
1001 fn divide_roots(&mut self, factor: F) {
1002 let mut factor_i = F::one();
1003 let lg_rows = self.data.rows.ilog2();
1004 for i in 0..self.data.rows {
1005 for p_i in &mut self.data[reverse_bits(lg_rows, i as u64) as usize] {
1006 *p_i *= &factor_i;
1007 }
1008 factor_i *= &factor;
1009 }
1010 }
1011
1012 /// For each polynomial P_i in this vector compute the evaluation of P_i / Q.
1013 ///
1014 /// Naturally, you can call [EvaluationVector::interpolate]. The reason we don't
1015 /// do this is that the algorithm naturally yields an [EvaluationVector], and
1016 /// some use-cases may want access to that data as well.
1017 ///
1018 /// This assumes that the number of coefficients in the polynomials of this vector
1019 /// matches that of `q` (the coefficients can be 0, but need to be padded to the right size).
1020 ///
1021 /// This assumes that `q` has no zeroes over `coset_shift() * root_of_unity()^i`,
1022 /// for any i. This will be the case for a vanishing polynomial produced by
1023 /// [EvaluationColumn::vanishing] and then interpolated.
1024 /// If this isn't the case, the result may be junk.
1025 ///
1026 /// If `q` doesn't divide a partiular polynomial in this vector, the result
1027 /// for that polynomial is not guaranteed to be anything meaningful.
1028 fn divide(&mut self, mut q: PolynomialColumn<F>) {
1029 // The algorithm operates column wise.
1030 //
1031 // You can compute P(X) / Q(X) by evaluating each polynomial, then computing
1032 //
1033 // P(w^i) / Q(w^i)
1034 //
1035 // for each evaluation point. Then, you can interpolate back.
1036 //
1037 // But wait! What if Q(w^i) = 0? In particular, for the case of recovering
1038 // a polynomial from data with missing rows, we *expect* P(w^i) = 0 = Q(w^i)
1039 // for the indicies we're missing, so this doesn't work.
1040 //
1041 // What we can do is to instead multiply each of the roots by some factor z,
1042 // such that z w^i != w^j, for any i, j. In other words, we change the roots
1043 // such that they're not in the evaluation domain anymore, allowing us to
1044 // divide. We can then interpolate the result back into a polynomial,
1045 // and divide back the roots to where they should be.
1046 //
1047 // c.f. [PolynomialVector::divide_roots]
1048 assert_eq!(
1049 self.data.rows,
1050 q.coefficients.len(),
1051 "cannot divide by polynomial of the wrong size"
1052 );
1053 let skew = F::coset_shift();
1054 let skew_inv = F::coset_shift_inv();
1055 self.divide_roots(skew.clone());
1056 q.divide_roots(skew);
1057 ntt::<true, F, _>(self.data.rows, self.data.cols, &mut self.data);
1058 ntt::<true, F, _>(
1059 q.coefficients.len(),
1060 1,
1061 &mut Columns {
1062 data: [&mut q.coefficients],
1063 },
1064 );
1065 // Do a point wise division.
1066 for i in 0..self.data.rows {
1067 let q_i = q.coefficients[i].clone();
1068 // If `q_i = 0`, then we will get 0 in the output.
1069 // We don't expect any of the q_i to be 0, but being 0 is only one
1070 // of the many possibilities for the coefficient to be incorrect,
1071 // so doing a runtime assertion here doesn't make sense.
1072 let q_i_inv = q_i.inv();
1073 for d_i_j in &mut self.data[i] {
1074 *d_i_j *= &q_i_inv;
1075 }
1076 }
1077 // Interpolate back, using the inverse skew
1078 ntt::<false, F, _>(self.data.rows, self.data.cols, &mut self.data);
1079 self.divide_roots(skew_inv);
1080 }
1081}
1082
1083impl<F> PolynomialVector<F> {
1084 /// Iterate over up to n rows of this vector.
1085 ///
1086 /// For example, given polynomials:
1087 ///
1088 /// a0 + a1 X + a2 X^2 + ...
1089 /// b0 + b1 X + b2 X^2 + ...
1090 ///
1091 /// This will return:
1092 ///
1093 /// a0 b0
1094 /// a1 b1
1095 /// ...
1096 ///
1097 /// up to n times.
1098 pub fn coefficients_up_to(&self, n: usize) -> impl Iterator<Item = &[F]> {
1099 let n = n.min(self.data.rows);
1100 let lg_rows = self.data.rows().ilog2();
1101 (0..n).map(move |i| &self.data[reverse_bits(lg_rows, i as u64) as usize])
1102 }
1103}
1104
1105/// The result of evaluating a vector of polynomials over all points in an interpolation domain.
1106///
1107/// This struct also remembers which rows have ever been filled with [Self::fill_row].
1108/// This is used in [Self::recover], which can use the rows that are present to fill in the missing
1109/// rows.
1110#[derive(Debug, PartialEq)]
1111pub struct EvaluationVector<F> {
1112 data: Matrix<F>,
1113 active_rows: VanishingPoints,
1114}
1115
1116impl<F: FieldNTT> EvaluationVector<F> {
1117 /// Figure out the polynomial which evaluates to this vector.
1118 ///
1119 /// i.e. the inverse of [PolynomialVector::evaluate].
1120 ///
1121 /// (This makes all the rows count as filled).
1122 fn interpolate(mut self) -> PolynomialVector<F> {
1123 self.data.ntt::<false>();
1124 PolynomialVector { data: self.data }
1125 }
1126
1127 /// Erase a particular row.
1128 ///
1129 /// Useful for testing the recovery procedure.
1130 #[cfg(any(test, feature = "fuzz"))]
1131 fn remove_row(&mut self, row: usize) {
1132 self.data[row].fill(F::zero());
1133 self.active_rows.set(row as u64, false);
1134 }
1135
1136 fn multiply(&mut self, evaluation: &EvaluationColumn<F>) {
1137 for (i, e_i) in evaluation.evaluations.iter().enumerate() {
1138 for self_j in &mut self.data[i] {
1139 *self_j = self_j.clone() * e_i;
1140 }
1141 }
1142 }
1143
1144 /// Attempt to recover the missing rows in this data.
1145 pub fn recover(mut self) -> PolynomialVector<F> {
1146 let non_vanishing = self.active_rows.count_non_vanishing();
1147 if non_vanishing == 0 || non_vanishing == self.data.rows as u64 {
1148 return self.interpolate();
1149 }
1150
1151 // If we had all of the rows, we could simply call [Self::interpolate],
1152 // in order to recover the original polynomial. If we do this while missing some
1153 // rows, what we get is D(X) * V(X) where D is the original polynomial,
1154 // and V(X) is a polynomial which vanishes at all the rows we're missing.
1155 //
1156 // As long as the degree of D is low enough, compared to the number of evaluations
1157 // we *do* have, then we can recover it by performing:
1158 //
1159 // (D(X) * V(X)) / V(X)
1160 //
1161 // If we have multiple columns, then this procedure can be done column by column,
1162 // with the same vanishing polynomial.
1163 let (_, vanishing) = EvaluationColumn::vanishing(&self.active_rows);
1164 self.multiply(&vanishing);
1165 let mut out = self.interpolate();
1166 out.divide(vanishing.interpolate());
1167 out
1168 }
1169}
1170
1171impl<F: Additive> EvaluationVector<F> {
1172 /// Create an empty element of this struct, with no filled rows.
1173 ///
1174 /// `2^lg_rows` must be a valid `usize`.
1175 pub fn empty(lg_rows: usize, cols: usize) -> Self {
1176 assert!(
1177 lg_rows < usize::BITS as usize,
1178 "2^lg_rows must be a valid usize"
1179 );
1180 let data = Matrix::zero(1 << lg_rows, cols);
1181 let active = VanishingPoints::new(lg_rows as u32);
1182 Self {
1183 data,
1184 active_rows: active,
1185 }
1186 }
1187
1188 /// Fill a specific row.
1189 pub fn fill_row(&mut self, row: usize, data: &[F])
1190 where
1191 F: Clone,
1192 {
1193 assert!(data.len() <= self.data.cols);
1194 self.data[row][..data.len()].clone_from_slice(data);
1195 self.active_rows.set(row as u64, true);
1196 }
1197}
1198
1199impl<F> EvaluationVector<F> {
1200 /// Get the underlying data, as a Matrix.
1201 pub fn data(self) -> Matrix<F> {
1202 self.data
1203 }
1204
1205 /// Return how many distinct rows have been filled.
1206 pub fn filled_rows(&self) -> usize {
1207 self.active_rows.count_non_vanishing() as usize
1208 }
1209}
1210
1211/// Compute Lagrange coefficients for interpolating a polynomial at 0 from evaluations
1212/// at roots of unity.
1213///
1214/// Given a subset S of indices where we have evaluations, this computes the Lagrange
1215/// coefficients needed to interpolate to 0. For each index `j` in S, the coefficient
1216/// is `L_j(0)` where `L_j` is the Lagrange basis polynomial.
1217///
1218/// The key formula is: `L_j(0) = P_Sbar(w^j) / (N * P_Sbar(0))`
1219///
1220/// where `P_Sbar` is the (possibly scaled) vanishing polynomial over the complement
1221/// (missing points), and N is the domain size. This follows from
1222/// `V_S(X) * V_Sbar(X) = X^N - 1`, which gives `V_S(0) = -1/V_Sbar(0)`.
1223/// The scaling factor of `P_Sbar` cancels in the ratio.
1224///
1225/// Building `P_Sbar` as the vanishing polynomial over missing points is cheaper than building `V_S`
1226/// when most points are present (the typical erasure-coding case), since `|Sbar| << |S|`.
1227///
1228/// # Arguments
1229/// * `total` - The total number of points in the domain (rounded up to power of 2)
1230/// * `iter` - Iterator of indices where we have evaluations (duplicates ignored, indices >= total ignored)
1231///
1232/// # Returns
1233/// A vector of `(index, coefficient)` pairs for each unique index in the input set.
1234pub fn lagrange_coefficients<F: FieldNTT>(
1235 total: NonZeroU32,
1236 iter: impl IntoIterator<Item = u32>,
1237) -> Vec<(u32, F)> {
1238 let total_u64 = u64::from(total.get());
1239 let size = total_u64.next_power_of_two();
1240 let lg_size = size.ilog2();
1241
1242 let mut present = VanishingPoints::new(lg_size);
1243 for i in iter {
1244 let i_u64 = u64::from(i);
1245 if i_u64 < total_u64 {
1246 present.set_non_vanishing(i_u64);
1247 }
1248 }
1249
1250 let num_present = present.count_non_vanishing();
1251
1252 if num_present == 0 {
1253 return Vec::new();
1254 }
1255
1256 let n_f = F::one().scale(&[size]);
1257 if num_present == size {
1258 let n_inv = n_f.inv();
1259 return (0..size).map(|i| (i as u32, n_inv.clone())).collect();
1260 }
1261
1262 // Build P_Sbar (vanishes at indices NOT in present) and evaluate at all
1263 // roots of unity via NTT. Note: vanishing() may produce a scaled polynomial
1264 // P_Sbar = c * V_Sbar, but the scaling cancels in the ratio below.
1265 let (p_sbar_at_zero, complement_evals) = EvaluationColumn::vanishing(&present);
1266
1267 // From V_S(0) * V_Sbar(0) = -1 (since V_S * V_Sbar = X^N - 1), we get:
1268 // L_j(0) = -V_S(0) * V_Sbar(w^j) / N = V_Sbar(w^j) / (N * V_Sbar(0))
1269 // Since P_Sbar = c * V_Sbar, the scaling c cancels:
1270 // L_j(0) = P_Sbar(w^j) / (N * P_Sbar(0))
1271 let factor = (n_f * &p_sbar_at_zero).inv();
1272
1273 let mut out = Vec::with_capacity(num_present as usize);
1274 for j in 0..size {
1275 if present.get(j) {
1276 let coeff = factor.clone() * &complement_evals.evaluations[j as usize];
1277 out.push((j as u32, coeff));
1278 }
1279 }
1280 out
1281}
1282
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    //! Randomized consistency checks for the NTT, vanishing-polynomial and
    //! recovery code in this module, driven by fuzzer-supplied bytes via [Plan].
    use super::*;
    use crate::{algebra::Ring, fields::goldilocks::F};
    use arbitrary::{Arbitrary, Unstructured};

    /// Draw a random [PolynomialVector] over `F`.
    ///
    /// The vector has `2^lg_rows` coefficients per polynomial, with `lg_rows` in
    /// `0..=max_log_rows`, and between 1 and `max_cols` polynomials. Every
    /// coefficient is drawn from the input, so none are implicitly zero-padded.
    fn arb_polynomial_vector(
        u: &mut Unstructured<'_>,
        max_log_rows: u32,
        max_cols: usize,
    ) -> arbitrary::Result<PolynomialVector<F>> {
        let lg_rows = u.int_in_range(0..=max_log_rows)?;
        let cols = u.int_in_range(1..=max_cols)?;
        let rows = 1usize << lg_rows;
        let coefficients: Vec<F> = (0..rows * cols)
            .map(|_| Ok(F::from(u.arbitrary::<u64>()?)))
            .collect::<arbitrary::Result<_>>()?;
        Ok(PolynomialVector::new(rows, cols, coefficients.into_iter()))
    }

    /// Draw a random [VanishingPoints] over `2^lg_rows` points, with `lg_rows` in
    /// `0..=max_log_rows`, in which at least one point is non-vanishing.
    ///
    /// A `bool` is drawn for every point (`true` = non-vanishing). One randomly
    /// chosen point is then forced to `true`.
    fn arb_bit_vec_not_all_0(
        u: &mut Unstructured<'_>,
        max_log_rows: u32,
    ) -> arbitrary::Result<VanishingPoints> {
        let lg_rows = u.int_in_range(0..=max_log_rows)?;
        let rows = 1usize << lg_rows;
        let set_row = u.int_in_range(0..=rows - 1)?;
        let mut bools: Vec<bool> = (0..rows)
            .map(|_| u.arbitrary())
            .collect::<arbitrary::Result<_>>()?;
        // Guarantee that at least one point is non-vanishing.
        bools[set_row] = true;
        let mut out = VanishingPoints::new(lg_rows);
        for (i, b) in bools.into_iter().enumerate() {
            out.set(i as u64, b);
        }
        Ok(out)
    }

    /// Draw a random recovery scenario.
    ///
    /// The data has `n` coefficients per column (`n` in `1..=max_n`) and between 1 and
    /// `max_cols` columns. It is encoded over `(n + k).next_power_of_two()` rows,
    /// with `k` in `0..=max_k`. Between `n` and all of those rows are kept as present.
    /// Never keeping fewer than `n` leaves enough evaluations to recover polynomials
    /// with `n` coefficients.
    fn arb_recovery_setup(
        u: &mut Unstructured<'_>,
        max_n: usize,
        max_k: usize,
        max_cols: usize,
    ) -> arbitrary::Result<RecoverySetup> {
        let n = u.int_in_range(1..=max_n)?;
        let k = u.int_in_range(0..=max_k)?;
        let cols = u.int_in_range(1..=max_cols)?;
        let data: Vec<F> = (0..n * cols)
            .map(|_| Ok(F::from(u.arbitrary::<u64>()?)))
            .collect::<arbitrary::Result<_>>()?;
        let padded_rows = (n + k).next_power_of_two();
        let num_present = u.int_in_range(n..=padded_rows)?;
        let mut indices: Vec<usize> = (0..padded_rows).collect();
        // Partial Fisher-Yates shuffle: afterwards, the first `num_present` entries
        // of `indices` are a random subset of the rows.
        for i in 0..num_present {
            let j = u.int_in_range(i..=padded_rows - 1)?;
            indices.swap(i, j);
        }
        let mut present = VanishingPoints::new(padded_rows.ilog2());
        for &i in &indices[..num_present] {
            present.set(i as u64, true);
        }
        Ok(RecoverySetup {
            n,
            k,
            cols,
            data,
            present,
        })
    }

    /// Parameters for one encode / erase / recover round trip.
    #[derive(Debug)]
    pub struct RecoverySetup {
        // Number of coefficients per polynomial that carry data.
        n: usize,
        // Extra rows added before rounding the row count up to a power of two.
        k: usize,
        // Number of polynomials (columns).
        cols: usize,
        // The `n * cols` coefficients, in the order `PolynomialVector::new` consumes them.
        data: Vec<F>,
        // Rows that survive erasure (`true` = kept).
        present: VanishingPoints,
    }

    impl RecoverySetup {
        /// Build a setup directly, for hand-written regression tests.
        #[cfg(test)]
        pub(crate) const fn new(
            n: usize,
            k: usize,
            cols: usize,
            data: Vec<F>,
            present: VanishingPoints,
        ) -> Self {
            Self {
                n,
                k,
                cols,
                data,
                present,
            }
        }

        /// Run the round trip, panicking if the recovered polynomials differ from the originals.
        ///
        /// The steps are:
        /// 1. Encode the data.
        /// 2. Erase every row not marked in `present`.
        /// 3. Recover, and compare against the original polynomials.
        pub fn test(self) {
            let data = PolynomialVector::new(self.n + self.k, self.cols, self.data.into_iter());
            let mut encoded = data.clone().evaluate();
            for (i, b_i) in self.present.iter_bits_in_order().enumerate() {
                if !b_i {
                    encoded.remove_row(i);
                }
            }
            let recovered_data = encoded.recover();
            assert_eq!(data, recovered_data);
        }
    }

    /// A single randomized property check.
    #[derive(Debug)]
    pub enum Plan {
        /// The fast NTT evaluation must match the naive Vandermonde evaluation.
        NttEqNaive(PolynomialVector<F>),
        /// Evaluating and then interpolating must give back the same polynomials.
        EvaluationThenInverse(PolynomialVector<F>),
        /// The vanishing polynomial must have the expected degree, constant term and roots.
        VanishingPolynomial(VanishingPoints),
        /// Erased rows must be recoverable (see [RecoverySetup::test]).
        Recovery(RecoverySetup),
    }

    impl<'a> Arbitrary<'a> for Plan {
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            // Sizes are kept small so each case runs quickly.
            match u.int_in_range(0..=3)? {
                0 => Ok(Self::NttEqNaive(arb_polynomial_vector(u, 6, 4)?)),
                1 => Ok(Self::EvaluationThenInverse(arb_polynomial_vector(u, 6, 4)?)),
                2 => Ok(Self::VanishingPolynomial(arb_bit_vec_not_all_0(u, 8)?)),
                _ => Ok(Self::Recovery(arb_recovery_setup(u, 128, 128, 4)?)),
            }
        }
    }

    impl Plan {
        /// Execute this plan, panicking if the checked property does not hold.
        ///
        /// The `Unstructured` argument is currently unused.
        pub fn run(self, _u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::NttEqNaive(p) => {
                    let ntt = p.clone().evaluate();
                    let ntt_naive = p.evaluate_naive();
                    assert_eq!(ntt, ntt_naive);
                }
                Self::EvaluationThenInverse(p) => {
                    assert_eq!(p.clone(), p.evaluate().interpolate());
                }
                Self::VanishingPolynomial(bv) => {
                    // One root per vanishing point.
                    let total = 1u64 << bv.lg_size();
                    let expected_degree = total - bv.count_non_vanishing();
                    let (at_zero, evals) = EvaluationColumn::<F>::vanishing(&bv);
                    let v = evals.interpolate();
                    assert_eq!(
                        v.degree(),
                        expected_degree as usize,
                        "expected v to have degree {}",
                        expected_degree
                    );
                    assert_eq!(
                        at_zero, v.coefficients[0],
                        "at_zero should be the 0th coefficient"
                    );
                    // Walk w^0, w^1, ... alongside the points, checking that v is zero
                    // exactly at the vanishing ones.
                    let w = F::root_of_unity(bv.lg_size() as u8).unwrap();
                    let mut w_i = F::one();
                    for b_i in bv.iter_bits_in_order() {
                        let v_at_w_i = v.evaluate_one(w_i);
                        if !b_i {
                            assert_eq!(v_at_w_i, F::zero(), "v should evaluate to 0 at {:?}", w_i);
                        } else {
                            assert_ne!(v_at_w_i, F::zero());
                        }
                        w_i = w_i * w;
                    }
                }
                Self::Recovery(setup) => {
                    setup.test();
                }
            }
            Ok(())
        }
    }

    /// Run random plans through `minifuzz` as part of the ordinary test suite.
    #[test]
    fn test_fuzz() {
        use commonware_invariants::minifuzz;
        minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }
}
1463}
1464
#[cfg(test)]
mod test {
    use super::*;
    use crate::{algebra::Ring, fields::goldilocks::F};

    #[test]
    fn test_reverse_bits() {
        assert_eq!(reverse_bits(4, 0b1000), 0b0001);
        assert_eq!(reverse_bits(4, 0b0100), 0b0010);
        assert_eq!(reverse_bits(4, 0b0010), 0b0100);
        assert_eq!(reverse_bits(4, 0b0001), 0b1000);
    }

    #[test]
    fn matrix_read_rejects_length_mismatch() {
        use bytes::BytesMut;
        use commonware_codec::{Read as _, Write as _};

        let mut buf = BytesMut::new();
        (2usize).write(&mut buf);
        (2usize).write(&mut buf);
        vec![F::one(); 3].write(&mut buf);

        let mut bytes = buf.freeze();
        let result = Matrix::<F>::read_cfg(&mut bytes, &(8, ()));
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid(
                "Matrix",
                "matrix element count does not match dimensions"
            ))
        ));
    }

    fn assert_vanishing_points_correct(points: &VanishingPoints) {
        let expected_degree = (1 << points.lg_size()) - points.count_non_vanishing();
        let (at_zero, evaluations) = EvaluationColumn::<F>::vanishing(points);
        if points.count_non_vanishing() == 0 {
            // EvaluationColumn::vanishing assumes at least one non-vanishing point.
            // We still invoke it so callers can exercise internal branch coverage.
            return;
        }
        let polynomial = evaluations.interpolate();
        assert_eq!(
            polynomial.degree(),
            expected_degree as usize,
            "expected v to have degree {expected_degree}"
        );
        assert_eq!(
            at_zero, polynomial.coefficients[0],
            "at_zero should be the 0th coefficient"
        );
        let w = F::root_of_unity(points.lg_size() as u8).unwrap();
        let mut w_i = F::one();
        for (i, point_is_non_vanishing) in points.iter_bits_in_order().enumerate() {
            let value = polynomial.evaluate_one(w_i);
            if point_is_non_vanishing {
                assert_ne!(value, F::zero(), "expected non-zero at i={i}");
            } else {
                assert_eq!(value, F::zero(), "expected zero at i={i}");
            }
            w_i = w_i * w;
        }
    }

    #[test]
    fn test_recovery_000() {
        let present = {
            let mut out = VanishingPoints::new(1);
            out.set_non_vanishing(1);
            out
        };
        fuzz::RecoverySetup::new(1, 1, 1, vec![F::one()], present).test()
    }

    #[test]
    fn test_recovery_empty_vector() {
        let recovered = EvaluationVector::<F>::empty(4, 3).recover();
        let expected = EvaluationVector::<F>::empty(4, 3).interpolate();
        assert_eq!(recovered, expected);
    }

    #[test]
    fn test_vanishing_polynomial_all_two_chunk_combinations() {
        fn fill_half(points: &mut VanishingPoints, half: usize, values: [bool; 2]) {
            let chunk_size = 1usize << LG_VANISHING_BASE;
            let start = half * chunk_size;
            let lg_size = points.lg_size();
            for i in 0..chunk_size {
                let value = values[i % 2];
                let raw_index = (start + i) as u64;
                points.set(reverse_bits(lg_size, raw_index), value);
            }
        }

        let lg_size = LG_VANISHING_BASE + 1;
        // (0,0) => Everywhere, (0,1) => Somewhere, (1,1) => Nowhere.
        let states = [[false, false], [false, true], [true, true]];
        for left in states {
            for right in states {
                let mut points = VanishingPoints::new(lg_size);
                // VanishingPoints stores roots in reverse bit order. Writing raw halves
                // directly makes chunk 0/1 align exactly with the implementation's chunks.
                fill_half(&mut points, 0, left);
                fill_half(&mut points, 1, right);
                assert_vanishing_points_correct(&points);
            }
        }
    }

    /// Evaluate the single polynomial with the given coefficients over a domain of
    /// `rows` points (rounded up to a power of two), returned in row order.
    fn evaluations_of(rows: usize, coefficients: &[F]) -> Vec<F> {
        let evaluations = PolynomialVector::new(rows, 1, coefficients.iter().copied())
            .evaluate()
            .data();
        (0..rows.next_power_of_two())
            .map(|i| evaluations[i][0])
            .collect()
    }

    /// Combine Lagrange coefficients with the matching evaluations to obtain the
    /// interpolated value at 0.
    fn interpolate_at_zero(coefficients: &[(u32, F)], evaluations: &[F]) -> F {
        coefficients
            .iter()
            .fold(F::zero(), |acc, &(j, c)| acc + c * evaluations[j as usize])
    }

    fn indices_of(coefficients: &[(u32, F)]) -> Vec<u32> {
        coefficients.iter().map(|&(j, _)| j).collect()
    }

    #[test]
    fn test_lagrange_coefficients_all_present() {
        let total = core::num::NonZeroU32::new(4).unwrap();
        let coefficients = lagrange_coefficients::<F>(total, 0u32..4);
        assert_eq!(indices_of(&coefficients), vec![0, 1, 2, 3]);
        // With the whole domain present, every coefficient is 1/N.
        for &(_, c) in &coefficients {
            assert_eq!(c * F::from(4u64), F::one());
        }
        let poly = [
            F::from(5u64),
            F::from(7u64),
            F::from(11u64),
            F::from(13u64),
        ];
        let evals = evaluations_of(4, &poly);
        assert_eq!(interpolate_at_zero(&coefficients, &evals), poly[0]);
    }

    #[test]
    fn test_lagrange_coefficients_with_missing_points() {
        let total = core::num::NonZeroU32::new(4).unwrap();
        let coefficients = lagrange_coefficients::<F>(total, [0u32, 1, 2]);
        assert_eq!(indices_of(&coefficients), vec![0, 1, 2]);
        // A degree-2 polynomial is determined by its values at 3 points.
        let poly = [F::from(3u64), F::from(1u64), F::from(4u64)];
        let evals = evaluations_of(4, &poly);
        assert_eq!(interpolate_at_zero(&coefficients, &evals), poly[0]);
    }

    #[test]
    fn test_lagrange_coefficients_ignores_duplicates_and_out_of_range() {
        let total = core::num::NonZeroU32::new(3).unwrap();
        // Index 2 is repeated; 3 and 7 are >= total and must be ignored.
        let coefficients = lagrange_coefficients::<F>(total, [2u32, 0, 2, 3, 7]);
        assert_eq!(indices_of(&coefficients), vec![0, 2]);
        // The domain is padded to 4 points; a linear polynomial needs 2 of them.
        let poly = [F::from(9u64), F::from(2u64)];
        let evals = evaluations_of(3, &poly);
        assert_eq!(interpolate_at_zero(&coefficients, &evals), poly[0]);
    }

    #[test]
    fn test_lagrange_coefficients_none_present() {
        let total = core::num::NonZeroU32::new(4).unwrap();
        assert!(lagrange_coefficients::<F>(total, [4u32, 9]).is_empty());
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Matrix<F>>,
        }
    }
}
1584}