//! commonware_math/ntt.rs

1use crate::algebra::{Additive, FieldNTT, Ring};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_utils::bitmap::BitMap;
6use core::{
7    num::NonZeroU32,
8    ops::{Index, IndexMut},
9};
10use rand_core::CryptoRngCore;
11#[cfg(feature = "std")]
12use std::vec::Vec;
13
/// Determines the size of polynomials we compute naively in [`EvaluationColumn::vanishing`].
///
/// Chunks of up to `2^LG_VANISHING_BASE` coefficients are built with the naive
/// (per-chunk quadratic) algorithm before switching to NTT-based merging.
///
/// Benchmarked to be optimal, based on BLS12381 threshold recovery time.
const LG_VANISHING_BASE: u32 = 8;
18
/// Reverse the first `bit_width` bits of `i`.
///
/// Any bits beyond that width will be erased.
///
/// # Panics
///
/// Panics if `bit_width > 64`.
fn reverse_bits(bit_width: u32, i: u64) -> u64 {
    assert!(bit_width <= 64, "bit_width must be <= 64");
    // `wrapping_shl` takes the shift amount mod 64, so a width of 0 would
    // shift by 0 and keep every bit instead of erasing them all. Special-case
    // it to honor the documented contract.
    if bit_width == 0 {
        return 0;
    }
    i.wrapping_shl(64 - bit_width).reverse_bits()
}
26
27/// Turn a slice into reversed bit order in place.
28///
29/// `out` MUST have length `2^bit_width`.
30fn reverse_slice<T>(bit_width: u32, out: &mut [T]) {
31    assert_eq!(out.len(), 1 << bit_width);
32    for i in 0..out.len() {
33        let j = reverse_bits(bit_width, i as u64) as usize;
34        // Only swap once, and don't swap if the location is the same.
35        if i < j {
36            out.swap(i, j);
37        }
38    }
39}
40
/// Calculate an NTT, or an inverse NTT (with FORWARD=false), in place.
///
/// We implement this generically over anything we can index into, which allows
/// performing NTTs in place.
///
/// `matrix` is indexed as `(row, col)`; each of the `cols` columns is
/// transformed independently, sharing the twiddle-factor bookkeeping.
///
/// # Panics
///
/// Panics if `rows` is not a power of two, or if `F` has no root of unity
/// of order `rows`.
fn ntt<const FORWARD: bool, F: FieldNTT, M: IndexMut<(usize, usize), Output = F>>(
    rows: usize,
    cols: usize,
    matrix: &mut M,
) {
    let lg_rows = rows.ilog2() as usize;
    assert_eq!(1 << lg_rows, rows, "rows should be a power of 2");
    // A number w such that w^(2^lg_rows) = 1.
    // (Or, in the inverse case, the inverse of that number, to undo the NTT).
    let w = {
        let w = F::root_of_unity(lg_rows as u8).expect("too many rows to perform NTT");
        if FORWARD {
            w
        } else {
            // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
            // making that left-hand term the inverse of w.
            w.exp(&[(1 << lg_rows) - 1])
        }
    };
    // The inverse algorithm consists of carefully undoing the work of the
    // standard algorithm, so we describe that in detail.
    //
    // To understand the NTT algorithm, first consider the case of a single
    // column. We have a polynomial f(X), and we want to turn that into:
    //
    // [f(w^0), f(w^1), ..., f(w^(2^lg_rows - 1))]
    //
    // Our polynomial can be written as:
    //
    // f+(X^2) + X f-(X^2)
    //
    // where f+ and f- are polynomials with half the degree.
    // f+ is obtained by taking the coefficients at even indices,
    // f- is obtained by taking the coefficients at odd indices.
    //
    // w^2 is also conveniently a 2^(lg_rows - 1) root of unity. Thus,
    // we can recursively compute an NTT on f+, using w^2 as the root,
    // and an NTT on f-, using w^2 as the root, each of which is a problem
    // of half the size.
    //
    // We can then compute:
    // f+((w^i)^2) + (w^i) f-((w^i)^2)
    // f+((w^i)^2) - (w^i) f-((w^i)^2)
    // for each i.
    // (Note that (-w^i)^2 = ((-w)^2)^i = (w^i)^2))
    //
    // Our coefficients are conveniently laid out as [f+ f-], already
    // in a neat order. When we recurse, the coefficients of f+ are, in
    // turn, already laid out as [f++ f+-], and so on.
    //
    // We just need to transform this recursive algorithm, in top down form,
    // into an iterative one, in bottom up form. For that, note that the NTT
    // for the case of 1 row is trivial: do nothing.

    // Will contain, in bottom up order, the power of w we need at that stage.
    // At the last stage, we need w itself.
    // At the stage before last, we need w^2.
    // And so on.
    // How many stages do we need? If we have 1 row, we need 0 stages.
    // In general, with 2^n rows, we need n stages.
    let stages = {
        let mut out = vec![(0usize, F::zero()); lg_rows];
        let mut w_i = w;
        for i in (0..lg_rows).rev() {
            out[i] = (i, w_i.clone());
            w_i = w_i.clone() * &w_i;
        }
        // In the case of the reverse algorithm, we undo each stage of the
        // forward algorithm, starting with the last stage.
        if !FORWARD {
            out.reverse();
        }
        out
    };
    for (stage, w) in stages.into_iter() {
        // At stage i, we have polynomials with 2^i coefficients,
        // which have already been evaluated to create 2^i entries.
        // We need to combine these evaluations to create 2^(i + 1) entries,
        // representing the evaluation of a polynomial with 2^(i + 1) coefficients.
        // If we have two of these evaluations, laid out one after the other:
        //
        // [x_0, x_1, ...] [y_0, y_1, ...]
        //
        // Then the number of elements we need to skip to get the corresponding
        // element in the other half is simply the number of elements in each half,
        // i.e. 2^i.
        let skip = 1 << stage;
        let mut i = 0;
        while i < rows {
            // In the case of a backwards NTT, skew should be the inverse of the skew
            // in the forwards direction.
            let mut w_j = F::one();
            for j in 0..skip {
                let index_a = i + j;
                let index_b = index_a + skip;
                // Apply the same butterfly across every column.
                for k in 0..cols {
                    let (a, b) = (matrix[(index_a, k)].clone(), matrix[(index_b, k)].clone());
                    if FORWARD {
                        let w_j_b = w_j.clone() * &b;
                        matrix[(index_a, k)] = a.clone() + &w_j_b;
                        matrix[(index_b, k)] = a - &w_j_b;
                    } else {
                        // To check the math, convince yourself that applying the forward
                        // transformation, and then this transformation, with w_j being the
                        // inverse of the value above, that you get (a, b).
                        // (a + w_j * b) + (a - w_j * b) = 2 * a
                        matrix[(index_a, k)] = (a.clone() + &b).div_2();
                        // (a + w_j * b) - (a - w_j * b) = 2 * w_j * b.
                        // w_j in this branch is the inverse of w_j in the other branch.
                        matrix[(index_b, k)] = ((a - &b) * &w_j).div_2();
                    }
                }
                w_j *= &w;
            }
            i += 2 * skip;
        }
    }
}
163
/// Columns of some larger piece of data.
///
/// This allows us to easily do NTTs over partial segments of some bigger matrix.
struct Columns<'a, const N: usize, F> {
    /// One mutable slice per column.
    data: [&'a mut [F]; N],
}

impl<const N: usize, F> Index<(usize, usize)> for Columns<'_, N, F> {
    type Output = F;

    /// Fetch the element at row `row` of column `col`.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        &self.data[col][row]
    }
}

