commonware_math/
ntt.rs

1use crate::{
2    algebra::{Additive as _, Ring},
3    fields::goldilocks::F,
4};
5#[cfg(not(feature = "std"))]
6use alloc::{vec, vec::Vec};
7use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
8use commonware_utils::bitmap::{BitMap, DEFAULT_CHUNK_SIZE};
9use core::ops::{Index, IndexMut};
10use rand_core::CryptoRngCore;
11
/// Reverse the first `bit_width` bits of `i`.
///
/// Any bits beyond that width will be erased, so the result is always below
/// `2^bit_width`. In particular, a `bit_width` of 0 always yields 0.
///
/// # Panics
///
/// Panics if `bit_width > 64`.
fn reverse_bits(bit_width: u32, i: u64) -> u64 {
    assert!(bit_width <= 64, "bit_width must be <= 64");
    // Move the low `bit_width` bits to the top of the word (discarding everything
    // above them), then reverse all 64 bits so they land back at the bottom, reversed.
    // For `bit_width == 0` the shift amount is 64, which `wrapping_shl` would
    // reduce to a shift of 0 (keeping every bit); `checked_shl` reports it as
    // `None` instead, which is exactly the "keep no bits" case.
    i.checked_shl(64 - bit_width).map_or(0, u64::reverse_bits)
}
19
20/// Calculate an NTT, or an inverse NTT (with FORWARD=false), in place.
21///
22/// We implement this generically over anything we can index into, which allows
23/// performing NTTs in place
24fn ntt<const FORWARD: bool, M: IndexMut<(usize, usize), Output = F>>(
25    rows: usize,
26    cols: usize,
27    matrix: &mut M,
28) {
29    let lg_rows = rows.ilog2() as usize;
30    assert_eq!(1 << lg_rows, rows, "rows should be a power of 2");
31    // A number w such that w^(2^lg_rows) = 1.
32    // (Or, in the inverse case, the inverse of that number, to undo the NTT).
33    let w = {
34        let w = F::root_of_unity(lg_rows as u8).expect("too many rows to perform NTT");
35        if FORWARD {
36            w
37        } else {
38            // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
39            // making that left-hand term the inverse of w.
40            w.exp(&[(1 << lg_rows) - 1])
41        }
42    };
43    // The inverse algorithm consists of carefully undoing the work of the
44    // standard algorithm, so we describe that in detail.
45    //
46    // To understand the NTT algorithm, first consider the case of a single
47    // column. We have a polynomial f(X), and we want to turn that into:
48    //
49    // [f(w^0), f(w^1), ..., f(w^(2^lg_rows - 1))]
50    //
51    // Our polynomial can be written as:
52    //
53    // f+(X^2) + X f-(X^2)
54    //
55    // where f+ and f- are polynomials with half the degree.
56    // f+ is obtained by taking the coefficients at even indices,
57    // f- is obtained by taking the coefficients at odd indices.
58    //
59    // w^2 is also conveniently a 2^(lg_rows - 1) root of unity. Thus,
60    // we can recursively compute an NTT on f+, using w^2 as the root,
61    // and an NTT on f-, using w^2 as the root, each of which is a problem
62    // of half the size.
63    //
64    // We can then compute:
65    // f+((w^i)^2) + (w^i) f-((w^i)^2)
66    // f+((w^i)^2) - (w^i) f-((w^i)^2)
67    // for each i.
68    // (Note that (-w^i)^2 = ((-w)^2)^i = (w^i)^2))
69    //
70    // Our coefficients are conveniently laid out as [f+ f-], already
71    // in a neat order. When we recurse, the coefficients of f+ are, in
72    // turn, already laid out as [f++ f+-], and so on.
73    //
74    // We just need to transform this recursive algorithm, in top down form,
75    // into an iterative one, in bottom up form. For that, note that the NTT
76    // for the case of 1 row is trivial: do nothing.
77
78    // Will contain, in bottom up order, the power of w we need at that stage.
79    // At the last stage, we need w itself.
80    // At the stage before last, we need w^2.
81    // And so on.
82    // How many stages do we need? If we have 1 row, we need 0 stages.
83    // In general, with 2^n rows, we need n stages.
84    let stages = {
85        let mut out = vec![(0usize, F::zero()); lg_rows];
86        let mut w_i = w;
87        for i in (0..lg_rows).rev() {
88            out[i] = (i, w_i);
89            w_i = w_i * w_i;
90        }
91        // In the case of the reverse algorithm, we undo each stage of the
92        // forward algorithm, starting with the last stage.
93        if !FORWARD {
94            out.reverse();
95        }
96        out
97    };
98    for (stage, w) in stages.into_iter() {
99        // At stage i, we have polynomials with 2^i coefficients,
100        // which have already been evaluated to create 2^i entries.
101        // We need to combine these evaluations to create 2^(i + 1) entries,
102        // representing the evaluation of a polynomial with 2^(i + 1) coefficients.
103        // If we have two of these evaluations, laid out one after the other:
104        //
105        // [x_0, x_1, ...] [y_0, y_1, ...]
106        //
107        // Then the number of elements we need to skip to get the corresponding
108        // element in the other half is simply the number of elements in each half,
109        // i.e. 2^i.
110        let skip = 1 << stage;
111        let mut i = 0;
112        while i < rows {
113            // In the case of a backwards NTT, skew should be the inverse of the skew
114            // in the forwards direction.
115            let mut w_j = F::one();
116            for j in 0..skip {
117                let index_a = i + j;
118                let index_b = index_a + skip;
119                for k in 0..cols {
120                    let (a, b) = (matrix[(index_a, k)], matrix[(index_b, k)]);
121                    if FORWARD {
122                        matrix[(index_a, k)] = a + w_j * b;
123                        matrix[(index_b, k)] = a - w_j * b;
124                    } else {
125                        // To check the math, convince yourself that applying the forward
126                        // transformation, and then this transformation, with w_j being the
127                        // inverse of the value above, that you get (a, b).
128                        // (a + w_j * b) + (a - w_j * b) = 2 * a
129                        matrix[(index_a, k)] = (a + b).div_2();
130                        // (a + w_j * b) - (a - w_j * b) = 2 * w_j * b.
131                        // w_j in this branch is the inverse of w_j in the other branch.
132                        matrix[(index_b, k)] = ((a - b) * w_j).div_2();
133                    }
134                }
135                w_j = w_j * w;
136            }
137            i += 2 * skip;
138        }
139    }
140}
141
142/// A single column of some larger data.
143///
144/// This allows us to easily do NTTs over partial segments of some bigger matrix.
145struct Column<'a> {
146    data: &'a mut [F],
147}
148
149impl<'a> Index<(usize, usize)> for Column<'a> {
150    type Output = F;
151
152    fn index(&self, (i, _): (usize, usize)) -> &Self::Output {
153        &self.data[i]
154    }
155}
156impl<'a> IndexMut<(usize, usize)> for Column<'a> {
157    fn index_mut(&mut self, (i, _): (usize, usize)) -> &mut Self::Output {
158        &mut self.data[i]
159    }
160}
161
162/// Represents a matrix of field elements, of arbitrary dimensions
163///
164/// This is in row major order, so consider processing elements in the same
165/// row first, for locality.
166#[derive(Clone, PartialEq)]
167pub struct Matrix {
168    rows: usize,
169    cols: usize,
170    data: Vec<F>,
171}
172
173impl EncodeSize for Matrix {
174    fn encode_size(&self) -> usize {
175        self.rows.encode_size() + self.cols.encode_size() + self.data.encode_size()
176    }
177}
178
179impl Write for Matrix {
180    fn write(&self, buf: &mut impl bytes::BufMut) {
181        self.rows.write(buf);
182        self.cols.write(buf);
183        self.data.write(buf);
184    }
185}
186
187impl Read for Matrix {
188    type Cfg = usize;
189
190    fn read_cfg(
191        buf: &mut impl bytes::Buf,
192        &max_els: &Self::Cfg,
193    ) -> Result<Self, commonware_codec::Error> {
194        let cfg = RangeCfg::from(..=max_els);
195        let rows = usize::read_cfg(buf, &cfg)?;
196        let cols = usize::read_cfg(buf, &cfg)?;
197        let data = Vec::<F>::read_cfg(buf, &(cfg, ()))?;
198        let expected_len = rows
199            .checked_mul(cols)
200            .ok_or(commonware_codec::Error::Invalid(
201                "Matrix",
202                "matrix dimensions overflow",
203            ))?;
204        if data.len() != expected_len {
205            return Err(commonware_codec::Error::Invalid(
206                "Matrix",
207                "matrix element count does not match dimensions",
208            ));
209        }
210        Ok(Self { rows, cols, data })
211    }
212}
213
214impl core::fmt::Debug for Matrix {
215    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
216        for i in 0..self.rows {
217            let row_i = &self[i];
218            for &row_i_j in row_i {
219                write!(f, "{row_i_j:?} ")?;
220            }
221            writeln!(f)?;
222        }
223        Ok(())
224    }
225}
226
227impl Matrix {
228    /// Create a zero matrix, with a certain number of rows and columns
229    fn zero(rows: usize, cols: usize) -> Self {
230        Self {
231            rows,
232            cols,
233            data: vec![F::zero(); rows * cols],
234        }
235    }
236
237    /// Initialize a matrix, with dimensions, and data to pull from.
238    ///
239    /// Any extra data is ignored, any data not supplied is treated as 0.
240    pub fn init(rows: usize, cols: usize, mut data: impl Iterator<Item = F>) -> Self {
241        let mut out = Self::zero(rows, cols);
242        'outer: for i in 0..rows {
243            for row_i in &mut out[i] {
244                let Some(x) = data.next() else {
245                    break 'outer;
246                };
247                *row_i = x;
248            }
249        }
250        out
251    }
252
253    /// Interpret the columns of this matrix as polynomials, with at least `min_coefficients`.
254    ///
255    /// This will, in fact, produce a matrix padded to the next power of 2 of that number.
256    ///
257    /// This will return `None` if `min_coefficients < self.rows`, which would mean
258    /// discarding data, instead of padding it.
259    pub fn as_polynomials(&self, min_coefficients: usize) -> Option<PolynomialVector> {
260        if min_coefficients < self.rows {
261            return None;
262        }
263        Some(PolynomialVector::new(
264            min_coefficients,
265            self.cols,
266            (0..self.rows).flat_map(|i| self[i].iter().copied()),
267        ))
268    }
269
270    /// Multiply this matrix by another.
271    ///
272    /// This assumes that the number of columns in this matrix match the number
273    /// of rows in the other matrix.
274    pub fn mul(&self, other: &Self) -> Self {
275        assert_eq!(self.cols, other.rows);
276        let mut out = Self::zero(self.rows, other.cols);
277        for i in 0..self.rows {
278            for j in 0..self.cols {
279                let c = self[(i, j)];
280                let other_j = &other[j];
281                for k in 0..other.cols {
282                    out[(i, k)] = out[(i, k)] + c * other_j[k]
283                }
284            }
285        }
286        out
287    }
288
289    fn ntt<const FORWARD: bool>(&mut self) {
290        ntt::<FORWARD, Self>(self.rows, self.cols, self)
291    }
292
293    pub const fn rows(&self) -> usize {
294        self.rows
295    }
296
297    pub const fn cols(&self) -> usize {
298        self.cols
299    }
300
301    // Iterate over the rows of this matrix.
302    pub fn iter(&self) -> impl Iterator<Item = &[F]> {
303        (0..self.rows).map(|i| &self[i])
304    }
305
306    /// Create a random matrix with certain dimensions.
307    pub fn rand(mut rng: impl CryptoRngCore, rows: usize, cols: usize) -> Self {
308        Self::init(rows, cols, (0..rows * cols).map(|_| F::rand(&mut rng)))
309    }
310}
311
312impl Index<usize> for Matrix {
313    type Output = [F];
314
315    fn index(&self, index: usize) -> &Self::Output {
316        &self.data[self.cols * index..self.cols * (index + 1)]
317    }
318}
319
320impl IndexMut<usize> for Matrix {
321    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
322        &mut self.data[self.cols * index..self.cols * (index + 1)]
323    }
324}
325
326impl Index<(usize, usize)> for Matrix {
327    type Output = F;
328
329    fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
330        &self.data[self.cols * i + j]
331    }
332}
333
334impl IndexMut<(usize, usize)> for Matrix {
335    fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
336        &mut self.data[self.cols * i + j]
337    }
338}
339
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Matrix {
    /// Build a matrix with 1 to 16 rows and 1 to 16 columns, filled row by row
    /// with arbitrary field elements.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let rows: usize = u.int_in_range(1..=16)?;
        let cols: usize = u.int_in_range(1..=16)?;
        let len = rows * cols;
        let mut data = Vec::with_capacity(len);
        for _ in 0..len {
            data.push(F::arbitrary(u)?);
        }
        Ok(Self { rows, cols, data })
    }
}
351
352#[derive(Clone, Debug, PartialEq)]
353struct NTTPolynomial {
354    coefficients: Vec<F>,
355}
356
357impl NTTPolynomial {
358    /// Create a polynomial which vanishes (evaluates to 0) except at a few points.
359    ///
360    /// It's assumed that `except` is a bit vector with length a power of 2.
361    ///
362    /// For each index i NOT IN `except`, the resulting polynomial will evaluate
363    /// to w^i, where w is a `except.len()` root of unity.
364    ///
365    /// e.g. with `except` = 1001, then the resulting polynomial will
366    /// evaluate to 0 at w^1 and w^2, where w is a 4th root of unity.
367    fn vanishing(except: &BitMap) -> Self {
368        // Algorithm taken from: https://ethresear.ch/t/reed-solomon-erasure-code-recovery-in-n-log-2-n-time-with-ffts/3039.
369        // The basic idea of the algorithm is that given a set of indices S,
370        // we can split it in two: the even indices (first bit = 0) and the odd indices.
371        // We compute two vanishing polynomials over
372        //
373        //   S_L := {i / 2 | i in S}
374        //   S_R := {(i - 1) / 2 | i in S}
375        //
376        // Using a domain of half the size. i.e. instead of w, they use w^2 as the root.
377        //
378        // V_L vanishes at (w^2)^(i / 2) for each i in S, i.e. w^i, for each even i in S.
379        // Similarly, V_R vanishes at (w^2)^((i - 1) / 2) = w^(i - 1), for each odd i in S.
380        //
381        // To combine these into one polynomial, we multiply the roots of V_R by w, so that it
382        // vanishes at the w^i (for odd i) instead of w^(i - 1).
383        //
384        // To multiply the roots of a polynomial
385        //
386        //   P(X) := a0 + a1 X + a2 X^2 + ...
387        //
388        // by some factor z, it suffices to divide the ith coefficient by z^i:
389        //
390        //   Q(X) := a0 + (a1 / z) X + (a2 / z^2) X^2 + ...
391        //
392        // Notice that Q(z X) = P(X), so if P(x) = 0, then Q(z x) = 0, so we've multiplied
393        // the roots by a factor of z.
394        //
395        // After multiplying the roots of V_R by w, we can then multiply the resulting polynomial
396        // with V_L, producing a polynomial which vanishes at the right indices.
397        //
398        // To multiply efficiently, we can do multiplication over the evaluation domain:
399        // we perform an NTT over each polynomial, multiplie the evaluations pointwise,
400        // and then perform an inverse NTT to get the result. We just need to make sure that
401        // when we perform the NTT, we've added enough extra 0 coefficients in each polynomial
402        // to accommodate the extra degree. e.g. if we have two polynomials of degree 1, then
403        // we need to make sure to pad them to have enough coefficients for a polynomial of degree 2,
404        // so that we can correctly interpolate the result back.
405        //
406        // The tricky part is transforming this algorithm into an iterative one, and respecting
407        // the reverse bit order of the coefficients we need
408        let rows = except.len() as usize;
409        let padded_rows = rows.next_power_of_two();
410        let zeroes = except.count_zeros() as usize + padded_rows - rows;
411        assert!(zeroes < padded_rows, "too many points to vanish over");
412        let lg_rows = padded_rows.ilog2();
413        // At each iteration, we split `except` into sections.
414        // Each section has a polynomial associated with it, which should
415        // be the polynomial that vanishes over all the 0 bits of that section,
416        // or the 0 polynomial if that section has no 0 bits.
417        //
418        // The sections are organized into a tree:
419        //
420        // 0xx             1xx
421        // 00x     01x     10x         11x
422        // 000 001 010 011 100 101 110 111
423        //
424        // The first half of the sections are even, the second half are odd.
425        // The first half of the first half have their first two bits set to 00,
426        // the second half of the first half have their first two bits set to 01,
427        // and so on.
428        //
429        // In other words, the ith index in except becomes the i.reverse_bits()
430        // section.
431        //
432        // How many polynomials do we have? (Potentially 0 ones).
433        let mut polynomial_count = padded_rows;
434        // How many coefficients does each polynomial have?
435        let mut polynomial_size: usize = 2;
436        // For the first iteration, each
437        let mut polynomials = vec![F::zero(); 2 * padded_rows];
438        let mut active = BitMap::<DEFAULT_CHUNK_SIZE>::with_capacity(polynomial_count as u64);
439        for i in 0..polynomial_count {
440            let rev_i = reverse_bits(lg_rows, i as u64) as usize;
441            if !except.get(rev_i as u64) {
442                polynomials[2 * i] = -F::one();
443                polynomials[2 * i + 1] = F::one();
444                active.push(true);
445            } else {
446                active.push(false);
447            }
448        }
449        // Rather than store w at each iteration, and divide by it, just store its inverse,
450        // allowing us to multiply by it.
451        let w_invs = {
452            // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
453            // making that left-hand term the inverse of w.
454            let mut w_inv = F::root_of_unity(lg_rows as u8)
455                .expect("too many rows to create vanishing polynomial")
456                .exp(&[(1 << lg_rows) - 1]);
457            let lg_rows = lg_rows as usize;
458            let mut out = Vec::with_capacity(lg_rows);
459            for _ in 0..lg_rows {
460                out.push(w_inv);
461                w_inv = w_inv * w_inv;
462            }
463            out.reverse();
464            out
465        };
466        // When we multiply
467        let mut scratch: Vec<F> = Vec::with_capacity(padded_rows);
468        for w_inv in w_invs.into_iter() {
469            // After this iteration, we're going to end up with half the polynomials
470            polynomial_count >>= 1;
471            // and each of them will be twice as large.
472            let new_polynomial_size = polynomial_size << 1;
473            // Our goal is to construct the ith polynomial.
474            for i in 0..polynomial_count {
475                let start = new_polynomial_size * i;
476                let has_left = if ((2 * i) as u64) < active.len() {
477                    active.get((2 * i) as u64)
478                } else {
479                    false
480                };
481                let has_right = if ((2 * i + 1) as u64) < active.len() {
482                    active.get((2 * i + 1) as u64)
483                } else {
484                    false
485                };
486                match (has_left, has_right) {
487                    // No polynomials to combine.
488                    (false, false) => {}
489                    // We need to multiply the roots of the right side,
490                    // but then it can just expand to fill the entire polynomial.
491                    (false, true) => {
492                        let slice = &mut polynomials[start..start + new_polynomial_size];
493                        // Scale the roots of the right side by w.
494                        let lg_p_size = polynomial_size.ilog2();
495                        let mut w_j = F::one();
496                        for j in 0..polynomial_size {
497                            let index =
498                                polynomial_size + reverse_bits(lg_p_size, j as u64) as usize;
499                            slice[index] = slice[index] * w_j;
500                            w_j = w_j * w_inv;
501                        }
502                        // Expand the right side to occupy the entire space.
503                        // The left side must be 0s.
504                        for j in 0..polynomial_size {
505                            slice.swap(polynomial_size + j, 2 * j);
506                        }
507                    }
508                    // No need to multiply roots, but we do need to expand the left side.
509                    (true, false) => {
510                        let slice = &mut polynomials[start..start + new_polynomial_size];
511                        // Expand the left side to occupy the entire space.
512                        // The right side must be 0s.
513                        for j in (0..polynomial_size).rev() {
514                            slice.swap(j, 2 * j);
515                        }
516                    }
517                    // We need to combine the two doing an actual multiplication.
518                    (true, true) => {
519                        debug_assert_eq!(scratch.len(), 0);
520                        scratch.resize(new_polynomial_size, F::zero());
521                        let slice = &mut polynomials[start..start + new_polynomial_size];
522
523                        let lg_p_size = polynomial_size.ilog2();
524                        let mut w_j = F::one();
525                        for j in 0..polynomial_size {
526                            let index =
527                                polynomial_size + reverse_bits(lg_p_size, j as u64) as usize;
528                            slice[index] = slice[index] * w_j;
529                            w_j = w_j * w_inv;
530                        }
531
532                        // Expand the right side to occupy all of scratch.
533                        // Clear the right side.
534                        for j in 0..polynomial_size {
535                            scratch[2 * j] = slice[polynomial_size + j];
536                            slice[polynomial_size + j] = F::zero();
537                        }
538
539                        // Expand the left side to occupy the entire space.
540                        // The right side has been cleared above.
541                        for j in (0..polynomial_size).rev() {
542                            slice.swap(j, 2 * j);
543                        }
544
545                        // Multiply the polynomials together, by first evaluating each of them,
546                        // then multiplying their evaluations, producing (f * g) evaluated over
547                        // the domain, which we can then interpolate back.
548                        ntt::<true, _>(new_polynomial_size, 1, &mut Column { data: &mut scratch });
549                        ntt::<true, _>(new_polynomial_size, 1, &mut Column { data: slice });
550                        for (s_i, p_i) in scratch.drain(..).zip(slice.iter_mut()) {
551                            *p_i = *p_i * s_i
552                        }
553                        ntt::<false, _>(new_polynomial_size, 1, &mut Column { data: slice })
554                    }
555                }
556                // If there was a polynomial on the left or the right, then on the next iteration
557                // the combined section will have data to process, so we need to set it to true
558                // Resize active if needed and set the bit
559                active.set(i as u64, has_left | has_right);
560            }
561            polynomial_size = new_polynomial_size;
562        }
563        // If the final polynomial is inactive, there are no points to vanish over,
564        // so we want to return the polynomial f(X) = 1.
565        if !active.get(0) {
566            let mut coefficients = vec![F::zero(); padded_rows];
567            coefficients[0] = F::one();
568            return Self { coefficients };
569        }
570        // We have a polynomial that's twice the size we need, so we need to truncate it.
571        // This is the opposite of the sub-routine we had for expanding the left side to fit
572        // the entire polynomial.
573        for i in 0..padded_rows {
574            polynomials.swap(i, 2 * i);
575        }
576        polynomials.truncate(padded_rows);
577        Self {
578            coefficients: polynomials,
579        }
580    }
581
582    #[cfg(test)]
583    fn evaluate(&self, point: F) -> F {
584        let mut out = F::zero();
585        let rows = self.coefficients.len();
586        let lg_rows = rows.ilog2();
587        for i in (0..rows).rev() {
588            out = out * point + self.coefficients[reverse_bits(lg_rows, i as u64) as usize];
589        }
590        out
591    }
592
593    #[cfg(test)]
594    fn degree(&self) -> usize {
595        let rows = self.coefficients.len();
596        let lg_rows = rows.ilog2();
597        for i in (0..rows).rev() {
598            if self.coefficients[reverse_bits(lg_rows, i as u64) as usize] != F::zero() {
599                return i;
600            }
601        }
602        0
603    }
604
605    /// Divide the roots of each polynomial by some factor.
606    ///
607    /// If f(x) = 0, then after this transformation, f(x / z) = 0 instead.
608    ///
609    /// The number of roots does not change.
610    ///
611    /// c.f. [Self::vanishing] for an explanation of how this works.
612    fn divide_roots(&mut self, factor: F) {
613        let mut factor_i = F::one();
614        let lg_rows = self.coefficients.len().ilog2();
615        for i in 0..self.coefficients.len() {
616            let index = reverse_bits(lg_rows, i as u64) as usize;
617            self.coefficients[index] = self.coefficients[index] * factor_i;
618            factor_i = factor_i * factor;
619        }
620    }
621}
622
623#[derive(Clone, Debug, PartialEq)]
624pub struct PolynomialVector {
625    // Each column of this matrix contains the coefficients of a polynomial,
626    // in reverse bit order. So, the ith coefficient appears at index i.reverse_bits().
627    //
628    // For example, a polynomial a0 + a1 X + a2 X^2 + a3 X^3 is stored as:
629    //
630    // a0 a2 a1 a3
631    //
632    // This is convenient because the even coefficients and the odd coefficients
633    // split nicely into halves. The first half of the rows have the property
634    // that the first bit of their coefficient index is 0, then in that subset
635    // the first half has the second bit set to 0, and the second half set to 1,
636    // and so on, recursively.
637    data: Matrix,
638}
639
640impl PolynomialVector {
641    /// Construct a new vector of polynomials, from dimensions, and coefficients.
642    ///
643    /// The coefficients should be supplied in order of increasing index,
644    /// and then for each polynomial.
645    ///
646    /// In other words, if you have 3 polynomials:
647    ///
648    /// a0 + a1 X + ...
649    /// b0 + b1 X + ...
650    /// c0 + c1 X + ...
651    ///
652    /// The iterator should yield:
653    ///
654    /// a0 b0 c0
655    /// a1 b1 c1
656    /// ...
657    ///
658    /// Any coefficients not supplied are treated as being equal to 0.
659    fn new(rows: usize, cols: usize, mut coefficients: impl Iterator<Item = F>) -> Self {
660        assert!(rows > 0);
661        let rows = rows.next_power_of_two();
662        let lg_rows = rows.ilog2();
663        let mut data = Matrix::zero(rows, cols);
664        'outer: for i in 0..rows {
665            let row_i = &mut data[reverse_bits(lg_rows, i as u64) as usize];
666            for row_i_j in row_i {
667                let Some(c) = coefficients.next() else {
668                    break 'outer;
669                };
670                *row_i_j = c;
671            }
672        }
673        Self { data }
674    }
675
676    /// Evaluate each polynomial in this vector over all points in an interpolation domain.
677    pub fn evaluate(mut self) -> EvaluationVector {
678        self.data.ntt::<true>();
679        let active_rows = BitMap::ones(self.data.rows as u64);
680        EvaluationVector {
681            data: self.data,
682            active_rows,
683        }
684    }
685
686    /// Like [Self::evaluate], but with a simpler algorithm that's much less efficient.
687    ///
688    /// Exists as a useful tool for testing
689    #[cfg(test)]
690    fn evaluate_naive(self) -> EvaluationVector {
691        let rows = self.data.rows;
692        let lg_rows = rows.ilog2();
693        let w = F::root_of_unity(lg_rows as u8).expect("too much data to calculate NTT");
694        // entry (i, j) of this matrix will contain w^ij. Thus, multiplying it
695        // with the coefficients of a polynomial, in column order, will evaluate it.
696        // We also need to re-arrange the columns of the matrix to match the same
697        // order we have for polynomial coefficients.
698        let mut vandermonde_matrix = Matrix::zero(rows, rows);
699        let mut w_i = F::one();
700        for i in 0..rows {
701            let row_i = &mut vandermonde_matrix[i];
702            let mut w_ij = F::one();
703            for j in 0..rows {
704                // Remember, the coeffients of the polynomial are in reverse bit order!
705                row_i[reverse_bits(lg_rows, j as u64) as usize] = w_ij;
706                w_ij = w_ij * w_i;
707            }
708            w_i = w_i * w;
709        }
710
711        EvaluationVector {
712            data: vandermonde_matrix.mul(&self.data),
713            active_rows: BitMap::ones(rows as u64),
714        }
715    }
716
717    /// Divide the roots of each polynomial by some factor.
718    ///
719    /// c.f. [NTTPolynomial::divide_roots]. This performs the same operation on
720    /// each polynomial in this vector.
721    fn divide_roots(&mut self, factor: F) {
722        let mut factor_i = F::one();
723        let lg_rows = self.data.rows.ilog2();
724        for i in 0..self.data.rows {
725            for p_i in &mut self.data[reverse_bits(lg_rows, i as u64) as usize] {
726                *p_i = *p_i * factor_i;
727            }
728            factor_i = factor_i * factor;
729        }
730    }
731
732    /// For each polynomial P_i in this vector compute the evaluation of P_i / Q.
733    ///
734    /// Naturally, you can call [EvaluationVector::interpolate]. The reason we don't
735    /// do this is that the algorithm naturally yields an [EvaluationVector], and
736    /// some use-cases may want access to that data as well.
737    ///
738    /// This assumes that the number of coefficients in the polynomials of this vector
739    /// matches that of `q` (the coefficients can be 0, but need to be padded to the right size).
740    ///
741    /// This assumes that `q` has no zeroes over [F::NOT_ROOT_OF_UNITY] * [F::ROOT_OF_UNITY]^i,
742    /// for any i. This will be the case for [NTTPolynomial::vanishing].
743    /// If this isn't the case, the result may be junk.
744    ///
745    /// If `q` doesn't divide a partiular polynomial in this vector, the result
746    /// for that polynomial is not guaranteed to be anything meaningful.
747    fn divide(&mut self, mut q: NTTPolynomial) {
748        // The algorithm operates column wise.
749        //
750        // You can compute P(X) / Q(X) by evaluating each polynomial, then computing
751        //
752        //   P(w^i) / Q(w^i)
753        //
754        // for each evaluation point. Then, you can interpolate back.
755        //
756        // But wait! What if Q(w^i) = 0? In particular, for the case of recovering
757        // a polynomial from data with missing rows, we *expect* P(w^i) = 0 = Q(w^i)
758        // for the indicies we're missing, so this doesn't work.
759        //
760        // What we can do is to instead multiply each of the roots by some factor z,
761        // such that z w^i != w^j, for any i, j. In other words, we change the roots
762        // such that they're not in the evaluation domain anymore, allowing us to
763        // divide. We can then interpolate the result back into a polynomial,
764        // and divide back the roots to where they should be.
765        //
766        // c.f. [PolynomialVector::divide_roots]
767        assert_eq!(
768            self.data.rows,
769            q.coefficients.len(),
770            "cannot divide by polynomial of the wrong size"
771        );
772        let skew = F::NOT_ROOT_OF_UNITY;
773        let skew_inv = F::NOT_ROOT_OF_UNITY_INV;
774        self.divide_roots(skew);
775        q.divide_roots(skew);
776        ntt::<true, _>(self.data.rows, self.data.cols, &mut self.data);
777        ntt::<true, _>(
778            q.coefficients.len(),
779            1,
780            &mut Column {
781                data: &mut q.coefficients,
782            },
783        );
784        // Do a point wise division.
785        for i in 0..self.data.rows {
786            let q_i = q.coefficients[i];
787            // If `q_i = 0`, then we will get 0 in the output.
788            // We don't expect any of the q_i to be 0, but being 0 is only one
789            // of the many possibilities for the coefficient to be incorrect,
790            // so doing a runtime assertion here doesn't make sense.
791            let q_i_inv = q_i.inv();
792            for d_i_j in &mut self.data[i] {
793                *d_i_j = *d_i_j * q_i_inv;
794            }
795        }
796        // Interpolate back, using the inverse skew
797        ntt::<false, _>(self.data.rows, self.data.cols, &mut self.data);
798        self.divide_roots(skew_inv);
799    }
800
801    /// Iterate over up to n rows of this vector.
802    ///
803    /// For example, given polynomials:
804    ///
805    ///   a0 + a1 X + a2 X^2 + ...
806    ///   b0 + b1 X + b2 X^2 + ...
807    ///
808    /// This will return:
809    ///
810    ///   a0 b0
811    ///   a1 b1
812    ///   ...
813    ///
814    /// up to n times.
815    pub fn coefficients_up_to(&self, n: usize) -> impl Iterator<Item = &[F]> {
816        let n = n.min(self.data.rows);
817        let lg_rows = self.data.rows().ilog2();
818        (0..n).map(move |i| &self.data[reverse_bits(lg_rows, i as u64) as usize])
819    }
820}
821
/// The result of evaluating a vector of polynomials over all points in an interpolation domain.
///
/// This struct also remembers which rows currently hold data, e.g. rows set with
/// [Self::fill_row]. This is used in [Self::recover], which can use the rows that are
/// present to fill in the missing rows.
#[derive(Debug, PartialEq)]
pub struct EvaluationVector {
    /// The evaluations: one row per evaluation point, one column per polynomial.
    data: Matrix,
    /// Bit `i` is set when row `i` of `data` is considered filled.
    active_rows: BitMap,
}
832
833impl EvaluationVector {
834    /// Figure out the polynomial which evaluates to this vector.
835    ///
836    /// i.e. the inverse of [PolynomialVector::evaluate].
837    ///
838    /// (This makes all the rows count as filled).
839    fn interpolate(mut self) -> PolynomialVector {
840        self.data.ntt::<false>();
841        PolynomialVector { data: self.data }
842    }
843
844    /// Create an empty element of this struct, with no filled rows.
845    pub fn empty(lg_rows: usize, cols: usize) -> Self {
846        let data = Matrix::zero(1 << lg_rows, cols);
847        let active = BitMap::zeroes(data.rows as u64);
848        Self {
849            data,
850            active_rows: active,
851        }
852    }
853
854    /// Fill a specific row.
855    pub fn fill_row(&mut self, row: usize, data: &[F]) {
856        assert!(data.len() <= self.data.cols);
857        self.data[row][..data.len()].copy_from_slice(data);
858        self.active_rows.set(row as u64, true);
859    }
860
861    /// Erase a particular row.
862    ///
863    /// Useful for testing the recovery procedure.
864    #[cfg(test)]
865    fn remove_row(&mut self, row: usize) {
866        self.data[row].fill(F::zero());
867        self.active_rows.set(row as u64, false);
868    }
869
870    fn multiply(&mut self, polynomial: NTTPolynomial) {
871        let NTTPolynomial { mut coefficients } = polynomial;
872        ntt::<true, _>(
873            coefficients.len(),
874            1,
875            &mut Column {
876                data: &mut coefficients,
877            },
878        );
879        for (i, &c_i) in coefficients.iter().enumerate() {
880            for self_j in &mut self.data[i] {
881                *self_j = *self_j * c_i;
882            }
883        }
884    }
885
886    /// Attempt to recover the missing rows in this data.
887    pub fn recover(mut self) -> PolynomialVector {
888        // If we had all of the rows, we could simply call [Self::interpolate],
889        // in order to recover the original polynomial. If we do this while missing some
890        // rows, what we get is D(X) * V(X) where D is the original polynomial,
891        // and V(X) is a polynomial which vanishes at all the rows we're missing.
892        //
893        // As long as the degree of D is low enough, compared to the number of evaluations
894        // we *do* have, then we can recover it by performing:
895        //
896        //   (D(X) * V(X)) / V(X)
897        //
898        // If we have multiple columns, then this procedure can be done column by column,
899        // with the same vanishing polynomial.
900        let vanishing = NTTPolynomial::vanishing(&self.active_rows);
901        self.multiply(vanishing.clone());
902        let mut out = self.interpolate();
903        out.divide(vanishing);
904        out
905    }
906
907    /// Get the underlying data, as a Matrix.
908    pub fn data(self) -> Matrix {
909        self.data
910    }
911
912    /// Return how many distinct rows have been filled.
913    pub fn filled_rows(&self) -> usize {
914        self.active_rows.count_ones() as usize
915    }
916}
917
#[cfg(test)]
mod test {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing arbitrary field elements, built from any `u64`.
    fn any_f() -> impl Strategy<Value = F> {
        any::<u64>().prop_map(F::from)
    }

