use crate::algebra::{
msm_naive, powers, Additive, CryptoGroup, Field, Object, Random, Ring, Space,
};
#[cfg(not(feature = "std"))]
use alloc::{borrow::Cow, vec, vec::Vec};
use commonware_codec::{EncodeSize, RangeCfg, Read, Write};
use commonware_parallel::Strategy;
use commonware_utils::{non_empty_vec, ordered::Map, vec::NonEmptyVec, TryCollect};
use core::{
fmt::Debug,
iter,
num::NonZeroU32,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use rand_core::CryptoRngCore;
#[cfg(feature = "std")]
use std::borrow::Cow;
/// Below this many input polynomials, `Space::msm` for [`Poly`] falls back to
/// `msm_naive` instead of the row-wise MSM path.
const MIN_POINTS_FOR_MSM: usize = 2;
/// A polynomial with coefficients of type `K`.
///
/// Coefficients are stored from the constant term upward: `coeffs[i]` is the
/// coefficient of `X^i`. The vector is never empty, so every polynomial has at
/// least a constant term. Trailing zero coefficients may be present; equality
/// (`PartialEq`) ignores them.
#[derive(Clone)]
pub struct Poly<K> {
    // Constant term first; guaranteed non-empty by `NonEmptyVec`.
    coeffs: NonEmptyVec<K>,
}
impl<K> Poly<K> {
    /// Returns the number of coefficients.
    ///
    /// # Panics
    ///
    /// If the polynomial has more than `u32::MAX` coefficients.
    fn len(&self) -> NonZeroU32 {
        self.coeffs
            .len()
            .try_into()
            .expect("Impossible: polynomial length not in 1..=u32::MAX")
    }

    /// Returns the number of coefficients as a `usize`.
    const fn len_usize(&self) -> usize {
        self.coeffs.len().get()
    }

    /// Builds a polynomial from coefficients listed from the constant term upward.
    ///
    /// # Panics
    ///
    /// If `iter` cannot be collected into a `NonEmptyVec` (i.e. it yields no items).
    fn from_iter_unchecked(iter: impl IntoIterator<Item = K>) -> Self {
        let coeffs = iter
            .into_iter()
            .try_collect::<NonEmptyVec<_>>()
            .expect("polynomial must have at least 1 coefficient");
        Self { coeffs }
    }

    /// Returns the nominal degree, `required() - 1`.
    ///
    /// Trailing zero coefficients are counted; use `degree_exact` to ignore them.
    pub fn degree(&self) -> u32 {
        self.len().get() - 1
    }

    /// Returns the number of coefficients, `degree() + 1`.
    ///
    /// This is the number of distinct evaluations needed to determine the polynomial.
    pub fn required(&self) -> NonZeroU32 {
        self.len()
    }

    /// Returns the constant term (the coefficient of `X^0`), i.e. the value at zero.
    pub fn constant(&self) -> &K {
        &self.coeffs[0]
    }

    /// Applies `f` to every coefficient, producing a polynomial over `L` with the
    /// same number of coefficients.
    pub fn translate<L>(&self, f: impl Fn(&K) -> L) -> Poly<L> {
        Poly {
            coeffs: self.coeffs.map(f),
        }
    }

    /// Evaluates the polynomial at `r` using Horner's rule.
    ///
    /// Starts from the highest-order coefficient, then repeatedly multiplies the
    /// accumulator by `r` and adds the next lower coefficient.
    pub fn eval<R>(&self, r: &R) -> K
    where
        K: Space<R>,
    {
        let mut iter = self.coeffs.iter().rev();
        let mut acc = iter
            .next()
            .expect("Impossible: Polynomial has no coefficients")
            .clone();
        for coeff in iter {
            acc *= r;
            acc += coeff;
        }
        acc
    }

    /// Evaluates the polynomial at `r` as a multi-scalar multiplication of the
    /// coefficients against the powers of `r`, starting from `R::one()`.
    ///
    /// Produces the same value as `eval`, but lets `K::msm` batch the work and
    /// dispatch it through `strategy`.
    pub fn eval_msm<R: Ring>(&self, r: &R, strategy: &impl Strategy) -> K
    where
        K: Space<R>,
    {
        let weights = powers(R::one(), r)
            .take(self.len_usize())
            .collect::<Vec<_>>();
        K::msm(&self.coeffs, &weights, strategy)
    }

    /// Computes `sum_i a_i * self(b_i)` over the `(a_i, b_i)` pairs in `into_iter`.
    ///
    /// It does not evaluate the polynomial once per pair. Instead it folds all pairs
    /// into a single weight vector `w_j = sum_i a_i * b_i^j` and performs one MSM of
    /// the coefficients against it. This assumes `powers(a, b)` yields
    /// `a, a*b, a*b^2, ...`.
    ///
    /// Returns `K::zero()` if `into_iter` is empty.
    pub fn lin_comb_eval<'a, R: Ring + 'a>(
        &self,
        into_iter: impl IntoIterator<Item = (R, Cow<'a, R>)>,
        strategy: &impl Strategy,
    ) -> K
    where
        K: Space<R>,
    {
        let weights = {
            let mut iter = into_iter.into_iter();
            // An empty linear combination is zero.
            let Some((a0, b0)) = iter.next() else {
                return K::zero();
            };
            let len = self.len_usize();
            // Seed the weights with the first pair, then accumulate the rest in place.
            let mut out: Vec<_> = powers(a0, b0.as_ref()).take(len).collect();
            for (ai, bi) in iter {
                powers(ai, bi.as_ref())
                    .take(len)
                    .zip(out.iter_mut())
                    .for_each(|(c_j, o_j)| *o_j += &c_j);
            }
            out
        };
        K::msm(&self.coeffs, &weights, strategy)
    }
}
impl<K: Debug> Debug for Poly<K> {
    /// Renders the polynomial as `Poly(c0 + c1 X^1 + c2 X^2 + ...)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Poly(")?;
        let mut terms = self.coeffs.iter().enumerate();
        // The constant term is printed bare; every later term carries its power of X.
        if let Some((_, constant)) = terms.next() {
            write!(f, "{constant:?}")?;
        }
        for (power, coeff) in terms {
            write!(f, " + {coeff:?} X^{power}")?;
        }
        f.write_str(")")
    }
}
impl<K: EncodeSize> EncodeSize for Poly<K> {
    /// The encoding is exactly that of the coefficient vector.
    fn encode_size(&self) -> usize {
        let Self { coeffs } = self;
        coeffs.encode_size()
    }
}

impl<K: Write> Write for Poly<K> {
    /// Writes the coefficient vector, constant term first.
    fn write(&self, buf: &mut impl bytes::BufMut) {
        let Self { coeffs } = self;
        coeffs.write(buf);
    }
}

impl<K: Read> Read for Poly<K> {
    /// `(allowed coefficient-count range, per-coefficient config)`.
    type Cfg = (RangeCfg<NonZeroU32>, <K as Read>::Cfg);

    /// Reads a coefficient vector whose length must fall in `cfg.0`, decoding each
    /// coefficient with `cfg.1`.
    fn read_cfg(
        buf: &mut impl bytes::Buf,
        cfg: &Self::Cfg,
    ) -> Result<Self, commonware_codec::Error> {
        let coeff_cfg = cfg.1.clone();
        let coeffs = NonEmptyVec::<K>::read_cfg(buf, &(cfg.0.into(), coeff_cfg))?;
        Ok(Self { coeffs })
    }
}
impl<K: Random> Poly<K> {
    /// Samples a polynomial with `degree + 1` coefficients, each drawn from `rng`.
    pub fn new(mut rng: impl CryptoRngCore, degree: u32) -> Self {
        let sampled = (0..=degree).map(|_| K::random(&mut rng));
        Self::from_iter_unchecked(sampled)
    }

    /// Samples a polynomial with `degree + 1` coefficients whose constant term is
    /// `constant`; the remaining `degree` coefficients are drawn from `rng`.
    pub fn new_with_constant(mut rng: impl CryptoRngCore, degree: u32, constant: K) -> Self {
        let higher = (1..=degree).map(|_| K::random(&mut rng));
        Self::from_iter_unchecked(iter::once(constant).chain(higher))
    }
}
impl<K: Additive> PartialEq for Poly<K> {
    /// Compares polynomials as mathematical objects.
    ///
    /// Overlapping coefficients must match pairwise, and any extra coefficients on
    /// the longer side must be zero, so `1 + 2X` equals `1 + 2X + 0X^2`.
    fn eq(&self, other: &Self) -> bool {
        let zero = K::zero();
        let shared = self.len_usize().min(other.len_usize());
        let prefix_matches = self
            .coeffs
            .iter()
            .zip(other.coeffs.iter())
            .all(|(lhs, rhs)| lhs == rhs);
        prefix_matches
            && self.coeffs.iter().skip(shared).all(|lhs| lhs == &zero)
            && other.coeffs.iter().skip(shared).all(|rhs| &zero == rhs)
    }
}
impl<K: Additive> Eq for Poly<K> {}
impl<K: Additive> Poly<K> {
    /// Zero-extends `self` to at least `rhs`'s length, then combines each pair of
    /// coefficients in place with `f(self_coeff, rhs_coeff)`.
    fn merge_with(&mut self, rhs: &Self, f: impl Fn(&mut K, &K)) {
        let target_len = self.coeffs.len().max(rhs.coeffs.len());
        self.coeffs.resize(target_len, K::zero());
        for (lhs_coeff, rhs_coeff) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) {
            f(lhs_coeff, rhs_coeff);
        }
    }

    /// Returns the degree ignoring trailing zero coefficients.
    ///
    /// A polynomial whose coefficients are all zero has exact degree 0.
    pub fn degree_exact(&self) -> u32 {
        let zero = K::zero();
        // Number of zero coefficients at the high end (all of them if none is nonzero).
        let leading_zeroes = self
            .coeffs
            .iter()
            .rev()
            .position(|coeff| coeff != &zero)
            .unwrap_or(self.len_usize());
        let lz_u32 = u32::try_from(leading_zeroes)
            .expect("Impossible: Polynomial has >= 2^32 coefficients");
        self.degree().saturating_sub(lz_u32)
    }
}
impl<K: Additive> Object for Poly<K> {}