impl<const N: usize, F> IndexMut<(usize, usize)> for Columns<'_, N, F> {
    /// Fetch the element at row `row` of column `col`, mutably.
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        &mut self.data[col][row]
    }
}
184
/// Used to keep track of the points at which a polynomial needs to vanish.
///
/// This takes care of subtle details like padding and bit ordering.
///
/// This struct is associated with a particular size, which is a power of two,
/// and thus a particular root of unity.
#[derive(Debug, PartialEq)]
pub struct VanishingPoints {
    /// log2 of the domain size; the domain has `2^lg_size` roots.
    lg_size: u32,
    /// One bit per root, stored in reverse bit order (cf. [`reverse_bits`]).
    /// A set bit marks a non-vanishing point, a clear bit a vanishing one.
    bits: BitMap,
}
196
impl VanishingPoints {
    /// This will have size `2^lg_size`, and vanish everywhere.
    ///
    /// Be aware that this means all points are initially marked as vanishing.
    pub fn new(lg_size: u32) -> Self {
        Self {
            lg_size,
            bits: BitMap::zeroes(1 << lg_size),
        }
    }

    /// This will have size `2^lg_size`, and vanish nowhere.
    pub fn all_non_vanishing(lg_size: u32) -> Self {
        Self {
            lg_size,
            bits: BitMap::ones(1 << lg_size),
        }
    }

    /// The log2 of the size of this set's domain.
    pub const fn lg_size(&self) -> u32 {
        self.lg_size
    }

    /// Set the root `w^index` to vanish, `value = false`, or not, `value = true`.
    ///
    /// The bitmap is kept in reverse bit order, hence the index conversion.
    fn set(&mut self, index: u64, value: bool) {
        self.bits.set(reverse_bits(self.lg_size, index), value);
    }

    /// Set the root `w^index` to not vanish.
    ///
    /// cf. `set`.
    pub fn set_non_vanishing(&mut self, index: u64) {
        self.set(index, true);
    }

    /// Check whether the root `w^index` is marked as non-vanishing.
    ///
    /// cf. `set` for the bit convention.
    pub fn get(&self, index: u64) -> bool {
        self.bits.get(reverse_bits(self.lg_size, index))
    }

    /// Count how many points are marked as non-vanishing.
    pub fn count_non_vanishing(&self) -> u64 {
        self.bits.count_ones()
    }

    /// Check that a particular chunk of this set vanishes.
    ///
    /// `lg_chunk_size` determines the size of the chunk, which must be a power of two.
    ///
    /// `index` determines which chunk to use. After chunk 0, you have chunk 1, and so on.
    ///
    /// The chunk is taken from the set in reverse bit order. This is what methods
    /// that create a vanishing polynomial recursively want. Take care when using
    /// this naively.
    fn chunk_vanishes_everywhere(&self, lg_chunk_size: u32, index: u64) -> bool {
        assert!(lg_chunk_size <= self.lg_size);
        let start = index << lg_chunk_size;
        self.bits.is_unset(start..start + (1 << lg_chunk_size))
    }

