use super::variant::Variant;
use crate::bls12381::primitives::{
group::{self, Element, Scalar},
Error,
};
#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use bytes::{Buf, BufMut};
use commonware_codec::{varint::UInt, EncodeSize, Error as CodecError, Read, ReadExt, Write};
use core::{hash::Hash, iter};
#[cfg(feature = "std")]
use rand::rngs::OsRng;
use rand_core::CryptoRngCore;
#[cfg(feature = "std")]
use std::collections::BTreeMap;
/// A polynomial whose coefficients are of type `group::Private`.
pub type Private = Poly<group::Private>;
/// A polynomial whose coefficients live in the public-key group of variant `V`.
pub type Public<V> = Poly<<V as Variant>::Public>;
/// A polynomial whose coefficients live in the signature group of variant `V`.
pub type Signature<V> = Poly<<V as Variant>::Signature>;
/// A single indexed evaluation ([`Eval`]) whose value lives in the signature group of `V`.
pub type PartialSignature<V> = Eval<<V as Variant>::Signature>;
/// A polynomial evaluated at a single point.
///
/// `index` identifies the evaluation point (it is converted to an x-coordinate with
/// [`Scalar::from_index`] when evaluating or interpolating) and `value` is the
/// polynomial's value at that point.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Eval<C: Element> {
    pub index: u32,
    pub value: C,
}

impl<C: Element> Write for Eval<C> {
    /// Writes `index` as a varint, then the encoding of `value`.
    fn write(&self, buf: &mut impl BufMut) {
        let Self { index, value } = self;
        UInt(*index).write(buf);
        value.write(buf);
    }
}

impl<C: Element> Read for Eval<C> {
    type Cfg = ();

    /// Reads a varint-encoded `index` followed by a `value`, mirroring [`Write`].
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
        let index: u32 = UInt::read(buf)?.into();
        Ok(Self {
            index,
            value: C::read(buf)?,
        })
    }
}

impl<C: Element> EncodeSize for Eval<C> {
    /// Varint length of `index` plus the fixed element size `C::SIZE`.
    fn encode_size(&self) -> usize {
        let index_len = UInt(self.index).encode_size();
        index_len + C::SIZE
    }
}
/// A polynomial represented by its coefficients in ascending order of degree:
/// element `i` is the coefficient of `x^i`, so element 0 is the constant term.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Poly<C>(Vec<C>);
/// Samples a random polynomial of the given `degree` (`degree + 1` coefficients)
/// using the operating system RNG [`OsRng`].
///
/// Only available with the `std` feature; use [`new_from`] to supply an RNG.
#[cfg(feature = "std")]
pub fn new(degree: u32) -> Poly<Scalar> {
    let mut rng = OsRng;
    new_from(degree, &mut rng)
}
/// Samples a polynomial of the given `degree`: `degree + 1` coefficients, each
/// drawn with [`Scalar::from_rand`] from `rng`.
pub fn new_from<R: CryptoRngCore>(degree: u32, rng: &mut R) -> Poly<Scalar> {
    (0..=degree).map(|_| Scalar::from_rand(rng)).collect()
}
/// Samples a polynomial of the given `degree` whose constant term is `constant`.
///
/// The remaining `degree` coefficients (for `x^1..=x^degree`) are drawn with
/// [`Scalar::from_rand`] from `rng`.
pub fn new_with_constant(
    degree: u32,
    mut rng: impl CryptoRngCore,
    constant: Scalar,
) -> Poly<Scalar> {
    // Position 0 holds the caller's constant; positions 1..=degree are random.
    let random_terms = (1..=degree).map(|_| Scalar::from_rand(&mut rng));
    iter::once(constant).chain(random_terms).collect()
}
/// A precomputed Lagrange coefficient for one evaluation index, as produced by
/// [`compute_weights`] and consumed by [`Poly::recover_with_weights`].
pub struct Weight(Scalar);

