1use crate::algebra::{
2 msm_naive, powers, Additive, CryptoGroup, Field, Object, Random, Ring, Space,
3};
4#[cfg(not(feature = "std"))]
5use alloc::{borrow::Cow, vec, vec::Vec};
6use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
7use commonware_parallel::Strategy;
8use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
9use core::{
10 fmt::Debug,
11 iter,
12 num::NonZeroU32,
13 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
14};
15use rand_core::CryptoRngCore;
16#[cfg(feature = "std")]
17use std::borrow::Cow;
18
/// Minimum number of polynomials for which the `msm` implementation on [`Poly`] uses its
/// coefficient-by-coefficient path; with fewer inputs it falls back to `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
21
/// A polynomial stored as a non-empty list of coefficients.
///
/// The coefficient at index `i` multiplies `X^i`, so index 0 holds the constant term.
#[derive(Clone)]
pub struct Poly<K> {
    // Coefficients in ascending order of power; never empty by construction of `NonEmptyVec`.
    coeffs: NonEmptyVec<K>,
}
28
29impl<K> Poly<K> {
30 fn len(&self) -> NonZeroU32 {
31 self.coeffs
32 .len()
33 .try_into()
34 .expect("Impossible: polynomial length not in 1..=u32::MAX")
35 }
36
37 const fn len_usize(&self) -> usize {
38 self.coeffs.len().get()
39 }
40
41 fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
46 let coeffs = iter
47 .into_iter()
48 .try_collect::<NonEmptyVec<_>>()
49 .expect("polynomial must have a least 1 coefficient");
50 Self { coeffs }
51 }
52
53 pub fn degree(&self) -> u32 {
64 self.len().get() - 1
65 }
66
67 pub fn required(&self) -> NonZeroU32 {
71 self.len()
72 }
73
74 pub fn constant(&self) -> &K {
78 &self.coeffs[0]
79 }
80
81 pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
86 Poly {
87 coeffs: self.coeffs.map(f),
88 }
89 }
90
91 pub fn eval<R>(&self, r: &R) -> K
107 where
108 K: Space<R>,
109 {
110 let mut iter = self.coeffs.iter().rev();
111 let mut acc = iter
118 .next()
119 .expect("Impossible: Polynomial has no coefficients")
120 .clone();
121 for coeff in iter {
122 acc *= r;
123 acc += coeff;
124 }
125 acc
126 }
127
128 pub fn eval_msm<R: Ring>(&self, r: &R, strategy: &impl Strategy) -> K
134 where
135 K: Space<R>,
136 {
137 let weights = powers(R::one(), r)
139 .take(self.len_usize())
140 .collect::<Vec<_>>();
141 K::msm(&self.coeffs, &weights, strategy)
142 }
143
144 pub fn lin_comb_eval<'a, R: Ring + 'a>(
151 &self,
152 into_iter: impl IntoIterator<Item = (R, Cow<'a, R>)>,
153 strategy: &impl Strategy,
154 ) -> K
155 where
156 K: Space<R>,
157 {
158 let weights = {
160 let mut iter = into_iter.into_iter();
161 let Some((a0, b0)) = iter.next() else {
162 return K::zero();
163 };
164
165 let len = self.len_usize();
166 let mut out: Vec<_> = powers(a0, b0.as_ref()).take(len).collect();
167 for (ai, bi) in iter {
168 powers(ai, bi.as_ref())
169 .take(len)
170 .zip(out.iter_mut())
171 .for_each(|(c_j, o_j)| *o_j += &c_j);
172 }
173 out
174 };
175 K::msm(&self.coeffs, &weights, strategy)
176 }
177}
178
179impl<K: Debug> Debug for Poly<K> {
180 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
181 write!(f, "Poly(")?;
182 for (i, c) in self.coeffs.iter().enumerate() {
183 if i > 0 {
184 write!(f, " + {c:?} X^{i}")?;
185 } else {
186 write!(f, "{c:?}")?;
187 }
188 }
189 write!(f, ")")?;
190 Ok(())
191 }
192}
193
impl<K: EncodeSize> EncodeSize for Poly<K> {
    /// The encoded size of a polynomial is the encoded size of its coefficient list.
    fn encode_size(&self) -> usize {
        self.coeffs.encode_size()
    }
}
199
impl<K: Write> Write for Poly<K> {
    /// Writes the polynomial by writing its coefficient list to `buf`.
    fn write(&self, buf: &mut impl bytes::BufMut) {
        self.coeffs.write(buf);
    }
}
205
206impl<K: Read> Read for Poly<K> {
207 type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
208
209 fn read_cfg(
210 buf: &mut impl bytes::Buf,
211 cfg: &Self::Cfg,
212 ) -> Result<Self, commonware_codec::Error> {
213 Ok(Self {
214 coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
215 })
216 }
217}
218
219impl<K: Random> Poly<K> {
220 pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
223 Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
224 }
225
226 pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
230 Self::from_iter_unchecked(
231 iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
232 )
233 }
234}
235
236impl<K: Additive> PartialEq for Poly<K> {
241 fn eq(&self, other: &Self) -> bool {
242 let zero = K::zero();
243 let max_len = self.len().max(other.len());
244 let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
245 let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
246 self_then_zeros
247 .zip(other_then_zeros)
248 .take(max_len.get() as usize)
249 .all(|(a, b)| a == b)
250 }
251}
252
// Marker impl: declares the zero-padded comparison in `PartialEq` to be a full equivalence.
impl<K: Additive> Eq for Poly<K> {}
254
255impl<K: Additive> Poly<K> {
256 fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
257 self.coeffs
258 .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
259 self.coeffs
260 .iter_mut()
261 .zip(&rhs.coeffs)
262 .for_each(|(a, b)| f(a, b));
263 }
264
265 pub fn degree_exact(&self) -> u32 {
271 let zero = K::zero();
272 let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
273 let lz_u32 =
274 u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
275 self.degree().saturating_sub(lz_u32)
278 }
279}
280
// Marker impl: `Object` has no items to provide here.
impl<K: Additive> Object for Poly<K> {}
282
impl<'a, K: Additive> AddAssign<&'a Self> for Poly<K> {
    /// Adds `rhs` to `self` coefficient by coefficient, zero-extending `self` if `rhs` is
    /// longer.
    fn add_assign(&mut self, rhs: &'a Self) {
        self.merge_with(rhs, |a, b| *a += b);
    }
}
290
impl<'a, K: Additive> Add<&'a Self> for Poly<K> {
    type Output = Self;

