1use crate::algebra::{msm_naive, Additive, CryptoGroup, Field, Object, Random, Ring, Space};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
6use core::{
7 fmt::Debug,
8 iter,
9 num::NonZeroU32,
10 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12use rand_core::CryptoRngCore;
13#[cfg(feature = "std")]
14use rayon::{prelude::*, ThreadPoolBuilder};
15
/// Minimum input size for which `Space::msm` on `Poly` uses the coefficient-wise
/// (row-by-row) MSM; smaller inputs fall back to `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
18
/// A polynomial with coefficients of type `K`, stored lowest degree first
/// (`coeffs[0]` is the constant term).
///
/// The coefficient list is never empty. Trailing zero coefficients may be stored;
/// equality (when `K: Additive`) treats them as absent, so the stored length is only
/// an upper bound on the number of meaningful coefficients.
#[derive(Clone)]
pub struct Poly<K> {
    // Coefficients, lowest degree first; always at least one.
    coeffs: NonEmptyVec<K>,
}
25
26impl<K> Poly<K> {
27 fn len(&self) -> NonZeroU32 {
28 self.coeffs
29 .len()
30 .try_into()
31 .expect("Impossible: polynomial length not in 1..=u32::MAX")
32 }
33
34 fn len_usize(&self) -> usize {
35 self.coeffs.len().get()
36 }
37
38 fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
43 let coeffs = iter
44 .into_iter()
45 .try_collect::<NonEmptyVec<_>>()
46 .expect("polynomial must have a least 1 coefficient");
47 Self { coeffs }
48 }
49
50 pub fn degree(&self) -> u32 {
61 self.len().get() - 1
62 }
63
64 pub fn required(&self) -> NonZeroU32 {
68 self.len()
69 }
70
71 pub fn constant(&self) -> &K {
75 &self.coeffs[0]
76 }
77
78 pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
83 Poly {
84 coeffs: self.coeffs.map(f),
85 }
86 }
87
88 pub fn eval<R>(&self, r: &R) -> K
104 where
105 K: Space<R>,
106 {
107 let mut iter = self.coeffs.iter().rev();
108 let mut acc = iter
115 .next()
116 .expect("Impossible: Polynomial has no coefficients")
117 .clone();
118 for coeff in iter {
119 acc *= r;
120 acc += coeff;
121 }
122 acc
123 }
124
125 pub fn eval_msm<R: Ring>(&self, r: &R) -> K
131 where
132 K: Space<R>,
133 {
134 let weights = {
136 let len = self.len_usize();
137 let mut out = Vec::with_capacity(len);
138 out.push(R::one());
139 let mut acc = R::one();
140 for _ in 1..len {
141 acc *= r;
142 out.push(acc.clone());
143 }
144 out
145 };
146 K::msm(&self.coeffs, &weights, 1)
147 }
148}
149
150impl<K: Debug> Debug for Poly<K> {
151 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152 write!(f, "Poly(")?;
153 for (i, c) in self.coeffs.iter().enumerate() {
154 if i > 0 {
155 write!(f, " + {c:?} X^{i}")?;
156 } else {
157 write!(f, "{c:?}")?;
158 }
159 }
160 write!(f, ")")?;
161 Ok(())
162 }
163}
164
165impl<K: EncodeSize> EncodeSize for Poly<K> {
166 fn encode_size(&self) -> usize {
167 self.coeffs.encode_size()
168 }
169}
170
171impl<K: Write> Write for Poly<K> {
172 fn write(&self, buf: &mut impl bytes::BufMut) {
173 self.coeffs.write(buf);
174 }
175}
176
177impl<K: Read> Read for Poly<K> {
178 type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
179
180 fn read_cfg(
181 buf: &mut impl bytes::Buf,
182 cfg: &Self::Cfg,
183 ) -> Result<Self, commonware_codec::Error> {
184 Ok(Self {
185 coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
186 })
187 }
188}
189
190impl<K: Random> Poly<K> {
191 pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
194 Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
195 }
196
197 pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
201 Self::from_iter_unchecked(
202 iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
203 )
204 }
205}
206
207impl<K: Additive> PartialEq for Poly<K> {
212 fn eq(&self, other: &Self) -> bool {
213 let zero = K::zero();
214 let max_len = self.len().max(other.len());
215 let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
216 let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
217 self_then_zeros
218 .zip(other_then_zeros)
219 .take(max_len.get() as usize)
220 .all(|(a, b)| a == b)
221 }
222}
223
224impl<K: Additive> Eq for Poly<K> {}
225
226impl<K: Additive> Poly<K> {
227 fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
228 self.coeffs
229 .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
230 self.coeffs
231 .iter_mut()
232 .zip(&rhs.coeffs)
233 .for_each(|(a, b)| f(a, b));
234 }
235
236 pub fn degree_exact(&self) -> u32 {
242 let zero = K::zero();
243 let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
244 let lz_u32 =
245 u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
246 self.degree().saturating_sub(lz_u32)
249 }
250}
251
252impl<K: Additive> Object for Poly<K> {}
253
254impl<'a, K: Additive> AddAssign<&'a Poly<K>> for Poly<K> {
257 fn add_assign(&mut self, rhs: &'a Poly<K>) {
258 self.merge_with(rhs, |a, b| *a += b);
259 }
260}
261
262impl<'a, K: Additive> Add<&'a Poly<K>> for Poly<K> {
263 type Output = Self;
264
265 fn add(mut self, rhs: &'a Poly<K>) -> Self::Output {
266 self += rhs;
267 self
268 }
269}
270
271impl<'a, K: Additive> SubAssign<&'a Poly<K>> for Poly<K> {
272 fn sub_assign(&mut self, rhs: &'a Poly<K>) {
273 self.merge_with(rhs, |a, b| *a -= b);
274 }
275}
276
277impl<'a, K: Additive> Sub<&'a Poly<K>> for Poly<K> {
278 type Output = Self;
279
280 fn sub(mut self, rhs: &'a Poly<K>) -> Self::Output {
281 self -= rhs;
282 self
283 }
284}
285
286impl<K: Additive> Neg for Poly<K> {
287 type Output = Self;
288
289 fn neg(self) -> Self::Output {
290 Self {
291 coeffs: self.coeffs.map_into(Neg::neg),
292 }
293 }
294}
295
296impl<K: Additive> Additive for Poly<K> {
297 fn zero() -> Self {
298 Self {
299 coeffs: non_empty_vec![K::zero()],
300 }
301 }
302}
303
304impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
307 fn mul_assign(&mut self, rhs: &'a R) {
308 self.coeffs.iter_mut().for_each(|c| *c *= rhs);
309 }
310}
311
312impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
313 type Output = Self;
314
315 fn mul(mut self, rhs: &'a R) -> Self::Output {
316 self *= rhs;
317 self
318 }
319}
320
#[cfg(feature = "std")]
impl<R: Sync, K: Space<R>> Space<R> for Poly<K> {
    /// Computes `sum_j scalars[j] * polys[j]`.
    ///
    /// Only the first `min(polys.len(), scalars.len())` pairs are used. With fewer than
    /// `MIN_POINTS_FOR_MSM` pairs the work is delegated to `msm_naive`. Otherwise the sum
    /// is computed coefficient by coefficient: for each degree `i`, the `i`-th
    /// coefficients of all polynomials (zero for polynomials that are too short) are
    /// combined with `K::msm`. The result has as many coefficients as the longest input.
    ///
    /// When `concurrency > 1`, the rows are processed in parallel on a rayon thread pool
    /// with `concurrency` threads, and each row's inner MSM is single-threaded. Otherwise
    /// the rows are processed sequentially and `concurrency` is forwarded to `K::msm`.
    ///
    /// # Panics
    ///
    /// Panics if the rayon thread pool cannot be built.
    fn msm(polys: &[Self], scalars: &[R], concurrency: usize) -> Self {
        // Truncate to the number of usable pairs *before* choosing the algorithm.
        // Previously only `polys.len()` was checked. Two polynomials with no scalars then
        // reached the row computation with empty slices, and `max()` was `None` (panic).
        let cols = polys.len().min(scalars.len());
        let polys = &polys[..cols];
        let scalars = &scalars[..cols];
        if cols < MIN_POINTS_FOR_MSM {
            return msm_naive(polys, scalars);
        }

        // One row per coefficient index, up to the longest polynomial. `polys` is
        // non-empty here because `cols >= MIN_POINTS_FOR_MSM > 0`.
        let rows = polys
            .iter()
            .map(|x| x.len_usize())
            .max()
            .expect("at least 1 point");

        if concurrency > 1 {
            // NOTE(review): a fresh pool is built on every call; consider reusing one
            // if this path is hot.
            let pool = ThreadPoolBuilder::new()
                .num_threads(concurrency)
                .build()
                .expect("Unable to build thread pool");

            let coeffs = pool.install(|| {
                (0..rows)
                    .into_par_iter()
                    .map(|i| {
                        let row: Vec<_> = polys
                            .iter()
                            .map(|p| p.coeffs.get(i).cloned().unwrap_or_else(K::zero))
                            .collect();
                        K::msm(&row, scalars, 1)
                    })
                    .collect::<Vec<_>>()
            });
            return Poly::from_iter_unchecked(coeffs);
        }

        // Sequential path: reuse one scratch buffer for every row.
        let mut row = Vec::with_capacity(cols);
        let coeffs = (0..rows).map(|i| {
            row.clear();
            for p in polys {
                row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
            }
            K::msm(&row, scalars, concurrency)
        });

        Poly::from_iter_unchecked(coeffs)
    }
}
371
372#[cfg(not(feature = "std"))]
373impl<R, K: Space<R>> Space<R> for Poly<K> {
374 fn msm(polys: &[Self], scalars: &[R], concurrency: usize) -> Self {
375 if polys.len() < MIN_POINTS_FOR_MSM {
376 return msm_naive(polys, scalars);
377 }
378
379 let cols = polys.len().min(scalars.len());
380 let polys = &polys[..cols];
381 let scalars = &scalars[..cols];
382
383 let rows = polys
384 .iter()
385 .map(|x| x.len_usize())
386 .max()
387 .expect("at least 1 point");
388
389 let mut row = Vec::with_capacity(cols);
390 let coeffs = (0..rows).map(|i| {
391 row.clear();
392 for p in polys {
393 row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
394 }
395 K::msm(&row, scalars, concurrency)
396 });
397 Poly::from_iter_unchecked(coeffs)
398 }
399}
400
401impl<G: CryptoGroup> Poly<G> {
402 pub fn commit(p: Poly<G::Scalar>) -> Self {
404 p.translate(|c| G::generator() * c)
405 }
406}
407
/// Precomputed Lagrange weights for recovering a polynomial's constant term (its value
/// at zero) from its evaluations at a fixed set of points.
///
/// Built with `Interpolator::new`; applied with `Interpolator::interpolate`.
pub struct Interpolator<I, F> {
    // Maps each point's index `i` to the weight `prod_{j != i} x_j / (x_j - x_i)`.
    weights: Map<I, F>,
}
440
441impl<I: PartialEq, F: Ring> Interpolator<I, F> {
442 pub fn interpolate<K: Space<F>>(&self, evals: &Map<I, K>, concurrency: usize) -> Option<K> {
447 if evals.keys() != self.weights.keys() {
448 return None;
449 }
450 Some(K::msm(evals.values(), self.weights.values(), concurrency))
451 }
452}
453
454impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
455 pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
462 let points = Map::from_iter_dedup(points);
463 let weights = points
464 .iter_pairs()
465 .map(|(i, w_i)| {
466 let mut top_i = F::one();
467 let mut bot_i = F::one();
468 for (j, w_j) in points.iter_pairs() {
469 if i == j {
470 continue;
471 }
472 top_i *= w_j;
473 bot_i *= &(w_j.clone() - w_i);
474 }
475 top_i * &bot_i.inv()
476 })
477 .collect::<Vec<_>>();
478 let mut out = points;
480 for (out_i, weight_i) in out.values_mut().iter_mut().zip(weights.into_iter()) {
481 *out_i = weight_i;
482 }
483 Self { weights: out }
484 }
485}
486
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;

    // Fuzz input: a polynomial built directly from an arbitrary non-empty coefficient vector.
    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            u.arbitrary().map(|coeffs| Self { coeffs })
        }
    }
}
500
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::{F, G};
    use commonware_codec::Encode;
    use proptest::{
        prelude::{Arbitrary, BoxedStrategy, Strategy},
        prop_assume, proptest,
        sample::SizeRange,
    };

    // Proptest strategy for random polynomials. A requested size range that starts at
    // zero is bumped by one (`size + 1`), because a `Poly` always has a coefficient.
    impl Arbitrary for Poly<F> {
        type Parameters = SizeRange;
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
            let nonempty_size = if size.start() == 0 { size + 1 } else { size };
            proptest::collection::vec(F::arbitrary(), nonempty_size)
                .prop_map(Poly::from_iter_unchecked)
                .boxed()
        }
    }

    // Shared additive-group law checks.
    #[test]
    fn test_additive() {
        crate::algebra::test_suites::test_additive(file!(), &Poly::<F>::arbitrary());
    }

    // Shared checks that `Poly<F>` behaves as a space over the scalars `F`.
    #[test]
    fn test_space() {
        crate::algebra::test_suites::test_space_ring(
            file!(),
            &F::arbitrary(),
            &Poly::<F>::arbitrary(),
        );
    }

    // Equality ignores trailing zero coefficients, and nothing else.
    #[test]
    fn test_eq() {
        fn eq(a: &[u8], b: &[u8]) -> bool {
            Poly {
                coeffs: a.iter().copied().map(F::from).try_collect().unwrap(),
            } == Poly {
                coeffs: b.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }
        assert!(eq(&[1, 2], &[1, 2]));
        assert!(!eq(&[1, 2], &[2, 3]));
        assert!(!eq(&[1, 2], &[1, 2, 3]));
        assert!(!eq(&[1, 2, 3], &[1, 2]));
        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
        assert!(!eq(&[1, 2, 0], &[2, 3]));
        assert!(!eq(&[2, 3], &[1, 2, 0]));
    }

    proptest! {
        // Encoding then decoding with an exact length bound round-trips.
        #[test]
        fn test_codec(f: Poly<F>) {
            assert_eq!(&f, &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ())).unwrap())
        }

        // Evaluation is additive in the polynomial.
        #[test]
        fn test_eval_add(f: Poly<F>, g: Poly<F>, x: F) {
            assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
        }

        // Evaluation commutes with scaling.
        #[test]
        fn test_eval_scale(f: Poly<F>, x: F, w: F) {
            assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
        }

        // Evaluating at zero yields the constant term.
        #[test]
        fn test_eval_zero(f: Poly<F>) {
            assert_eq!(&f.eval(&F::zero()), f.constant());
        }

        // Horner evaluation and MSM evaluation agree.
        #[test]
        fn test_eval_msm(f: Poly<F>, x: F) {
            assert_eq!(f.eval(&x), f.eval_msm(&x));
        }

        // Interpolating from `required()` evaluations at the points 1..=n recovers the
        // constant term. Dropping one evaluation makes the key sets differ, so
        // interpolation returns `None`.
        #[test]
        fn test_interpolate(f: Poly<F>) {
            prop_assume!(f != Poly::zero());
            // NOTE(review): presumably keeps the points `(i + 1) as u8` distinct in `F` — confirm.
            prop_assume!(f.required().get() < F::MAX as u32);
            let mut points = (0..f.required().get()).map(|i| F::from((i + 1) as u8)).collect::<Vec<_>>();
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            let recovered = interpolator.interpolate(&evals, 1);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
            points.pop();
            assert!(interpolator.interpolate(&Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate()), 1).is_none());
        }

        // Translating by "multiply by x" equals scaling by x.
        #[test]
        fn test_translate_scale(f: Poly<F>, x: F) {
            assert_eq!(f.translate(|c| x * c), f * &x);
        }

        // Evaluating a commitment equals committing to the evaluation.
        #[test]
        fn test_commit_eval(f: Poly<F>, x: F) {
            assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
        }
    }

    // Codec conformance checks (only with the `arbitrary` feature).
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}