Skip to main content

commonware_math/
poly.rs

1use crate::algebra::{
2    msm_naive, powers, Additive, CryptoGroup, Field, Object, Random, Ring, Space,
3};
4#[cfg(not(feature = "std"))]
5use alloc::{borrow::Cow, vec, vec::Vec};
6use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
7use commonware_parallel::Strategy;
8use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
9use core::{
10    fmt::Debug,
11    iter,
12    num::NonZeroU32,
13    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
14};
15use rand_core::CryptoRngCore;
16#[cfg(feature = "std")]
17use std::borrow::Cow;
18
// SECTION: Performance knobs.
/// Minimum number of polynomials for which `Space::msm` on `Poly` takes the
/// coefficient-by-coefficient path; shorter inputs are handed to `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
21
/// A polynomial, with coefficients in `K`.
///
/// Coefficients are stored in order of increasing degree: index `0` holds the
/// constant term (see [`Poly::constant`]).
#[derive(Clone)]
pub struct Poly<K> {
    // Invariant: (1..=u32::MAX).contains(coeffs.len())
    //
    // `Poly::len` relies on this when converting the length to `NonZeroU32`.
    coeffs: NonEmptyVec<K>,
}
28
29impl<K> Poly<K> {
30    fn len(&self) -> NonZeroU32 {
31        self.coeffs
32            .len()
33            .try_into()
34            .expect("Impossible: polynomial length not in 1..=u32::MAX")
35    }
36
    /// The number of stored coefficients, as a plain `usize`.
    const fn len_usize(&self) -> usize {
        self.coeffs.len().get()
    }
40
41    /// Internal method to construct a polynomial from an iterator.
42    ///
43    /// This will panic if the iterator does not return any coefficients,
44    /// so make sure that the iterator you pass to this function does that.
45    fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
46        let coeffs = iter
47            .into_iter()
48            .try_collect::<NonEmptyVec<_>>()
49            .expect("polynomial must have a least 1 coefficient");
50        Self { coeffs }
51    }
52
53    /// The degree of this polynomial.
54    ///
55    /// Technically, this is only an *upper bound* on the degree, because
56    /// this method does not inspect the coefficients of a polynomial to check
57    /// if they're non-zero.
58    ///
59    /// Because of this, it's possible that two polynomials
60    /// considered equal have different degrees.
61    ///
62    /// For that behavior, see [`Self::degree_exact`].
63    pub fn degree(&self) -> u32 {
64        self.len().get() - 1
65    }
66
    /// Return the number of evaluation points required to recover this polynomial.
    ///
    /// In other words, [`Self::degree`] + 1, which is the number of stored
    /// coefficients.
    pub fn required(&self) -> NonZeroU32 {
        self.len()
    }
73
    /// Return the constant value of this polynomial.
    ///
    /// I.e. the first coefficient, which is the value of the polynomial at 0.
    pub fn constant(&self) -> &K {
        &self.coeffs[0]
    }