impl<'a, K: Additive> AddAssign<&'a Self> for Poly<K> {
    /// Adds `rhs` coefficient-wise; the result is as long as the longer operand.
    fn add_assign(&mut self, rhs: &'a Self) {
        self.merge_with(rhs, |lhs_coeff, rhs_coeff| *lhs_coeff += rhs_coeff);
    }
}

impl<'a, K: Additive> Add<&'a Self> for Poly<K> {
    type Output = Self;

    /// Returns `self + rhs`, reusing `self`'s storage.
    fn add(self, rhs: &'a Self) -> Self::Output {
        let mut sum = self;
        sum += rhs;
        sum
    }
}

impl<'a, K: Additive> SubAssign<&'a Self> for Poly<K> {
    /// Subtracts `rhs` coefficient-wise; the result is as long as the longer operand.
    fn sub_assign(&mut self, rhs: &'a Self) {
        self.merge_with(rhs, |lhs_coeff, rhs_coeff| *lhs_coeff -= rhs_coeff);
    }
}

impl<'a, K: Additive> Sub<&'a Self> for Poly<K> {
    type Output = Self;

    /// Returns `self - rhs`, reusing `self`'s storage.
    fn sub(self, rhs: &'a Self) -> Self::Output {
        let mut difference = self;
        difference -= rhs;
        difference
    }
}

