commonware_math/ntt.rs
1use crate::{
2 algebra::{Additive as _, Ring},
3 fields::goldilocks::F,
4};
5#[cfg(not(feature = "std"))]
6use alloc::{vec, vec::Vec};
7use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
8use commonware_utils::bitmap::{BitMap, DEFAULT_CHUNK_SIZE};
9use core::ops::{Index, IndexMut};
10use rand_core::CryptoRngCore;
11
/// Reverse the lowest `bit_width` bits of `i`.
///
/// Bit `k` of `i` (for `k < bit_width`) moves to bit `bit_width - 1 - k` of the
/// result. Any bits of `i` at or above `bit_width` are erased; in particular,
/// `bit_width == 0` always yields 0.
///
/// # Panics
///
/// Panics if `bit_width > 64`.
fn reverse_bits(bit_width: u32, i: u64) -> u64 {
    assert!(bit_width <= 64, "bit_width must be <= 64");
    // Reversing all 64 bits moves bit k to bit 63 - k; shifting right by
    // 64 - bit_width then lands it at bit_width - 1 - k and drops every bit that
    // was at or above the width. A shift by 64 (bit_width == 0) is not representable
    // with `>>` (and `wrapping_shl(64)` would be a shift by 0), so `checked_shr`
    // maps that case to 0.
    i.reverse_bits().checked_shr(64 - bit_width).unwrap_or(0)
}
19
/// Calculate an NTT, or an inverse NTT (with FORWARD=false), in place.
///
/// We implement this generically over anything we can index into with `(row, col)`,
/// which allows performing NTTs in place on different storage: a whole [`Matrix`],
/// or a single slice wrapped in a [`Column`]. Each of the `cols` columns is
/// transformed independently.
///
/// In the forward direction, each column should hold the coefficients of a polynomial
/// in reverse bit order; afterwards, row `i` holds that polynomial evaluated at `w^i`
/// (natural order), where `w = F::root_of_unity(log2(rows))`. The inverse direction
/// (FORWARD=false) exactly undoes the forward one.
///
/// # Panics
///
/// Panics if `rows` is 0 or not a power of 2, or if `F` has no root of unity of
/// order `rows`.
fn ntt<const FORWARD: bool, M: IndexMut<(usize, usize), Output = F>>(
    rows: usize,
    cols: usize,
    matrix: &mut M,
) {
    let lg_rows = rows.ilog2() as usize;
    assert_eq!(1 << lg_rows, rows, "rows should be a power of 2");
    // A number w such that w^(2^lg_rows) = 1.
    // (Or, in the inverse case, the inverse of that number, to undo the NTT).
    let w = {
        let w = F::root_of_unity(lg_rows as u8).expect("too many rows to perform NTT");
        if FORWARD {
            w
        } else {
            // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
            // making that left-hand term the inverse of w.
            w.exp(&[(1 << lg_rows) - 1])
        }
    };
    // The inverse algorithm consists of carefully undoing the work of the
    // standard algorithm, so we describe that in detail.
    //
    // To understand the NTT algorithm, first consider the case of a single
    // column. We have a polynomial f(X), and we want to turn that into:
    //
    // [f(w^0), f(w^1), ..., f(w^(2^lg_rows - 1))]
    //
    // Our polynomial can be written as:
    //
    // f+(X^2) + X f-(X^2)
    //
    // where f+ and f- are polynomials with half the degree.
    // f+ is obtained by taking the coefficients at even indices,
    // f- is obtained by taking the coefficients at odd indices.
    //
    // w^2 is also conveniently a 2^(lg_rows - 1) root of unity. Thus,
    // we can recursively compute an NTT on f+, using w^2 as the root,
    // and an NTT on f-, using w^2 as the root, each of which is a problem
    // of half the size.
    //
    // We can then compute:
    // f+((w^i)^2) + (w^i) f-((w^i)^2)   (= f(w^i))
    // f+((w^i)^2) - (w^i) f-((w^i)^2)   (= f(-w^i) = f(w^(i + 2^(lg_rows - 1))))
    // for each i.
    // (Note that (-w^i)^2 = ((-w)^2)^i = (w^i)^2.)
    //
    // Our coefficients are conveniently laid out as [f+ f-], already
    // in a neat order (this is what storing them in reverse bit order buys us).
    // When we recurse, the coefficients of f+ are, in turn, already laid out
    // as [f++ f+-], and so on.
    //
    // We just need to transform this recursive algorithm, in top down form,
    // into an iterative one, in bottom up form. For that, note that the NTT
    // for the case of 1 row is trivial: do nothing.

    // Will contain, in bottom up order, the power of w we need at that stage,
    // paired with the stage index.
    // At the last stage, we need w itself.
    // At the stage before last, we need w^2.
    // And so on.
    // How many stages do we need? If we have 1 row, we need 0 stages.
    // In general, with 2^n rows, we need n stages.
    let stages = {
        let mut out = vec![(0usize, F::zero()); lg_rows];
        let mut w_i = w;
        for i in (0..lg_rows).rev() {
            out[i] = (i, w_i);
            w_i = w_i * w_i;
        }
        // In the case of the reverse algorithm, we undo each stage of the
        // forward algorithm, starting with the last stage.
        if !FORWARD {
            out.reverse();
        }
        out
    };
    for (stage, w) in stages.into_iter() {
        // At stage i, we have polynomials with 2^i coefficients,
        // which have already been evaluated to create 2^i entries.
        // We need to combine these evaluations to create 2^(i + 1) entries,
        // representing the evaluation of a polynomial with 2^(i + 1) coefficients.
        // If we have two of these evaluations, laid out one after the other:
        //
        // [x_0, x_1, ...] [y_0, y_1, ...]
        //
        // Then the number of elements we need to skip to get the corresponding
        // element in the other half is simply the number of elements in each half,
        // i.e. 2^i.
        let skip = 1 << stage;
        // `i` walks over consecutive blocks of 2 * skip rows, each combined independently.
        let mut i = 0;
        while i < rows {
            // The twiddle factor w_j = w^j for the j-th butterfly of this block.
            // In the backwards NTT, `w` was derived from the inverse root, so w_j here
            // is the inverse of the forward twiddle factor at the same position.
            let mut w_j = F::one();
            for j in 0..skip {
                let index_a = i + j;
                let index_b = index_a + skip;
                for k in 0..cols {
                    let (a, b) = (matrix[(index_a, k)], matrix[(index_b, k)]);
                    if FORWARD {
                        matrix[(index_a, k)] = a + w_j * b;
                        matrix[(index_b, k)] = a - w_j * b;
                    } else {
                        // To check the math, convince yourself that applying the forward
                        // transformation, and then this one (with w_j being the inverse
                        // of the forward w_j), gives you back (a, b):
                        // (a + w_j * b) + (a - w_j * b) = 2 * a
                        matrix[(index_a, k)] = (a + b).div_2();
                        // (a + w_j * b) - (a - w_j * b) = 2 * w_j * b.
                        // w_j in this branch is the inverse of w_j in the other branch.
                        matrix[(index_b, k)] = ((a - b) * w_j).div_2();
                    }
                }
                w_j = w_j * w;
            }
            i += 2 * skip;
        }
    }
}
141
142/// A single column of some larger data.
143///
144/// This allows us to easily do NTTs over partial segments of some bigger matrix.
145struct Column<'a> {
146 data: &'a mut [F],
147}
148
149impl<'a> Index<(usize, usize)> for Column<'a> {
150 type Output = F;
151
152 fn index(&self, (i, _): (usize, usize)) -> &Self::Output {
153 &self.data[i]
154 }
155}
156impl<'a> IndexMut<(usize, usize)> for Column<'a> {
157 fn index_mut(&mut self, (i, _): (usize, usize)) -> &mut Self::Output {
158 &mut self.data[i]
159 }
160}
161
/// Represents a matrix of field elements, of arbitrary dimensions.
///
/// This is in row major order: entry `(i, j)` is stored at `data[cols * i + j]`, so
/// the elements of a row are contiguous. Consider processing elements in the same
/// row first, for locality.
#[derive(Clone, PartialEq)]
pub struct Matrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns, i.e. the length of every row.
    cols: usize,
    /// The `rows * cols` elements, one row after another.
    data: Vec<F>,
}
172
173impl EncodeSize for Matrix {
174 fn encode_size(&self) -> usize {
175 self.rows.encode_size() + self.cols.encode_size() + self.data.encode_size()
176 }
177}
178
179impl Write for Matrix {
180 fn write(&self, buf: &mut impl bytes::BufMut) {
181 self.rows.write(buf);
182 self.cols.write(buf);
183 self.data.write(buf);
184 }
185}
186
187impl Read for Matrix {
188 type Cfg = usize;
189
190 fn read_cfg(
191 buf: &mut impl bytes::Buf,
192 &max_els: &Self::Cfg,
193 ) -> Result<Self, commonware_codec::Error> {
194 let cfg = RangeCfg::from(..=max_els);
195 let rows = usize::read_cfg(buf, &cfg)?;
196 let cols = usize::read_cfg(buf, &cfg)?;
197 let data = Vec::<F>::read_cfg(buf, &(cfg, ()))?;
198 let expected_len = rows
199 .checked_mul(cols)
200 .ok_or(commonware_codec::Error::Invalid(
201 "Matrix",
202 "matrix dimensions overflow",
203 ))?;
204 if data.len() != expected_len {
205 return Err(commonware_codec::Error::Invalid(
206 "Matrix",
207 "matrix element count does not match dimensions",
208 ));
209 }
210 Ok(Self { rows, cols, data })
211 }
212}
213
214impl core::fmt::Debug for Matrix {
215 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
216 for i in 0..self.rows {
217 let row_i = &self[i];
218 for &row_i_j in row_i {
219 write!(f, "{row_i_j:?} ")?;
220 }
221 writeln!(f)?;
222 }
223 Ok(())
224 }
225}
226
227impl Matrix {
228 /// Create a zero matrix, with a certain number of rows and columns
229 fn zero(rows: usize, cols: usize) -> Self {
230 Self {
231 rows,
232 cols,
233 data: vec![F::zero(); rows * cols],
234 }
235 }
236
237 /// Initialize a matrix, with dimensions, and data to pull from.
238 ///
239 /// Any extra data is ignored, any data not supplied is treated as 0.
240 pub fn init(rows: usize, cols: usize, mut data: impl Iterator<Item = F>) -> Self {
241 let mut out = Self::zero(rows, cols);
242 'outer: for i in 0..rows {
243 for row_i in &mut out[i] {
244 let Some(x) = data.next() else {
245 break 'outer;
246 };
247 *row_i = x;
248 }
249 }
250 out
251 }
252
253 /// Interpret the columns of this matrix as polynomials, with at least `min_coefficients`.
254 ///
255 /// This will, in fact, produce a matrix padded to the next power of 2 of that number.
256 ///
257 /// This will return `None` if `min_coefficients < self.rows`, which would mean
258 /// discarding data, instead of padding it.
259 pub fn as_polynomials(&self, min_coefficients: usize) -> Option<PolynomialVector> {
260 if min_coefficients < self.rows {
261 return None;
262 }
263 Some(PolynomialVector::new(
264 min_coefficients,
265 self.cols,
266 (0..self.rows).flat_map(|i| self[i].iter().copied()),
267 ))
268 }
269
270 /// Multiply this matrix by another.
271 ///
272 /// This assumes that the number of columns in this matrix match the number
273 /// of rows in the other matrix.
274 pub fn mul(&self, other: &Self) -> Self {
275 assert_eq!(self.cols, other.rows);
276 let mut out = Self::zero(self.rows, other.cols);
277 for i in 0..self.rows {
278 for j in 0..self.cols {
279 let c = self[(i, j)];
280 let other_j = &other[j];
281 for k in 0..other.cols {
282 out[(i, k)] = out[(i, k)] + c * other_j[k]
283 }
284 }
285 }
286 out
287 }
288
289 fn ntt<const FORWARD: bool>(&mut self) {
290 ntt::<FORWARD, Self>(self.rows, self.cols, self)
291 }
292
293 pub const fn rows(&self) -> usize {
294 self.rows
295 }
296
297 pub const fn cols(&self) -> usize {
298 self.cols
299 }
300
301 // Iterate over the rows of this matrix.
302 pub fn iter(&self) -> impl Iterator<Item = &[F]> {
303 (0..self.rows).map(|i| &self[i])
304 }
305
306 /// Create a random matrix with certain dimensions.
307 pub fn rand(mut rng: impl CryptoRngCore, rows: usize, cols: usize) -> Self {
308 Self::init(rows, cols, (0..rows * cols).map(|_| F::rand(&mut rng)))
309 }
310}
311
312impl Index<usize> for Matrix {
313 type Output = [F];
314
315 fn index(&self, index: usize) -> &Self::Output {
316 &self.data[self.cols * index..self.cols * (index + 1)]
317 }
318}
319
320impl IndexMut<usize> for Matrix {
321 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
322 &mut self.data[self.cols * index..self.cols * (index + 1)]
323 }
324}
325
326impl Index<(usize, usize)> for Matrix {
327 type Output = F;
328
329 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
330 &self.data[self.cols * i + j]
331 }
332}
333
334impl IndexMut<(usize, usize)> for Matrix {
335 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
336 &mut self.data[self.cols * i + j]
337 }
338}
339
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Matrix {
    /// Generates a matrix with 1 to 16 rows and 1 to 16 columns, whose elements are
    /// drawn in row major order.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let rows: usize = u.int_in_range(1..=16)?;
        let cols: usize = u.int_in_range(1..=16)?;
        let len = rows * cols;
        let mut data = Vec::with_capacity(len);
        for _ in 0..len {
            data.push(F::arbitrary(u)?);
        }
        Ok(Self { rows, cols, data })
    }
}
351
/// A single polynomial, with its coefficients stored in reverse bit order.
///
/// The coefficient of `X^i` lives at index `reverse_bits(log2(len), i)`, the same
/// layout each column of a [`PolynomialVector`] uses. The length is expected to be
/// a power of 2.
#[derive(Clone, Debug, PartialEq)]
struct NTTPolynomial {
    coefficients: Vec<F>,
}
356
357impl NTTPolynomial {
358 /// Create a polynomial which vanishes (evaluates to 0) except at a few points.
359 ///
360 /// It's assumed that `except` is a bit vector with length a power of 2.
361 ///
362 /// For each index i NOT IN `except`, the resulting polynomial will evaluate
363 /// to w^i, where w is a `except.len()` root of unity.
364 ///
365 /// e.g. with `except` = 1001, then the resulting polynomial will
366 /// evaluate to 0 at w^1 and w^2, where w is a 4th root of unity.
367 fn vanishing(except: &BitMap) -> Self {
368 // Algorithm taken from: https://ethresear.ch/t/reed-solomon-erasure-code-recovery-in-n-log-2-n-time-with-ffts/3039.
369 // The basic idea of the algorithm is that given a set of indices S,
370 // we can split it in two: the even indices (first bit = 0) and the odd indices.
371 // We compute two vanishing polynomials over
372 //
373 // S_L := {i / 2 | i in S}
374 // S_R := {(i - 1) / 2 | i in S}
375 //
376 // Using a domain of half the size. i.e. instead of w, they use w^2 as the root.
377 //
378 // V_L vanishes at (w^2)^(i / 2) for each i in S, i.e. w^i, for each even i in S.
379 // Similarly, V_R vanishes at (w^2)^((i - 1) / 2) = w^(i - 1), for each odd i in S.
380 //
381 // To combine these into one polynomial, we multiply the roots of V_R by w, so that it
382 // vanishes at the w^i (for odd i) instead of w^(i - 1).
383 //
384 // To multiply the roots of a polynomial
385 //
386 // P(X) := a0 + a1 X + a2 X^2 + ...
387 //
388 // by some factor z, it suffices to divide the ith coefficient by z^i:
389 //
390 // Q(X) := a0 + (a1 / z) X + (a2 / z^2) X^2 + ...
391 //
392 // Notice that Q(z X) = P(X), so if P(x) = 0, then Q(z x) = 0, so we've multiplied
393 // the roots by a factor of z.
394 //
395 // After multiplying the roots of V_R by w, we can then multiply the resulting polynomial
396 // with V_L, producing a polynomial which vanishes at the right indices.
397 //
398 // To multiply efficiently, we can do multiplication over the evaluation domain:
399 // we perform an NTT over each polynomial, multiplie the evaluations pointwise,
400 // and then perform an inverse NTT to get the result. We just need to make sure that
401 // when we perform the NTT, we've added enough extra 0 coefficients in each polynomial
402 // to accommodate the extra degree. e.g. if we have two polynomials of degree 1, then
403 // we need to make sure to pad them to have enough coefficients for a polynomial of degree 2,
404 // so that we can correctly interpolate the result back.
405 //
406 // The tricky part is transforming this algorithm into an iterative one, and respecting
407 // the reverse bit order of the coefficients we need
408 let rows = except.len() as usize;
409 let padded_rows = rows.next_power_of_two();
410 let zeroes = except.count_zeros() as usize + padded_rows - rows;
411 assert!(zeroes < padded_rows, "too many points to vanish over");
412 let lg_rows = padded_rows.ilog2();
413 // At each iteration, we split `except` into sections.
414 // Each section has a polynomial associated with it, which should
415 // be the polynomial that vanishes over all the 0 bits of that section,
416 // or the 0 polynomial if that section has no 0 bits.
417 //
418 // The sections are organized into a tree:
419 //
420 // 0xx 1xx
421 // 00x 01x 10x 11x
422 // 000 001 010 011 100 101 110 111
423 //
424 // The first half of the sections are even, the second half are odd.
425 // The first half of the first half have their first two bits set to 00,
426 // the second half of the first half have their first two bits set to 01,
427 // and so on.
428 //
429 // In other words, the ith index in except becomes the i.reverse_bits()
430 // section.
431 //
432 // How many polynomials do we have? (Potentially 0 ones).
433 let mut polynomial_count = padded_rows;
434 // How many coefficients does each polynomial have?
435 let mut polynomial_size: usize = 2;
436 // For the first iteration, each
437 let mut polynomials = vec![F::zero(); 2 * padded_rows];
438 let mut active = BitMap::<DEFAULT_CHUNK_SIZE>::with_capacity(polynomial_count as u64);
439 for i in 0..polynomial_count {
440 let rev_i = reverse_bits(lg_rows, i as u64) as usize;
441 if !except.get(rev_i as u64) {
442 polynomials[2 * i] = -F::one();
443 polynomials[2 * i + 1] = F::one();
444 active.push(true);
445 } else {
446 active.push(false);
447 }
448 }
449 // Rather than store w at each iteration, and divide by it, just store its inverse,
450 // allowing us to multiply by it.
451 let w_invs = {
452 // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
453 // making that left-hand term the inverse of w.
454 let mut w_inv = F::root_of_unity(lg_rows as u8)
455 .expect("too many rows to create vanishing polynomial")
456 .exp(&[(1 << lg_rows) - 1]);
457 let lg_rows = lg_rows as usize;
458 let mut out = Vec::with_capacity(lg_rows);
459 for _ in 0..lg_rows {
460 out.push(w_inv);
461 w_inv = w_inv * w_inv;
462 }
463 out.reverse();
464 out
465 };
466 // When we multiply
467 let mut scratch: Vec<F> = Vec::with_capacity(padded_rows);
468 for w_inv in w_invs.into_iter() {
469 // After this iteration, we're going to end up with half the polynomials
470 polynomial_count >>= 1;
471 // and each of them will be twice as large.
472 let new_polynomial_size = polynomial_size << 1;
473 // Our goal is to construct the ith polynomial.
474 for i in 0..polynomial_count {
475 let start = new_polynomial_size * i;
476 let has_left = if ((2 * i) as u64) < active.len() {
477 active.get((2 * i) as u64)
478 } else {
479 false
480 };
481 let has_right = if ((2 * i + 1) as u64) < active.len() {
482 active.get((2 * i + 1) as u64)
483 } else {
484 false
485 };
486 match (has_left, has_right) {
487 // No polynomials to combine.
488 (false, false) => {}
489 // We need to multiply the roots of the right side,
490 // but then it can just expand to fill the entire polynomial.
491 (false, true) => {
492 let slice = &mut polynomials[start..start + new_polynomial_size];
493 // Scale the roots of the right side by w.
494 let lg_p_size = polynomial_size.ilog2();
495 let mut w_j = F::one();
496 for j in 0..polynomial_size {
497 let index =
498 polynomial_size + reverse_bits(lg_p_size, j as u64) as usize;
499 slice[index] = slice[index] * w_j;
500 w_j = w_j * w_inv;
501 }
502 // Expand the right side to occupy the entire space.
503 // The left side must be 0s.
504 for j in 0..polynomial_size {
505 slice.swap(polynomial_size + j, 2 * j);
506 }
507 }
508 // No need to multiply roots, but we do need to expand the left side.
509 (true, false) => {
510 let slice = &mut polynomials[start..start + new_polynomial_size];
511 // Expand the left side to occupy the entire space.
512 // The right side must be 0s.
513 for j in (0..polynomial_size).rev() {
514 slice.swap(j, 2 * j);
515 }
516 }
517 // We need to combine the two doing an actual multiplication.
518 (true, true) => {
519 debug_assert_eq!(scratch.len(), 0);
520 scratch.resize(new_polynomial_size, F::zero());
521 let slice = &mut polynomials[start..start + new_polynomial_size];
522
523 let lg_p_size = polynomial_size.ilog2();
524 let mut w_j = F::one();
525 for j in 0..polynomial_size {
526 let index =
527 polynomial_size + reverse_bits(lg_p_size, j as u64) as usize;
528 slice[index] = slice[index] * w_j;
529 w_j = w_j * w_inv;
530 }
531
532 // Expand the right side to occupy all of scratch.
533 // Clear the right side.
534 for j in 0..polynomial_size {
535 scratch[2 * j] = slice[polynomial_size + j];
536 slice[polynomial_size + j] = F::zero();
537 }
538
539 // Expand the left side to occupy the entire space.
540 // The right side has been cleared above.
541 for j in (0..polynomial_size).rev() {
542 slice.swap(j, 2 * j);
543 }
544
545 // Multiply the polynomials together, by first evaluating each of them,
546 // then multiplying their evaluations, producing (f * g) evaluated over
547 // the domain, which we can then interpolate back.
548 ntt::<true, _>(new_polynomial_size, 1, &mut Column { data: &mut scratch });
549 ntt::<true, _>(new_polynomial_size, 1, &mut Column { data: slice });
550 for (s_i, p_i) in scratch.drain(..).zip(slice.iter_mut()) {
551 *p_i = *p_i * s_i
552 }
553 ntt::<false, _>(new_polynomial_size, 1, &mut Column { data: slice })
554 }
555 }
556 // If there was a polynomial on the left or the right, then on the next iteration
557 // the combined section will have data to process, so we need to set it to true
558 // Resize active if needed and set the bit
559 active.set(i as u64, has_left | has_right);
560 }
561 polynomial_size = new_polynomial_size;
562 }
563 // If the final polynomial is inactive, there are no points to vanish over,
564 // so we want to return the polynomial f(X) = 1.
565 if !active.get(0) {
566 let mut coefficients = vec![F::zero(); padded_rows];
567 coefficients[0] = F::one();
568 return Self { coefficients };
569 }
570 // We have a polynomial that's twice the size we need, so we need to truncate it.
571 // This is the opposite of the sub-routine we had for expanding the left side to fit
572 // the entire polynomial.
573 for i in 0..padded_rows {
574 polynomials.swap(i, 2 * i);
575 }
576 polynomials.truncate(padded_rows);
577 Self {
578 coefficients: polynomials,
579 }
580 }
581
582 #[cfg(test)]
583 fn evaluate(&self, point: F) -> F {
584 let mut out = F::zero();
585 let rows = self.coefficients.len();
586 let lg_rows = rows.ilog2();
587 for i in (0..rows).rev() {
588 out = out * point + self.coefficients[reverse_bits(lg_rows, i as u64) as usize];
589 }
590 out
591 }
592
593 #[cfg(test)]
594 fn degree(&self) -> usize {
595 let rows = self.coefficients.len();
596 let lg_rows = rows.ilog2();
597 for i in (0..rows).rev() {
598 if self.coefficients[reverse_bits(lg_rows, i as u64) as usize] != F::zero() {
599 return i;
600 }
601 }
602 0
603 }
604
605 /// Divide the roots of each polynomial by some factor.
606 ///
607 /// If f(x) = 0, then after this transformation, f(x / z) = 0 instead.
608 ///
609 /// The number of roots does not change.
610 ///
611 /// c.f. [Self::vanishing] for an explanation of how this works.
612 fn divide_roots(&mut self, factor: F) {
613 let mut factor_i = F::one();
614 let lg_rows = self.coefficients.len().ilog2();
615 for i in 0..self.coefficients.len() {
616 let index = reverse_bits(lg_rows, i as u64) as usize;
617 self.coefficients[index] = self.coefficients[index] * factor_i;
618 factor_i = factor_i * factor;
619 }
620 }
621}
622
/// A vector of polynomials that all have the same (power of 2) number of coefficients.
///
/// Each column of the underlying matrix is one polynomial.
#[derive(Clone, Debug, PartialEq)]
pub struct PolynomialVector {
    // Each column of this matrix contains the coefficients of a polynomial,
    // in reverse bit order. So, the ith coefficient appears at row
    // reverse_bits(log2(rows), i).
    //
    // For example, a polynomial a0 + a1 X + a2 X^2 + a3 X^3 is stored as:
    //
    // a0 a2 a1 a3
    //
    // This is convenient because the even coefficients and the odd coefficients
    // split nicely into halves. The first half of the rows have the property
    // that the first bit of their coefficient index is 0, then in that subset
    // the first half has the second bit set to 0, and the second half set to 1,
    // and so on, recursively.
    data: Matrix,
}
639
impl PolynomialVector {
    /// Construct a new vector of polynomials, from dimensions, and coefficients.
    ///
    /// `rows` is the minimum number of coefficients per polynomial; it is rounded up
    /// to the next power of 2. `cols` is the number of polynomials.
    ///
    /// The coefficients should be supplied in order of increasing index,
    /// and then for each polynomial.
    ///
    /// In other words, if you have 3 polynomials:
    ///
    /// a0 + a1 X + ...
    /// b0 + b1 X + ...
    /// c0 + c1 X + ...
    ///
    /// The iterator should yield:
    ///
    /// a0 b0 c0
    /// a1 b1 c1
    /// ...
    ///
    /// Any coefficients not supplied are treated as being equal to 0, and any
    /// coefficients beyond the (padded) size are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `rows` is 0.
    fn new(rows: usize, cols: usize, mut coefficients: impl Iterator<Item = F>) -> Self {
        assert!(rows > 0);
        let rows = rows.next_power_of_two();
        let lg_rows = rows.ilog2();
        let mut data = Matrix::zero(rows, cols);
        // Coefficient i of every polynomial is stored in row reverse_bits(lg_rows, i).
        'outer: for i in 0..rows {
            let row_i = &mut data[reverse_bits(lg_rows, i as u64) as usize];
            for row_i_j in row_i {
                let Some(c) = coefficients.next() else {
                    break 'outer;
                };
                *row_i_j = c;
            }
        }
        Self { data }
    }

    /// Evaluate each polynomial in this vector over all points in an interpolation domain.
    ///
    /// Row `i` of the result holds every polynomial evaluated at `w^i`, where
    /// `w = F::root_of_unity(log2(rows))`. All rows of the result count as filled.
    pub fn evaluate(mut self) -> EvaluationVector {
        self.data.ntt::<true>();
        let active_rows = BitMap::ones(self.data.rows as u64);
        EvaluationVector {
            data: self.data,
            active_rows,
        }
    }

    /// Like [Self::evaluate], but with a simpler algorithm that's much less efficient.
    ///
    /// Exists as a useful tool for testing.
    #[cfg(test)]
    fn evaluate_naive(self) -> EvaluationVector {
        let rows = self.data.rows;
        let lg_rows = rows.ilog2();
        let w = F::root_of_unity(lg_rows as u8).expect("too much data to calculate NTT");
        // entry (i, j) of this matrix will contain w^ij. Thus, multiplying it
        // with the coefficients of a polynomial, in column order, will evaluate it.
        // We also need to re-arrange the columns of the matrix to match the same
        // order we have for polynomial coefficients.
        let mut vandermonde_matrix = Matrix::zero(rows, rows);
        let mut w_i = F::one();
        for i in 0..rows {
            let row_i = &mut vandermonde_matrix[i];
            let mut w_ij = F::one();
            for j in 0..rows {
                // Remember, the coefficients of the polynomial are in reverse bit order!
                row_i[reverse_bits(lg_rows, j as u64) as usize] = w_ij;
                w_ij = w_ij * w_i;
            }
            w_i = w_i * w;
        }

        EvaluationVector {
            data: vandermonde_matrix.mul(&self.data),
            active_rows: BitMap::ones(rows as u64),
        }
    }

    /// Divide the roots of each polynomial by some factor.
    ///
    /// c.f. [NTTPolynomial::divide_roots]. This performs the same operation on
    /// each polynomial in this vector: the ith coefficient is multiplied by `factor^i`.
    fn divide_roots(&mut self, factor: F) {
        let mut factor_i = F::one();
        let lg_rows = self.data.rows.ilog2();
        for i in 0..self.data.rows {
            for p_i in &mut self.data[reverse_bits(lg_rows, i as u64) as usize] {
                *p_i = *p_i * factor_i;
            }
            factor_i = factor_i * factor;
        }
    }

    /// For each polynomial P_i in this vector, replace it (in place) with P_i / Q.
    ///
    /// Naturally, you can call [EvaluationVector::interpolate]. The reason we don't
    /// do this is that the algorithm naturally yields an [EvaluationVector], and
    /// some use-cases may want access to that data as well.
    ///
    /// This assumes that the number of coefficients in the polynomials of this vector
    /// matches that of `q` (the coefficients can be 0, but need to be padded to the right size).
    ///
    /// This assumes that `q` has no zeroes over [F::NOT_ROOT_OF_UNITY] * [F::ROOT_OF_UNITY]^i,
    /// for any i. This will be the case for [NTTPolynomial::vanishing].
    /// If this isn't the case, the result may be junk.
    ///
    /// If `q` doesn't divide a particular polynomial in this vector, the result
    /// for that polynomial is not guaranteed to be anything meaningful.
    ///
    /// # Panics
    ///
    /// Panics if `q` does not have exactly as many coefficients as this vector has rows.
    fn divide(&mut self, mut q: NTTPolynomial) {
        // The algorithm operates column wise.
        //
        // You can compute P(X) / Q(X) by evaluating each polynomial, then computing
        //
        // P(w^i) / Q(w^i)
        //
        // for each evaluation point. Then, you can interpolate back.
        //
        // But wait! What if Q(w^i) = 0? In particular, for the case of recovering
        // a polynomial from data with missing rows, we *expect* P(w^i) = 0 = Q(w^i)
        // for the indices we're missing, so this doesn't work.
        //
        // What we can do is to instead multiply each of the roots by some factor z,
        // such that z w^i != w^j, for any i, j. In other words, we change the roots
        // such that they're not in the evaluation domain anymore, allowing us to
        // divide. We can then interpolate the result back into a polynomial,
        // and divide back the roots to where they should be.
        //
        // c.f. [PolynomialVector::divide_roots]
        assert_eq!(
            self.data.rows,
            q.coefficients.len(),
            "cannot divide by polynomial of the wrong size"
        );
        let skew = F::NOT_ROOT_OF_UNITY;
        let skew_inv = F::NOT_ROOT_OF_UNITY_INV;
        // Move the roots of both P and Q off the evaluation domain.
        self.divide_roots(skew);
        q.divide_roots(skew);
        ntt::<true, _>(self.data.rows, self.data.cols, &mut self.data);
        ntt::<true, _>(
            q.coefficients.len(),
            1,
            &mut Column {
                data: &mut q.coefficients,
            },
        );
        // Do a point wise division.
        for i in 0..self.data.rows {
            let q_i = q.coefficients[i];
            // If `q_i = 0`, then we will get 0 in the output.
            // We don't expect any of the q_i to be 0, but being 0 is only one
            // of the many possibilities for the coefficient to be incorrect,
            // so doing a runtime assertion here doesn't make sense.
            let q_i_inv = q_i.inv();
            for d_i_j in &mut self.data[i] {
                *d_i_j = *d_i_j * q_i_inv;
            }
        }
        // Interpolate back, then undo the skew (dividing the roots by skew_inv
        // multiplies them back by skew).
        ntt::<false, _>(self.data.rows, self.data.cols, &mut self.data);
        self.divide_roots(skew_inv);
    }

    /// Iterate over up to n rows of this vector.
    ///
    /// For example, given polynomials:
    ///
    /// a0 + a1 X + a2 X^2 + ...
    /// b0 + b1 X + b2 X^2 + ...
    ///
    /// This will return:
    ///
    /// a0 b0
    /// a1 b1
    /// ...
    ///
    /// up to n times (or fewer, if the polynomials have fewer coefficients).
    pub fn coefficients_up_to(&self, n: usize) -> impl Iterator<Item = &[F]> {
        let n = n.min(self.data.rows);
        let lg_rows = self.data.rows().ilog2();
        (0..n).map(move |i| &self.data[reverse_bits(lg_rows, i as u64) as usize])
    }
}
821
/// The result of evaluating a vector of polynomials over all points in an interpolation domain.
///
/// Row `i` holds the evaluation of every polynomial at the `i`-th point of the domain.
///
/// This struct also remembers which rows are currently filled (e.g. with [Self::fill_row]).
/// This is used in [Self::recover], which can use the rows that are present to fill in the missing
/// rows.
#[derive(Debug, PartialEq)]
pub struct EvaluationVector {
    /// The evaluations: one row per domain point, one column per polynomial.
    data: Matrix,
    /// Bit `i` is set iff row `i` of `data` holds real data.
    active_rows: BitMap,
}
832
833impl EvaluationVector {
834 /// Figure out the polynomial which evaluates to this vector.
835 ///
836 /// i.e. the inverse of [PolynomialVector::evaluate].
837 ///
838 /// (This makes all the rows count as filled).
839 fn interpolate(mut self) -> PolynomialVector {
840 self.data.ntt::<false>();
841 PolynomialVector { data: self.data }
842 }
843
844 /// Create an empty element of this struct, with no filled rows.
845 pub fn empty(lg_rows: usize, cols: usize) -> Self {
846 let data = Matrix::zero(1 << lg_rows, cols);
847 let active = BitMap::zeroes(data.rows as u64);
848 Self {
849 data,
850 active_rows: active,
851 }
852 }
853
854 /// Fill a specific row.
855 pub fn fill_row(&mut self, row: usize, data: &[F]) {
856 assert!(data.len() <= self.data.cols);
857 self.data[row][..data.len()].copy_from_slice(data);
858 self.active_rows.set(row as u64, true);
859 }
860
861 /// Erase a particular row.
862 ///
863 /// Useful for testing the recovery procedure.
864 #[cfg(test)]
865 fn remove_row(&mut self, row: usize) {
866 self.data[row].fill(F::zero());
867 self.active_rows.set(row as u64, false);
868 }
869
870 fn multiply(&mut self, polynomial: NTTPolynomial) {
871 let NTTPolynomial { mut coefficients } = polynomial;
872 ntt::<true, _>(
873 coefficients.len(),
874 1,
875 &mut Column {
876 data: &mut coefficients,
877 },
878 );
879 for (i, &c_i) in coefficients.iter().enumerate() {
880 for self_j in &mut self.data[i] {
881 *self_j = *self_j * c_i;
882 }
883 }
884 }
885
886 /// Attempt to recover the missing rows in this data.
887 pub fn recover(mut self) -> PolynomialVector {
888 // If we had all of the rows, we could simply call [Self::interpolate],
889 // in order to recover the original polynomial. If we do this while missing some
890 // rows, what we get is D(X) * V(X) where D is the original polynomial,
891 // and V(X) is a polynomial which vanishes at all the rows we're missing.
892 //
893 // As long as the degree of D is low enough, compared to the number of evaluations
894 // we *do* have, then we can recover it by performing:
895 //
896 // (D(X) * V(X)) / V(X)
897 //
898 // If we have multiple columns, then this procedure can be done column by column,
899 // with the same vanishing polynomial.
900 let vanishing = NTTPolynomial::vanishing(&self.active_rows);
901 self.multiply(vanishing.clone());
902 let mut out = self.interpolate();
903 out.divide(vanishing);
904 out
905 }
906
907 /// Get the underlying data, as a Matrix.
908 pub fn data(self) -> Matrix {
909 self.data
910 }
911
912 /// Return how many distinct rows have been filled.
913 pub fn filled_rows(&self) -> usize {
914 self.active_rows.count_ones() as usize
915 }
916}
917
918#[cfg(test)]
919mod test {
920 use super::*;
921 use proptest::prelude::*;
922
923 fn any_f() -> impl Strategy<Value = F> {
924 any::<u64>().prop_map(F::from)
925 }
926
927 #[test]
928 fn test_reverse_bits() {
929 assert_eq!(reverse_bits(4, 0b1000), 0b0001);
930 assert_eq!(reverse_bits(4, 0b0100), 0b0010);
931 assert_eq!(reverse_bits(4, 0b0010), 0b0100);
932 assert_eq!(reverse_bits(4, 0b0001), 0b1000);
933 }
934
935 #[test]
936 fn matrix_read_rejects_length_mismatch() {
937 use bytes::BytesMut;
938 use commonware_codec::{Read as _, Write as _};
939
940 let mut buf = BytesMut::new();
941 (2usize).write(&mut buf);
942 (2usize).write(&mut buf);
943 vec![F::one(); 3].write(&mut buf);
944
945 let mut bytes = buf.freeze();
946 let result = Matrix::read_cfg(&mut bytes, &8);
947 assert!(matches!(
948 result,
949 Err(commonware_codec::Error::Invalid(
950 "Matrix",
951 "matrix element count does not match dimensions"
952 ))
953 ));
954 }
955
956 fn any_polynomial_vector(
957 max_log_rows: usize,
958 max_cols: usize,
959 ) -> impl Strategy<Value = PolynomialVector> {
960 (0..=max_log_rows).prop_flat_map(move |lg_rows| {
961 (1..=max_cols).prop_flat_map(move |cols| {
962 let rows = 1 << lg_rows;
963 proptest::collection::vec(any_f(), rows * cols).prop_map(move |coefficients| {
964 PolynomialVector::new(rows, cols, coefficients.into_iter())
965 })
966 })
967 })
968 }
969
970 fn any_bit_vec_not_all_0(max_log_rows: usize) -> impl Strategy<Value = BitMap> {
971 (0..=max_log_rows).prop_flat_map(move |lg_rows| {
972 let rows = (1 << lg_rows) as usize;
973 (0..rows).prop_flat_map(move |set_row| {
974 proptest::collection::vec(any::<bool>(), 1 << lg_rows).prop_map(move |mut bools| {
975 bools[set_row] = true;
976 BitMap::from(bools.as_slice())
977 })
978 })
979 })
980 }
981
982 #[derive(Debug)]
983 struct RecoverySetup {
984 n: usize,
985 k: usize,
986 cols: usize,
987 data: Vec<F>,
988 present: BitMap,
989 }
990
991 impl RecoverySetup {
992 fn any(max_n: usize, max_k: usize, max_cols: usize) -> impl Strategy<Value = Self> {
993 (1..=max_n).prop_flat_map(move |n| {
994 (0..=max_k).prop_flat_map(move |k| {
995 (1..=max_cols).prop_flat_map(move |cols| {
996 proptest::collection::vec(any_f(), n * cols).prop_flat_map(move |data| {
997 let padded_rows = (n + k).next_power_of_two();
998 proptest::sample::subsequence(
999 (0..padded_rows).collect::<Vec<_>>(),
1000 n..=padded_rows,
1001 )
1002 .prop_map(move |indices| {
1003 let mut present = BitMap::zeroes(padded_rows as u64);
1004 for i in indices {
1005 present.set(i as u64, true);
1006 }
1007 Self {
1008 n,
1009 k,
1010 cols,
1011 // idk why this is necessary, but who cares
1012 data: data.clone(),
1013 present,
1014 }
1015 })
1016 })
1017 })
1018 })
1019 })
1020 }
1021
1022 fn test(self) {
1023 let data = PolynomialVector::new(self.n + self.k, self.cols, self.data.into_iter());
1024 let mut encoded = data.clone().evaluate();
1025 for (i, b_i) in self.present.iter().enumerate() {
1026 if !b_i {
1027 encoded.remove_row(i);
1028 }
1029 }
1030 let recovered_data = encoded.recover();
1031 assert_eq!(data, recovered_data);
1032 }
1033 }
1034
1035 #[test]
1036 fn test_recovery_000() {
1037 RecoverySetup {
1038 n: 1,
1039 k: 1,
1040 cols: 1,
1041 data: vec![F::one()],
1042 present: vec![false, true].into(),
1043 }
1044 .test()
1045 }
1046
1047 proptest! {
1048 #[test]
1049 fn test_ntt_eq_naive(p in any_polynomial_vector(6, 4)) {
1050 let ntt = p.clone().evaluate();
1051 let ntt_naive = p.evaluate_naive();
1052 assert_eq!(ntt, ntt_naive);
1053 }
1054
1055 #[test]
1056 fn test_evaluation_then_inverse(p in any_polynomial_vector(6, 4)) {
1057 assert_eq!(p.clone(), p.evaluate().interpolate());
1058 }
1059
1060 #[test]
1061 fn test_vanishing_polynomial(bv in any_bit_vec_not_all_0(8)) {
1062 let v = NTTPolynomial::vanishing(&bv);
1063 let expected_degree = bv.count_zeros();
1064 assert_eq!(v.degree(), expected_degree as usize, "expected v to have degree {expected_degree}");
1065 let w = F::root_of_unity(bv.len().ilog2() as u8).unwrap();
1066 let mut w_i = F::one();
1067 for b_i in bv.iter() {
1068 let v_at_w_i = v.evaluate(w_i);
1069 if !b_i {
1070 assert_eq!(v_at_w_i, F::zero(), "v should evaluate to 0 at {w_i:?}");
1071 } else {
1072 assert_ne!(v_at_w_i, F::zero());
1073 }
1074 w_i = w_i * w;
1075 }
1076 }
1077
1078 #[test]
1079 fn test_recovery(setup in RecoverySetup::any(128, 128, 4)) {
1080 setup.test();
1081 }
1082 }
1083
1084 #[cfg(feature = "arbitrary")]
1085 mod conformance {
1086 use super::*;
1087 use commonware_codec::conformance::CodecConformance;
1088
1089 commonware_conformance::conformance_tests! {
1090 CodecConformance<Matrix>,
1091 }
1092 }
1093}