1use crate::algebra::{msm_naive, Additive, CryptoGroup, Field, Object, Random, Ring, Space};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_parallel::Strategy;
6use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
7use core::{
8 fmt::Debug,
9 iter,
10 num::NonZeroU32,
11 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
12};
13use rand_core::CryptoRngCore;
14
/// Minimum number of inputs for which `Space::msm` on [`Poly`] takes the coefficient-wise
/// (strategy-driven) path; smaller inputs are handled by `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
17
/// A polynomial `c_0 + c_1 X + ... + c_d X^d` with coefficients of type `K`.
///
/// Coefficients are stored from the constant term upward: `coeffs[0]` is the constant term.
/// At least one coefficient is always present. Trailing zero coefficients are permitted, so
/// `degree` may be larger than `degree_exact`. Equality ignores trailing zeros.
#[derive(Clone)]
pub struct Poly<K> {
    // Never empty; the `NonEmptyVec` type enforces this.
    coeffs: NonEmptyVec<K>,
}
24
25impl<K> Poly<K> {
26 fn len(&self) -> NonZeroU32 {
27 self.coeffs
28 .len()
29 .try_into()
30 .expect("Impossible: polynomial length not in 1..=u32::MAX")
31 }
32
33 const fn len_usize(&self) -> usize {
34 self.coeffs.len().get()
35 }
36
37 fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
42 let coeffs = iter
43 .into_iter()
44 .try_collect::<NonEmptyVec<_>>()
45 .expect("polynomial must have a least 1 coefficient");
46 Self { coeffs }
47 }
48
49 pub fn degree(&self) -> u32 {
60 self.len().get() - 1
61 }
62
63 pub fn required(&self) -> NonZeroU32 {
67 self.len()
68 }
69
70 pub fn constant(&self) -> &K {
74 &self.coeffs[0]
75 }
76
77 pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
82 Poly {
83 coeffs: self.coeffs.map(f),
84 }
85 }
86
87 pub fn eval<R>(&self, r: &R) -> K
103 where
104 K: Space<R>,
105 {
106 let mut iter = self.coeffs.iter().rev();
107 let mut acc = iter
114 .next()
115 .expect("Impossible: Polynomial has no coefficients")
116 .clone();
117 for coeff in iter {
118 acc *= r;
119 acc += coeff;
120 }
121 acc
122 }
123
124 pub fn eval_msm<R: Ring>(&self, r: &R, strategy: &impl Strategy) -> K
130 where
131 K: Space<R>,
132 {
133 let weights = {
135 let len = self.len_usize();
136 let mut out = Vec::with_capacity(len);
137 out.push(R::one());
138 let mut acc = R::one();
139 for _ in 1..len {
140 acc *= r;
141 out.push(acc.clone());
142 }
143 out
144 };
145 K::msm(&self.coeffs, &weights, strategy)
146 }
147}
148
149impl<K: Debug> Debug for Poly<K> {
150 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
151 write!(f, "Poly(")?;
152 for (i, c) in self.coeffs.iter().enumerate() {
153 if i > 0 {
154 write!(f, " + {c:?} X^{i}")?;
155 } else {
156 write!(f, "{c:?}")?;
157 }
158 }
159 write!(f, ")")?;
160 Ok(())
161 }
162}
163
impl<K: EncodeSize> EncodeSize for Poly<K> {
    /// Size of the encoded coefficient vector. This impl adds no framing of its own.
    fn encode_size(&self) -> usize {
        self.coeffs.encode_size()
    }
}
169
impl<K: Write> Write for Poly<K> {
    /// Writes the coefficient vector, constant term first, using its own encoding.
    fn write(&self, buf: &mut impl bytes::BufMut) {
        self.coeffs.write(buf);
    }
}
175
176impl<K: Read> Read for Poly<K> {
177 type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
178
179 fn read_cfg(
180 buf: &mut impl bytes::Buf,
181 cfg: &Self::Cfg,
182 ) -> Result<Self, commonware_codec::Error> {
183 Ok(Self {
184 coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
185 })
186 }
187}
188
189impl<K: Random> Poly<K> {
190 pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
193 Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
194 }
195
196 pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
200 Self::from_iter_unchecked(
201 iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
202 )
203 }
204}
205
206impl<K: Additive> PartialEq for Poly<K> {
211 fn eq(&self, other: &Self) -> bool {
212 let zero = K::zero();
213 let max_len = self.len().max(other.len());
214 let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
215 let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
216 self_then_zeros
217 .zip(other_then_zeros)
218 .take(max_len.get() as usize)
219 .all(|(a, b)| a == b)
220 }
221}
222
// Zero-padded comparison is an equivalence relation as long as `K`'s own equality is one.
impl<K: Additive> Eq for Poly<K> {}
224
225impl<K: Additive> Poly<K> {
226 fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
227 self.coeffs
228 .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
229 self.coeffs
230 .iter_mut()
231 .zip(&rhs.coeffs)
232 .for_each(|(a, b)| f(a, b));
233 }
234
235 pub fn degree_exact(&self) -> u32 {
241 let zero = K::zero();
242 let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
243 let lz_u32 =
244 u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
245 self.degree().saturating_sub(lz_u32)
248 }
249}
250
// Opts polynomials into the crate's `Object` trait; no items are overridden here.
impl<K: Additive> Object for Poly<K> {}
252
253impl<'a, K: Additive> AddAssign<&'a Self> for Poly<K> {
256 fn add_assign(&mut self, rhs: &'a Self) {
257 self.merge_with(rhs, |a, b| *a += b);
258 }
259}
260
261impl<'a, K: Additive> Add<&'a Self> for Poly<K> {
262 type Output = Self;
263
264 fn add(mut self, rhs: &'a Self) -> Self::Output {
265 self += rhs;
266 self
267 }
268}
269
270impl<'a, K: Additive> SubAssign<&'a Self> for Poly<K> {
271 fn sub_assign(&mut self, rhs: &'a Self) {
272 self.merge_with(rhs, |a, b| *a -= b);
273 }
274}
275
276impl<'a, K: Additive> Sub<&'a Self> for Poly<K> {
277 type Output = Self;
278
279 fn sub(mut self, rhs: &'a Self) -> Self::Output {
280 self -= rhs;
281 self
282 }
283}
284
285impl<K: Additive> Neg for Poly<K> {
286 type Output = Self;
287
288 fn neg(self) -> Self::Output {
289 Self {
290 coeffs: self.coeffs.map_into(Neg::neg),
291 }
292 }
293}
294
295impl<K: Additive> Additive for Poly<K> {
296 fn zero() -> Self {
297 Self {
298 coeffs: non_empty_vec![K::zero()],
299 }
300 }
301}
302
303impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
306 fn mul_assign(&mut self, rhs: &'a R) {
307 self.coeffs.iter_mut().for_each(|c| *c *= rhs);
308 }
309}
310
311impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
312 type Output = Self;
313
314 fn mul(mut self, rhs: &'a R) -> Self::Output {
315 self *= rhs;
316 self
317 }
318}
319
320impl<R: Sync, K: Space<R> + Send> Space<R> for Poly<K> {
321 fn msm(polys: &[Self], scalars: &[R], strategy: &impl Strategy) -> Self {
322 if polys.len() < MIN_POINTS_FOR_MSM {
323 return msm_naive(polys, scalars);
324 }
325
326 let cols = polys.len().min(scalars.len());
327 let polys = &polys[..cols];
328 let scalars = &scalars[..cols];
329
330 let rows = polys
331 .iter()
332 .map(|x| x.len_usize())
333 .max()
334 .expect("at least 1 point");
335
336 let coeffs = strategy.map_init_collect_vec(
337 0..rows,
338 || Vec::with_capacity(cols),
339 |row, i| {
340 row.clear();
341 for p in polys {
342 row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
343 }
344 K::msm(row, scalars, strategy)
345 },
346 );
347 Self::from_iter_unchecked(coeffs)
348 }
349}
350
351impl<G: CryptoGroup> Poly<G> {
352 pub fn commit(p: Poly<G::Scalar>) -> Self {
354 p.translate(|c| G::generator() * c)
355 }
356}
357
/// Precomputed weights for interpolating a value from evaluations keyed by `I`.
///
/// `interpolate` returns the weighted sum of the supplied evaluations using these weights.
/// It requires the evaluation keys to match the stored keys exactly.
pub struct Interpolator<I, F> {
    // One weight per evaluation point, keyed by that point's identifier.
    weights: Map<I, F>,
}
391
392impl<I: PartialEq, F: Ring> Interpolator<I, F> {
393 pub fn interpolate<K: Space<F>>(
398 &self,
399 evals: &Map<I, K>,
400 strategy: &impl Strategy,
401 ) -> Option<K> {
402 if evals.keys() != self.weights.keys() {
403 return None;
404 }
405 Some(K::msm(evals.values(), self.weights.values(), strategy))
406 }
407}
408
impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
    /// Precomputes Lagrange weights for recovering a polynomial's value at zero.
    ///
    /// `points` pairs each identifier with the field element at which its evaluation is
    /// taken. Entries are deduplicated by identifier via `Map::from_iter_dedup`.
    ///
    /// For points `x_i`, the stored weight is `λ_i = ∏_{j≠i} x_j / (x_j - x_i)`. This gives
    /// `Σ λ_i f(x_i) = f(0)` for any polynomial `f` of degree below the number of points.
    ///
    /// Special cases:
    /// - no points: the interpolator holds no weights;
    /// - a point equal to zero: that point gets weight one and every other point weight
    ///   zero, since its evaluation already is `f(0)`.
    ///
    /// NOTE(review): two distinct identifiers with equal field elements make some
    /// `x_j - x_i` zero, so the single inversion below would be of zero. What `F::inv` does
    /// then is not visible here; callers presumably must supply distinct points.
    pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
        let points = Map::from_iter_dedup(points);
        let n = points.len();
        if n == 0 {
            // Nothing to interpolate: the (empty) map doubles as the weight map.
            return Self { weights: points };
        }

        let values = points.values();
        let zero = F::zero();
        // ∏ x_j over all points: the numerator shared by every weight.
        let mut total_product = F::one();
        // c[i] = x_i * ∏_{j≠i} (x_j - x_i), so that λ_i = total_product / c[i].
        let mut c = Vec::with_capacity(n);
        for (i, w_i) in values.iter().enumerate() {
            if w_i == &zero {
                // The evaluation at x_i = 0 is exactly f(0): weight it 1, everything else 0.
                let mut out = points;
                for (j, w) in out.values_mut().iter_mut().enumerate() {
                    *w = if j == i { F::one() } else { F::zero() };
                }
                return Self { weights: out };
            }

            total_product *= w_i;
            let mut c_i = w_i.clone();
            for w_j in values
                .iter()
                .enumerate()
                .filter_map(|(j, v)| (j != i).then_some(v))
            {
                c_i *= &(w_j.clone() - w_i);
            }
            c.push(c_i);
        }

        // prefix[k] = c[0] * ... * c[k-1], with prefix[0] = 1 and prefix[n] = ∏ c.
        let mut prefix = Vec::with_capacity(n + 1);
        prefix.push(F::one());
        let mut acc = F::one();
        for c_i in &c {
            acc *= c_i;
            prefix.push(acc.clone());
        }

        // Batch inversion: a single inverse of ∏ c, pre-scaled by the numerator. At loop step
        // i, inv_acc = total_product / (c[0] * ... * c[i]), so multiplying by prefix[i]
        // leaves total_product / c[i], the weight for point i.
        let mut inv_acc = total_product * &prefix[n].inv();

        // Reuse the deduplicated map, overwriting each point with its weight (keys kept).
        let mut out = points;
        let out_vals = out.values_mut();
        for i in (0..n).rev() {
            out_vals[i] = inv_acc.clone() * &prefix[i];
            inv_acc *= &c[i];
        }
        Self { weights: out }
    }
}
475
#[commonware_macros::stability(ALPHA)]
impl<I: Clone + Ord, F: crate::algebra::FieldNTT> Interpolator<I, F> {
    /// Builds an interpolator whose points are powers of a root of unity.
    ///
    /// Each entry of `points` maps an identifier to an exponent `k`. The weights come from
    /// `crate::ntt::lagrange_coefficients(total, exponents)`. Each returned exponent is
    /// mapped back to its identifier via `BiMap::get_key`; coefficients whose exponent has
    /// no identifier are dropped.
    ///
    /// The fuzz harness checks that the result equals `roots_of_unity_naive`, which calls
    /// `Interpolator::new` on the points `w^k`.
    ///
    /// Panics if building the weight map fails. The code treats that as impossible because
    /// `points` is already deduplicated.
    pub fn roots_of_unity(
        total: NonZeroU32,
        points: commonware_utils::ordered::BiMap<I, u32>,
    ) -> Self {
        let weights = <Map<I, F> as commonware_utils::TryFromIterator<(I, F)>>::try_from_iter(
            crate::ntt::lagrange_coefficients(total, points.values().iter().copied())
                .into_iter()
                .filter_map(|(k, coeff)| Some((points.get_key(&k)?.clone(), coeff))),
        )
        .expect("points has already been deduped");
        Self { weights }
    }