    /// Returns `self + rhs`, reusing `self`'s storage via `+=`.
    fn add(mut self, rhs: &'a Self) -> Self::Output {
        self += rhs;
        self
    }
}
299
impl<'a, K: Additive> SubAssign<&'a Self> for Poly<K> {
    /// Subtracts `rhs` from `self` coefficient by coefficient, zero-extending `self` if
    /// `rhs` is longer.
    fn sub_assign(&mut self, rhs: &'a Self) {
        self.merge_with(rhs, |a, b| *a -= b);
    }
}
305
impl<'a, K: Additive> Sub<&'a Self> for Poly<K> {
    type Output = Self;

    /// Returns `self - rhs`, reusing `self`'s storage via `-=`.
    fn sub(mut self, rhs: &'a Self) -> Self::Output {
        self -= rhs;
        self
    }
}
314
impl<K: Additive> Neg for Poly<K> {
    type Output = Self;

    /// Negates every coefficient, consuming `self`.
    fn neg(self) -> Self::Output {
        Self {
            coeffs: self.coeffs.map_into(Neg::neg),
        }
    }
}
324
impl<K: Additive> Additive for Poly<K> {
    /// The zero polynomial, represented by the single coefficient `K::zero()`.
    fn zero() -> Self {
        Self {
            coeffs: non_empty_vec![K::zero()],
        }
    }
}
332
333impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
336 fn mul_assign(&mut self, rhs: &'a R) {
337 self.coeffs.iter_mut().for_each(|c| *c *= rhs);
338 }
339}
340
impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
    type Output = Self;

    /// Returns the polynomial scaled by `rhs`, reusing `self`'s storage via `*=`.
    fn mul(mut self, rhs: &'a R) -> Self::Output {
        self *= rhs;
        self
    }
}
349
350impl<R: Sync, K: Space<R> + Send> Space<R> for Poly<K> {
351 fn msm(polys: &[Self], scalars: &[R], strategy: &impl Strategy) -> Self {
352 if polys.len() < MIN_POINTS_FOR_MSM {
353 return msm_naive(polys, scalars);
354 }
355
356 let cols = polys.len().min(scalars.len());
357 let polys = &polys[..cols];
358 let scalars = &scalars[..cols];
359
360 let rows = polys
361 .iter()
362 .map(|x| x.len_usize())
363 .max()
364 .expect("at least 1 point");
365
366 let coeffs = strategy.map_init_collect_vec(
367 0..rows,
368 || Vec::with_capacity(cols),
369 |row, i| {
370 row.clear();
371 for p in polys {
372 row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
373 }
374 K::msm(row, scalars, strategy)
375 },
376 );
377 Self::from_iter_unchecked(coeffs)
378 }
379}
380
impl<G: CryptoGroup> Poly<G> {
    /// Converts a polynomial over scalars into one over the group by multiplying
    /// `G::generator()` by each coefficient.
    pub fn commit(p: Poly<G::Scalar>) -> Self {
        p.translate(|c| G::generator() * c)
    }
}
387
/// Precomputed per-key weights for recovering a polynomial's value at zero from its
/// evaluations.
///
/// `I` identifies an evaluation (e.g. a participant index) and `F` is the weight type.
pub struct Interpolator<I, F> {
    // One weight per key; `interpolate` pairs these with the supplied evaluations.
    weights: Map<I, F>,
}
421
422impl<I: PartialEq, F: Ring> Interpolator<I, F> {
423 pub fn interpolate<K: Space<F>>(
428 &self,
429 evals: &Map<I, K>,
430 strategy: &impl Strategy,
431 ) -> Option<K> {
432 if evals.keys() != self.weights.keys() {
433 return None;
434 }
435 Some(K::msm(evals.values(), self.weights.values(), strategy))
436 }
437}
438
impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
    /// Builds an interpolator from `(key, evaluation point)` pairs.
    ///
    /// For evaluation points `w_0, ..., w_{n-1}`, the weight stored for key `i` is
    /// `W / c_i`, where `W` is the product of all `w_j` and
    /// `c_i = w_i * prod_{j != i} (w_j - w_i)`. This equals the Lagrange basis polynomial
    /// for `w_i` evaluated at zero.
    ///
    /// If one of the evaluation points is zero, that key gets weight one and every other
    /// key gets weight zero.
    ///
    /// NOTE(review): two different keys sharing the same evaluation point make some `c_i`
    /// zero, so the result then depends on what `inv()` does for zero — confirm with callers.
    pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
        // Presumably removes duplicate keys; verify against `Map::from_iter_dedup`.
        let points = Map::from_iter_dedup(points);
        let n = points.len();
        if n == 0 {
            return Self { weights: points };
        }

        // Accumulate W (the product of all w_i) and each c_i in one pass.
        let values = points.values();
        let zero = F::zero();
        let mut total_product = F::one();
        let mut c = Vec::with_capacity(n);
        for (i, w_i) in values.iter().enumerate() {
            if w_i == &zero {
                // An evaluation at zero already is the value being recovered.
                let mut out = points;
                for (j, w) in out.values_mut().iter_mut().enumerate() {
                    *w = if j == i { F::one() } else { F::zero() };
                }
                return Self { weights: out };
            }

            total_product *= w_i;
            let mut c_i = w_i.clone();
            for w_j in values
                .iter()
                .enumerate()
                .filter_map(|(j, v)| (j != i).then_some(v))
            {
                c_i *= &(w_j.clone() - w_i);
            }
            c.push(c_i);
        }

        // Batch inversion: prefix[k] = c_0 * ... * c_{k-1}, so only prefix[n] is inverted.
        let mut prefix = Vec::with_capacity(n + 1);
        prefix.push(F::one());
        let mut acc = F::one();
        for c_i in &c {
            acc *= c_i;
            prefix.push(acc.clone());
        }

        // inv_acc starts as W / (c_0 * ... * c_{n-1}).
        let mut inv_acc = total_product * &prefix[n].inv();

        let mut out = points;
        let out_vals = out.values_mut();
        // Walking backwards, inv_acc is W / (c_0 * ... * c_i) at step i, so multiplying by
        // prefix[i] leaves W / c_i; multiplying by c_i afterwards prepares step i - 1.
        for i in (0..n).rev() {
            out_vals[i] = inv_acc.clone() * &prefix[i];
            inv_acc *= &c[i];
        }
        Self { weights: out }
    }
}
505
506#[commonware_macros::stability(ALPHA)]
507impl<I: Clone + Ord, F: crate::algebra::FieldNTT> Interpolator<I, F> {
508 pub fn roots_of_unity(
517 total: NonZeroU32,
518 points: commonware_utils::ordered::BiMap<I, u32>,
519 ) -> Self {
520 let weights = <Map<I, F> as commonware_utils::TryFromIterator<(I, F)>>::try_from_iter(
521 crate::ntt::lagrange_coefficients(total, points.values().iter().copied())
522 .into_iter()
523 .filter_map(|(k, coeff)| Some((points.get_key(&k)?.clone(), coeff))),
524 )
525 .expect("points has already been deduped");
526 Self { weights }
527 }
528
529 #[cfg(any(test, feature = "fuzz"))]
536 fn roots_of_unity_naive(
537 total: NonZeroU32,
538 points: commonware_utils::ordered::BiMap<I, u32>,
539 ) -> Self {
540 use crate::algebra::powers;
541
542 let total_u32 = total.get();
543 let size = (total_u32 as u64).next_power_of_two();
544 let lg_size = size.ilog2() as u8;
545 let w = F::root_of_unity(lg_size).expect("domain too large for NTT");
546
547 let points: Vec<(I, u32)> = points.into_iter().filter(|(_, k)| *k < total_u32).collect();
548 let max_k = points.iter().map(|(_, k)| *k).max().unwrap_or(0) as usize;
549 let powers: Vec<_> = powers(F::one(), &w).take(max_k + 1).collect();
550
551 let eval_points = points
552 .into_iter()
553 .map(|(i, k)| (i, powers[k as usize].clone()));
554 Self::new(eval_points)
555 }
556}
557
#[cfg(any(test, feature = "arbitrary"))]
mod impl_arbitrary {
    use super::*;
    use arbitrary::Arbitrary;

    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        /// Generates one mandatory coefficient followed by an arbitrary number of further
        /// ones, so the coefficient list is never empty.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let head = F::arbitrary(u)?;
            let tail = Vec::<F>::arbitrary(u)?;
            let mut coeffs = NonEmptyVec::new(head);
            coeffs.extend(tail);
            Ok(Self { coeffs })
        }
    }
}
573
/// Fuzzing plans that exercise [`Poly`] and [`Interpolator`] invariants.
#[commonware_macros::stability(ALPHA)]
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    use super::*;
    use crate::{
        algebra::test_suites,
        test::{F, G},
    };
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_codec::Encode as _;
    use commonware_parallel::Sequential;
    use commonware_utils::{
        ordered::{BiMap, Map},
        TryFromIterator,
    };

    /// One property check, together with the inputs it runs on.
    #[derive(Debug, Arbitrary)]
    pub enum Plan {
        /// Encoding then decoding a polynomial yields an equal polynomial.
        Codec(Poly<F>),
        /// Evaluation distributes over polynomial addition.
        EvalAdd(Poly<F>, Poly<F>, F),
        /// Evaluation commutes with scaling by a field element.
        EvalScale(Poly<F>, F, F),
        /// Evaluating at zero yields the constant term.
        EvalZero(Poly<F>),
        /// `eval_msm` agrees with `eval`.
        EvalMsm(Poly<F>, F),
        /// `lin_comb_eval` agrees with summing `a * f.eval(b)` over the pairs.
        LinCombEval(Poly<F>, Vec<(F, F)>),
        /// Interpolation recovers the constant term, and rejects a mismatched key set.
        Interpolate(Poly<F>),
        /// Interpolation still works when the first evaluation point is zero.
        InterpolateWithZeroPoint(Poly<F>),
        /// Interpolation still works when zero is the last evaluation point.
        InterpolateWithZeroPointMiddle(Poly<F>),
        /// `translate` with a scaling closure agrees with scalar multiplication.
        TranslateScale(Poly<F>, F),
        /// Evaluating a commitment agrees with committing to the evaluation.
        CommitEval(Poly<F>, F),
        /// `roots_of_unity` produces the same weights as `roots_of_unity_naive`.
        RootsOfUnityEqNaive(u16),
        /// Runs the shared additive test suite on `Poly<F>`.
        FuzzAdditive,
        /// Runs the shared space/ring test suite on `Poly<F>`.
        FuzzSpaceRing,
    }

    impl Plan {
        /// Executes the check, panicking if the property does not hold.
        ///
        /// `u` is only consumed by the variants that delegate to the shared test suites.
        pub fn run(self, u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::Codec(f) => {
                    assert_eq!(
                        &f,
                        &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ()))
                            .unwrap()
                    );
                }
                Self::EvalAdd(f, g, x) => {
                    assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
                }
                Self::EvalScale(f, x, w) => {
                    assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
                }
                Self::EvalZero(f) => {
                    assert_eq!(&f.eval(&F::zero()), f.constant());
                }
                Self::EvalMsm(f, x) => {
                    assert_eq!(f.eval(&x), f.eval_msm(&x, &Sequential));
                }
                Self::LinCombEval(f, pairs) => {
                    let naive_eval = pairs.iter().fold(F::zero(), |mut acc, (a, b)| {
                        acc += &(*a * &f.eval(b));
                        acc
                    });
                    let lin_comb = f.lin_comb_eval(
                        pairs.iter().map(|(a, b)| (*a, Cow::Borrowed(b))),
                        &Sequential,
                    );
                    assert_eq!(naive_eval, lin_comb);
                }
                Self::Interpolate(f) => {
                    // Skip inputs with too many coefficients for the `u8`-derived points
                    // below (presumably so the points stay distinct — confirm against `F`).
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Evaluation points 1, 2, ..., required().
                    let mut points = (0..f.required().get())
                        .map(|i| F::from((i + 1) as u8))
                        .collect::<Vec<_>>();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                    // With one evaluation missing the key sets differ, so this must fail.
                    points.pop();
                    assert_eq!(
                        interpolator.interpolate(
                            &Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate()),
                            &Sequential
                        ),
                        None
                    );
                }
                Self::InterpolateWithZeroPoint(f) => {
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Evaluation points 0, 1, ..., required() - 1: zero comes first.
                    let points: Vec<_> =
                        (0..f.required().get()).map(|i| F::from(i as u8)).collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::InterpolateWithZeroPointMiddle(f) => {
                    if f == Poly::zero()
                        || f.required().get() < 2
                        || f.required().get() >= F::MAX as u32
                    {
                        return Ok(());
                    }
                    let n = f.required().get();
                    // Evaluation points 1, ..., n - 1 followed by zero, so the zero point
                    // is reached only after non-zero points have been processed.
                    let points: Vec<_> = (1..n)
                        .map(|i| F::from(i as u8))
                        .chain(core::iter::once(F::zero()))
                        .collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::TranslateScale(f, x) => {
                    assert_eq!(f.translate(|c| x * c), f * &x);
                }
                Self::CommitEval(f, x) => {
                    assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
                }
                Self::RootsOfUnityEqNaive(n) => {
                    // Clamp the domain size to 1..=256 and use every index as a point.
                    let n = (u32::from(n) % 256) + 1;
                    let total = NonZeroU32::new(n).expect("n is in 1..=256");
                    let points = BiMap::try_from_iter((0..n as usize).map(|i| (i, i as u32)))
                        .expect("interpolation points should be bijective");
                    let fast = Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity(
                        total,
                        points.clone(),
                    );
                    let naive =
                        Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity_naive(
                            total, points,
                        );
                    assert_eq!(fast.weights, naive.weights);
                }
                Self::FuzzAdditive => {
                    test_suites::fuzz_additive::<Poly<F>>(u)?;
                }
                Self::FuzzSpaceRing => {
                    test_suites::fuzz_space_ring::<F, Poly<F>>(u)?;
                }
            }
            Ok(())
        }
    }

    /// Drives randomly generated plans through `run`.
    #[test]
    fn test_fuzz() {
        commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }
}
#[cfg(test)]
mod test {
    use super::{fuzz::Plan, *};
    use crate::test::F;
    use arbitrary::Unstructured;