    /// Yield the bits of a chunk, in reverse bit order.
    ///
    /// cf. `chunk_vanishes_everywhere`, which uses the same chunk indexing scheme.
    fn get_chunk(&self, lg_chunk_size: u32, index: u64) -> impl Iterator<Item = bool> + '_ {
        (index << lg_chunk_size..(index + 1) << lg_chunk_size).map(|i| self.bits.get(i))
    }

    /// Yield every bit in natural (non-reversed) index order; test helper.
    #[cfg(any(test, feature = "fuzz"))]
    fn iter_bits_in_order(&self) -> impl Iterator<Item = bool> + '_ {
        (0..(1u64 << self.lg_size)).map(|i| self.get(i))
    }
}
267
/// Represents the evaluation of a single polynomial over a full domain.
#[derive(Debug)]
struct EvaluationColumn<F> {
    /// One evaluation per point of the domain, in the order produced by [`ntt`].
    evaluations: Vec<F>,
}
273
impl<F: FieldNTT> EvaluationColumn<F> {
    /// Evaluate the vanishing polynomial over `points` on the domain.
    ///
    /// This returns the evaluation of the polynomial at `0`, and then the evaluation
    /// of the polynomial over the whole domain.
    ///
    /// This assumes that `points` has at least one non-vanishing point.
    ///
    /// # Panics
    ///
    /// Panics if `F` has no root of unity of order `2^points.lg_size()`.
    pub fn vanishing(points: &VanishingPoints) -> (F, Self) {
        // The goal of this function is to produce a polynomial v such that
        // v(w^j) = 0 for each index j where points.get(j) = false.
        //
        // The core idea is to split this up recursively. We split the possible
        // roots into two groups, and figure out the vanishing polynomials
        // v_L and v_R for the first and second groups, respectively. Then,
        // multiplying v_L and v_R yields a polynomial with the appropriate roots.
        //
        // We can multiply the polynomials in O(N lg N) time, by performing an
        // NTT on both of them, multiplying the evaluations point wise, and then
        // using a reverse NTT to get a polynomial back.
        //
        // Naturally, we can extend this to construct each sub-polynomial recursively
        // as well, giving an O(N lg^2 N) algorithm in total.
        //
        // This function doesn't return the polynomial directly, but rather an
        // evaluation of the polynomial. This is because many consumers often
        // need this anyways, and by providing them with this result, we avoid
        // performing a reverse NTT that they then proceed to undo. However,
        // they can also need the evaluation at 0, so we provide and calculate that
        // as well. That can also be calculated recursively, and merged with the
        // above calculation.
        //
        // One point we haven't clarified yet is how to split up the roots.
        // Let's use an example. With size 8, the roots are:
        //
        // w^0 w^1 w^2 w^3 w^4 w^5 w^6 w^7
        //
        // or, writing down just the exponent
        //
        // 0 1 2 3 4 5 6 7
        //
        // We could build up our final polynomial by merging polynomials of size
        // two, with roots chosen among the following possibilities:
        //
        // 0 1    2 3    4 5    6 7
        //
        // However, this requires using different roots for each polynomial.
        //
        // If we instead use reverse bit order, we can have things be:
        //
        // 0 4    2 6    1 5    3 7
        //
        // which is equal to:
        //
        // 0 4    2 + (0 4)    1 + (0 4    2 + (0 4))
        //
        // So, we can start by having polynomials with the same possible roots
        // at the lowest level, and then merge by multiplying the roots by
        // the right power, for the polynomial on the right.
        //
        // The roots of a polynomial can easily be multiplied by some factor
        // by dividing its coefficients by powers of a factor.
        // cf [`PolynomialColumn::divide_roots`].
        //
        // Another optimization we can do for the merges is to keep track
        // of polynomials that vanish everywhere and nowhere. A polynomial
        // vanishing nowhere has no effect when merging, so we can skip a multiplication.
        // Similarly, a polynomial vanishing everywhere is of the form X^N - 1,
        // with which multiplication is simple.

        /// Used to keep track of special polynomial values.
        #[derive(Clone, Copy)]
        enum Where {
            /// Vanishes at none of the roots; i.e. is f(X) = 1.
            Nowhere,
            /// Vanishes at at least one of the roots.
            Somewhere,
            /// Vanishes at every single one of the roots.
            Everywhere,
        }

        use Where::*;

        let lg_len = points.lg_size();
        let len = 1usize << lg_len;
        // This will store our in progress polynomials, and eventually,
        // the final evaluations.
        let mut out = vec![F::zero(); len];
        // For small inputs, one chunk might more than cover it all, so we
        // need to make sure the chunk size isn't too big.
        let lg_chunk_size = LG_VANISHING_BASE.min(lg_len);
        // We use this to keep track of the polynomial evaluated at 0.
        let mut at_zero = F::one();

        // Populate out with polynomials up to a low degree.
        // We also get a vector with the status of each polymomial, letting
        // us accelerate the merging step.
        let mut vanishes = {
            let chunk_size = 1usize << lg_chunk_size;
            // The negation of each possible root vanishing polynomials can have.
            // We have the roots in reverse bit order.
            let minus_roots = {
                // We can panic without worry here, because we require a smaller
                // root of unity to exist elsewhere.
                let w = u8::try_from(lg_chunk_size)
                    .ok()
                    .and_then(|s| F::root_of_unity(s))
                    .expect("sub-root of unity should exist");
                // The powers of w we'll use as roots, pre-negated.
                let mut out: Vec<_> = (0..)
                    .scan(F::one(), |state, _| {
                        let out = -state.clone();
                        *state *= &w;
                        Some(out)
                    })
                    .take(chunk_size)
                    .collect();
                // Make sure the order is what the rest of this routine expects.
                reverse_slice(lg_chunk_size, out.as_mut_slice());
                out
            };
            // Instead of actually negating `at_zero` inside of the loop below,
            // we instead keep track of whether or not it needs to be negated
            // after the loop, to just perform that operation once.
            let mut negate_at_zero = false;
            // Populate each chunk with the initial polynomial.
            let vanishing = out
                .chunks_exact_mut(chunk_size)
                .enumerate()
                .map(|(i, poly)| {
                    let i_u64 = i as u64;
                    if points.chunk_vanishes_everywhere(lg_chunk_size, i_u64) {
                        // Implicitly, there's a 1 past the end of the polynomial,
                        // which we handle when merging.
                        poly[0] = -F::one();
                        negate_at_zero ^= true;
                        return Where::Everywhere;
                    }
                    poly[0] = F::one();
                    let mut coeffs = 1;
                    for (b_j, minus_root) in points
                        .get_chunk(lg_chunk_size, i_u64)
                        .zip(minus_roots.iter())
                    {
                        if b_j {
                            continue;
                        }
                        // Multiply the polynomial by (X - w^j).
                        poly[coeffs] = F::one();
                        for k in (1..coeffs).rev() {
                            let (chunk_head, chunk_tail) = poly.split_at_mut(k);
                            chunk_tail[0] *= minus_root;
                            chunk_tail[0] += &chunk_head[k - 1];
                        }
                        poly[0] *= minus_root;
                        coeffs += 1;
                    }
                    if coeffs > 1 {
                        reverse_slice(lg_chunk_size, poly);
                        at_zero *= &poly[0];
                        Where::Somewhere
                    } else {
                        Where::Nowhere
                    }
                })
                .collect::<Vec<_>>();
            if negate_at_zero {
                at_zero = -at_zero.clone();
            }
            vanishing
        };
        // Avoid doing any of the subsequent work if we've already covered this case.
        if lg_chunk_size >= lg_len {
            // We do, however, need to turn the coefficients into evaluations.
            return (at_zero, PolynomialColumn { coefficients: out }.evaluate());
        }
        let w_invs = {
            // since w^(2^lg_rows) = 1, w^(2^lg_rows - 1) * w = 1,
            // making that left-hand term the inverse of w.
            let mut w_inv = F::root_of_unity(lg_len as u8)
                .expect("too many rows to create vanishing polynomial")
                .exp(&[(1 << lg_len) - 1]);
            let mut out = Vec::with_capacity((lg_len - lg_chunk_size) as usize);
            for _ in lg_chunk_size..lg_len {
                out.push(w_inv.clone());
                w_inv = w_inv.clone() * &w_inv;
            }
            out.reverse();
            out
        };
        let mut lg_chunk_size = lg_chunk_size;
        let mut scratch = Vec::<F>::with_capacity(len);
        let mut coeff_shifts = Vec::with_capacity(1 << lg_chunk_size);
        for w_inv in w_invs.into_iter() {
            let chunk_size = 1 << lg_chunk_size;
            // Closure to shift coefficients by the current power.
            // This lets us reuse the computation of the powers.
            let mut shift_coeffs = |coeffs: &mut [F]| {
                if coeff_shifts.len() != chunk_size {
                    coeff_shifts.clear();
                    let mut acc = F::one();
                    for _ in 0..chunk_size {
                        coeff_shifts.push(acc.clone());
                        acc *= &w_inv;
                    }
                }
                for (i, coeff_i) in coeffs.iter_mut().enumerate() {
                    *coeff_i *= &coeff_shifts[reverse_bits(lg_chunk_size, i as u64) as usize];
                }
            };
            let next_lg_chunk_size = lg_chunk_size + 1;
            let next_chunk_size = 1 << next_lg_chunk_size;
            for (i, chunk) in out.chunks_exact_mut(1 << next_lg_chunk_size).enumerate() {
                let (left, right) = chunk.split_at_mut(1 << lg_chunk_size);
                let (vanishes_l, vanishes_r) = (vanishes[2 * i], vanishes[2 * i + 1]);
                // We keep track of whether or not the polynomial resulting from
                // the merge is evaluated or not.
                let mut evaluated = false;
                vanishes[i] = match (vanishes_l, vanishes_r) {
                    (Nowhere, Nowhere) => {
                        // Both polynomials consist of 1 0 0 0 ..., and we
                        // want the final result to be that, just with more zeroes,
                        // so we need to clear the 1 value on the right side.
                        right[0] = F::zero();
                        Nowhere
                    }
                    (Nowhere, Somewhere) => {
                        // Clear the one value on the left.
                        left[0] = F::zero();
                        // Adjust the roots on the right.
                        shift_coeffs(right);
                        // Make it take all of the left space.
                        for i in 0..chunk_size {
                            chunk.swap(chunk_size + i, 2 * i);
                        }
                        Somewhere
                    }
                    (Nowhere, Everywhere) => {
                        // (X^(N/2) - 1) is on the right.
                        // First, we multiply its roots by w_N, yielding:
                        //
                        // -X^(N/2) - 1
                        //
                        // in reverse bit order we get the following:
                        left[0] = -F::one();
                        left[1] = -F::one();
                        // And we remove the -1 on the right side.
                        right[0] = F::zero();
                        Somewhere
                    }
                    // These two cases mirror the two above.
                    (Somewhere, Nowhere) => {
                        // Clear the one on the right side.
                        right[0] = F::zero();
                        // Make it take all of the right space.
                        // We can skip moving index 0.
                        for i in (1..chunk_size).rev() {
                            chunk.swap(i, 2 * i);
                        }
                        Somewhere
                    }
                    (Everywhere, Nowhere) => {
                        // Like above, but with the polynomial on the left,
                        // there's no need to adjust the roots.
                        left[0] = -F::one();
                        left[1] = F::one();
                        right[0] = F::zero();
                        Somewhere
                    }
                    (Somewhere, Everywhere) => {
                        // We need to make the left side occupy the whole space.
                        // Shifting by one index has the effect of multiplying
                        // the polynomial by X^(chunk_size), which is what we want.
                        for i in (0..chunk_size).rev() {
                            chunk.swap(i, 2 * i + 1);
                            // We copy the value in i, negate it, and make it occupy
                            // both 2 * i + 1 and 2 * i, thus multiplying by -(X^chunk_size + 1).
                            chunk[2 * i + 1] = -chunk[2 * i + 1].clone();
                            chunk[2 * i] = chunk[2 * i + 1].clone();
                        }
                        Somewhere
                    }
                    (Everywhere, Somewhere) => {
                        // Adjust the roots on the right.
                        shift_coeffs(right);
                        // Like above, but moving the right side, and multiplying by
                        // (X^chunk_size - 1).
                        for i in 0..chunk_size {
                            chunk.swap(chunk_size + i, 2 * i + 1);
                            chunk[2 * i] = -chunk[2 * i + 1].clone();
                        }
                        Somewhere
                    }
                    (Everywhere, Everywhere) => {
                        // Make sure to clear the -1 on the right side.
                        right[0] = F::zero();
                        // By choosing to do things this way, we effectively
                        // negate the final polynomial, so we need to correct
                        // for this with the zero value.
                        at_zero = -at_zero.clone();
                        Everywhere
                    }
                    // In this case, we can assume nothing, and have to do
                    // the full logic for actually multiplying the polynomials.
                    (Somewhere, Somewhere) => {
                        // Adjust the roots on the right.
                        shift_coeffs(right);
                        // Populate the scratch buffer with the right side.
                        scratch.clear();
                        scratch.resize(next_chunk_size, F::zero());
                        for i in 0..chunk_size {
                            core::mem::swap(&mut right[i], &mut scratch[2 * i]);
                        }
                        // We can skip moving index 0.
                        for i in (1..chunk_size).rev() {
                            chunk.swap(i, 2 * i);
                        }
                        // Turn the polynomials into evaluations.
                        ntt::<true, _, _>(
                            next_chunk_size,
                            2,
                            &mut Columns {
                                data: [chunk, scratch.as_mut_slice()],
                            },
                        );
                        // Multiply them, into the chunk.
                        for (l, r) in chunk.iter_mut().zip(scratch.iter_mut()) {
                            *l *= r;
                        }
                        evaluated = true;
                        Somewhere
                    }
                };
                // If this isn't the last iteration, make sure to turn back into coefficients.
                let should_be_evaluated = next_chunk_size >= len;
                if should_be_evaluated != evaluated {
                    if evaluated {
                        ntt::<false, _, _>(next_chunk_size, 1, &mut Columns { data: [chunk] });
                    } else {
                        ntt::<true, _, _>(next_chunk_size, 1, &mut Columns { data: [chunk] });
                    }
                }
            }
            lg_chunk_size = next_lg_chunk_size;
        }
        // The final merge left `out` in evaluated form (cf. `should_be_evaluated`
        // above), so no further NTT is needed here.
        (at_zero, Self { evaluations: out })
    }