impl<K: Additive> Neg for Poly<K> {
    type Output = Self;

    /// Negates every coefficient.
    fn neg(self) -> Self::Output {
        let coeffs = self.coeffs.map_into(|coeff| -coeff);
        Self { coeffs }
    }
}

impl<K: Additive> Additive for Poly<K> {
    /// The zero polynomial, represented by a single zero constant term.
    fn zero() -> Self {
        let coeffs = non_empty_vec![K::zero()];
        Self { coeffs }
    }
}
impl<'a, R, K: Space<R>> MulAssign<&'a R> for Poly<K> {
    /// Scales every coefficient by `rhs` in place.
    fn mul_assign(&mut self, rhs: &'a R) {
        for coeff in self.coeffs.iter_mut() {
            *coeff *= rhs;
        }
    }
}

impl<'a, R, K: Space<R>> Mul<&'a R> for Poly<K> {
    type Output = Self;

    /// Returns the polynomial with every coefficient scaled by `rhs`.
    fn mul(self, rhs: &'a R) -> Self::Output {
        let mut scaled = self;
        scaled *= rhs;
        scaled
    }
}
impl<R: Sync, K: Space<R> + Send> Space<R> for Poly<K> {
    /// Computes `sum_i scalars[i] * polys[i]`.
    ///
    /// Only the first `min(polys.len(), scalars.len())` pairs contribute. Shorter
    /// polynomials are treated as zero-padded to the longest one. When fewer than
    /// `MIN_POINTS_FOR_MSM` pairs contribute, this delegates to `msm_naive`.
    /// Otherwise each output coefficient `i` is computed with `K::msm` over the `i`-th
    /// coefficients of the inputs, with rows dispatched through `strategy`.
    fn msm(polys: &[Self], scalars: &[R], strategy: &impl Strategy) -> Self {
        // Truncate to the contributing pairs *before* choosing a path. Previously the
        // fallback was gated on `polys.len()` alone, so `polys.len() >= 2` with empty
        // `scalars` reached the row-wise path with zero polynomials and panicked
        // while computing `rows`.
        let cols = polys.len().min(scalars.len());
        let polys = &polys[..cols];
        let scalars = &scalars[..cols];
        if cols < MIN_POINTS_FOR_MSM {
            return msm_naive(polys, scalars);
        }
        // Number of output coefficients: the longest input polynomial.
        let rows = polys
            .iter()
            .map(|p| p.len_usize())
            .max()
            .expect("at least 1 point");
        let coeffs = strategy.map_init_collect_vec(
            0..rows,
            // One reusable row buffer per worker.
            || Vec::with_capacity(cols),
            |row, i| {
                row.clear();
                for p in polys {
                    // Missing high-order coefficients are implicitly zero.
                    row.push(p.coeffs.get(i).cloned().unwrap_or_else(K::zero));
                }
                K::msm(row, scalars, strategy)
            },
        );
        Self::from_iter_unchecked(coeffs)
    }
}
impl<G: CryptoGroup> Poly<G> {
    /// Commits to a scalar polynomial by multiplying the group generator by each
    /// coefficient.
    ///
    /// Evaluating the commitment at `x` yields `G::generator() * p.eval(x)`.
    pub fn commit(p: Poly<G::Scalar>) -> Self {
        let coeffs = p.coeffs.map(|scalar| G::generator() * scalar);
        Self { coeffs }
    }
}
/// Precomputed weights for recovering a polynomial's value at zero (its constant
/// term) from its evaluations at a fixed set of points.
///
/// Each key `I` identifies one evaluation point; the associated `F` is the weight
/// applied to that point's evaluation in `interpolate`.
pub struct Interpolator<I, F> {
    // One weight per evaluation-point identifier.
    weights: Map<I, F>,
}
impl<I: PartialEq, F: Ring> Interpolator<I, F> {
    /// Combines `evals` with the precomputed weights to recover the value at zero.
    ///
    /// Returns `None` unless `evals` has exactly the same key set as the
    /// interpolator, in the same order.
    pub fn interpolate<K: Space<F>>(
        &self,
        evals: &Map<I, K>,
        strategy: &impl Strategy,
    ) -> Option<K> {
        let keys_match = evals.keys() == self.weights.keys();
        keys_match.then(|| K::msm(evals.values(), self.weights.values(), strategy))
    }
}
impl<I: Clone + Ord, F: Field> Interpolator<I, F> {
    /// Builds an interpolator from `(identifier, evaluation point)` pairs.
    ///
    /// The weights are the Lagrange basis coefficients evaluated at zero:
    /// `weight_i = prod_{j != i} w_j / (w_j - w_i)`. As a result,
    /// `sum_i weight_i * f(w_i) = f(0)` for any polynomial `f` with at most `n`
    /// coefficients.
    ///
    /// Pairs are deduplicated by identifier via `Map::from_iter_dedup`; which value
    /// survives for a repeated identifier is decided by that function (TODO confirm).
    /// If one of the points is zero, the weights become the indicator of that point,
    /// since `f(0)` is then one of the evaluations.
    ///
    /// NOTE(review): if two distinct identifiers carry the *same* field value, some
    /// `c_i` below is zero, so `prefix[n]` is zero. The result then depends on how
    /// `F::inv` treats zero. Confirm whether callers guarantee distinct points.
    pub fn new(points: impl IntoIterator<Item = (I, F)>) -> Self {
        let points = Map::from_iter_dedup(points);
        let n = points.len();
        if n == 0 {
            // No points: the (empty) map is returned as the weights.
            return Self { weights: points };
        }
        let values = points.values();
        let zero = F::zero();
        // Product of all points, `prod_j w_j`: the shared numerator of every weight.
        let mut total_product = F::one();
        // c[i] = w_i * prod_{j != i} (w_j - w_i). Including w_i makes
        // `total_product / c[i]` drop w_i from the numerator.
        let mut c = Vec::with_capacity(n);
        for (i, w_i) in values.iter().enumerate() {
            if w_i == &zero {
                // f(0) is directly one of the evaluations: weight it 1, the rest 0.
                let mut out = points;
                for (j, w) in out.values_mut().iter_mut().enumerate() {
                    *w = if j == i { F::one() } else { F::zero() };
                }
                return Self { weights: out };
            }
            total_product *= w_i;
            let mut c_i = w_i.clone();
            for w_j in values
                .iter()
                .enumerate()
                .filter_map(|(j, v)| (j != i).then_some(v))
            {
                c_i *= &(w_j.clone() - w_i);
            }
            c.push(c_i);
        }
        // Batch inversion: prefix[k] = c[0] * ... * c[k-1], so a single `inv()` of
        // prefix[n] is enough to recover every 1 / c[i].
        let mut prefix = Vec::with_capacity(n + 1);
        prefix.push(F::one());
        let mut acc = F::one();
        for c_i in &c {
            acc *= c_i;
            prefix.push(acc.clone());
        }
        // Invariant at the top of iteration i: inv_acc = total_product / prefix[i + 1].
        let mut inv_acc = total_product * &prefix[n].inv();
        let mut out = points;
        let out_vals = out.values_mut();
        for i in (0..n).rev() {
            // total_product * prefix[i] / prefix[i + 1] = total_product / c[i].
            out_vals[i] = inv_acc.clone() * &prefix[i];
            // Fold c[i] back in so inv_acc = total_product / prefix[i] next time round.
            inv_acc *= &c[i];
        }
        Self { weights: out }
    }
}
#[commonware_macros::stability(ALPHA)]
impl<I: Clone + Ord, F: crate::algebra::FieldNTT> Interpolator<I, F> {
    /// Builds an interpolator whose evaluation points are powers of a root of unity.
    ///
    /// `points` maps each identifier to an exponent `k`. The coefficients come from
    /// `crate::ntt::lagrange_coefficients(total, ..)` over those exponents. Each is
    /// then re-keyed by looking its exponent up in `points`; exponents not found
    /// there are skipped.
    ///
    /// # Panics
    ///
    /// If the re-keyed coefficients contain a duplicate identifier. The one-to-one
    /// `BiMap` input is expected to rule that out.
    pub fn roots_of_unity(
        total: NonZeroU32,
        points: commonware_utils::ordered::BiMap<I, u32>,
    ) -> Self {
        let weights = <Map<I, F> as commonware_utils::TryFromIterator<(I, F)>>::try_from_iter(
            crate::ntt::lagrange_coefficients(total, points.values().iter().copied())
                .into_iter()
                // Map each exponent back to its identifier.
                .filter_map(|(k, coeff)| Some((points.get_key(&k)?.clone(), coeff))),
        )
        .expect("points has already been deduped");
        Self { weights }
    }