80
81    /// Translate the coefficients of this polynomial.
82    ///
83    /// This applies some kind of map to each coefficient, creating a new
84    /// polynomial.
85    pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
86        Poly {
87            coeffs: self.coeffs.map(f),
88        }
89    }
90
91    /// Evaluate a polynomial at a particular point.
92    ///
93    /// For
94    ///
95    ///   `p(X) := a_0 + a_1 X + a_2 X^2 + ...`
96    ///
97    /// this returns:
98    ///
99    ///   `a_0 + a_1 r + a_2 r^2 + ...`
100    ///
101    /// This can work for any type which can multiply the coefficients of
102    /// this polynomial.
103    ///
104    /// For example, if you have a polynomial consistent of elements of a group,
105    /// you can evaluate it using a scalar over that group.
106    pub fn eval<R>(&self, r: &R) -> K
107    where
108        K: Space<R>,
109    {
110        let mut iter = self.coeffs.iter().rev();
111        // Evaluation using Horner's method.
112        //
113        // p(r)
114        // = a_0 + a_1 r + ... + a_n r^N =
115        // = a_n r^n + ...
116        // = ((a_n) r + a_(n - 1))r + ...)
117        let mut acc = iter
118            .next()
119            .expect("Impossible: Polynomial has no coefficients")
120            .clone();
121        for coeff in iter {
122            acc *= r;
123            acc += coeff;
124        }
125        acc
126    }
127
128    /// Like [`Self::eval`], but using [`Space::msm`].
129    ///
130    /// This method uses more scratch space, and requires cloning values of
131    /// type `R` more, but should be better if [`Space::msm`] has a better algorithm
132    /// for `K`.
133    pub fn eval_msm<R: Ring>(&self, r: &R, strategy: &impl Strategy) -> K
134    where
135        K: Space<R>,
136    {
137        // Contains 1, r, r^2, ...
138        let weights = powers(R::one(), r)
139            .take(self.len_usize())
140            .collect::<Vec<_>>();
141        K::msm(&self.coeffs, &weights, strategy)
142    }
143
144    /// Compute `sum_i a_i * self(b_i)`.
145    ///
146    /// This is more efficient than several calls to `eval_msm`, but produces the
147    /// same result.
148    ///
149    /// This returns `0` if the iterator is empty.
150    pub fn lin_comb_eval<'a, R: Ring + 'a>(
151        &self,
152        into_iter: impl IntoIterator<Item = (R, Cow<'a, R>)>,
153        strategy: &impl Strategy,
154    ) -> K
155    where
156        K: Space<R>,
157    {
158        // Contains a0 + a1 + ..., a0 b0 + a1 b1 + ..., a0 b0^2 + a1 b1^2 + ...
159        let weights = {
160            let mut iter = into_iter.into_iter();
161            let Some((a0, b0)) = iter.next() else {
162                return K::zero();
163            };
164
165            let len = self.len_usize();
166            let mut out: Vec<_> = powers(a0, b0.as_ref()).take(len).collect();
167            for (ai, bi) in iter {
168                powers(ai, bi.as_ref())
169                    .take(len)
170                    .zip(out.iter_mut())
171                    .for_each(|(c_j, o_j)| *o_j += &c_j);
172            }
173            out
174        };
175        K::msm(&self.coeffs, &weights, strategy)
176    }
177}
178
179impl<K: Debug> Debug for Poly<K> {
180    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
181        write!(f, "Poly(")?;
182        for (i, c) in self.coeffs.iter().enumerate() {
183            if i > 0 {
184                write!(f, " + {c:?} X^{i}")?;
185            } else {
186                write!(f, "{c:?}")?;
187            }
188        }
189        write!(f, ")")?;
190        Ok(())
191    }
192}
193
/// The encoded size of a polynomial is the encoded size of its coefficient
/// vector.
impl<K: EncodeSize> EncodeSize for Poly<K> {
    fn encode_size(&self) -> usize {
        self.coeffs.encode_size()
    }
}
199
/// A polynomial is written as its coefficient vector, with nothing added.
impl<K: Write> Write for Poly<K> {
    fn write(&self, buf: &mut impl bytes::BufMut) {
        self.coeffs.write(buf);
    }
}
205
206impl<K: Read> Read for Poly<K> {
207    type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
208
209    fn read_cfg(
210        buf: &mut impl bytes::Buf,
211        cfg: &Self::Cfg,
212    ) -> Result<Self, commonware_codec::Error> {
213        Ok(Self {
214            coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
215        })
216    }
217}
218
219impl<K: Random> Poly<K> {
220    // Returns a new polynomial of the given degree where each coefficient is
221    // sampled at random from the provided RNG.
222    pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
223        Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
224    }
225
226    /// Returns a new scalar polynomial with a particular value for the constant coefficient.
227    ///
228    /// This does the same thing as [`Poly::new`] otherwise.
229    pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
230        Self::from_iter_unchecked(
231            iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
232        )
233    }
234}
235
236/// An equality test taking into account high 0 coefficients.
237///
238/// Without this behavior, the additive test suite does not past, because
239/// `x - x` may result in a polynomial with extra 0 coefficients.
240impl<K: Additive> PartialEq for Poly<K> {
241    fn eq(&self, other: &Self) -> bool {
242        let zero = K::zero();
243        let max_len = self.len().max(other.len());
244        let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
245        let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
246        self_then_zeros
247            .zip(other_then_zeros)
248            .take(max_len.get() as usize)
249            .all(|(a, b)| a == b)
250    }
251}
252
// Marker impl: declares the `PartialEq` comparison on `Poly` to be a full
// equivalence relation.
impl<K: Additive> Eq for Poly<K> {}
254
255impl<K: Additive> Poly<K> {
256    fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
257        self.coeffs
258            .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
259        self.coeffs
260            .iter_mut()
261            .zip(&rhs.coeffs)
262            .for_each(|(a, b)| f(a, b));
263    }
264
265    /// Like [`Self::degree`], but checking for zero coefficients.
266    ///
267    /// This method is slower, but reports exact results in case there are zeros.
268    ///
269    /// This will return 0 for a polynomial with no coefficients.
270    pub fn degree_exact(&self) -> u32 {
271        let zero = K::zero();
272        let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
273        let lz_u32 =
274            u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
275        // The saturation is critical, otherwise you get a negative number for
276        // the zero polynomial.
277        self.degree().saturating_sub(lz_u32)
278    }
279}
280
// `Object` is implemented with an empty body: no items are overridden here.
impl<K: Additive> Object for Poly<K> {}
282
283// SECTION: implementing Additive
284
285impl<'a, K: Additive> AddAssign<&'a Self> for Poly<K> {
286    fn add_assign(&mut self, rhs: &'a Self) {
287        self.merge_with(rhs, |a, b| *a += b);
288    }
289}
290
291impl<'a, K: Additive> Add<&'a Self> for Poly<K> {
292    type Output = Self;
293
294    fn add(mut self, rhs: &'a Self) -> Self::Output {
295        self += rhs;
296        self
297    }
298}
299
300impl<'a, K: Additive> SubAssign<&'a Self> for Poly<K> {
301    fn sub_assign(&mut self, rhs: &'a Self) {
302        self.merge_with(rhs, |a, b| *a -= b);
303    }
304}
305
306impl<'a, K: Additive> Sub<&'a Self> for Poly<K> {
307    type Output = Self;
308
309    fn sub(mut self, rhs: &'a Self) -> Self::Output {
310        self -= rhs;
311        self
312    }
313}
314
315impl<K: Additive> Neg for Poly<K> {
316    type Output = Self;
317
318    fn neg(self) -> Self::Output {
319        Self {
320            coeffs: self.coeffs.map_into(Neg::neg),
321        }
322    }
323}
324
325impl<K: Additive> Additive for Poly<K> {
326    fn zero() -> Self {
327        Self {
328            coeffs: non_empty_vec![K::zero()],
329        }
330    }
331}
332
333// SECTION: implementing Space<K>.
334
335impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
336    fn mul_assign(&mut self, rhs: &'a R) {
337        self.coeffs.iter_mut().for_each(|c| *c *= rhs);
338    }
339}
340
341impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
342    type Output = Self;
343
344    fn mul(mut self, rhs: &'a R) -> Self::Output {
345        self *= rhs;
346        self
347    }
348}
349
350impl<R: Sync, K: Space<R> + Send> Space<R> for Poly<K> {
351    fn msm(polys: &[Self], scalars: &[R], strategy: &impl Strategy) -> Self {
352        if polys.len() < MIN_POINTS_FOR_MSM {
353            return msm_naive(polys, scalars);
354        }
355
356        let cols = polys.len().min(scalars.len());
357        let polys = &polys[..cols];
358        let scalars = &scalars[..cols];
359
360        let rows = polys
361            .iter()
362            .map(|x| x.len_usize())
363            .max()
364            .expect("at least 1 point");
365
366        let coeffs = strategy.map_init_collect_vec(
367            0..rows,
368            || Vec::with_capacity(cols),
369            |row, i| {
370                row.clear();
371                for p in polys {
372                    row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
373                }
374                K::msm(row, scalars, strategy)
375            },
376        );
377        Self::from_iter_unchecked(coeffs)
378    }
379}
380
381impl<G: CryptoGroup> Poly<G> {
382    /// Commit to a polynomial of scalars, producing a polynomial of group elements.
383    pub fn commit(p: Poly<G::Scalar>) -> Self {
384        p.translate(|c| G::generator() * c)
385    }
386}
387
/// An interpolator allows recovering a polynomial's constant from values.
///
/// This is useful for polynomial secret sharing. There, a secret is stored
/// in the constant of a polynomial. Shares of the secret are created by
/// evaluating the polynomial at various points. Given enough values for
/// these points, the secret can be recovered.
///
/// Using an [`Interpolator`] can be more efficient, because work can be
/// done in advance based only on the points that will be used for recovery,
/// before the value of the polynomial at these points is known. The interpolator
/// can use these values to recover the secret at a later time.
///
/// ### Usage
///
/// ```
/// # use commonware_math::{fields::goldilocks::F, poly::{Poly, Interpolator}};
/// # use commonware_parallel::Sequential;
/// # use commonware_utils::TryCollect;
/// # fn example(f: Poly<F>, g: Poly<F>, p0: F, p1: F) {
///     let interpolator = Interpolator::new([(0, p0), (1, p1)]);
///     assert_eq!(
///         Some(*f.constant()),
///         interpolator.interpolate(&[(0, f.eval(&p0)), (1, f.eval(&p1))].into_iter().try_collect().unwrap(), &Sequential)
///     );
///     assert_eq!(
///         Some(*g.constant()),
///         interpolator.interpolate(&[(1, g.eval(&p1)), (0, g.eval(&p0))].into_iter().try_collect().unwrap(), &Sequential)
///     );
/// # }
/// ```
pub struct Interpolator<I, F> {
    // For each index, the precomputed weight that multiplies the evaluation
    // supplied for that index in `interpolate`.
    weights: Map<I, F>,
}
421
422impl<I: PartialEq, F: Ring> Interpolator<I, F> {
423    /// Interpolate a polynomial's evaluations to recover its constant.
424    ///
425    /// The indices provided here MUST match those provided to [`Self::new`] exactly,
426    /// otherwise `None` will be returned.
427    pub fn interpolate<K: Space<F>>(
428        &self,
429        evals: &Map<I, K>,
430        strategy: &impl Strategy,
431    ) -> Option<K> {
432        if evals.keys() != self.weights.keys() {
433            return None;
434        }
435        Some(K::msm(evals.values(), self.weights.values(), strategy))
436    }
437}
438
impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
    /// Create a new interpolator, given an association from indices to evaluation points.
    ///
    /// If an index appears multiple times, the implementation is free to use
    /// any one of the evaluation points associated with that index. In other words,
    /// don't do that, or ensure that if, for some reason, an index appears more
    /// than once, then it has the same evaluation point.
    ///
    /// The weight computed for point `w_i` is the Lagrange basis polynomial
    /// evaluated at zero, `L_i(0) = product(w_j / (w_j - w_i) for j != i)`.
    ///
    /// NOTE(review): if two distinct indices share the same evaluation point,
    /// the corresponding `c_i` is zero and a zero value is passed to `inv`;
    /// what happens then depends on `Field::inv` — confirm.
    pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
        let points = Map::from_iter_dedup(points);
        let n = points.len();
        // No points: there is nothing to compute, the weights are empty.
        if n == 0 {
            return Self { weights: points };
        }

        // Compute W = product of all w_i
        // Compute c_i = w_i * product((w_j - w_i) for j != i)
        //
        // The weight for point i is then W / c_i.
        let values = points.values();
        let zero = F::zero();
        let mut total_product = F::one();
        let mut c = Vec::with_capacity(n);
        for (i, w_i) in values.iter().enumerate() {
            // If evaluation point is zero, L_i(0) = 1 for this point and 0 for all others.
            if w_i == &zero {
                // Reuse the map, overwriting the evaluation points with the weights.
                let mut out = points;
                for (j, w) in out.values_mut().iter_mut().enumerate() {
                    *w = if j == i { F::one() } else { F::zero() };
                }
                return Self { weights: out };
            }

            // Accumulate c_i = w_i * product((w_j - w_i) for j != i) for batch inversion.
            total_product *= w_i;
            let mut c_i = w_i.clone();
            for w_j in values
                .iter()
                .enumerate()
                .filter_map(|(j, v)| (j != i).then_some(v))
            {
                c_i *= &(w_j.clone() - w_i);
            }
            c.push(c_i);
        }

        // Batch inversion using Montgomery's trick to compute W/c_i for all i
        // Step 1: Compute prefix products (prefix[i] = c[0] * ... * c[i-1])
        let mut prefix = Vec::with_capacity(n + 1);
        prefix.push(F::one());
        let mut acc = F::one();
        for c_i in &c {
            acc *= c_i;
            prefix.push(acc.clone());
        }

        // Step 2: Single inversion, multiplied by W
        // (inv_acc = W / (c[0] * ... * c[n-1]))
        let mut inv_acc = total_product * &prefix[n].inv();

        // Step 3: Compute weights directly into output
        //
        // Walking backwards, inv_acc equals W / (c[0] * ... * c[i]) at the top
        // of each iteration, so multiplying by prefix[i] leaves W / c[i].
        let mut out = points;
        let out_vals = out.values_mut();
        for i in (0..n).rev() {
            out_vals[i] = inv_acc.clone() * &prefix[i];
            // Remove c[i] from the denominator for the next (lower) index.
            inv_acc *= &c[i];
        }
        Self { weights: out }
    }
}
505
506#[commonware_macros::stability(ALPHA)]
507impl<I: Clone + Ord, F: crate::algebra::FieldNTT> Interpolator<I, F> {
508    /// Create an interpolator for evaluation points at roots of unity.
509    ///
510    /// This uses the fast O(n log n) algorithm from [`crate::ntt::lagrange_coefficients`].
511    ///
512    /// Each `(I, u32)` pair maps an index `I` to an evaluation point `w^k` where `w` is
513    /// a primitive root of unity of order `next_power_of_two(total)`.
514    ///
515    /// Indices `k >= total` are ignored.
516    pub fn roots_of_unity(
517        total: NonZeroU32,
518        points: commonware_utils::ordered::BiMap<I, u32>,
519    ) -> Self {
520        let weights = <Map<I, F> as commonware_utils::TryFromIterator<(I, F)>>::try_from_iter(
521            crate::ntt::lagrange_coefficients(total, points.values().iter().copied())
522                .into_iter()
523                .filter_map(|(k, coeff)| Some((points.get_key(&k)?.clone(), coeff))),
524        )
525        .expect("points has already been deduped");
526        Self { weights }
527    }
528
529    /// Create an interpolator for evaluation points at roots of unity using naive O(n^2) algorithm.
530    ///
531    /// This computes the actual root of unity values and delegates to [`Interpolator::new`].
532    /// Useful for testing against [`Self::roots_of_unity`].
533    ///
534    /// Indices `k >= total` are ignored.
535    #[cfg(any(test, feature = "fuzz"))]
536    fn roots_of_unity_naive(
537        total: NonZeroU32,
538        points: commonware_utils::ordered::BiMap<I, u32>,
539    ) -> Self {
540        use crate::algebra::powers;
541
542        let total_u32 = total.get();
543        let size = (total_u32 as u64).next_power_of_two();
544        let lg_size = size.ilog2() as u8;
545        let w = F::root_of_unity(lg_size).expect("domain too large for NTT");
546
547        let points: Vec<(I, u32)> = points.into_iter().filter(|(_, k)| *k < total_u32).collect();
548        let max_k = points.iter().map(|(_, k)| *k).max().unwrap_or(0) as usize;
549        let powers: Vec<_> = powers(F::one(), &w).take(max_k + 1).collect();
550
551        let eval_points = points
552            .into_iter()
553            .map(|(i, k)| (i, powers[k as usize].clone()));
554        Self::new(eval_points)
555    }
556}
557
#[cfg(any(test, feature = "arbitrary"))]
mod impl_arbitrary {
    use super::*;
    use arbitrary::Arbitrary;