impl Weight {
    /// Borrows the underlying scalar coefficient.
    pub fn as_scalar(&self) -> &Scalar {
        let Weight(scalar) = self;
        scalar
    }
}
pub fn prepare_evaluations<'a, C, I>(threshold: u32, evals: I) -> Result<Vec<&'a Eval<C>>, Error>
where
C: 'a + Element,
I: IntoIterator<Item = &'a Eval<C>>,
{
let t = threshold as usize;
let mut evals = evals.into_iter().collect::<Vec<_>>();
if evals.len() < t {
return Err(Error::NotEnoughPartialSignatures(t, evals.len()));
}
evals.sort_by_key(|e| e.index);
evals.truncate(t);
Ok(evals)
}
/// Computes the Lagrange basis coefficients `l_i(0)` for the given evaluation indices.
///
/// Each index is converted to an x-coordinate with [`Scalar::from_index`], and the
/// weight for index `i` is `l_i(0) = prod_{j != i} x_j / (x_j - x_i)`. Multiplying each
/// evaluation by its weight and summing (see [`Poly::recover_with_weights`]) yields the
/// interpolated polynomial's value at `x = 0`.
///
/// # Errors
///
/// Returns [`Error::NoInverse`] if a denominator has no inverse. This is the case when
/// `indices` contains the same index more than once, because interpolation through a
/// repeated point is undefined.
pub fn compute_weights(indices: Vec<u32>) -> Result<BTreeMap<u32, Weight>, Error> {
    // Convert every index to its x-coordinate once, instead of once per (i, j) pair.
    let xs = indices
        .iter()
        .map(|i| Scalar::from_index(*i))
        .collect::<Vec<_>>();

    let mut weights = BTreeMap::new();
    for (pos_i, (i, xi)) in indices.iter().zip(&xs).enumerate() {
        let (mut num, mut den) = (Scalar::one(), Scalar::one());
        for (pos_j, xj) in xs.iter().enumerate() {
            // Skip only the term for this exact position. Comparing index *values* would
            // also skip duplicates of `i`, hiding them and producing silently wrong
            // weights; comparing positions makes a duplicate contribute a zero factor
            // `x_j - x_i` to the denominator, which is rejected below.
            if pos_i == pos_j {
                continue;
            }

            // Numerator: product of x_j for all j != i.
            num.mul(xj);

            // Denominator: product of (x_j - x_i) for all j != i.
            let mut diff = xj.clone();
            diff.sub(xi);
            den.mul(&diff);
        }

        // Fails when `den` is zero, i.e. when the index set contains a duplicate.
        let inv = den.inverse().ok_or(Error::NoInverse)?;
        num.mul(&inv);
        weights.insert(*i, Weight(num));
    }
    Ok(weights)
}
impl<C> FromIterator<C> for Poly<C> {
    /// Builds a polynomial whose coefficients are the iterator's items in order;
    /// the first item becomes the constant term.
    fn from_iter<T: IntoIterator<Item = C>>(iter: T) -> Self {
        let coeffs: Vec<C> = Vec::from_iter(iter);
        Self(coeffs)
    }
}
impl<C> Poly<C> {
    /// Wraps `c` as a polynomial whose element `i` is the coefficient of `x^i`.
    ///
    /// No validation is performed: an empty vector produces a polynomial on which
    /// [`Poly::constant`] and [`Poly::degree`] panic.
    pub fn from(c: Vec<C>) -> Self {
        Self(c)
    }

    /// Returns the constant term (the coefficient of `x^0`).
    ///
    /// # Panics
    ///
    /// Panics if the polynomial has no coefficients.
    pub fn constant(&self) -> &C {
        &self.0[0]
    }

    /// Returns the degree, i.e. the number of stored coefficients minus one.
    ///
    /// The degree is structural: trailing zero coefficients are still counted.
    ///
    /// # Panics
    ///
    /// Panics if the polynomial has no coefficients.
    pub fn degree(&self) -> u32 {
        // `checked_sub` makes an empty polynomial panic in every build profile; a plain
        // `len() - 1` panics in debug builds but silently wraps to `u32::MAX` in release.
        let last = self
            .0
            .len()
            .checked_sub(1)
            .expect("polynomial must have at least one coefficient");
        last as u32
    }