    /// Reference implementation of `roots_of_unity`, used to cross-check it in tests.
    ///
    /// Computes each evaluation point directly as `w^k`. Here `w` is
    /// `F::root_of_unity(lg_size)` and `2^lg_size` is `total` rounded up to a power
    /// of two. Exponents `k >= total` are discarded before delegating to
    /// `Interpolator::new`.
    ///
    /// # Panics
    ///
    /// If `F::root_of_unity` returns `None` for that domain size.
    #[cfg(any(test, feature = "fuzz"))]
    fn roots_of_unity_naive(
        total: NonZeroU32,
        points: commonware_utils::ordered::BiMap<I, u32>,
    ) -> Self {
        use crate::algebra::powers;
        let total_u32 = total.get();
        let size = (total_u32 as u64).next_power_of_two();
        let lg_size = size.ilog2() as u8;
        let w = F::root_of_unity(lg_size).expect("domain too large for NTT");
        // Keep only exponents inside the evaluation domain `0..total`.
        let points: Vec<(I, u32)> = points.into_iter().filter(|(_, k)| *k < total_u32).collect();
        let max_k = points.iter().map(|(_, k)| *k).max().unwrap_or(0) as usize;
        // Table of w^0 ..= w^max_k, indexed by exponent.
        let powers: Vec<_> = powers(F::one(), &w).take(max_k + 1).collect();
        let eval_points = points
            .into_iter()
            .map(|(i, k)| (i, powers[k as usize].clone()));
        Self::new(eval_points)
    }
}
#[cfg(any(test, feature = "arbitrary"))]
mod impl_arbitrary {
    use super::*;
    use arbitrary::Arbitrary;