    /// Equality ignores trailing zero coefficients but nothing else.
    #[test]
    fn test_eq() {
        // Builds two polynomials from small integer coefficients and compares them.
        fn eq(a: &[u8], b: &[u8]) -> bool {
            Poly {
                coeffs: a.iter().copied().map(F::from).try_collect().unwrap(),
            } == Poly {
                coeffs: b.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }
        assert!(eq(&[1, 2], &[1, 2]));
        assert!(!eq(&[1, 2], &[2, 3]));
        assert!(!eq(&[1, 2], &[1, 2, 3]));
        assert!(!eq(&[1, 2, 3], &[1, 2]));
        // Trailing zeros on either side do not affect equality.
        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
        assert!(!eq(&[1, 2, 0], &[2, 3]));
        assert!(!eq(&[2, 3], &[1, 2, 0]));
    }

    /// Runs hand-picked `LinCombEval` plans: no pairs, a constant polynomial, repeated
    /// evaluation points, and evaluation points equal to zero.
    #[test]
    fn lin_comb_eval_edge_cases() {
        fn poly(coeffs: &[u8]) -> Poly<F> {
            Poly {
                coeffs: coeffs.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }

        fn pairs(values: &[(u8, u8)]) -> Vec<(F, F)> {
            values
                .iter()
                .map(|(a, b)| (F::from(*a), F::from(*b)))
                .collect()
        }

        let cases = [
            Plan::LinCombEval(poly(&[3, 5, 7]), vec![]),
            Plan::LinCombEval(poly(&[11]), pairs(&[(2, 0), (3, 1), (5, 8)])),
            Plan::LinCombEval(poly(&[4, 6, 8]), pairs(&[(2, 5), (7, 5), (3, 5)])),
            Plan::LinCombEval(poly(&[9, 2, 3, 4]), pairs(&[(6, 0), (1, 0), (5, 7)])),
            Plan::LinCombEval(poly(&[1, 2, 4, 8]), pairs(&[(3, 1), (7, 1), (2, 6)])),
        ];
        // `LinCombEval` does not read from the unstructured input, so an empty one suffices.
        let mut u = Unstructured::new(&[]);
        for case in cases {
            case.run(&mut u).unwrap();
        }
    }

    /// Codec conformance checks for `Poly<F>`; these need the `arbitrary` feature.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}