    /// Returns the number of coefficients (`degree + 1`), which is the number of
    /// distinct evaluations needed to recover the constant term with [`Poly::recover`].
    pub fn required(&self) -> u32 {
        self.0.len() as u32
    }
}
impl<C: Element> Poly<C> {
pub fn commit(commits: Poly<Scalar>) -> Self {
let commits = commits
.0
.iter()
.map(|c| {
let mut commitment = C::one();
commitment.mul(c);
commitment
})
.collect::<Vec<C>>();
Poly::<C>::from(commits)
}
pub fn zero() -> Self {
Self::from(vec![C::zero()])
}
pub fn get(&self, i: u32) -> C {
self.0[i as usize].clone()
}
pub fn set(&mut self, index: u32, value: C) {
self.0[index as usize] = value;
}
pub fn add(&mut self, other: &Self) {
if self.0.len() < other.0.len() {
self.0.resize(other.0.len(), C::zero())
}
self.0.iter_mut().zip(&other.0).for_each(|(a, b)| a.add(b))
}
pub fn evaluate(&self, index: u32) -> Eval<C> {
let xi = Scalar::from_index(index);
let value = self.0.iter().rev().fold(C::zero(), |mut sum, coeff| {
sum.mul(&xi);
sum.add(coeff);
sum
});
Eval { value, index }
}
pub fn recover_with_weights<'a, I>(
weights: &BTreeMap<u32, Weight>,
evals: I,
) -> Result<C, Error>
where
C: 'a,
I: IntoIterator<Item = &'a Eval<C>>,
{
let mut result = C::zero();
for eval in evals.into_iter() {
let Some(weight) = weights.get(&eval.index) else {
return Err(Error::InvalidIndex);
};
let mut scaled_value = eval.value.clone();
scaled_value.mul(&weight.0);
result.add(&scaled_value);
}
Ok(result)
}
pub fn recover<'a, I>(t: u32, evals: I) -> Result<C, Error>
where
C: 'a,
I: IntoIterator<Item = &'a Eval<C>>,
{
let evals = prepare_evaluations(t, evals)?;
let indices = evals.iter().map(|e| e.index).collect::<Vec<_>>();
let weights = compute_weights(indices)?;
Self::recover_with_weights(&weights, evals)
}
}
impl<C: Element> Write for Poly<C> {
    /// Writes every coefficient in order (constant term first) with no length
    /// prefix; readers must supply the coefficient count via `Read::Cfg`.
    fn write(&self, buf: &mut impl BufMut) {
        self.0.iter().for_each(|coeff| coeff.write(buf));
    }
}
impl<C: Element> Read for Poly<C> {
    /// The number of coefficients to read; the encoding carries no length prefix.
    type Cfg = usize;