    /// Generates a polynomial from one mandatory coefficient followed by an
    /// arbitrary number of further coefficients, so the result is never empty.
    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let head: F = u.arbitrary()?;
            let tail: Vec<F> = u.arbitrary()?;
            let mut coeffs = NonEmptyVec::new(head);
            coeffs.extend(tail);
            Ok(Self { coeffs })
        }
    }
}
573
574#[commonware_macros::stability(ALPHA)]
575#[cfg(any(test, feature = "fuzz"))]
576pub mod fuzz {
577    use super::*;
578    use crate::{
579        algebra::test_suites,
580        test::{F, G},
581    };
582    use arbitrary::{Arbitrary, Unstructured};
583    use commonware_codec::Encode as _;
584    use commonware_parallel::Sequential;
585    use commonware_utils::{
586        ordered::{BiMap, Map},
587        TryFromIterator,
588    };
589
    /// A single property check to run against [`Poly`] or [`Interpolator`].
    ///
    /// Each variant carries the inputs for one check; see [`Plan::run`] for
    /// the property asserted by each.
    #[derive(Debug, Arbitrary)]
    pub enum Plan {
        /// Encoding then decoding gives back an equal polynomial.
        Codec(Poly<F>),
        /// Evaluation commutes with addition of polynomials.
        EvalAdd(Poly<F>, Poly<F>, F),
        /// Evaluation commutes with scaling a polynomial.
        EvalScale(Poly<F>, F, F),
        /// Evaluating at zero gives the constant coefficient.
        EvalZero(Poly<F>),
        /// `eval` and `eval_msm` agree.
        EvalMsm(Poly<F>, F),
        /// `lin_comb_eval` agrees with summing individual evaluations.
        LinCombEval(Poly<F>, Vec<(F, F)>),
        /// Interpolation over nonzero points recovers the constant.
        Interpolate(Poly<F>),
        /// Interpolation recovers the constant when the first point is zero.
        InterpolateWithZeroPoint(Poly<F>),
        /// Interpolation recovers the constant when the last point is zero.
        InterpolateWithZeroPointMiddle(Poly<F>),
        /// `translate` with a scaling closure agrees with scalar multiplication.
        TranslateScale(Poly<F>, F),
        /// Committing then evaluating agrees with evaluating then committing.
        CommitEval(Poly<F>, F),
        /// The fast and naive roots-of-unity interpolators produce equal weights.
        RootsOfUnityEqNaive(u16),
        /// Runs the generic additive test suite on `Poly<F>`.
        FuzzAdditive,
        /// Runs the generic space/ring test suite on `Poly<F>`.
        FuzzSpaceRing,
    }
607
    impl Plan {
        /// Run the check described by this plan, panicking if it fails.
        ///
        /// `u` supplies extra input to the generic test suites; an error is
        /// returned only when one of those suites reports one.
        pub fn run(self, u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::Codec(f) => {
                    // Decode with a range that accepts exactly this polynomial's length.
                    assert_eq!(
                        &f,
                        &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ()))
                            .unwrap()
                    );
                }
                Self::EvalAdd(f, g, x) => {
                    assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
                }
                Self::EvalScale(f, x, w) => {
                    assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
                }
                Self::EvalZero(f) => {
                    assert_eq!(&f.eval(&F::zero()), f.constant());
                }
                Self::EvalMsm(f, x) => {
                    assert_eq!(f.eval(&x), f.eval_msm(&x, &Sequential));
                }
                Self::LinCombEval(f, pairs) => {
                    // Reference value: sum of a * f(b) over all pairs.
                    let naive_eval = pairs.iter().fold(F::zero(), |mut acc, (a, b)| {
                        acc += &(*a * &f.eval(b));
                        acc
                    });
                    let lin_comb = f.lin_comb_eval(
                        pairs.iter().map(|(a, b)| (*a, Cow::Borrowed(b))),
                        &Sequential,
                    );
                    assert_eq!(naive_eval, lin_comb);
                }
                Self::Interpolate(f) => {
                    // Skip the zero polynomial and polynomials needing too many points.
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Points 1, 2, ..., required: all nonzero.
                    let mut points = (0..f.required().get())
                        .map(|i| F::from((i + 1) as u8))
                        .collect::<Vec<_>>();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                    // With one evaluation missing the indices no longer match,
                    // so interpolation must refuse.
                    points.pop();
                    assert_eq!(
                        interpolator.interpolate(
                            &Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate()),
                            &Sequential
                        ),
                        None
                    );
                }
                Self::InterpolateWithZeroPoint(f) => {
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Points 0, 1, ..., required - 1: zero comes first.
                    let points: Vec<_> =
                        (0..f.required().get()).map(|i| F::from(i as u8)).collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::InterpolateWithZeroPointMiddle(f) => {
                    if f == Poly::zero()
                        || f.required().get() < 2
                        || f.required().get() >= F::MAX as u32
                    {
                        return Ok(());
                    }
                    let n = f.required().get();
                    // Points 1, ..., n - 1 followed by zero, so that the zero
                    // point is reached only after nonzero points were processed.
                    let points: Vec<_> = (1..n)
                        .map(|i| F::from(i as u8))
                        .chain(core::iter::once(F::zero()))
                        .collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::TranslateScale(f, x) => {
                    assert_eq!(f.translate(|c| x * c), f * &x);
                }
                Self::CommitEval(f, x) => {
                    assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
                }
                Self::RootsOfUnityEqNaive(n) => {
                    // Reduce the input to a domain size in 1..=256.
                    let n = (u32::from(n) % 256) + 1;
                    let total = NonZeroU32::new(n).expect("n is in 1..=256");
                    // Index i is associated with exponent i.
                    let points = BiMap::try_from_iter((0..n as usize).map(|i| (i, i as u32)))
                        .expect("interpolation points should be bijective");
                    let fast = Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity(
                        total,
                        points.clone(),
                    );
                    let naive =
                        Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity_naive(
                            total, points,
                        );
                    assert_eq!(fast.weights, naive.weights);
                }
                Self::FuzzAdditive => {
                    test_suites::fuzz_additive::<Poly<F>>(u)?;
                }
                Self::FuzzSpaceRing => {
                    test_suites::fuzz_space_ring::<F, Poly<F>>(u)?;
                }
            }
            Ok(())
        }
    }