    impl<'a, F: Arbitrary<'a>> Arbitrary<'a> for Poly<F> {
        /// Draws a constant term, then an arbitrary-length vector of higher-order
        /// coefficients, so the result always has at least one coefficient.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let constant = F::arbitrary(u)?;
            let higher = Vec::<F>::arbitrary(u)?;
            let mut coeffs = NonEmptyVec::new(constant);
            coeffs.extend(higher);
            Ok(Self { coeffs })
        }
    }
}
/// Property-based fuzz harness for `Poly` and `Interpolator`.
///
/// Each `Plan` variant is one property check. `Plan::run` executes it and panics
/// (via `assert!`/`assert_eq!`) if the property does not hold.
#[commonware_macros::stability(ALPHA)]
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    use super::*;
    use crate::{
        algebra::test_suites,
        test::{F, G},
    };
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_codec::Encode as _;
    use commonware_parallel::Sequential;
    use commonware_utils::{
        ordered::{BiMap, Map},
        TryFromIterator,
    };

    /// A single fuzz property, with its inputs drawn via `Arbitrary`.
    ///
    /// NOTE: the derived `Arbitrary` impl selects variants by position, so reordering
    /// or inserting variants changes how existing fuzz inputs are decoded.
    #[derive(Debug, Arbitrary)]
    pub enum Plan {
        /// Encoding then decoding round-trips.
        Codec(Poly<F>),
        /// `(f + g)(x) == f(x) + g(x)`.
        EvalAdd(Poly<F>, Poly<F>, F),
        /// `(f * w)(x) == f(x) * w`.
        EvalScale(Poly<F>, F, F),
        /// `f(0)` equals the constant term.
        EvalZero(Poly<F>),
        /// `eval` and `eval_msm` agree.
        EvalMsm(Poly<F>, F),
        /// `lin_comb_eval` matches the naive `sum a * f(b)`.
        LinCombEval(Poly<F>, Vec<(F, F)>),
        /// Interpolating from nonzero points recovers the constant term; a missing
        /// point yields `None`.
        Interpolate(Poly<F>),
        /// Interpolation recovers the constant term when zero is the first point.
        InterpolateWithZeroPoint(Poly<F>),
        /// Interpolation recovers the constant term when zero follows the nonzero
        /// points (it is placed last).
        InterpolateWithZeroPointMiddle(Poly<F>),
        /// `translate(|c| x * c)` equals scalar multiplication by `x`.
        TranslateScale(Poly<F>, F),
        /// Committing then evaluating equals evaluating then multiplying the generator.
        CommitEval(Poly<F>, F),
        /// `roots_of_unity` matches `roots_of_unity_naive` for totals in `1..=256`.
        RootsOfUnityEqNaive(u16),
        /// Generic `Additive` law checks for `Poly<F>`.
        FuzzAdditive,
        /// Generic `Space<F>` law checks for `Poly<F>`.
        FuzzSpaceRing,
    }

    impl Plan {
        /// Runs this property check, drawing any extra input it needs from `u`.
        pub fn run(self, u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::Codec(f) => {
                    assert_eq!(
                        &f,
                        &Poly::<F>::read_cfg(&mut f.encode(), &(RangeCfg::exact(f.required()), ()))
                            .unwrap()
                    );
                }
                Self::EvalAdd(f, g, x) => {
                    assert_eq!(f.eval(&x) + &g.eval(&x), (f + &g).eval(&x));
                }
                Self::EvalScale(f, x, w) => {
                    assert_eq!(f.eval(&x) * &w, (f * &w).eval(&x));
                }
                Self::EvalZero(f) => {
                    assert_eq!(&f.eval(&F::zero()), f.constant());
                }
                Self::EvalMsm(f, x) => {
                    assert_eq!(f.eval(&x), f.eval_msm(&x, &Sequential));
                }
                Self::LinCombEval(f, pairs) => {
                    let naive_eval = pairs.iter().fold(F::zero(), |mut acc, (a, b)| {
                        acc += &(*a * &f.eval(b));
                        acc
                    });
                    let lin_comb = f.lin_comb_eval(
                        pairs.iter().map(|(a, b)| (*a, Cow::Borrowed(b))),
                        &Sequential,
                    );
                    assert_eq!(naive_eval, lin_comb);
                }
                Self::Interpolate(f) => {
                    // Skip the zero polynomial and polynomials needing more points than
                    // presumably fit as distinct small values of the test field.
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Points 1, 2, ..., required(): all nonzero.
                    let mut points = (0..f.required().get())
                        .map(|i| F::from((i + 1) as u8))
                        .collect::<Vec<_>>();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                    // Dropping a point makes the key sets differ, so interpolation must fail.
                    points.pop();
                    assert_eq!(
                        interpolator.interpolate(
                            &Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate()),
                            &Sequential
                        ),
                        None
                    );
                }
                Self::InterpolateWithZeroPoint(f) => {
                    if f == Poly::zero() || f.required().get() >= F::MAX as u32 {
                        return Ok(());
                    }
                    // Points 0, 1, ..., required() - 1: zero comes first.
                    let points: Vec<_> =
                        (0..f.required().get()).map(|i| F::from(i as u8)).collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::InterpolateWithZeroPointMiddle(f) => {
                    if f == Poly::zero()
                        || f.required().get() < 2
                        || f.required().get() >= F::MAX as u32
                    {
                        return Ok(());
                    }
                    let n = f.required().get();
                    // Points 1, ..., n - 1 followed by 0, so the zero point is not first.
                    let points: Vec<_> = (1..n)
                        .map(|i| F::from(i as u8))
                        .chain(core::iter::once(F::zero()))
                        .collect();
                    let interpolator = Interpolator::new(points.iter().copied().enumerate());
                    let evals = Map::from_iter_dedup(points.iter().map(|p| f.eval(p)).enumerate());
                    let recovered = interpolator.interpolate(&evals, &Sequential);
                    assert_eq!(recovered.as_ref(), Some(f.constant()));
                }
                Self::TranslateScale(f, x) => {
                    assert_eq!(f.translate(|c| x * c), f * &x);
                }
                Self::CommitEval(f, x) => {
                    assert_eq!(G::generator() * &f.eval(&x), Poly::<G>::commit(f).eval(&x));
                }
                Self::RootsOfUnityEqNaive(n) => {
                    // Map the fuzzed u16 into 1..=256.
                    let n = (u32::from(n) % 256) + 1;
                    let total = NonZeroU32::new(n).expect("n is in 1..=256");
                    // Identifier i is assigned exponent i, covering the whole domain.
                    let points = BiMap::try_from_iter((0..n as usize).map(|i| (i, i as u32)))
                        .expect("interpolation points should be bijective");
                    let fast = Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity(
                        total,
                        points.clone(),
                    );
                    let naive =
                        Interpolator::<usize, crate::fields::goldilocks::F>::roots_of_unity_naive(
                            total, points,
                        );
                    assert_eq!(fast.weights, naive.weights);
                }
                Self::FuzzAdditive => {
                    test_suites::fuzz_additive::<Poly<F>>(u)?;
                }
                Self::FuzzSpaceRing => {
                    test_suites::fuzz_space_ring::<F, Poly<F>>(u)?;
                }
            }
            Ok(())
        }
    }

    // Drives `Plan::run` with inputs generated by the in-repo minifuzz runner.
    #[test]
    fn test_fuzz() {
        commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }
}
#[cfg(test)]
mod test {
    use super::{fuzz::Plan, *};
    use crate::test::F;
    use arbitrary::Unstructured;