    /// Reads exactly `expected` fixed-size coefficients, constant term first.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::EndOfBuffer`] if the buffer holds fewer than
    /// `expected * C::SIZE` bytes (including when that product overflows `usize`),
    /// and propagates any error from decoding an individual coefficient.
    fn read_cfg(buf: &mut impl Buf, expected: &Self::Cfg) -> Result<Self, CodecError> {
        let count = *expected;

        // An overflowing byte count can never be satisfied by a real buffer. Checking
        // the multiplication also keeps the pre-allocation below bounded by
        // `buf.remaining()`; a wrapped product could pass the length check and then
        // request an enormous capacity.
        let expected_size = C::SIZE
            .checked_mul(count)
            .ok_or(CodecError::EndOfBuffer)?;
        if buf.remaining() < expected_size {
            return Err(CodecError::EndOfBuffer);
        }

        let mut coeffs = Vec::with_capacity(count);
        for _ in 0..count {
            coeffs.push(C::read(buf)?);
        }
        Ok(Self(coeffs))
    }
}
impl<C: Element> EncodeSize for Poly<C> {
    /// Every coefficient occupies the fixed size `C::SIZE`, and no length prefix is written.
    fn encode_size(&self) -> usize {
        let coefficients = self.0.len();
        coefficients * C::SIZE
    }
}
pub fn public<V: Variant>(public: &Public<V>) -> &V::Public {
public.constant()
}
/// Unit tests for polynomial construction, arithmetic, evaluation, interpolation and encoding.
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::bls12381::primitives::group::{Scalar, G2};
    use commonware_codec::{Decode, Encode};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // A freshly sampled polynomial reports the degree it was created with.
    #[test]
    fn poly_degree() {
        let s = 5;
        let p = new(s);
        assert_eq!(p.degree(), s);
    }

    // Adding the zero polynomial is the identity, whichever operand is zero.
    #[test]
    fn add_zero() {
        let p1 = new(3);
        let p2 = Poly::<Scalar>::zero();
        let mut res = p1.clone();
        res.add(&p2);
        assert_eq!(res, p1);
        let p1 = Poly::<Scalar>::zero();
        let p2 = new(3);
        let mut res = p1;
        res.add(&p2);
        assert_eq!(res, p2);
    }

    // Recovery must fail when fewer than `threshold` evaluations are supplied.
    #[test]
    fn interpolation_insufficient_shares() {
        let degree = 4;
        let threshold = degree + 1;
        let poly = new(degree);
        let shares = (0..threshold - 1)
            .map(|i| poly.evaluate(i))
            .collect::<Vec<_>>();
        Poly::recover(threshold, &shares).unwrap_err();
    }

    // Evaluating at the largest possible index must not panic (e.g. if the index is
    // offset when converted to a scalar — NOTE(review): confirm `from_index` mapping).
    #[test]
    fn evaluate_with_overflow() {
        let degree = 4;
        let poly = new(degree);
        poly.evaluate(u32::MAX);
    }

    // `Poly::commit` matches multiplying `G2::one()` by each coefficient by hand.
    #[test]
    fn commit() {
        let secret = new(5);
        let coeffs = secret.0.clone();
        let commitment = coeffs
            .iter()
            .map(|coeff| {
                let mut p = G2::one();
                p.mul(coeff);
                p
            })
            .collect::<Vec<_>>();
        let commitment = Poly::from(commitment);
        assert_eq!(commitment, Poly::commit(secret));
    }

    // Naive exponentiation by repeated multiplication; reference for `evaluate`.
    fn pow(base: Scalar, pow: usize) -> Scalar {
        let mut res = Scalar::one();
        for _ in 0..pow {
            res.mul(&base)
        }
        res
    }

    // Coefficient-wise addition across all degree pairs: overlapping coefficients are
    // summed, the longer operand's tail is copied, and the result has the larger degree.
    #[test]
    fn addition() {
        for deg1 in 0..100u32 {
            for deg2 in 0..100u32 {
                let p1 = new(deg1);
                let p2 = new(deg2);
                let mut res = p1.clone();
                res.add(&p2);
                let (larger, smaller) = if p1.degree() > p2.degree() {
                    (&p1, &p2)
                } else {
                    (&p2, &p1)
                };
                for i in 0..larger.degree() + 1 {
                    let i = i as usize;
                    if i < (smaller.degree() + 1) as usize {
                        let mut coeff_sum = p1.0[i].clone();
                        coeff_sum.add(&p2.0[i]);
                        assert_eq!(res.0[i], coeff_sum);
                    } else {
                        assert_eq!(res.0[i], larger.0[i]);
                    }
                }
                assert_eq!(res.degree(), larger.degree(), "deg1={deg1}, deg2={deg2}");
            }
        }
    }

    // Recovery returns the constant term exactly when more than `degree` evaluations are
    // used; with `degree` or fewer it yields a different value (with overwhelming probability).
    #[test]
    fn interpolation() {
        for degree in 0..100u32 {
            for num_evals in 0..100u32 {
                let poly = new(degree);
                let expected = poly.0[0].clone();
                let shares = (0..num_evals).map(|i| poly.evaluate(i)).collect::<Vec<_>>();
                let recovered_constant = Poly::recover(num_evals, &shares).unwrap();
                if num_evals > degree {
                    assert_eq!(
                        expected, recovered_constant,
                        "degree={degree}, num_evals={num_evals}"
                    );
                } else {
                    assert_ne!(
                        expected, recovered_constant,
                        "degree={degree}, num_evals={num_evals}"
                    );
                }
            }
        }
    }

    // Horner evaluation agrees with the naive sum of coeff_i * x^i.
    #[test]
    fn evaluate() {
        for d in 0..100u32 {
            for idx in 0..100_u32 {
                let x = Scalar::from_index(idx);
                let p1 = new(d);
                let evaluation = p1.evaluate(idx).value;
                let coeffs = p1.0;
                let mut sum = coeffs[0].clone();
                for (i, coeff) in coeffs
                    .into_iter()
                    .enumerate()
                    .take((d + 1) as usize)
                    .skip(1)
                {
                    let xi = pow(x.clone(), i);
                    let mut var = coeff;
                    var.mul(&xi);
                    sum.add(&var);
                }
                assert_eq!(sum, evaluation, "degree={d}, idx={idx}");
            }
        }
    }

    // Encoding then decoding with the coefficient count as config round-trips.
    #[test]
    fn test_codec() {
        let original = new(5);
        let encoded = original.encode();
        let decoded = Poly::<Scalar>::decode_cfg(encoded, &(original.required() as usize)).unwrap();
        assert_eq!(original, decoded);
    }

    // The supplied constant ends up as the polynomial's constant term.
    #[test]
    fn test_new_with_constant() {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let constant = Scalar::from_rand(&mut rng);
        let poly = new_with_constant(5, &mut rng, constant.clone());
        assert_eq!(poly.constant(), &constant);
    }
}