720
    /// Draws arbitrary [`Plan`]s through `minifuzz` and runs each one.
    #[test]
    fn test_fuzz() {
        commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }
725}
726#[cfg(test)]
727mod test {
728    use super::{fuzz::Plan, *};
729    use crate::test::F;
730    use arbitrary::Unstructured;
731
732    #[test]
733    fn test_eq() {
734        fn eq(a: &[u8], b: &[u8]) -> bool {
735            Poly {
736                coeffs: a.iter().copied().map(F::from).try_collect().unwrap(),
737            } == Poly {
738                coeffs: b.iter().copied().map(F::from).try_collect().unwrap(),
739            }
740        }
741        assert!(eq(&[1, 2], &[1, 2]));
742        assert!(!eq(&[1, 2], &[2, 3]));
743        assert!(!eq(&[1, 2], &[1, 2, 3]));
744        assert!(!eq(&[1, 2, 3], &[1, 2]));
745        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
746        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
747        assert!(!eq(&[1, 2, 0], &[2, 3]));
748        assert!(!eq(&[2, 3], &[1, 2, 0]));
749    }
750
751    #[test]
752    fn lin_comb_eval_edge_cases() {
753        fn poly(coeffs: &[u8]) -> Poly<F> {
754            Poly {
755                coeffs: coeffs.iter().copied().map(F::from).try_collect().unwrap(),
756            }
757        }
758
759        fn pairs(values: &[(u8, u8)]) -> Vec<(F, F)> {
760            values
761                .iter()
762                .map(|(a, b)| (F::from(*a), F::from(*b)))
763                .collect()
764        }
765
766        let cases = [
767            Plan::LinCombEval(poly(&[3, 5, 7]), vec![]),
768            Plan::LinCombEval(poly(&[11]), pairs(&[(2, 0), (3, 1), (5, 8)])),
769            Plan::LinCombEval(poly(&[4, 6, 8]), pairs(&[(2, 5), (7, 5), (3, 5)])),
770            Plan::LinCombEval(poly(&[9, 2, 3, 4]), pairs(&[(6, 0), (1, 0), (5, 7)])),
771            Plan::LinCombEval(poly(&[1, 2, 4, 8]), pairs(&[(3, 1), (7, 1), (2, 6)])),
772        ];
773        let mut u = Unstructured::new(&[]);
774        for case in cases {
775            case.run(&mut u).unwrap();
776        }
777    }
778
    // Codec conformance tests for `Poly<F>`; only built with the `arbitrary`
    // feature.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
788}