    /// Interpolate these evaluations back into a polynomial, via an inverse NTT.
    pub fn interpolate(self) -> PolynomialColumn<F> {
        let mut data = self.evaluations;
        ntt::<false, _, _>(
            data.len(),
            1,
            &mut Columns {
                data: [data.as_mut_slice()],
            },
        );
        PolynomialColumn { coefficients: data }
    }
}
634
/// A column containing a single polynomial.
#[derive(Debug)]
struct PolynomialColumn<F> {
    /// The coefficients of the polynomial, stored in reverse bit order
    /// (cf. [`reverse_bits`]), as the NTT routines here expect.
    coefficients: Vec<F>,
}
640
impl<F: FieldNTT> PolynomialColumn<F> {
    /// Evaluate this polynomial over the domain, returning its evaluations
    /// at every root of unity.
    pub fn evaluate(self) -> EvaluationColumn<F> {
        let mut data = self.coefficients;
        ntt::<true, _, _>(
            data.len(),
            1,
            &mut Columns {
                data: [data.as_mut_slice()],
            },
        );
        EvaluationColumn { evaluations: data }
    }

    /// Evaluate the polynomial at a single `point`, using Horner's method.
    ///
    /// The coefficients are stored in reverse bit order, hence the index
    /// conversion when reading them.
    #[cfg(any(test, feature = "fuzz"))]
    fn evaluate_one(&self, point: F) -> F {
        let mut out = F::zero();
        let rows = self.coefficients.len();
        let lg_rows = rows.ilog2();
        for i in (0..rows).rev() {
            out = out * &point + &self.coefficients[reverse_bits(lg_rows, i as u64) as usize];
        }
        out
    }

    /// The degree of this polynomial: the largest index holding a non-zero
    /// coefficient, or 0 for the zero polynomial.
    #[cfg(any(test, feature = "fuzz"))]
    fn degree(&self) -> usize {
        let rows = self.coefficients.len();
        let lg_rows = rows.ilog2();
        for i in (0..rows).rev() {
            if self.coefficients[reverse_bits(lg_rows, i as u64) as usize] != F::zero() {
                return i;
            }
        }
        0
    }

    /// Divide the roots of each polynomial by some factor.
    ///
    /// If f(x) = 0, then after this transformation, f(x / z) = 0 instead.
    ///
    /// The number of roots does not change.
    ///
    /// c.f. [`EvaluationColumn::vanishing`] for how this is used.
    fn divide_roots(&mut self, factor: F) {
        let mut factor_i = F::one();
        let lg_rows = self.coefficients.len().ilog2();
        // Multiply the coefficient of X^i by factor^i, walking the
        // coefficients in natural (non-reversed) degree order.
        for i in 0..self.coefficients.len() {
            let index = reverse_bits(lg_rows, i as u64) as usize;
            self.coefficients[index] *= &factor_i;
            factor_i *= &factor;
        }
    }
}
695
/// Represents a matrix of field elements, of arbitrary dimensions
///
/// This is in row major order, so consider processing elements in the same
/// row first, for locality.
#[derive(Clone, PartialEq)]
pub struct Matrix<F> {
    /// The number of rows.
    rows: usize,
    /// The number of columns.
    cols: usize,
    /// The backing storage: `rows * cols` elements, row major.
    data: Vec<F>,
}
706
707impl<F: EncodeSize> EncodeSize for Matrix<F> {
708    fn encode_size(&self) -> usize {
709        self.rows.encode_size() + self.cols.encode_size() + self.data.encode_size()
710    }
711}
712
713impl<F: Write> Write for Matrix<F> {
714    fn write(&self, buf: &mut impl bytes::BufMut) {
715        self.rows.write(buf);
716        self.cols.write(buf);
717        self.data.write(buf);
718    }
719}
720
impl<F: Read> Read for Matrix<F> {
    /// The maximum number of rows, columns, and elements, plus the
    /// per-element read configuration.
    type Cfg = (usize, <F as Read>::Cfg);

