// commonware_math/poly.rs

1use crate::algebra::{msm_naive, Additive, CryptoGroup, Field, Object, Random, Ring, Space};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
6use core::{
7    fmt::Debug,
8    iter,
9    num::NonZeroU32,
10    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12use rand_core::CryptoRngCore;
13#[cfg(feature = "std")]
14use rayon::{prelude::*, ThreadPoolBuilder};
15
// SECTION: Performance knobs.

/// Smallest number of (polynomial, scalar) pairs for which `Space::msm` on
/// [`Poly`] uses the coefficient-by-coefficient algorithm instead of
/// `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
18
19/// A polynomial, with coefficients in `K`.
20#[derive(Clone)]
21pub struct Poly<K> {
22    // Invariant: (1..=u32::MAX).contains(coeffs.len())
23    coeffs: NonEmptyVec<K>,
24}
25
26impl<K> Poly<K> {
27    fn len(&self) -> NonZeroU32 {
28        self.coeffs
29            .len()
30            .try_into()
31            .expect("Impossible: polynomial length not in 1..=u32::MAX")
32    }
33
34    fn len_usize(&self) -> usize {
35        self.coeffs.len().get()
36    }
37
38    /// Internal method to construct a polynomial from an iterator.
39    ///
40    /// This will panic if the iterator does not return any coefficients,
41    /// so make sure that the iterator you pass to this function does that.
42    fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
43        let coeffs = iter
44            .into_iter()
45            .try_collect::<NonEmptyVec<_>>()
46            .expect("polynomial must have a least 1 coefficient");
47        Self { coeffs }
48    }
49
50    /// The degree of this polynomial.
51    ///
52    /// Technically, this is only an *upper bound* on the degree, because
53    /// this method does not inspect the coefficients of a polynomial to check
54    /// if they're non-zero.
55    ///
56    /// Because of this, it's possible that two polynomials
57    /// considered equal have different degrees.
58    ///
59    /// For that behavior, see [`Self::degree_exact`].
60    pub fn degree(&self) -> u32 {
61        self.len().get() - 1
62    }
63
64    /// Return the number of evaluation points required to recover this polynomial.
65    ///
66    /// In other words, [`Self::degree`] + 1.
67    pub fn required(&self) -> NonZeroU32 {
68        self.len()
69    }
70
71    /// Return the constant value of this polynomial.
72    ///
73    /// I.e. the first coefficient.
74    pub fn constant(&self) -> &K {
75        &self.coeffs[0]
76    }
77
78    /// Translate the coefficients of this polynomial.
79    ///
80    /// This applies some kind of map to each coefficient, creating a new
81    /// polynomial.
82    pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
83        Poly {
84            coeffs: self.coeffs.map(f),
85        }
86    }
87
88    /// Evaluate a polynomial at a particular point.
89    ///
90    /// For
91    ///
92    ///   `p(X) := a_0 + a_1 X + a_2 X^2 + ...`
93    ///
94    /// this returns:
95    ///
96    ///   `a_0 + a_1 r + a_2 r^2 + ...`
97    ///
98    /// This can work for any type which can multiply the coefficients of
99    /// this polynomial.
100    ///
101    /// For example, if you have a polynomial consistent of elements of a group,
102    /// you can evaluate it using a scalar over that group.
103    pub fn eval<R>(&self, r: &R) -> K
104    where
105        K: Space<R>,
106    {
107        let mut iter = self.coeffs.iter().rev();
108        // Evaluation using Horner's method.
109        //
110        // p(r)
111        // = a_0 + a_1 r + ... + a_n r^N =
112        // = a_n r^n + ...
113        // = ((a_n) r + a_(n - 1))r + ...)
114        let mut acc = iter
115            .next()
116            .expect("Impossible: Polynomial has no coefficients")
117            .clone();
118        for coeff in iter {
119            acc *= r;
120            acc += coeff;
121        }
122        acc
123    }
124
125    /// Like [`Self::eval`], but using [`Space::msm`].
126    ///
127    /// This method uses more scratch space, and requires cloning values of
128    /// type `R` more, but should be better if [`Space::msm`] has a better algorithm
129    /// for `K`.
130    pub fn eval_msm<R: Ring>(&self, r: &R) -> K
131    where
132        K: Space<R>,
133    {
134        // Contains 1, r, r^2, ...
135        let weights = {
136            let len = self.len_usize();
137            let mut out = Vec::with_capacity(len);
138            out.push(R::one());
139            let mut acc = R::one();
140            for _ in 1..len {
141                acc *= r;
142                out.push(acc.clone());
143            }
144            out
145        };
146        K::msm(&self.coeffs, &weights, 1)
147    }
148}
149
150impl<K: Debug> Debug for Poly<K> {
151    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152        write!(f, "Poly(")?;
153        for (i, c) in self.coeffs.iter().enumerate() {
154            if i > 0 {
155                write!(f, " + {c:?} X^{i}")?;
156            } else {
157                write!(f, "{c:?}")?;
158            }
159        }
160        write!(f, ")")?;
161        Ok(())
162    }
163}
164
165impl<K: EncodeSize> EncodeSize for Poly<K> {
166    fn encode_size(&self) -> usize {
167        self.coeffs.encode_size()
168    }
169}
170
171impl<K: Write> Write for Poly<K> {
172    fn write(&self, buf: &mut impl bytes::BufMut) {
173        self.coeffs.write(buf);
174    }
175}
176
177impl<K: Read> Read for Poly<K> {
178    type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
179
180    fn read_cfg(
181        buf: &mut impl bytes::Buf,
182        cfg: &Self::Cfg,
183    ) -> Result<Self, commonware_codec::Error> {
184        Ok(Self {
185            coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
186        })
187    }
188}
189
190impl<K: Random> Poly<K> {
191    // Returns a new polynomial of the given degree where each coefficient is
192    // sampled at random from the provided RNG.
193    pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
194        Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
195    }
196
197    /// Returns a new scalar polynomial with a particular value for the constant coefficient.
198    ///
199    /// This does the same thing as [`Poly::new`] otherwise.
200    pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
201        Self::from_iter_unchecked(
202            iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
203        )
204    }
205}
206
207/// An equality test taking into account high 0 coefficients.
208///
209/// Without this behavior, the additive test suite does not past, because
210/// `x - x` may result in a polynomial with extra 0 coefficients.
211impl<K: Additive> PartialEq for Poly<K> {
212    fn eq(&self, other: &Self) -> bool {
213        let zero = K::zero();
214        let max_len = self.len().max(other.len());
215        let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
216        let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
217        self_then_zeros
218            .zip(other_then_zeros)
219            .take(max_len.get() as usize)
220            .all(|(a, b)| a == b)
221    }
222}
223
224impl<K: Additive> Eq for Poly<K> {}
225
226impl<K: Additive> Poly<K> {
227    fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
228        self.coeffs
229            .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
230        self.coeffs
231            .iter_mut()
232            .zip(&rhs.coeffs)
233            .for_each(|(a, b)| f(a, b));
234    }
235
236    /// Like [`Self::degree`], but checking for zero coefficients.
237    ///
238    /// This method is slower, but reports exact results in case there are zeros.
239    ///
240    /// This will return 0 for a polynomial with no coefficients.
241    pub fn degree_exact(&self) -> u32 {
242        let zero = K::zero();
243        let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
244        let lz_u32 =
245            u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
246        // The saturation is critical, otherwise you get a negative number for
247        // the zero polynomial.
248        self.degree().saturating_sub(lz_u32)
249    }
250}
251
252impl<K: Additive> Object for Poly<K> {}
253
254// SECTION: implementing Additive
255
256impl<'a, K: Additive> AddAssign<&'a Poly<K>> for Poly<K> {
257    fn add_assign(&mut self, rhs: &'a Poly<K>) {
258        self.merge_with(rhs, |a, b| *a += b);
259    }
260}
261
262impl<'a, K: Additive> Add<&'a Poly<K>> for Poly<K> {
263    type Output = Self;
264
265    fn add(mut self, rhs: &'a Poly<K>) -> Self::Output {
266        self += rhs;
267        self
268    }
269}
270
271impl<'a, K: Additive> SubAssign<&'a Poly<K>> for Poly<K> {
272    fn sub_assign(&mut self, rhs: &'a Poly<K>) {
273        self.merge_with(rhs, |a, b| *a -= b);
274    }
275}
276
277impl<'a, K: Additive> Sub<&'a Poly<K>> for Poly<K> {
278    type Output = Self;
279
280    fn sub(mut self, rhs: &'a Poly<K>) -> Self::Output {
281        self -= rhs;
282        self
283    }
284}
285
286impl<K: Additive> Neg for Poly<K> {
287    type Output = Self;
288
289    fn neg(self) -> Self::Output {
290        Self {
291            coeffs: self.coeffs.map_into(Neg::neg),
292        }
293    }
294}
295
296impl<K: Additive> Additive for Poly<K> {
297    fn zero() -> Self {
298        Self {
299            coeffs: non_empty_vec![K::zero()],
300        }
301    }
302}
303
// SECTION: implementing Space<R> (scaling coefficients by an `R`).
305
306impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
307    fn mul_assign(&mut self, rhs: &'a R) {
308        self.coeffs.iter_mut().for_each(|c| *c *= rhs);
309    }
310}
311
312impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
313    type Output = Self;
314
315    fn mul(mut self, rhs: &'a R) -> Self::Output {
316        self *= rhs;
317        self
318    }
319}
320
#[cfg(feature = "std")]
impl<R: Sync, K: Space<R>> Space<R> for Poly<K> {
    /// Multi-scalar multiplication of polynomials: `sum_j scalars[j] * polys[j]`.
    ///
    /// Only the first `min(polys.len(), scalars.len())` pairs are used. The
    /// `i`-th coefficient of the result is itself an MSM over the `i`-th
    /// coefficients of the inputs, where missing coefficients count as zero.
    /// Rows are therefore independent. With `concurrency > 1` they are
    /// computed on a rayon thread pool built for this call.
    ///
    /// # Panics
    ///
    /// Panics if `concurrency > 1` and the thread pool cannot be built.
    fn msm(polys: &[Self], scalars: &[R], concurrency: usize) -> Self {
        // Truncate *before* choosing the algorithm. Otherwise, e.g., two
        // polynomials paired with no scalars would reach the row-wise path
        // with zero columns and panic when computing the number of rows.
        let cols = polys.len().min(scalars.len());
        let polys = &polys[..cols];
        let scalars = &scalars[..cols];

        // Number of coefficients in the result: that of the longest input.
        let rows = match polys.iter().map(|p| p.len_usize()).max() {
            Some(rows) if cols >= MIN_POINTS_FOR_MSM => rows,
            // Too few points for the row-wise algorithm to pay off (this
            // also covers empty input).
            _ => return msm_naive(polys, scalars),
        };

        if concurrency > 1 {
            let pool = ThreadPoolBuilder::new()
                .num_threads(concurrency)
                .build()
                .expect("Unable to build thread pool");

            let coeffs = pool.install(|| {
                (0..rows)
                    .into_par_iter()
                    .map(|i| {
                        let row: Vec<_> = polys
                            .iter()
                            .map(|p| p.coeffs.get(i).cloned().unwrap_or_else(K::zero))
                            .collect();
                        // Rows already run in parallel; don't nest more threads.
                        K::msm(&row, scalars, 1)
                    })
                    .collect::<Vec<_>>()
            });
            return Poly::from_iter_unchecked(coeffs);
        }

        // Sequential path: reuse a single row buffer across all rows.
        let mut row = Vec::with_capacity(cols);
        let coeffs = (0..rows).map(|i| {
            row.clear();
            row.extend(
                polys
                    .iter()
                    .map(|p| p.coeffs.get(i).cloned().unwrap_or_else(K::zero)),
            );
            K::msm(&row, scalars, concurrency)
        });

        Poly::from_iter_unchecked(coeffs)
    }
}
371
372#[cfg(not(feature = "std"))]
373impl<R, K: Space<R>> Space<R> for Poly<K> {
374    fn msm(polys: &[Self], scalars: &[R], concurrency: usize) -> Self {
375        if polys.len() < MIN_POINTS_FOR_MSM {
376            return msm_naive(polys, scalars);
377        }
378
379        let cols = polys.len().min(scalars.len());
380        let polys = &polys[..cols];
381        let scalars = &scalars[..cols];
382
383        let rows = polys
384            .iter()
385            .map(|x| x.len_usize())
386            .max()
387            .expect("at least 1 point");
388
389        let mut row = Vec::with_capacity(cols);
390        let coeffs = (0..rows).map(|i| {
391            row.clear();
392            for p in polys {
393                row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
394            }
395            K::msm(&row, scalars, concurrency)
396        });
397        Poly::from_iter_unchecked(coeffs)
398    }
399}
400
401impl<G: CryptoGroup> Poly<G> {
402    /// Commit to a polynomial of scalars, producing a polynomial of group elements.
403    pub fn commit(p: Poly<G::Scalar>) -> Self {
404        p.translate(|c| G::generator() * c)
405    }
406}
407
408/// An interpolator allows recovering a polynomial's constant from values.
409///
410/// This is useful for polynomial secret sharing. There, a secret is stored
411/// in the constant of a polynomial. Shares of the secret are created by
412/// evaluating the polynomial at various points. Given enough values for
413/// these points, the secret can be recovered.
414///
415/// Using an [`Interpolator`] can be more efficient, because work can be
416/// done in advance based only on the points that will be used for recovery,
417/// before the value of the polynomial at these points is known. The interpolator
418/// can use these values to recover the secret at a later time.
419///
420/// ### Usage
421///
422/// ```
423/// # use commonware_math::{fields::goldilocks::F, poly::{Poly, Interpolator}};
424/// # use commonware_utils::TryCollect;
425/// # fn example(f: Poly<F>, g: Poly<F>, p0: F, p1: F) {
426///     let interpolator = Interpolator::new([(0, p0), (1, p1)]);
427///     assert_eq!(
428///         Some(*f.constant()),
429///         interpolator.interpolate(&[(0, f.eval(&p0)), (1, f.eval(&p1))].into_iter().try_collect().unwrap(), 1)
430///     );
431///     assert_eq!(
432///         Some(*g.constant()),
433///         interpolator.interpolate(&[(1, g.eval(&p1)), (0, g.eval(&p0))].into_iter().try_collect().unwrap(), 1)
434///     );
435/// # }
436/// ```
437pub struct Interpolator<I, F> {
438    weights: Map<I, F>,
439}
440
441impl<I: PartialEq, F: Ring> Interpolator<I, F> {
442    /// Interpolate a polynomial's evaluations to recover its constant.
443    ///
444    /// The indices provided here MUST match those provided to [`Self::new`] exactly,
445    /// otherwise `None` will be returned.
446    pub fn interpolate<K: Space<F>>(&self, evals: &Map<I, K>, concurrency: usize) -> Option<K> {
447        if evals.keys() != self.weights.keys() {
448            return None;
449        }
450        Some(K::msm(evals.values(), self.weights.values(), concurrency))
451    }
452}
453
454impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
455    /// Create a new interpolator, given an association from indices to evaluation points.
456    ///
457    /// If an index appears multiple times, the implementation is free to use
458    /// any one of the evaluation points associated with that index. In other words,
459    /// don't do that, or ensure that if, for some reason, an index appears more
460    /// than once, then it has the same evaluation point.
461    pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
462        let points = Map::from_iter_dedup(points);
463        let weights = points
464            .iter_pairs()
465            .map(|(i, w_i)| {
466                let mut top_i = F::one();
467                let mut bot_i = F::one();
468                for (j, w_j) in points.iter_pairs() {
469                    if i == j {
470                        continue;
471                    }
472                    top_i *= w_j;
473                    bot_i *= &(w_j.clone() - w_i);
474                }
475                top_i * &bot_i.inv()
476            })
477            .collect::<Vec<_>>();
478        // Avoid re-sorting by using the memory of points.
479        let mut out = points;
480        for (out_i, weight_i) in out.values_mut().iter_mut().zip(weights.into_iter()) {
481            *out_i = weight_i;
482        }
483        Self { weights: out }
484    }
485}
486
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::{Arbitrary, Unstructured};

    /// Fuzzing support: a polynomial is built directly from an arbitrary
    /// non-empty coefficient vector.
    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            let coeffs = NonEmptyVec::<F>::arbitrary(u)?;
            Ok(Self { coeffs })
        }
    }
}
500
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::{F, G};
    use commonware_codec::Encode;
    use proptest::{
        prelude::{Arbitrary, BoxedStrategy, Strategy},
        prop_assume, proptest,
        sample::SizeRange,
    };

    /// Random polynomials for property tests. The size range is bumped when
    /// it allows zero, so at least one coefficient is always generated.
    impl Arbitrary for Poly<F> {
        type Parameters = SizeRange;
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
            let size = if size.start() == 0 { size + 1 } else { size };
            let coeffs = proptest::collection::vec(F::arbitrary(), size);
            coeffs.prop_map(Poly::from_iter_unchecked).boxed()
        }
    }

    #[test]
    fn test_additive() {
        let polys = Poly::<F>::arbitrary();
        crate::algebra::test_suites::test_additive(file!(), &polys);
    }

    #[test]
    fn test_space() {
        let scalars = F::arbitrary();
        let polys = Poly::<F>::arbitrary();
        crate::algebra::test_suites::test_space_ring(file!(), &scalars, &polys);
    }

    #[test]
    fn test_eq() {
        fn poly(coeffs: &[u8]) -> Poly<F> {
            Poly {
                coeffs: coeffs.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }
        fn eq(a: &[u8], b: &[u8]) -> bool {
            poly(a) == poly(b)
        }
        // Same length.
        assert!(eq(&[1, 2], &[1, 2]));
        assert!(!eq(&[1, 2], &[2, 3]));
        // Different lengths with a non-zero extra coefficient.
        assert!(!eq(&[1, 2], &[1, 2, 3]));
        assert!(!eq(&[1, 2, 3], &[1, 2]));
        // Trailing zeros are ignored...
        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
        // ...but do not hide differences in the other coefficients.
        assert!(!eq(&[1, 2, 0], &[2, 3]));
        assert!(!eq(&[2, 3], &[1, 2, 0]));
    }

    proptest! {
        #[test]
        fn test_codec(f: Poly<F>) {
            let cfg = (RangeCfg::exact(f.required()), ());
            let decoded = Poly::<F>::read_cfg(&mut f.encode(), &cfg).unwrap();
            assert_eq!(&f, &decoded);
        }

        #[test]
        fn test_eval_add(f: Poly<F>, g: Poly<F>, x: F) {
            let separately = f.eval(&x) + &g.eval(&x);
            let together = (f + &g).eval(&x);
            assert_eq!(separately, together);
        }

        #[test]
        fn test_eval_scale(f: Poly<F>, x: F, w: F) {
            let scaled_value = f.eval(&x) * &w;
            let value_of_scaled = (f * &w).eval(&x);
            assert_eq!(scaled_value, value_of_scaled);
        }

        #[test]
        fn test_eval_zero(f: Poly<F>) {
            let at_zero = f.eval(&F::zero());
            assert_eq!(&at_zero, f.constant());
        }

        #[test]
        fn test_eval_msm(f: Poly<F>, x: F) {
            assert_eq!(f.eval(&x), f.eval_msm(&x));
        }

        #[test]
        fn test_interpolate(f: Poly<F>) {
            // Make sure this isn't the zero polynomial.
            prop_assume!(f != Poly::zero());
            prop_assume!(f.required().get() < F::MAX as u32);
            // Evaluation points 1, 2, ..., required.
            let mut points: Vec<F> = (1..=f.required().get()).map(|i| F::from(i as u8)).collect();
            let evals_at = |pts: &[F]| Map::from_iter_dedup(pts.iter().map(|p| f.eval(p)).enumerate());
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let recovered = interpolator.interpolate(&evals_at(&points), 1);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
            // With one evaluation missing the index sets differ, so
            // interpolation must refuse.
            points.pop();
            assert!(interpolator.interpolate(&evals_at(&points), 1).is_none());
        }

        #[test]
        fn test_translate_scale(f: Poly<F>, x: F) {
            let translated = f.translate(|c| x * c);
            assert_eq!(translated, f * &x);
        }

        #[test]
        fn test_commit_eval(f: Poly<F>, x: F) {
            let expected = G::generator() * &f.eval(&x);
            let committed = Poly::<G>::commit(f);
            assert_eq!(expected, committed.eval(&x));
        }
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}