    // With a width of 4, `reverse_bits` mirrors the low 4 bits.
    #[test]
    fn test_reverse_bits() {
        assert_eq!(reverse_bits(4, 0b1000), 0b0001);
        assert_eq!(reverse_bits(4, 0b0100), 0b0010);
        assert_eq!(reverse_bits(4, 0b0010), 0b0100);
        assert_eq!(reverse_bits(4, 0b0001), 0b1000);
    }

    // The encoded header declares 2 x 2 dimensions (presumably rows then cols), but
    // only 3 elements follow, so decoding must be rejected.
    #[test]
    fn matrix_read_rejects_length_mismatch() {
        use bytes::BytesMut;
        use commonware_codec::{Read as _, Write as _};

        let mut buf = BytesMut::new();
        (2usize).write(&mut buf);
        (2usize).write(&mut buf);
        vec![F::one(); 3].write(&mut buf);

        let mut bytes = buf.freeze();
        let result = Matrix::read_cfg(&mut bytes, &8);
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid(
                "Matrix",
                "matrix element count does not match dimensions"
            ))
        ));
    }

    /// Strategy producing a `PolynomialVector` with `2^lg_rows` rows
    /// (`lg_rows <= max_log_rows`) and between 1 and `max_cols` columns, filled with
    /// arbitrary coefficients.
    fn any_polynomial_vector(
        max_log_rows: usize,
        max_cols: usize,
    ) -> impl Strategy<Value = PolynomialVector> {
        (0..=max_log_rows).prop_flat_map(move |lg_rows| {
            (1..=max_cols).prop_flat_map(move |cols| {
                let rows = 1 << lg_rows;
                proptest::collection::vec(any_f(), rows * cols).prop_map(move |coefficients| {
                    PolynomialVector::new(rows, cols, coefficients.into_iter())
                })
            })
        })
    }

    /// Strategy producing a bitmap of length `2^lg_rows` (`lg_rows <= max_log_rows`)
    /// with at least one bit set: the bit at a randomly chosen `set_row` is forced on.
    fn any_bit_vec_not_all_0(max_log_rows: usize) -> impl Strategy<Value = BitMap> {
        (0..=max_log_rows).prop_flat_map(move |lg_rows| {
            let rows = (1 << lg_rows) as usize;
            (0..rows).prop_flat_map(move |set_row| {
                proptest::collection::vec(any::<bool>(), 1 << lg_rows).prop_map(move |mut bools| {
                    bools[set_row] = true;
                    BitMap::from(bools.as_slice())
                })
            })
        })
    }

    /// Parameters for one encode / erase / recover round trip.
    #[derive(Debug)]
    struct RecoverySetup {
        /// Number of coefficient rows actually carrying data.
        n: usize,
        /// Extra rows requested on top of `n` (presumably zero-filled by
        /// `PolynomialVector::new`; confirm).
        k: usize,
        /// Number of polynomials (columns).
        cols: usize,
        /// `n * cols` coefficients.
        data: Vec<F>,
        /// Which of the `(n + k).next_power_of_two()` evaluation rows survive erasure.
        present: BitMap,
    }

    impl RecoverySetup {
        /// Strategy producing setups with `1..=max_n` data rows, `0..=max_k` extra rows,
        /// and `1..=max_cols` columns. At least `n` of the padded rows are marked present,
        /// which is the minimum needed to recover polynomials with `n` coefficients.
        fn any(max_n: usize, max_k: usize, max_cols: usize) -> impl Strategy<Value = Self> {
            (1..=max_n).prop_flat_map(move |n| {
                (0..=max_k).prop_flat_map(move |k| {
                    (1..=max_cols).prop_flat_map(move |cols| {
                        proptest::collection::vec(any_f(), n * cols).prop_flat_map(move |data| {
                            let padded_rows = (n + k).next_power_of_two();
                            proptest::sample::subsequence(
                                (0..padded_rows).collect::<Vec<_>>(),
                                n..=padded_rows,
                            )
                            .prop_map(move |indices| {
                                let mut present = BitMap::zeroes(padded_rows as u64);
                                for i in indices {
                                    present.set(i as u64, true);
                                }
                                Self {
                                    n,
                                    k,
                                    cols,
                                    // `prop_map` takes an `Fn` closure that may run many times,
                                    // so `data` can't be moved out of the capture: each
                                    // generated value needs its own copy.
                                    data: data.clone(),
                                    present,
                                }
                            })
                        })
                    })
                })
            })
        }

        /// Encode the data, erase every row not marked in `present`, recover, and
        /// check that the original polynomials come back unchanged.
        fn test(self) {
            let data = PolynomialVector::new(self.n + self.k, self.cols, self.data.into_iter());
            let mut encoded = data.clone().evaluate();
            for (i, b_i) in self.present.iter().enumerate() {
                if !b_i {
                    encoded.remove_row(i);
                }
            }
            let recovered_data = encoded.recover();
            assert_eq!(data, recovered_data);
        }
    }

    // Regression case: one coefficient padded to two rows, with only row 1 present.
    #[test]
    fn test_recovery_000() {
        RecoverySetup {
            n: 1,
            k: 1,
            cols: 1,
            data: vec![F::one()],
            present: vec![false, true].into(),
        }
        .test()
    }

    proptest! {
        // The fast NTT must agree with the naive evaluation.
        #[test]
        fn test_ntt_eq_naive(p in any_polynomial_vector(6, 4)) {
            let ntt = p.clone().evaluate();
            let ntt_naive = p.evaluate_naive();
            assert_eq!(ntt, ntt_naive);
        }

        // Interpolation must undo evaluation.
        #[test]
        fn test_evaluation_then_inverse(p in any_polynomial_vector(6, 4)) {
            assert_eq!(p.clone(), p.evaluate().interpolate());
        }

        // The vanishing polynomial has one root per missing row, vanishes exactly
        // at the missing rows' points, and is nonzero at the present rows' points.
        #[test]
        fn test_vanishing_polynomial(bv in any_bit_vec_not_all_0(8)) {
            let v = NTTPolynomial::vanishing(&bv);
            let expected_degree = bv.count_zeros();
            assert_eq!(v.degree(), expected_degree as usize, "expected v to have degree {expected_degree}");
            let w = F::root_of_unity(bv.len().ilog2() as u8).unwrap();
            let mut w_i = F::one();
            for b_i in bv.iter() {
                let v_at_w_i = v.evaluate(w_i);
                if !b_i {
                    assert_eq!(v_at_w_i, F::zero(), "v should evaluate to 0 at {w_i:?}");
                } else {
                    assert_ne!(v_at_w_i, F::zero());
                }
                w_i = w_i * w;
            }
        }

        // Randomized encode / erase / recover round trips.
        #[test]
        fn test_recovery(setup in RecoverySetup::any(128, 128, 4)) {
            setup.test();
        }
    }

    // Codec conformance tests for `Matrix`; only built with the `arbitrary` feature.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Matrix>,
        }
    }
}