1use crate::algebra::{msm_naive, Additive, CryptoGroup, Field, Object, Random, Ring, Space};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_parallel::Strategy;
6use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
7use core::{
8 fmt::Debug,
9 iter,
10 num::NonZeroU32,
11 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
12};
13use rand_core::CryptoRngCore;
14
/// Minimum number of (polynomial, scalar) pairs for which the `Space::msm`
/// implementation on `Poly` transposes the problem into per-coefficient MSMs;
/// smaller inputs are handed to `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
17
/// A univariate polynomial with coefficients of type `K`.
///
/// Coefficients are stored in ascending order of degree: the first entry is
/// the constant term and entry `i` multiplies `X^i`. The vector is never
/// empty, so every polynomial has at least a constant term. Trailing zero
/// coefficients may be present; equality treats them as insignificant.
#[derive(Clone)]
pub struct Poly<K> {
    // Coefficients, lowest degree first; non-empty by construction.
    coeffs: NonEmptyVec<K>,
}
24
25impl<K> Poly<K> {
26 fn len(&self) -> NonZeroU32 {
27 self.coeffs
28 .len()
29 .try_into()
30 .expect("Impossible: polynomial length not in 1..=u32::MAX")
31 }
32
33 const fn len_usize(&self) -> usize {
34 self.coeffs.len().get()
35 }
36
37 fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
42 let coeffs = iter
43 .into_iter()
44 .try_collect::<NonEmptyVec<_>>()
45 .expect("polynomial must have a least 1 coefficient");
46 Self { coeffs }
47 }
48
49 pub fn degree(&self) -> u32 {
60 self.len().get() - 1
61 }
62
63 pub fn required(&self) -> NonZeroU32 {
67 self.len()
68 }
69
70 pub fn constant(&self) -> &K {
74 &self.coeffs[0]
75 }
76
77 pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
82 Poly {
83 coeffs: self.coeffs.map(f),
84 }
85 }
86
87 pub fn eval<R>(&self, r: &R) -> K
103 where
104 K: Space<R>,
105 {
106 let mut iter = self.coeffs.iter().rev();
107 let mut acc = iter
114 .next()
115 .expect("Impossible: Polynomial has no coefficients")
116 .clone();
117 for coeff in iter {
118 acc *= r;
119 acc += coeff;
120 }
121 acc
122 }
123
124 pub fn eval_msm<R: Ring>(&self, r: &R, strategy: &impl Strategy) -> K
130 where
131 K: Space<R>,
132 {
133 let weights = {
135 let len = self.len_usize();
136 let mut out = Vec::with_capacity(len);
137 out.push(R::one());
138 let mut acc = R::one();
139 for _ in 1..len {
140 acc *= r;
141 out.push(acc.clone());
142 }
143 out
144 };
145 K::msm(&self.coeffs, &weights, strategy)
146 }
147}
148
149impl<K: Debug> Debug for Poly<K> {
150 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
151 write!(f, "Poly(")?;
152 for (i, c) in self.coeffs.iter().enumerate() {
153 if i > 0 {
154 write!(f, " + {c:?} X^{i}")?;
155 } else {
156 write!(f, "{c:?}")?;
157 }
158 }
159 write!(f, ")")?;
160 Ok(())
161 }
162}
163
164impl<K: EncodeSize> EncodeSize for Poly<K> {
165 fn encode_size(&self) -> usize {
166 self.coeffs.encode_size()
167 }
168}
169
170impl<K: Write> Write for Poly<K> {
171 fn write(&self, buf: &mut impl bytes::BufMut) {
172 self.coeffs.write(buf);
173 }
174}
175
176impl<K: Read> Read for Poly<K> {
177 type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
178
179 fn read_cfg(
180 buf: &mut impl bytes::Buf,
181 cfg: &Self::Cfg,
182 ) -> Result<Self, commonware_codec::Error> {
183 Ok(Self {
184 coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
185 })
186 }
187}
188
189impl<K: Random> Poly<K> {
190 pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
193 Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
194 }
195
196 pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
200 Self::from_iter_unchecked(
201 iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
202 )
203 }
204}
205
206impl<K: Additive> PartialEq for Poly<K> {
211 fn eq(&self, other: &Self) -> bool {
212 let zero = K::zero();
213 let max_len = self.len().max(other.len());
214 let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
215 let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
216 self_then_zeros
217 .zip(other_then_zeros)
218 .take(max_len.get() as usize)
219 .all(|(a, b)| a == b)
220 }
221}
222
223impl<K: Additive> Eq for Poly<K> {}
224
225impl<K: Additive> Poly<K> {
226 fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
227 self.coeffs
228 .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
229 self.coeffs
230 .iter_mut()
231 .zip(&rhs.coeffs)
232 .for_each(|(a, b)| f(a, b));
233 }
234
235 pub fn degree_exact(&self) -> u32 {
241 let zero = K::zero();
242 let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
243 let lz_u32 =
244 u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
245 self.degree().saturating_sub(lz_u32)
248 }
249}
250
// Polynomials over an `Additive` coefficient type are `Object`s; no items
// need to be provided in this impl.
impl<K: Additive> Object for Poly<K> {}
252
253impl<'a, K: Additive> AddAssign<&'a Self> for Poly<K> {
256 fn add_assign(&mut self, rhs: &'a Self) {
257 self.merge_with(rhs, |a, b| *a += b);
258 }
259}
260
261impl<'a, K: Additive> Add<&'a Self> for Poly<K> {
262 type Output = Self;
263
264 fn add(mut self, rhs: &'a Self) -> Self::Output {
265 self += rhs;
266 self
267 }
268}
269
270impl<'a, K: Additive> SubAssign<&'a Self> for Poly<K> {
271 fn sub_assign(&mut self, rhs: &'a Self) {
272 self.merge_with(rhs, |a, b| *a -= b);
273 }
274}
275
276impl<'a, K: Additive> Sub<&'a Self> for Poly<K> {
277 type Output = Self;
278
279 fn sub(mut self, rhs: &'a Self) -> Self::Output {
280 self -= rhs;
281 self
282 }
283}
284
285impl<K: Additive> Neg for Poly<K> {
286 type Output = Self;
287
288 fn neg(self) -> Self::Output {
289 Self {
290 coeffs: self.coeffs.map_into(Neg::neg),
291 }
292 }
293}
294
295impl<K: Additive> Additive for Poly<K> {
296 fn zero() -> Self {
297 Self {
298 coeffs: non_empty_vec![K::zero()],
299 }
300 }
301}
302
303impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
306 fn mul_assign(&mut self, rhs: &'a R) {
307 self.coeffs.iter_mut().for_each(|c| *c *= rhs);
308 }
309}
310
311impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
312 type Output = Self;
313
314 fn mul(mut self, rhs: &'a R) -> Self::Output {
315 self *= rhs;
316 self
317 }
318}
319
320impl<R: Sync, K: Space<R> + Send> Space<R> for Poly<K> {
321 fn msm(polys: &[Self], scalars: &[R], strategy: &impl Strategy) -> Self {
322 if polys.len() < MIN_POINTS_FOR_MSM {
323 return msm_naive(polys, scalars);
324 }
325
326 let cols = polys.len().min(scalars.len());
327 let polys = &polys[..cols];
328 let scalars = &scalars[..cols];
329
330 let rows = polys
331 .iter()
332 .map(|x| x.len_usize())
333 .max()
334 .expect("at least 1 point");
335
336 let coeffs = strategy.map_init_collect_vec(
337 0..rows,
338 || Vec::with_capacity(cols),
339 |row, i| {
340 row.clear();
341 for p in polys {
342 row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
343 }
344 K::msm(row, scalars, strategy)
345 },
346 );
347 Self::from_iter_unchecked(coeffs)
348 }
349}
350
351impl<G: CryptoGroup> Poly<G> {
352 pub fn commit(p: Poly<G::Scalar>) -> Self {
354 p.translate(|c| G::generator() * c)
355 }
356}
357
/// Precomputed Lagrange weights for recovering a polynomial's value at zero
/// (its constant term) from evaluations at a fixed set of points.
///
/// Built by `Interpolator::new`. `Interpolator::interpolate` applies the
/// weights as an MSM over the evaluations.
pub struct Interpolator<I, F> {
    // Maps each point's identifier to its Lagrange weight for evaluation at zero.
    weights: Map<I, F>,
}
391
392impl<I: PartialEq, F: Ring> Interpolator<I, F> {
393 pub fn interpolate<K: Space<F>>(
398 &self,
399 evals: &Map<I, K>,
400 strategy: &impl Strategy,
401 ) -> Option<K> {
402 if evals.keys() != self.weights.keys() {
403 return None;
404 }
405 Some(K::msm(evals.values(), self.weights.values(), strategy))
406 }
407}
408
impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
    /// Precomputes the weights that evaluate, at zero, the unique polynomial of
    /// degree `< n` through the `n` given points.
    ///
    /// Each item of `points` is `(id, x)`. `id` is the key later matched
    /// against evaluations in `interpolate`, and `x` is the evaluation point.
    /// Points are gathered with `Map::from_iter_dedup`. NOTE(review): this
    /// presumably drops duplicate ids; confirm which entry is kept.
    ///
    /// For point `i` the weight is `w_i = prod_{j != i} x_j / (x_j - x_i)`, so
    /// that `sum_i w_i * f(x_i) = f(0)`. If some `x_i` is zero, then `f(0)` is
    /// that evaluation itself, and the weights become `1` at `i` and `0`
    /// elsewhere.
    ///
    /// NOTE(review): two distinct ids with the same `x` make some
    /// `x_j - x_i` zero. The product that gets inverted below is then zero,
    /// and the result depends on `Field::inv` (not visible here). Confirm
    /// that callers never pass repeated `x` values.
    pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
        let points = Map::from_iter_dedup(points);
        let n = points.len();
        // No points: nothing to compute.
        if n == 0 {
            return Self { weights: points };
        }

        let values = points.values();
        let zero = F::zero();
        // Product of all evaluation points x_0 * ... * x_{n-1}.
        let mut total_product = F::one();
        // c[i] = x_i * prod_{j != i} (x_j - x_i). Then w_i = total_product / c[i].
        let mut c = Vec::with_capacity(n);
        for (i, w_i) in values.iter().enumerate() {
            if w_i == &zero {
                // One of the points is zero itself: f(0) = f(x_i), so use a
                // unit weight at `i` and zero everywhere else.
                let mut out = points;
                for (j, w) in out.values_mut().iter_mut().enumerate() {
                    *w = if j == i { F::one() } else { F::zero() };
                }
                return Self { weights: out };
            }

            total_product *= w_i;
            let mut c_i = w_i.clone();
            for w_j in values
                .iter()
                .enumerate()
                .filter_map(|(j, v)| (j != i).then_some(v))
            {
                c_i *= &(w_j.clone() - w_i);
            }
            c.push(c_i);
        }

        // Batch inversion: prefix[k] = c[0] * ... * c[k-1], so one field
        // inversion of prefix[n] suffices to recover every 1 / c[i].
        let mut prefix = Vec::with_capacity(n + 1);
        prefix.push(F::one());
        let mut acc = F::one();
        for c_i in &c {
            acc *= c_i;
            prefix.push(acc.clone());
        }

        // inv_acc = total_product / (c[0] * ... * c[n-1]).
        let mut inv_acc = total_product * &prefix[n].inv();

        let mut out = points;
        let out_vals = out.values_mut();
        // Walk backwards. At step i, inv_acc = total_product / (c[0] * ... * c[i]),
        // so multiplying by prefix[i] leaves total_product / c[i] = w_i.
        // Folding c[i] back in then prepares inv_acc for step i - 1.
        for i in (0..n).rev() {
            out_vals[i] = inv_acc.clone() * &prefix[i];
            inv_acc *= &c[i];
        }
        Self { weights: out }
    }
}
475
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;

    /// Builds a `Poly` directly from an arbitrary `NonEmptyVec` of
    /// coefficients, so fuzzed polynomials always have at least one
    /// coefficient.
    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            Ok(Self {
                coeffs: u.arbitrary()?,
            })
        }
    }
}
489
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::{F, G};
    use commonware_codec::Encode;
    use commonware_parallel::Sequential;
    use proptest::{
        prelude::{Arbitrary, BoxedStrategy, Strategy as _},
        prop_assume, proptest,
        sample::SizeRange,
    };

    /// Generates polynomials with at least one coefficient. A `size` range
    /// that admits 0 is shifted up by one so the vector is never empty.
    impl Arbitrary for Poly<F> {
        type Parameters = SizeRange;
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
            let nonempty_size = if size.start() == 0 { size + 1 } else { size };
            proptest::collection::vec(F::arbitrary(), nonempty_size)
                .prop_map(Self::from_iter_unchecked)
                .boxed()
        }
    }

    #[test]
    fn test_additive() {
        crate::algebra::test_suites::test_additive(file!(), &Poly::<F>::arbitrary());
    }

    #[test]
    fn test_space() {
        crate::algebra::test_suites::test_space_ring(
            file!(),
            &F::arbitrary(),
            &Poly::<F>::arbitrary(),
        );
    }

    #[test]
    fn test_eq() {
        fn eq(a: &[u8], b: &[u8]) -> bool {
            Poly {
                coeffs: a.iter().copied().map(F::from).try_collect().unwrap(),
            } == Poly {
                coeffs: b.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }
        assert!(eq(&[1, 2], &[1, 2]));
        assert!(!eq(&[1, 2], &[2, 3]));
        assert!(!eq(&[1, 2], &[1, 2, 3]));
        assert!(!eq(&[1, 2, 3], &[1, 2]));
        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
        assert!(!eq(&[1, 2, 0], &[2, 3]));
        assert!(!eq(&[2, 3], &[1, 2, 0]));
    }

    proptest! {
        #[test]
        fn test_codec(f: Poly<F>) {
            assert_eq!(
                &f,
                &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ())).unwrap()
            )
        }

        #[test]
        fn test_eval_add(f: Poly<F>, g: Poly<F>, x: F) {
            assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
        }

        #[test]
        fn test_eval_scale(f: Poly<F>, x: F, w: F) {
            assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
        }

        #[test]
        fn test_eval_zero(f: Poly<F>) {
            assert_eq!(&f.eval(&F::zero()), f.constant());
        }

        #[test]
        fn test_eval_msm(f: Poly<F>, x: F) {
            assert_eq!(f.eval(&x), f.eval_msm(&x, &Sequential));
        }

        #[test]
        fn test_interpolate(f: Poly<F>) {
            prop_assume!(f != Poly::zero());
            prop_assume!(f.required().get() < F::MAX as u32);
            let mut points = (0..f.required().get())
                .map(|i| F::from((i + 1) as u8))
                .collect::<Vec<_>>();
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            let recovered = interpolator.interpolate(&evals, &Sequential);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
            // With one evaluation missing, the key sets no longer match.
            points.pop();
            let fewer = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            assert!(interpolator.interpolate(&fewer, &Sequential).is_none());
        }

        #[test]
        fn test_interpolate_with_zero_point(f: Poly<F>) {
            prop_assume!(f != Poly::zero());
            prop_assume!(f.required().get() < F::MAX as u32);
            // Zero is the first point.
            let points: Vec<_> = (0..f.required().get()).map(|i| F::from(i as u8)).collect();
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            let recovered = interpolator.interpolate(&evals, &Sequential);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
        }

        #[test]
        fn test_interpolate_with_zero_point_middle(f: Poly<F>) {
            prop_assume!(f != Poly::zero());
            prop_assume!(f.required().get() >= 2);
            prop_assume!(f.required().get() < F::MAX as u32);
            let n = f.required().get();
            // Non-zero points 1..n with zero inserted at index n / 2, which is
            // strictly inside the list for n >= 3 (and last for n == 2).
            let mut points: Vec<_> = (1..n).map(|i| F::from(i as u8)).collect();
            points.insert((n / 2) as usize, F::zero());
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            let recovered = interpolator.interpolate(&evals, &Sequential);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
        }

        #[test]
        fn test_interpolate_with_zero_point_last(f: Poly<F>) {
            prop_assume!(f != Poly::zero());
            prop_assume!(f.required().get() >= 2);
            prop_assume!(f.required().get() < F::MAX as u32);
            let n = f.required().get();
            // Non-zero points 1..n followed by zero as the final point.
            let points: Vec<_> = (1..n)
                .map(|i| F::from(i as u8))
                .chain(core::iter::once(F::zero()))
                .collect();
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            let recovered = interpolator.interpolate(&evals, &Sequential);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
        }

        #[test]
        fn test_translate_scale(f: Poly<F>, x: F) {
            assert_eq!(f.translate(|c| x * c), f * &x);
        }

        #[test]
        fn test_commit_eval(f: Poly<F>, x: F) {
            assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
        }
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}