    /// Reference implementation of `roots_of_unity`, used only to cross-check it.
    ///
    /// Takes `w = F::root_of_unity(log2(total.next_power_of_two()))` (presumably a root of
    /// that power-of-two order; confirm against `FieldNTT`). It keeps only entries with
    /// `k < total`, maps each to the point `w^k`, and defers to `Interpolator::new`.
    #[cfg(any(test, feature = "fuzz"))]
    fn roots_of_unity_naive(
        total: NonZeroU32,
        points: commonware_utils::ordered::BiMap<I, u32>,
    ) -> Self {
        use crate::algebra::powers;

        let total_u32 = total.get();
        let size = (total_u32 as u64).next_power_of_two();
        let lg_size = size.ilog2() as u8;
        let w = F::root_of_unity(lg_size).expect("domain too large for NTT");

        let points: Vec<(I, u32)> = points.into_iter().filter(|(_, k)| *k < total_u32).collect();
        // Precompute w^0 ..= w^max_k once so each point is a table lookup.
        let max_k = points.iter().map(|(_, k)| *k).max().unwrap_or(0) as usize;
        let powers: Vec<_> = powers(&w, max_k + 1).collect();

        let eval_points = points
            .into_iter()
            .map(|(i, k)| (i, powers[k as usize].clone()));
        Self::new(eval_points)
    }
}
527
#[cfg(any(test, feature = "arbitrary"))]
mod impl_arbitrary {
    use super::*;
    use arbitrary::Arbitrary;

    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        /// Draws a constant term, then an arbitrary-length tail of higher coefficients.
        ///
        /// The result therefore always has at least one coefficient.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let mut coeffs = NonEmptyVec::new(u.arbitrary::<F>()?);
            coeffs.extend(u.arbitrary::<Vec<F>>()?);
            Ok(Self { coeffs })
        }
    }
}
543
/// Property checks for [`Poly`] and [`Interpolator`], driven by `arbitrary` input.
///
/// Compiled for tests and under the `fuzz` feature. `Plan::run` executes one decoded plan.
#[commonware_macros::stability(ALPHA)]
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    use super::*;
    use crate::{
        algebra::test_suites,
        test::{F, G},
    };
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_codec::Encode as _;
    use commonware_parallel::Sequential;
    use commonware_utils::{
        ordered::{BiMap, Map},
        TryFromIterator,
    };

    /// One property check, with its inputs generated by `arbitrary`.
    #[derive(Debug, Arbitrary)]
    pub enum Plan {
        /// Encoding then decoding (length range fixed to `f.required()`) round-trips.
        Codec(Poly<F>),
        /// `f(x) + g(x) == (f + g)(x)`.
        EvalAdd(Poly<F>, Poly<F>, F),
        /// `f(x) * w == (f * w)(x)`.
        EvalScale(Poly<F>, F, F),
        /// `f(0)` equals the constant term.
        EvalZero(Poly<F>),
        /// `eval` (Horner) agrees with `eval_msm`.
        EvalMsm(Poly<F>, F),
        /// Interpolating from `required()` evaluations at `1..=n` recovers the constant term;
        /// with one evaluation dropped, `interpolate` returns `None`.
        Interpolate(Poly<F>),
        /// Like `Interpolate`, but the points are `0..n`, so zero is the first point.
        InterpolateWithZeroPoint(Poly<F>),
        /// Like `Interpolate`, but zero is placed after the non-zero points `1..n`.
        InterpolateWithZeroPointMiddle(Poly<F>),
        /// `translate` applying `x * c` per coefficient equals `f * &x`.
        TranslateScale(Poly<F>, F),
        /// `commit(f).eval(x)` equals `G::generator() * f.eval(x)`.
        CommitEval(Poly<F>, F),
        /// `Interpolator::roots_of_unity` matches `roots_of_unity_naive` for 1..=256 points.
        RootsOfUnityEqNaive(u16),
        /// Runs `test_suites::fuzz_additive` on `Poly<F>`.
        FuzzAdditive,
        /// Runs `test_suites::fuzz_space_ring` on `Poly<F>` over scalars `F`.
        FuzzSpaceRing,
    }

    impl Plan {
        /// Executes this plan, panicking (via `assert!`) if the property fails.
        ///
        /// `u` supplies extra input for the generic suite variants.
        pub fn run(self, u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::Codec(f) => {
                    assert_eq!(
                        &f,
                        &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ()))
                            .unwrap()
                    );
                }
                Self::EvalAdd(f, g, x) => {
                    assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
                }
                Self::EvalScale(f, x, w) => {
                    assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
                }
                Self::EvalZero(f) => {
                    assert_eq!(&f.eval(&F::zero()), f.constant());
                }
                Self::EvalMsm(f, x) => {
                    assert_eq!(f.eval(&x), f.eval_msm(&x, &Sequential));
                }
                Self::Interpolate(f) => {
                    // Skip the zero polynomial and sizes at or above `F::MAX`; the points
                    // below are built from small integers via `as u8`.
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    let mut points = (0..f.required().get())
                        .map(|i| F::from((i + 1) as u8))
                        .collect::<Vec<_>>();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                    // One evaluation short: the key sets differ, so interpolation refuses.
                    points.pop();
                    assert_eq!(
                        interpolator.interpolate(
                            &Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate()),
                            &Sequential
                        ),
                        None
                    );
                }
                Self::InterpolateWithZeroPoint(f) => {
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Points 0, 1, ..., n-1: exercises the zero-point shortcut at index 0.
                    let points: Vec<_> =
                        (0..f.required().get()).map(|i| F::from(i as u8)).collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::InterpolateWithZeroPointMiddle(f) => {
                    if f == Poly::zero()
                        || f.required().get() < 2
                        || f.required().get() >= F::MAX as u32
                    {
                        return Ok(());
                    }
                    // Points 1, ..., n-1 followed by 0: the zero point is found only after
                    // the other points have been visited.
                    let n = f.required().get();
                    let points: Vec<_> = (1..n)
                        .map(|i| F::from(i as u8))
                        .chain(core::iter::once(F::zero()))
                        .collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::TranslateScale(f, x) => {
                    assert_eq!(f.translate(|c| x * c), f * &x);
                }
                Self::CommitEval(f, x) => {
                    assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
                }
                Self::RootsOfUnityEqNaive(n) => {
                    // Map the raw u16 into 1..=256 points.
                    let n = (u32::from(n) % 256) + 1;
                    let total = NonZeroU32::new(n).expect("n is in 1..=256");
                    let points = BiMap::try_from_iter((0..n as usize).map(|i| (i, i as u32)))
                        .expect("interpolation points should be bijective");
                    let fast = Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity(
                        total,
                        points.clone(),
                    );
                    let naive =
                        Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity_naive(
                            total, points,
                        );
                    assert_eq!(fast.weights, naive.weights);
                }
                Self::FuzzAdditive => {
                    test_suites::fuzz_additive::<Poly<F>>(u)?;
                }
                Self::FuzzSpaceRing => {
                    test_suites::fuzz_space_ring::<F, Poly<F>>(u)?;
                }
            }
            Ok(())
        }
    }

    /// Drives `Plan::run` with inputs produced by `commonware_invariants::minifuzz`.
    #[test]
    fn test_fuzz() {
        commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::F;

    /// Equality must ignore trailing zero coefficients and nothing else.
    #[test]
    fn test_eq() {
        // Builds two polynomials from small integer coefficients (constant term first) and
        // compares them with `Poly`'s `PartialEq`.
        fn eq(a: &[u8], b: &[u8]) -> bool {
            Poly {
                coeffs: a.iter().copied().map(F::from).try_collect().unwrap(),
            } == Poly {
                coeffs: b.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }
        assert!(eq(&[1, 2], &[1, 2]));
        assert!(!eq(&[1, 2], &[2, 3]));
        assert!(!eq(&[1, 2], &[1, 2, 3]));
        assert!(!eq(&[1, 2, 3], &[1, 2]));
        // Trailing zeros on either side do not affect equality.
        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
        // ...but padding does not hide a genuine difference.
        assert!(!eq(&[1, 2, 0], &[2, 3]));
        assert!(!eq(&[2, 3], &[1, 2, 0]));
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        // Runs the shared codec conformance suite against `Poly<F>`.
        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}