    /// Read a matrix: dimensions first, then row-major data, mirroring the
    /// order written by `write`; rejects data inconsistent with the dimensions.
    fn read_cfg(
        buf: &mut impl bytes::Buf,
        (max_els, f_cfg): &Self::Cfg,
    ) -> Result<Self, commonware_codec::Error> {
        let cfg = RangeCfg::from(..=*max_els);
        let rows = usize::read_cfg(buf, &cfg)?;
        let cols = usize::read_cfg(buf, &cfg)?;
        let data = Vec::<F>::read_cfg(buf, &(cfg, f_cfg.clone()))?;
        // Guard against `rows * cols` overflowing before validating the length.
        let expected_len = rows
            .checked_mul(cols)
            .ok_or(commonware_codec::Error::Invalid(
                "Matrix",
                "matrix dimensions overflow",
            ))?;
        if data.len() != expected_len {
            return Err(commonware_codec::Error::Invalid(
                "Matrix",
                "matrix element count does not match dimensions",
            ));
        }
        Ok(Self { rows, cols, data })
    }
}
747
748impl<F: core::fmt::Debug> core::fmt::Debug for Matrix<F> {
749    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
750        for i in 0..self.rows {
751            let row_i = &self[i];
752            for row_i_j in row_i {
753                write!(f, "{row_i_j:?} ")?;
754            }
755            writeln!(f)?;
756        }
757        Ok(())
758    }
759}
760
761impl<F: Additive> Matrix<F> {
762    /// Create a zero matrix, with a certain number of rows and columns
763    fn zero(rows: usize, cols: usize) -> Self {
764        Self {
765            rows,
766            cols,
767            data: vec![F::zero(); rows * cols],
768        }
769    }
770
771    /// Initialize a matrix, with dimensions, and data to pull from.
772    ///
773    /// Any extra data is ignored, any data not supplied is treated as 0.
774    pub fn init(rows: usize, cols: usize, mut data: impl Iterator<Item = F>) -> Self {
775        let mut out = Self::zero(rows, cols);
776        'outer: for i in 0..rows {
777            for row_i in &mut out[i] {
778                let Some(x) = data.next() else {
779                    break 'outer;
780                };
781                *row_i = x;
782            }
783        }
784        out
785    }
786
787    /// Interpret the columns of this matrix as polynomials, with at least `min_coefficients`.
788    ///
789    /// This will, in fact, produce a matrix padded to the next power of 2 of that number.
790    ///
791    /// This will return `None` if `min_coefficients < self.rows`, which would mean
792    /// discarding data, instead of padding it.
793    pub fn as_polynomials(&self, min_coefficients: usize) -> Option<PolynomialVector<F>>
794    where
795        F: Clone,
796    {
797        if min_coefficients < self.rows {
798            return None;
799        }
800        Some(PolynomialVector::new(
801            min_coefficients,
802            self.cols,
803            (0..self.rows).flat_map(|i| self[i].iter().cloned()),
804        ))
805    }
806
807    /// Multiply this matrix by another.
808    ///
809    /// This assumes that the number of columns in this matrix match the number
810    /// of rows in the other matrix.
811    pub fn mul(&self, other: &Self) -> Self
812    where
813        F: Clone + Ring,
814    {
815        assert_eq!(self.cols, other.rows);
816        let mut out = Self::zero(self.rows, other.cols);
817        for i in 0..self.rows {
818            for j in 0..self.cols {
819                let c = self[(i, j)].clone();
820                let other_j = &other[j];
821                for k in 0..other.cols {
822                    out[(i, k)] += &(c.clone() * &other_j[k])
823                }
824            }
825        }
826        out
827    }
828}
829
830impl<F: FieldNTT> Matrix<F> {
831    fn ntt<const FORWARD: bool>(&mut self) {
832        ntt::<FORWARD, F, Self>(self.rows, self.cols, self)
833    }
834}
835
836impl<F> Matrix<F> {
837    pub const fn rows(&self) -> usize {
838        self.rows
839    }
840
841    pub const fn cols(&self) -> usize {
842        self.cols
843    }
844
845    /// Iterate over the rows of this matrix.
846    pub fn iter(&self) -> impl Iterator<Item = &[F]> {
847        (0..self.rows).map(|i| &self[i])
848    }
849}
850
851impl<F: crate::algebra::Random> Matrix<F> {
852    /// Create a random matrix with certain dimensions.
853    pub fn rand(mut rng: impl CryptoRngCore, rows: usize, cols: usize) -> Self
854    where
855        F: Additive,
856    {
857        Self::init(rows, cols, (0..rows * cols).map(|_| F::random(&mut rng)))
858    }
859}
860
861impl<F> Index<usize> for Matrix<F> {
862    type Output = [F];
863
864    fn index(&self, index: usize) -> &Self::Output {
865        &self.data[self.cols * index..self.cols * (index + 1)]
866    }
867}
868
869impl<F> IndexMut<usize> for Matrix<F> {
870    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
871        &mut self.data[self.cols * index..self.cols * (index + 1)]
872    }
873}
874
875impl<F> Index<(usize, usize)> for Matrix<F> {
876    type Output = F;
877
878    fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
879        &self.data[self.cols * i + j]
880    }
881}
882
883impl<F> IndexMut<(usize, usize)> for Matrix<F> {
884    fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
885        &mut self.data[self.cols * i + j]
886    }
887}
888
#[cfg(any(test, feature = "arbitrary"))]
impl<'a, F: arbitrary::Arbitrary<'a>> arbitrary::Arbitrary<'a> for Matrix<F> {
    /// Generate a small matrix (each dimension between 1 and 16).
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let rows = u.int_in_range(1..=16)?;
        let cols = u.int_in_range(1..=16)?;
        let mut data = Vec::with_capacity(rows * cols);
        for _ in 0..rows * cols {
            data.push(F::arbitrary(u)?);
        }
        Ok(Self { rows, cols, data })
    }
}
900
/// A vector of polynomials, stored as one column of coefficients per polynomial.
#[derive(Clone, Debug, PartialEq)]
pub struct PolynomialVector<F> {
    // Each column of this matrix contains the coefficients of a polynomial,
    // in reverse bit order. So, the ith coefficient appears at index i.reverse_bits().
    //
    // For example, a polynomial a0 + a1 X + a2 X^2 + a3 X^3 is stored as:
    //
    // a0 a2 a1 a3
    //
    // This is convenient because the even coefficients and the odd coefficients
    // split nicely into halves. The first half of the rows have the property
    // that the first bit of their coefficient index is 0, then in that subset
    // the first half has the second bit set to 0, and the second half set to 1,
    // and so on, recursively.
    data: Matrix<F>,
}
917
918impl<F: Additive> PolynomialVector<F> {
919    /// Construct a new vector of polynomials, from dimensions, and coefficients.
920    ///
921    /// The coefficients should be supplied in order of increasing index,
922    /// and then for each polynomial.
923    ///
924    /// In other words, if you have 3 polynomials:
925    ///
926    /// a0 + a1 X + ...
927    /// b0 + b1 X + ...
928    /// c0 + c1 X + ...
929    ///
930    /// The iterator should yield:
931    ///
932    /// a0 b0 c0
933    /// a1 b1 c1
934    /// ...
935    ///
936    /// Any coefficients not supplied are treated as being equal to 0.
937    fn new(rows: usize, cols: usize, mut coefficients: impl Iterator<Item = F>) -> Self {
938        assert!(rows > 0);
939        let rows = rows.next_power_of_two();
940        let lg_rows = rows.ilog2();
941        let mut data = Matrix::zero(rows, cols);
942        'outer: for i in 0..rows {
943            let row_i = &mut data[reverse_bits(lg_rows, i as u64) as usize];
944            for row_i_j in row_i {
945                let Some(c) = coefficients.next() else {
946                    break 'outer;
947                };
948                *row_i_j = c;
949            }
950        }
951        Self { data }
952    }
953}
954
impl<F: FieldNTT> PolynomialVector<F> {
    /// Evaluate each polynomial in this vector over all points in an interpolation domain.
    pub fn evaluate(mut self) -> EvaluationVector<F> {
        self.data.ntt::<true>();
        // Every row of the result holds a genuine evaluation, so mark them all present.
        let active_rows = VanishingPoints::all_non_vanishing(self.data.rows().ilog2());
        EvaluationVector {
            data: self.data,
            active_rows,
        }
    }

    /// Like [Self::evaluate], but with a simpler algorithm that's much less efficient.
    ///
    /// Exists as a useful tool for testing
    #[cfg(any(test, feature = "fuzz"))]
    fn evaluate_naive(self) -> EvaluationVector<F> {
        let rows = self.data.rows;
        let lg_rows = rows.ilog2();
        let w = F::root_of_unity(lg_rows as u8).expect("too much data to calculate NTT");
        // entry (i, j) of this matrix will contain w^ij. Thus, multiplying it
        // with the coefficients of a polynomial, in column order, will evaluate it.
        // We also need to re-arrange the columns of the matrix to match the same
        // order we have for polynomial coefficients.
        let mut vandermonde_matrix = Matrix::zero(rows, rows);
        let mut w_i = F::one();
        for i in 0..rows {
            let row_i = &mut vandermonde_matrix[i];
            let mut w_ij = F::one();
            for j in 0..rows {
                // Remember, the coefficients of the polynomial are in reverse bit order!
                row_i[reverse_bits(lg_rows, j as u64) as usize] = w_ij.clone();
                w_ij *= &w_i;
            }
            w_i *= &w;
        }

        EvaluationVector {
            data: vandermonde_matrix.mul(&self.data),
            active_rows: VanishingPoints::all_non_vanishing(lg_rows),
        }
    }

    /// Divide the roots of each polynomial by some factor.
    ///
    /// c.f. [`PolynomialColumn::divide_roots`]. This performs the same operation on
    /// each polynomial in this vector.
    fn divide_roots(&mut self, factor: F) {
        let mut factor_i = F::one();
        let lg_rows = self.data.rows.ilog2();
        for i in 0..self.data.rows {
            // Scale coefficient index i (stored at the bit-reversed row) by factor^i.
            for p_i in &mut self.data[reverse_bits(lg_rows, i as u64) as usize] {
                *p_i *= &factor_i;
            }
            factor_i *= &factor;
        }
    }