    /// Builds a polynomial from small integer coefficients, constant term first.
    fn poly(coeffs: &[u8]) -> Poly<F> {
        Poly {
            coeffs: coeffs.iter().copied().map(F::from).try_collect().unwrap(),
        }
    }

    /// Lifts small integer `(a, b)` pairs into field elements.
    fn pairs(values: &[(u8, u8)]) -> Vec<(F, F)> {
        values
            .iter()
            .map(|&(a, b)| (F::from(a), F::from(b)))
            .collect()
    }

    #[test]
    fn test_eq() {
        // (lhs, rhs, expected): trailing zeros must not affect equality.
        let cases: [(&[u8], &[u8], bool); 8] = [
            (&[1, 2], &[1, 2], true),
            (&[1, 2], &[2, 3], false),
            (&[1, 2], &[1, 2, 3], false),
            (&[1, 2, 3], &[1, 2], false),
            (&[1, 2], &[1, 2, 0, 0], true),
            (&[1, 2, 0, 0], &[1, 2], true),
            (&[1, 2, 0], &[2, 3], false),
            (&[2, 3], &[1, 2, 0], false),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(poly(lhs) == poly(rhs), expected, "{lhs:?} == {rhs:?}");
        }
    }

    #[test]
    fn lin_comb_eval_edge_cases() {
        // Empty combination, constant polynomial, repeated evaluation point, and
        // evaluation points 0 and 1.
        let cases = [
            Plan::LinCombEval(poly(&[3, 5, 7]), vec![]),
            Plan::LinCombEval(poly(&[11]), pairs(&[(2, 0), (3, 1), (5, 8)])),
            Plan::LinCombEval(poly(&[4, 6, 8]), pairs(&[(2, 5), (7, 5), (3, 5)])),
            Plan::LinCombEval(poly(&[9, 2, 3, 4]), pairs(&[(6, 0), (1, 0), (5, 7)])),
            Plan::LinCombEval(poly(&[1, 2, 4, 8]), pairs(&[(3, 1), (7, 1), (2, 6)])),
        ];
        let mut u = Unstructured::new(&[]);
        cases
            .into_iter()
            .for_each(|case| case.run(&mut u).unwrap());
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Poly<F>>
        }
    }
}