1use crate::algebra::{msm_naive, Additive, CryptoGroup, Field, Object, Random, Ring, Space};
2#[cfg(not(feature = "std"))]
3use alloc::{vec, vec::Vec};
4use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
5use commonware_parallel::Strategy;
6use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
7use core::{
8 fmt::Debug,
9 iter,
10 num::NonZeroU32,
11 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
12};
13use rand_core::CryptoRngCore;
14
/// Minimum input size for which `Poly`'s `Space::msm` uses its row-wise MSM;
/// smaller inputs fall back to `msm_naive`.
const MIN_POINTS_FOR_MSM: usize = 2;
17
/// A polynomial `c_0 + c_1 X + ... + c_n X^n`, stored as its coefficients in
/// increasing order of degree (`coeffs[0]` is the constant term).
///
/// There is always at least one coefficient. High-order zero coefficients may
/// be present: equality ignores them, but [`Poly::degree`] counts them (use
/// `degree_exact` for the true degree).
#[derive(Clone)]
pub struct Poly<K> {
    coeffs: NonEmptyVec<K>,
}
24
25impl<K> Poly<K> {
26 fn len(&self) -> NonZeroU32 {
27 self.coeffs
28 .len()
29 .try_into()
30 .expect("Impossible: polynomial length not in 1..=u32::MAX")
31 }
32
33 fn len_usize(&self) -> usize {
34 self.coeffs.len().get()
35 }
36
37 fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
42 let coeffs = iter
43 .into_iter()
44 .try_collect::<NonEmptyVec<_>>()
45 .expect("polynomial must have a least 1 coefficient");
46 Self { coeffs }
47 }
48
49 pub fn degree(&self) -> u32 {
60 self.len().get() - 1
61 }
62
63 pub fn required(&self) -> NonZeroU32 {
67 self.len()
68 }
69
70 pub fn constant(&self) -> &K {
74 &self.coeffs[0]
75 }
76
77 pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
82 Poly {
83 coeffs: self.coeffs.map(f),
84 }
85 }
86
87 pub fn eval<R>(&self, r: &R) -> K
103 where
104 K: Space<R>,
105 {
106 let mut iter = self.coeffs.iter().rev();
107 let mut acc = iter
114 .next()
115 .expect("Impossible: Polynomial has no coefficients")
116 .clone();
117 for coeff in iter {
118 acc *= r;
119 acc += coeff;
120 }
121 acc
122 }
123
124 pub fn eval_msm<R: Ring>(&self, r: &R, strategy: &impl Strategy) -> K
130 where
131 K: Space<R>,
132 {
133 let weights = {
135 let len = self.len_usize();
136 let mut out = Vec::with_capacity(len);
137 out.push(R::one());
138 let mut acc = R::one();
139 for _ in 1..len {
140 acc *= r;
141 out.push(acc.clone());
142 }
143 out
144 };
145 K::msm(&self.coeffs, &weights, strategy)
146 }
147}
148
149impl<K: Debug> Debug for Poly<K> {
150 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
151 write!(f, "Poly(")?;
152 for (i, c) in self.coeffs.iter().enumerate() {
153 if i > 0 {
154 write!(f, " + {c:?} X^{i}")?;
155 } else {
156 write!(f, "{c:?}")?;
157 }
158 }
159 write!(f, ")")?;
160 Ok(())
161 }
162}
163
impl<K: EncodeSize> EncodeSize for Poly<K> {
    /// Encoded size of the coefficient vector; a `Poly` encodes as nothing
    /// more than its coefficients.
    fn encode_size(&self) -> usize {
        self.coeffs.encode_size()
    }
}
169
impl<K: Write> Write for Poly<K> {
    /// Writes the coefficient vector using the `NonEmptyVec` encoding.
    fn write(&self, buf: &mut impl bytes::BufMut) {
        self.coeffs.write(buf);
    }
}
175
176impl<K: Read> Read for Poly<K> {
177 type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);
178
179 fn read_cfg(
180 buf: &mut impl bytes::Buf,
181 cfg: &Self::Cfg,
182 ) -> Result<Self, commonware_codec::Error> {
183 Ok(Self {
184 coeffs: NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), cfg.1.clone()))?,
185 })
186 }
187}
188
189impl<K: Random> Poly<K> {
190 pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
193 Self::from_iter_unchecked((0..=degree).map(|_| K::random(&mut rng)))
194 }
195
196 pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
200 Self::from_iter_unchecked(
201 iter::once(constant).chain((0..=degree).skip(1).map(|_| K::random(&mut rng))),
202 )
203 }
204}
205
206impl<K: Additive> PartialEq for Poly<K> {
211 fn eq(&self, other: &Self) -> bool {
212 let zero = K::zero();
213 let max_len = self.len().max(other.len());
214 let self_then_zeros = self.coeffs.iter().chain(iter::repeat(&zero));
215 let other_then_zeros = other.coeffs.iter().chain(iter::repeat(&zero));
216 self_then_zeros
217 .zip(other_then_zeros)
218 .take(max_len.get() as usize)
219 .all(|(a, b)| a == b)
220 }
221}
222
// Marker impl: the zero-padded comparison in `PartialEq` is an equivalence
// relation as long as `K`'s own equality is one.
impl<K: Additive> Eq for Poly<K> {}
224
225impl<K: Additive> Poly<K> {
226 fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
227 self.coeffs
228 .resize(self.coeffs.len().max(rhs.coeffs.len()), K::zero());
229 self.coeffs
230 .iter_mut()
231 .zip(&rhs.coeffs)
232 .for_each(|(a, b)| f(a, b));
233 }
234
235 pub fn degree_exact(&self) -> u32 {
241 let zero = K::zero();
242 let leading_zeroes = self.coeffs.iter().rev().take_while(|&x| x == &zero).count();
243 let lz_u32 =
244 u32::try_from(leading_zeroes).expect("Impossible: Polynomial has >= 2^32 coefficients");
245 self.degree().saturating_sub(lz_u32)
248 }
249}
250
// Marker impl: `Poly<K>` satisfies the crate's `Object` requirements whenever
// its coefficients are `Additive`.
impl<K: Additive> Object for Poly<K> {}
252
253impl<'a, K: Additive> AddAssign<&'a Poly<K>> for Poly<K> {
256 fn add_assign(&mut self, rhs: &'a Poly<K>) {
257 self.merge_with(rhs, |a, b| *a += b);
258 }
259}
260
261impl<'a, K: Additive> Add<&'a Poly<K>> for Poly<K> {
262 type Output = Self;
263
264 fn add(mut self, rhs: &'a Poly<K>) -> Self::Output {
265 self += rhs;
266 self
267 }
268}
269
270impl<'a, K: Additive> SubAssign<&'a Poly<K>> for Poly<K> {
271 fn sub_assign(&mut self, rhs: &'a Poly<K>) {
272 self.merge_with(rhs, |a, b| *a -= b);
273 }
274}
275
276impl<'a, K: Additive> Sub<&'a Poly<K>> for Poly<K> {
277 type Output = Self;
278
279 fn sub(mut self, rhs: &'a Poly<K>) -> Self::Output {
280 self -= rhs;
281 self
282 }
283}
284
285impl<K: Additive> Neg for Poly<K> {
286 type Output = Self;
287
288 fn neg(self) -> Self::Output {
289 Self {
290 coeffs: self.coeffs.map_into(Neg::neg),
291 }
292 }
293}
294
295impl<K: Additive> Additive for Poly<K> {
296 fn zero() -> Self {
297 Self {
298 coeffs: non_empty_vec![K::zero()],
299 }
300 }
301}
302
303impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
306 fn mul_assign(&mut self, rhs: &'a R) {
307 self.coeffs.iter_mut().for_each(|c| *c *= rhs);
308 }
309}
310
311impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
312 type Output = Self;
313
314 fn mul(mut self, rhs: &'a R) -> Self::Output {
315 self *= rhs;
316 self
317 }
318}
319
320impl<R: Sync, K: Space<R> + Send> Space<R> for Poly<K> {
321 fn msm(polys: &[Self], scalars: &[R], strategy: &impl Strategy) -> Self {
322 if polys.len() < MIN_POINTS_FOR_MSM {
323 return msm_naive(polys, scalars);
324 }
325
326 let cols = polys.len().min(scalars.len());
327 let polys = &polys[..cols];
328 let scalars = &scalars[..cols];
329
330 let rows = polys
331 .iter()
332 .map(|x| x.len_usize())
333 .max()
334 .expect("at least 1 point");
335
336 let coeffs = strategy.map_init_collect_vec(
337 0..rows,
338 || Vec::with_capacity(cols),
339 |row, i| {
340 row.clear();
341 for p in polys {
342 row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
343 }
344 K::msm(row, scalars, strategy)
345 },
346 );
347 Poly::from_iter_unchecked(coeffs)
348 }
349}
350
351impl<G: CryptoGroup> Poly<G> {
352 pub fn commit(p: Poly<G::Scalar>) -> Self {
354 p.translate(|c| G::generator() * c)
355 }
356}
357
/// Precomputed Lagrange weights for recovering a polynomial's constant term
/// (its value at zero) from evaluations at a fixed set of points.
///
/// Each key identifies an evaluation point; the associated value is that
/// point's weight, computed by `Interpolator::new`.
pub struct Interpolator<I, F> {
    weights: Map<I, F>,
}
391
392impl<I: PartialEq, F: Ring> Interpolator<I, F> {
393 pub fn interpolate<K: Space<F>>(
398 &self,
399 evals: &Map<I, K>,
400 strategy: &impl Strategy,
401 ) -> Option<K> {
402 if evals.keys() != self.weights.keys() {
403 return None;
404 }
405 Some(K::msm(evals.values(), self.weights.values(), strategy))
406 }
407}
408
409impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
410 pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
417 let points = Map::from_iter_dedup(points);
418 let weights = points
419 .iter_pairs()
420 .map(|(i, w_i)| {
421 let mut top_i = F::one();
422 let mut bot_i = F::one();
423 for (j, w_j) in points.iter_pairs() {
424 if i == j {
425 continue;
426 }
427 top_i *= w_j;
428 bot_i *= &(w_j.clone() - w_i);
429 }
430 top_i * &bot_i.inv()
431 })
432 .collect::<Vec<_>>();
433 let mut out = points;
435 for (out_i, weight_i) in out.values_mut().iter_mut().zip(weights.into_iter()) {
436 *out_i = weight_i;
437 }
438 Self { weights: out }
439 }
440}
441
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;

    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        /// Builds a polynomial from an arbitrary non-empty coefficient vector.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let coeffs = NonEmptyVec::<F>::arbitrary(u)?;
            Ok(Self { coeffs })
        }
    }
}
455
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::{F, G};
    use commonware_codec::Encode;
    use proptest::{
        prelude::{Arbitrary, BoxedStrategy, Strategy as _},
        prop_assume, proptest,
        sample::SizeRange,
    };

    // Proptest generator for polynomials over the test field `F`; the
    // `SizeRange` parameter controls the number of coefficients.
    impl Arbitrary for Poly<F> {
        type Parameters = SizeRange;
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
            // `from_iter_unchecked` panics on an empty iterator, so a range
            // starting at zero is bumped (intended to keep vectors non-empty).
            let nonempty_size = if size.start() == 0 { size + 1 } else { size };
            proptest::collection::vec(F::arbitrary(), nonempty_size)
                .prop_map(Poly::from_iter_unchecked)
                .boxed()
        }
    }

    // Runs the shared `test_additive` suite on arbitrary polynomials.
    #[test]
    fn test_additive() {
        crate::algebra::test_suites::test_additive(file!(), &Poly::<F>::arbitrary());
    }

    // Runs the shared `test_space_ring` suite with `F` scalars acting on
    // arbitrary polynomials.
    #[test]
    fn test_space() {
        crate::algebra::test_suites::test_space_ring(
            file!(),
            &F::arbitrary(),
            &Poly::<F>::arbitrary(),
        );
    }

    // Equality pads the shorter polynomial with zeros: trailing zeros are
    // ignored, any other difference is not.
    #[test]
    fn test_eq() {
        fn eq(a: &[u8], b: &[u8]) -> bool {
            Poly {
                coeffs: a.iter().copied().map(F::from).try_collect().unwrap(),
            } == Poly {
                coeffs: b.iter().copied().map(F::from).try_collect().unwrap(),
            }
        }
        assert!(eq(&[1, 2], &[1, 2]));
        assert!(!eq(&[1, 2], &[2, 3]));
        assert!(!eq(&[1, 2], &[1, 2, 3]));
        assert!(!eq(&[1, 2, 3], &[1, 2]));
        assert!(eq(&[1, 2], &[1, 2, 0, 0]));
        assert!(eq(&[1, 2, 0, 0], &[1, 2]));
        assert!(!eq(&[1, 2, 0], &[2, 3]));
        assert!(!eq(&[2, 3], &[1, 2, 0]));
    }

    proptest! {
        // Encoding then decoding with an exact length bound round-trips.
        #[test]
        fn test_codec(f: Poly<F>) {
            assert_eq!(&f, &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ())).unwrap())
        }

        // Evaluation is additive in the polynomial.
        #[test]
        fn test_eval_add(f: Poly<F>, g: Poly<F>, x: F) {
            assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
        }

        // Evaluation commutes with scaling the polynomial.
        #[test]
        fn test_eval_scale(f: Poly<F>, x: F, w: F) {
            assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
        }

        // The value at zero is the constant coefficient.
        #[test]
        fn test_eval_zero(f: Poly<F>) {
            assert_eq!(&f.eval(&F::zero()), f.constant());
        }

        // The MSM-based evaluation agrees with Horner evaluation.
        #[test]
        fn test_eval_msm(f: Poly<F>, x: F) {
            use commonware_parallel::Sequential;
            assert_eq!(f.eval(&x), f.eval_msm(&x, &Sequential));
        }

        // `required()` evaluations recover the constant term; one fewer is
        // rejected because the key set no longer matches.
        #[test]
        fn test_interpolate(f: Poly<F>) {
            use commonware_parallel::Sequential;
            prop_assume!(f != Poly::zero());
            // Points are built from `u8`s 1..=required; presumably `F::MAX`
            // bounds how many distinct such points exist.
            prop_assume!(f.required().get() < F::MAX as u32);
            let mut points = (0..f.required().get()).map(|i| F::from((i + 1) as u8)).collect::<Vec<_>>();
            let interpolator = Interpolator::new(points.iter().copied().enumerate());
            let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
            let recovered = interpolator.interpolate(&evals, &Sequential);
            assert_eq!(recovered.as_ref(), Some(f.constant()));
            points.pop();
            assert!(interpolator.interpolate(&Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate()), &Sequential).is_none());
        }

        // `translate` with a scalar multiplication matches scaling.
        #[test]
        fn test_translate_scale(f: Poly<F>, x: F) {
            assert_eq!(f.translate(|c| x * c), f * &x);
        }

        // Evaluating a commitment matches committing to the evaluation.
        #[test]
        fn test_commit_eval(f: Poly<F>, x: F) {
            assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
        }
    }

    // Codec conformance checks for `Poly<F>`, only with the `arbitrary` feature.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}