    /// For each polynomial P_i in this vector compute the evaluation of P_i / Q.
    ///
    /// Naturally, you can call [EvaluationVector::interpolate]. The reason we don't
    /// do this is that the algorithm naturally yields an [EvaluationVector], and
    /// some use-cases may want access to that data as well.
    ///
    /// This assumes that the number of coefficients in the polynomials of this vector
    /// matches that of `q` (the coefficients can be 0, but need to be padded to the right size).
    ///
    /// This assumes that `q` has no zeroes over `coset_shift() * root_of_unity()^i`,
    /// for any i. This will be the case for a vanishing polynomial produced by
    /// [EvaluationColumn::vanishing] and then interpolated.
    /// If this isn't the case, the result may be junk.
    ///
    /// If `q` doesn't divide a particular polynomial in this vector, the result
    /// for that polynomial is not guaranteed to be anything meaningful.
    fn divide(&mut self, mut q: PolynomialColumn<F>) {
        // The algorithm operates column wise.
        //
        // You can compute P(X) / Q(X) by evaluating each polynomial, then computing
        //
        //   P(w^i) / Q(w^i)
        //
        // for each evaluation point. Then, you can interpolate back.
        //
        // But wait! What if Q(w^i) = 0? In particular, for the case of recovering
        // a polynomial from data with missing rows, we *expect* P(w^i) = 0 = Q(w^i)
        // for the indices we're missing, so this doesn't work.
        //
        // What we can do is to instead multiply each of the roots by some factor z,
        // such that z w^i != w^j, for any i, j. In other words, we change the roots
        // such that they're not in the evaluation domain anymore, allowing us to
        // divide. We can then interpolate the result back into a polynomial,
        // and divide back the roots to where they should be.
        //
        // c.f. [PolynomialVector::divide_roots]
        assert_eq!(
            self.data.rows,
            q.coefficients.len(),
            "cannot divide by polynomial of the wrong size"
        );
        let skew = F::coset_shift();
        let skew_inv = F::coset_shift_inv();
        self.divide_roots(skew.clone());
        q.divide_roots(skew);
        // Evaluate both the (skewed) numerator columns and q over the domain.
        ntt::<true, F, _>(self.data.rows, self.data.cols, &mut self.data);
        ntt::<true, F, _>(
            q.coefficients.len(),
            1,
            &mut Columns {
                data: [&mut q.coefficients],
            },
        );
        // Do a point wise division.
        for i in 0..self.data.rows {
            let q_i = q.coefficients[i].clone();
            // If `q_i = 0`, then we will get 0 in the output.
            // We don't expect any of the q_i to be 0, but being 0 is only one
            // of the many possibilities for the coefficient to be incorrect,
            // so doing a runtime assertion here doesn't make sense.
            let q_i_inv = q_i.inv();
            for d_i_j in &mut self.data[i] {
                *d_i_j *= &q_i_inv;
            }
        }
        // Interpolate back, using the inverse skew
        ntt::<false, F, _>(self.data.rows, self.data.cols, &mut self.data);
        self.divide_roots(skew_inv);
    }
}
1082
1083impl<F> PolynomialVector<F> {
1084    /// Iterate over up to n rows of this vector.
1085    ///
1086    /// For example, given polynomials:
1087    ///
1088    ///   a0 + a1 X + a2 X^2 + ...
1089    ///   b0 + b1 X + b2 X^2 + ...
1090    ///
1091    /// This will return:
1092    ///
1093    ///   a0 b0
1094    ///   a1 b1
1095    ///   ...
1096    ///
1097    /// up to n times.
1098    pub fn coefficients_up_to(&self, n: usize) -> impl Iterator<Item = &[F]> {
1099        let n = n.min(self.data.rows);
1100        let lg_rows = self.data.rows().ilog2();
1101        (0..n).map(move |i| &self.data[reverse_bits(lg_rows, i as u64) as usize])
1102    }
1103}
1104
/// The result of evaluating a vector of polynomials over all points in an interpolation domain.
///
/// This struct also remembers which rows have ever been filled with [Self::fill_row].
/// This is used in [Self::recover], which can use the rows that are present to fill in the missing
/// rows.
#[derive(Debug, PartialEq)]
pub struct EvaluationVector<F> {
    // One row per evaluation point, one column per polynomial.
    data: Matrix<F>,
    // Tracks which rows hold real evaluations, as opposed to missing data.
    active_rows: VanishingPoints,
}
1115
1116impl<F: FieldNTT> EvaluationVector<F> {
1117    /// Figure out the polynomial which evaluates to this vector.
1118    ///
1119    /// i.e. the inverse of [PolynomialVector::evaluate].
1120    ///
1121    /// (This makes all the rows count as filled).
1122    fn interpolate(mut self) -> PolynomialVector<F> {
1123        self.data.ntt::<false>();
1124        PolynomialVector { data: self.data }
1125    }
1126
1127    /// Erase a particular row.
1128    ///
1129    /// Useful for testing the recovery procedure.
1130    #[cfg(any(test, feature = "fuzz"))]
1131    fn remove_row(&mut self, row: usize) {
1132        self.data[row].fill(F::zero());
1133        self.active_rows.set(row as u64, false);
1134    }
1135
1136    fn multiply(&mut self, evaluation: &EvaluationColumn<F>) {
1137        for (i, e_i) in evaluation.evaluations.iter().enumerate() {
1138            for self_j in &mut self.data[i] {
1139                *self_j = self_j.clone() * e_i;
1140            }
1141        }
1142    }
1143
1144    /// Attempt to recover the missing rows in this data.
1145    pub fn recover(mut self) -> PolynomialVector<F> {
1146        let non_vanishing = self.active_rows.count_non_vanishing();
1147        if non_vanishing == 0 || non_vanishing == self.data.rows as u64 {
1148            return self.interpolate();
1149        }
1150
1151        // If we had all of the rows, we could simply call [Self::interpolate],
1152        // in order to recover the original polynomial. If we do this while missing some
1153        // rows, what we get is D(X) * V(X) where D is the original polynomial,
1154        // and V(X) is a polynomial which vanishes at all the rows we're missing.
1155        //
1156        // As long as the degree of D is low enough, compared to the number of evaluations
1157        // we *do* have, then we can recover it by performing:
1158        //
1159        //   (D(X) * V(X)) / V(X)
1160        //
1161        // If we have multiple columns, then this procedure can be done column by column,
1162        // with the same vanishing polynomial.
1163        let (_, vanishing) = EvaluationColumn::vanishing(&self.active_rows);
1164        self.multiply(&vanishing);
1165        let mut out = self.interpolate();
1166        out.divide(vanishing.interpolate());
1167        out
1168    }
1169}
1170
1171impl<F: Additive> EvaluationVector<F> {
1172    /// Create an empty element of this struct, with no filled rows.
1173    ///
1174    /// `2^lg_rows` must be a valid `usize`.
1175    pub fn empty(lg_rows: usize, cols: usize) -> Self {
1176        assert!(
1177            lg_rows < usize::BITS as usize,
1178            "2^lg_rows must be a valid usize"
1179        );
1180        let data = Matrix::zero(1 << lg_rows, cols);
1181        let active = VanishingPoints::new(lg_rows as u32);
1182        Self {
1183            data,
1184            active_rows: active,
1185        }
1186    }
1187
1188    /// Fill a specific row.
1189    pub fn fill_row(&mut self, row: usize, data: &[F])
1190    where
1191        F: Clone,
1192    {
1193        assert!(data.len() <= self.data.cols);
1194        self.data[row][..data.len()].clone_from_slice(data);
1195        self.active_rows.set(row as u64, true);
1196    }
1197}
1198
1199impl<F> EvaluationVector<F> {
1200    /// Get the underlying data, as a Matrix.
1201    pub fn data(self) -> Matrix<F> {
1202        self.data
1203    }
1204
1205    /// Return how many distinct rows have been filled.
1206    pub fn filled_rows(&self) -> usize {
1207        self.active_rows.count_non_vanishing() as usize
1208    }
1209}
1210
1211/// Compute Lagrange coefficients for interpolating a polynomial at 0 from evaluations
1212/// at roots of unity.
1213///
1214/// Given a subset S of indices where we have evaluations, this computes the Lagrange
1215/// coefficients needed to interpolate to 0. For each index `j` in S, the coefficient
1216/// is `L_j(0)` where `L_j` is the Lagrange basis polynomial.
1217///
1218/// The key formula is: `L_j(0) = P_Sbar(w^j) / (N * P_Sbar(0))`
1219///
1220/// where `P_Sbar` is the (possibly scaled) vanishing polynomial over the complement
1221/// (missing points), and N is the domain size. This follows from
1222/// `V_S(X) * V_Sbar(X) = X^N - 1`, which gives `V_S(0) = -1/V_Sbar(0)`.
1223/// The scaling factor of `P_Sbar` cancels in the ratio.
1224///
1225/// Building `P_Sbar` as the vanishing polynomial over missing points is cheaper than building `V_S`
1226/// when most points are present (the typical erasure-coding case), since `|Sbar| << |S|`.
1227///
1228/// # Arguments
1229/// * `total` - The total number of points in the domain (rounded up to power of 2)
1230/// * `iter` - Iterator of indices where we have evaluations (duplicates ignored, indices >= total ignored)
1231///
1232/// # Returns
1233/// A vector of `(index, coefficient)` pairs for each unique index in the input set.
1234pub fn lagrange_coefficients<F: FieldNTT>(
1235    total: NonZeroU32,
1236    iter: impl IntoIterator<Item = u32>,
1237) -> Vec<(u32, F)> {
1238    let total_u64 = u64::from(total.get());
1239    let size = total_u64.next_power_of_two();
1240    let lg_size = size.ilog2();
1241
1242    let mut present = VanishingPoints::new(lg_size);
1243    for i in iter {
1244        let i_u64 = u64::from(i);
1245        if i_u64 < total_u64 {
1246            present.set_non_vanishing(i_u64);
1247        }
1248    }
1249
1250    let num_present = present.count_non_vanishing();
1251
1252    if num_present == 0 {
1253        return Vec::new();
1254    }
1255
1256    let n_f = F::one().scale(&[size]);
1257    if num_present == size {
1258        let n_inv = n_f.inv();
1259        return (0..size).map(|i| (i as u32, n_inv.clone())).collect();
1260    }
1261
1262    // Build P_Sbar (vanishes at indices NOT in present) and evaluate at all
1263    // roots of unity via NTT. Note: vanishing() may produce a scaled polynomial
1264    // P_Sbar = c * V_Sbar, but the scaling cancels in the ratio below.
1265    let (p_sbar_at_zero, complement_evals) = EvaluationColumn::vanishing(&present);
1266
1267    // From V_S(0) * V_Sbar(0) = -1 (since V_S * V_Sbar = X^N - 1), we get:
1268    //   L_j(0) = -V_S(0) * V_Sbar(w^j) / N = V_Sbar(w^j) / (N * V_Sbar(0))
1269    // Since P_Sbar = c * V_Sbar, the scaling c cancels:
1270    //   L_j(0) = P_Sbar(w^j) / (N * P_Sbar(0))
1271    let factor = (n_f * &p_sbar_at_zero).inv();
1272
1273    let mut out = Vec::with_capacity(num_present as usize);
1274    for j in 0..size {
1275        if present.get(j) {
1276            let coeff = factor.clone() * &complement_evals.evaluations[j as usize];
1277            out.push((j as u32, coeff));
1278        }
1279    }
1280    out
1281}
1282
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    use super::*;
    use crate::{algebra::Ring, fields::goldilocks::F};
    use arbitrary::{Arbitrary, Unstructured};

    /// Generate an arbitrary [`PolynomialVector`] over the Goldilocks field,
    /// with `2^lg_rows` coefficients (`lg_rows <= max_log_rows`) and between
    /// 1 and `max_cols` columns.
    fn arb_polynomial_vector(
        u: &mut Unstructured<'_>,
        max_log_rows: u32,
        max_cols: usize,
    ) -> arbitrary::Result<PolynomialVector<F>> {
        let lg_rows = u.int_in_range(0..=max_log_rows)?;
        let cols = u.int_in_range(1..=max_cols)?;
        let rows = 1usize << lg_rows;
        let coefficients: Vec<F> = (0..rows * cols)
            .map(|_| Ok(F::from(u.arbitrary::<u64>()?)))
            .collect::<arbitrary::Result<_>>()?;
        Ok(PolynomialVector::new(rows, cols, coefficients.into_iter()))
    }

    /// Generate an arbitrary [`VanishingPoints`] with at least one bit set.
    fn arb_bit_vec_not_all_0(
        u: &mut Unstructured<'_>,
        max_log_rows: u32,
    ) -> arbitrary::Result<VanishingPoints> {
        let lg_rows = u.int_in_range(0..=max_log_rows)?;
        let rows = 1usize << lg_rows;
        // Pick one index to force to true, so the set is never all-vanishing.
        let set_row = u.int_in_range(0..=rows - 1)?;
        let mut bools: Vec<bool> = (0..rows)
            .map(|_| u.arbitrary())
            .collect::<arbitrary::Result<_>>()?;
        bools[set_row] = true;
        let mut out = VanishingPoints::new(lg_rows);
        for (i, b) in bools.into_iter().enumerate() {
            out.set(i as u64, b);
        }
        Ok(out)
    }

    /// Generate a random [`RecoverySetup`]: dimensions, data, and a set of at
    /// least `n` present rows, so that recovery should succeed.
    fn arb_recovery_setup(
        u: &mut Unstructured<'_>,
        max_n: usize,
        max_k: usize,
        max_cols: usize,
    ) -> arbitrary::Result<RecoverySetup> {
        let n = u.int_in_range(1..=max_n)?;
        let k = u.int_in_range(0..=max_k)?;
        let cols = u.int_in_range(1..=max_cols)?;
        let data: Vec<F> = (0..n * cols)
            .map(|_| Ok(F::from(u.arbitrary::<u64>()?)))
            .collect::<arbitrary::Result<_>>()?;
        let padded_rows = (n + k).next_power_of_two();
        let num_present = u.int_in_range(n..=padded_rows)?;
        // Partially shuffle the indices, so that the first `num_present`
        // entries form a random set of distinct rows to keep.
        let mut indices: Vec<usize> = (0..padded_rows).collect();
        for i in 0..num_present {
            let j = u.int_in_range(i..=padded_rows - 1)?;
            indices.swap(i, j);
        }
        let mut present = VanishingPoints::new(padded_rows.ilog2());
        for &i in &indices[..num_present] {
            present.set(i as u64, true);
        }
        Ok(RecoverySetup {
            n,
            k,
            cols,
            data,
            present,
        })
    }

    /// A description of an erasure-recovery scenario, exercised by [`Self::test`].
    #[derive(Debug)]
    pub struct RecoverySetup {
        // The number of data rows supplied.
        n: usize,
        // The number of extra coefficient rows to pad the polynomials with.
        k: usize,
        // The number of columns (i.e. polynomials).
        cols: usize,
        // The data itself: n * cols field elements.
        data: Vec<F>,
        // Which rows of the encoding survive; the rest get erased.
        present: VanishingPoints,
    }

    impl RecoverySetup {
        #[cfg(test)]
        pub(crate) const fn new(
            n: usize,
            k: usize,
            cols: usize,
            data: Vec<F>,
            present: VanishingPoints,
        ) -> Self {
            Self {
                n,
                k,
                cols,
                data,
                present,
            }
        }

        /// Check that encoding, erasing the non-present rows, and recovering
        /// round-trips back to the original polynomials.
        pub fn test(self) {
            let data = PolynomialVector::new(self.n + self.k, self.cols, self.data.into_iter());
            let mut encoded = data.clone().evaluate();
            // Erase every row not marked as present.
            for (i, b_i) in self.present.iter_bits_in_order().enumerate() {
                if !b_i {
                    encoded.remove_row(i);
                }
            }
            let recovered_data = encoded.recover();
            assert_eq!(data, recovered_data);
        }
    }

    /// One fuzzing scenario to execute.
    #[derive(Debug)]
    pub enum Plan {
        // The fast NTT evaluation should agree with the naive evaluation.
        NttEqNaive(PolynomialVector<F>),
        // Interpolation should invert evaluation.
        EvaluationThenInverse(PolynomialVector<F>),
        // The vanishing polynomial should have the expected properties.
        VanishingPolynomial(VanishingPoints),
        // Erasure recovery should round-trip.
        Recovery(RecoverySetup),
    }

    impl<'a> Arbitrary<'a> for Plan {
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            match u.int_in_range(0..=3)? {
                0 => Ok(Self::NttEqNaive(arb_polynomial_vector(u, 6, 4)?)),
                1 => Ok(Self::EvaluationThenInverse(arb_polynomial_vector(u, 6, 4)?)),
                2 => Ok(Self::VanishingPolynomial(arb_bit_vec_not_all_0(u, 8)?)),
                _ => Ok(Self::Recovery(arb_recovery_setup(u, 128, 128, 4)?)),
            }
        }
    }

    impl Plan {
        /// Execute this scenario, panicking if any checked property is violated.
        pub fn run(self, _u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::NttEqNaive(p) => {
                    let ntt = p.clone().evaluate();
                    let ntt_naive = p.evaluate_naive();
                    assert_eq!(ntt, ntt_naive);
                }
                Self::EvaluationThenInverse(p) => {
                    assert_eq!(p.clone(), p.evaluate().interpolate());
                }
                Self::VanishingPolynomial(bv) => {
                    let total = 1u64 << bv.lg_size();
                    // The vanishing polynomial has one root per vanishing point.
                    let expected_degree = total - bv.count_non_vanishing();
                    let (at_zero, evals) = EvaluationColumn::<F>::vanishing(&bv);
                    let v = evals.interpolate();
                    assert_eq!(
                        v.degree(),
                        expected_degree as usize,
                        "expected v to have degree {}",
                        expected_degree
                    );
                    assert_eq!(
                        at_zero, v.coefficients[0],
                        "at_zero should be the 0th coefficient"
                    );
                    // v should be zero exactly on the vanishing points w^i.
                    let w = F::root_of_unity(bv.lg_size() as u8).unwrap();
                    let mut w_i = F::one();
                    for b_i in bv.iter_bits_in_order() {
                        let v_at_w_i = v.evaluate_one(w_i);
                        if !b_i {
                            assert_eq!(v_at_w_i, F::zero(), "v should evaluate to 0 at {:?}", w_i);
                        } else {
                            assert_ne!(v_at_w_i, F::zero());
                        }
                        w_i = w_i * w;
                    }
                }
                Self::Recovery(setup) => {
                    setup.test();
                }
            }
            Ok(())
        }
    }

    #[test]
    fn test_fuzz() {
        use commonware_invariants::minifuzz;
        minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }
}
1464
#[cfg(test)]
mod test {
    use super::*;
    use crate::{algebra::Ring, fields::goldilocks::F};

    #[test]
    fn test_reverse_bits() {
        assert_eq!(reverse_bits(4, 0b1000), 0b0001);
        assert_eq!(reverse_bits(4, 0b0100), 0b0010);
        assert_eq!(reverse_bits(4, 0b0010), 0b0100);
        assert_eq!(reverse_bits(4, 0b0001), 0b1000);
    }

    #[test]
    fn matrix_read_rejects_length_mismatch() {
        use bytes::BytesMut;
        use commonware_codec::{Read as _, Write as _};

        // Encode a 2x2 matrix header, but only 3 elements of data.
        let mut buf = BytesMut::new();
        (2usize).write(&mut buf);
        (2usize).write(&mut buf);
        vec![F::one(); 3].write(&mut buf);

        let mut bytes = buf.freeze();
        let result = Matrix::<F>::read_cfg(&mut bytes, &(8, ()));
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid(
                "Matrix",
                "matrix element count does not match dimensions"
            ))
        ));
    }

    /// Check [`EvaluationColumn::vanishing`] for `points`: the interpolated
    /// polynomial must have the expected degree, report the right value at
    /// zero, and vanish exactly on the vanishing points.
    fn assert_vanishing_points_correct(points: &VanishingPoints) {
        let expected_degree = (1 << points.lg_size()) - points.count_non_vanishing();
        let (at_zero, evaluations) = EvaluationColumn::<F>::vanishing(points);
        if points.count_non_vanishing() == 0 {
            // EvaluationColumn::vanishing assumes at least one non-vanishing point.
            // We still invoke it so callers can exercise internal branch coverage.
            return;
        }
        let polynomial = evaluations.interpolate();
        assert_eq!(
            polynomial.degree(),
            expected_degree as usize,
            "expected v to have degree {expected_degree}"
        );
        assert_eq!(
            at_zero, polynomial.coefficients[0],
            "at_zero should be the 0th coefficient"
        );
        // Walk the domain w^0, w^1, ..., checking the zero pattern.
        let w = F::root_of_unity(points.lg_size() as u8).unwrap();
        let mut w_i = F::one();
        for (i, point_is_non_vanishing) in points.iter_bits_in_order().enumerate() {
            let value = polynomial.evaluate_one(w_i);
            if point_is_non_vanishing {
                assert_ne!(value, F::zero(), "expected non-zero at i={i}");
            } else {
                assert_eq!(value, F::zero(), "expected zero at i={i}");
            }
            w_i = w_i * w;
        }
    }

    #[test]
    fn test_recovery_000() {
        // A minimal case: a 2-row domain with only the second row present.
        let present = {
            let mut out = VanishingPoints::new(1);
            out.set_non_vanishing(1);
            out
        };
        fuzz::RecoverySetup::new(1, 1, 1, vec![F::one()], present).test()
    }

    #[test]
    fn test_recovery_empty_vector() {
        // With no rows filled, recover() should fall back to plain interpolation.
        let recovered = EvaluationVector::<F>::empty(4, 3).recover();
        let expected = EvaluationVector::<F>::empty(4, 3).interpolate();
        assert_eq!(recovered, expected);
    }

    #[test]
    fn test_vanishing_polynomial_all_two_chunk_combinations() {
        // Fill one half of the domain with a repeating two-value pattern.
        fn fill_half(points: &mut VanishingPoints, half: usize, values: [bool; 2]) {
            let chunk_size = 1usize << LG_VANISHING_BASE;
            let start = half * chunk_size;
            let lg_size = points.lg_size();
            for i in 0..chunk_size {
                let value = values[i % 2];
                let raw_index = (start + i) as u64;
                points.set(reverse_bits(lg_size, raw_index), value);
            }
        }

        let lg_size = LG_VANISHING_BASE + 1;
        // (0,0) => Everywhere, (0,1) => Somewhere, (1,1) => Nowhere.
        let states = [[false, false], [false, true], [true, true]];
        for left in states {
            for right in states {
                let mut points = VanishingPoints::new(lg_size);
                // VanishingPoints stores roots in reverse bit order. Writing raw halves
                // directly makes chunk 0/1 align exactly with the implementation's chunks.
                fill_half(&mut points, 0, left);
                fill_half(&mut points, 1, right);
                assert_vanishing_points_correct(&points);
            }
        }
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Matrix<F>>,
        }
    }
}