commonware_cryptography/bls12381/primitives/poly.rs

1//! Polynomial operations over the BLS12-381 scalar field.
2//!
3//! # Warning
4//!
5//! The security of the polynomial operations is critical for the overall
6//! security of the threshold schemes. Ensure that the scalar field operations
7//! are performed over the correct field and that all elements are valid.
8
9use crate::bls12381::primitives::{
10    group::{self, Element, Scalar},
11    Error,
12};
13use bytes::{Buf, BufMut};
14use commonware_codec::{EncodeSize, Error as CodecError, FixedSize, Read, ReadExt, Write};
15use rand::{rngs::OsRng, RngCore};
16use std::hash::Hash;
17
18/// Private polynomials are used to generate secret shares.
19pub type Private = Poly<group::Private>;
20
21/// Public polynomials represent commitments to secrets on a private polynomial.
22pub type Public = Poly<group::Public>;
23
24/// Signature polynomials are used in threshold signing (where a signature
25/// is interpolated using at least `threshold` evaluations).
26pub type Signature = Poly<group::Signature>;
27
28/// The default partial signature type (G2).
29pub type PartialSignature = Eval<group::Signature>;
30
31/// The default partial signature length (G2).
32pub const PARTIAL_SIGNATURE_LENGTH: usize = u32::SIZE + group::SIGNATURE_LENGTH;
33
34/// A polynomial evaluation at a specific index.
35#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
36pub struct Eval<C: Element> {
37    pub index: u32,
38    pub value: C,
39}
40
41impl<C: Element> Write for Eval<C> {
42    fn write(&self, buf: &mut impl BufMut) {
43        self.index.write(buf);
44        self.value.write(buf);
45    }
46}
47
48impl<C: Element> Read for Eval<C> {
49    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
50        let index = buf.get_u32();
51        let value = C::read(buf)?;
52        Ok(Self { index, value })
53    }
54}
55
56impl<C: Element> FixedSize for Eval<C> {
57    const SIZE: usize = u32::SIZE + C::SIZE;
58}
59
/// A polynomial whose variable `x` is always a [`Scalar`] and whose
/// coefficients are of a generic element type `C`.
///
/// The coefficient stored at position `k` multiplies `x^k`, so position 0 holds
/// the constant term. The coefficient type must support multiplication by a
/// scalar (the type of the variable).
// Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L24-L28
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Poly<C>(Vec<C>);
68
69/// Returns a new scalar polynomial of the given degree where each coefficients is
70/// sampled at random using kernel randomness.
71///
72/// In the context of secret sharing, the threshold is the degree + 1.
73pub fn new(degree: u32) -> Poly<Scalar> {
74    // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L46-L52
75    new_from(degree, &mut OsRng)
76}
77
78// Returns a new scalar polynomial of the given degree where each coefficient is
79// sampled at random from the provided RNG.
80///
81/// In the context of secret sharing, the threshold is the degree + 1.
82pub fn new_from<R: RngCore>(degree: u32, rng: &mut R) -> Poly<Scalar> {
83    // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L46-L52
84    let coeffs = (0..=degree).map(|_| Scalar::rand(rng)).collect::<Vec<_>>();
85    Poly::<Scalar>(coeffs)
86}
87
88impl<C> Poly<C> {
89    /// Creates a new polynomial from the given coefficients.
90    pub fn from(c: Vec<C>) -> Self {
91        Self(c)
92    }
93
94    /// Returns the constant term of the polynomial.
95    pub fn constant(&self) -> &C {
96        &self.0[0]
97    }
98
99    /// Returns the degree of the polynomial
100    pub fn degree(&self) -> u32 {
101        (self.0.len() - 1) as u32 // check size in deserialize, safe to cast
102    }
103
104    /// Returns the number of required shares to reconstruct the polynomial.
105    ///
106    /// This will be the threshold
107    pub fn required(&self) -> u32 {
108        self.0.len() as u32 // check size in deserialize, safe to cast
109    }
110}
111
112impl<C: Element> Poly<C> {
113    /// Commits the scalar polynomial to the group and returns a polynomial over
114    /// the group.
115    ///
116    /// This is done by multiplying each coefficient of the polynomial with the
117    /// group's generator.
118    pub fn commit(commits: Poly<Scalar>) -> Self {
119        // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L322-L340
120        let commits = commits
121            .0
122            .iter()
123            .map(|c| {
124                let mut commitment = C::one();
125                commitment.mul(c);
126                commitment
127            })
128            .collect::<Vec<C>>();
129
130        Poly::<C>::from(commits)
131    }
132
133    /// Returns a zero polynomial.
134    pub fn zero() -> Self {
135        Self::from(vec![C::zero()])
136    }
137
138    /// Returns the given coefficient at the requested index.
139    ///
140    /// It panics if the index is out of range.
141    pub fn get(&self, i: u32) -> C {
142        self.0[i as usize].clone()
143    }
144
145    /// Set the given element at the specified index.
146    ///
147    /// It panics if the index is out of range.
148    pub fn set(&mut self, index: u32, value: C) {
149        self.0[index as usize] = value;
150    }
151
152    /// Performs polynomial addition in place
153    pub fn add(&mut self, other: &Self) {
154        // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L87-L95
155
156        // if we have a smaller degree we should pad with zeros
157        if self.0.len() < other.0.len() {
158            self.0.resize(other.0.len(), C::zero())
159        }
160
161        self.0.iter_mut().zip(&other.0).for_each(|(a, b)| a.add(b))
162    }
163
164    /// Evaluates the polynomial at the specified value.
165    pub fn evaluate(&self, i: u32) -> Eval<C> {
166        // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L111-L129
167
168        // We add +1 because we must never evaluate the polynomial at its first point
169        // otherwise it reveals the "secret" value after a reshare (where the constant
170        // term is set to be the secret of the previous dealing).
171        let mut xi = Scalar::zero();
172        xi.set_int(i + 1);
173
174        // Use Horner's method to evaluate the polynomial
175        let res = self.0.iter().rev().fold(C::zero(), |mut sum, coeff| {
176            sum.mul(&xi);
177            sum.add(coeff);
178            sum
179        });
180        Eval {
181            value: res,
182            index: i,
183        }
184    }
185
186    /// Recovers the constant term of a polynomial of degree less than `t` using at least `t` evaluations of the polynomial.
187    ///
188    /// This function uses Lagrange interpolation to compute the constant term (i.e., the value of the polynomial at `x=0`)
189    /// given at least `t` distinct evaluations of the polynomial. Each evaluation is assumed to have a unique index,
190    /// which is mapped to a unique x-value as `x = index + 1`.
191    ///
192    /// # Warning
193    ///
194    /// This function assumes that each evaluation has a unique index. If there are duplicate indices, the function may
195    /// fail with an error when attempting to compute the inverse of zero.
196    pub fn recover<'a, I>(t: u32, evals: I) -> Result<C, Error>
197    where
198        C: 'a,
199        I: IntoIterator<Item = &'a Eval<C>>,
200    {
201        // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/a714310be76620e10e8797d6637df64011926430/crates/threshold-bls/src/poly.rs#L131-L165
202
203        // Check if we have at least `t` evaluations; if not, return an error
204        let t = t as usize;
205        let mut evals = evals.into_iter().collect::<Vec<_>>();
206        if evals.len() < t {
207            return Err(Error::NotEnoughPartialSignatures(t, evals.len()));
208        }
209
210        // Convert the first `t` sorted shares into scalars
211        //
212        // We sort the evaluations by index to ensure that two invocations of
213        // `recover` select the same evals.
214        evals.sort_by_key(|e| e.index);
215
216        // Take the first `t` evaluations and prepare them for interpolation
217        //
218        // Each index `i` is mapped to `x = i + 1` to avoid `x=0` (the constant term we’re recovering).
219        let xs = evals
220            .into_iter()
221            .take(t)
222            .fold(Vec::with_capacity(t), |mut m, sh| {
223                let mut xi = Scalar::zero();
224                xi.set_int(sh.index + 1);
225                m.push((sh.index, (xi, &sh.value)));
226                m
227            });
228
229        // Use Lagrange interpolation to compute the constant term at `x=0`
230        //
231        // The constant term is `sum_{i=1 to t} yi * l_i(0)`, where `l_i(0) = product_{j != i} (xj / (xj - xi))`.
232        xs.iter().try_fold(C::zero(), |mut acc, (i, (xi, yi))| {
233            let (mut num, den) = xs.iter().fold(
234                (Scalar::one(), Scalar::one()),
235                |(mut num, mut den), (j, (xj, _))| {
236                    if i != j {
237                        // Include `xj` in the numerator product for `l_i(0)`
238                        num.mul(xj);
239
240                        // Compute `xj - xi` and include it in the denominator product
241                        let mut tmp = xj.clone();
242                        tmp.sub(xi);
243                        den.mul(&tmp);
244                    }
245                    (num, den)
246                },
247            );
248
249            // Compute the inverse of the denominator product; fails if den is zero (e.g., duplicate `xj`)
250            let inv = den.inverse().ok_or(Error::NoInverse)?;
251
252            // Compute `l_i(0) = num * inv`, the Lagrange basis coefficient at `x=0`
253            num.mul(&inv);
254
255            // Scale `yi` by `l_i(0)` to contribute to the constant term
256            let mut yi_scaled = (*yi).clone();
257            yi_scaled.mul(&num);
258
259            // Add `yi * l_i(0)` to the running sum
260            acc.add(&yi_scaled);
261            Ok(acc)
262        })
263    }
264}
265
266impl<C: Element> Write for Poly<C> {
267    fn write(&self, buf: &mut impl BufMut) {
268        for c in &self.0 {
269            c.write(buf);
270        }
271    }
272}
273
274impl<C: Element> Read<usize> for Poly<C> {
275    fn read_cfg(buf: &mut impl Buf, expected: &usize) -> Result<Self, CodecError> {
276        let expected_size = C::SIZE * (*expected);
277        if buf.remaining() < expected_size {
278            return Err(CodecError::EndOfBuffer);
279        }
280        let mut coeffs = Vec::with_capacity(*expected);
281        for _ in 0..*expected {
282            coeffs.push(C::read(buf)?);
283        }
284        Ok(Self(coeffs))
285    }
286}
287
288impl<C: Element> EncodeSize for Poly<C> {
289    fn encode_size(&self) -> usize {
290        C::SIZE * self.0.len()
291    }
292}
293
294/// Returns the public key of the polynomial (constant term).
295pub fn public(public: &Public) -> &group::Public {
296    public.constant()
297}
298
#[cfg(test)]
pub mod tests {
    // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/b0ef82ff79769d085a5a7d3f4fe690b1c8fe6dc9/crates/threshold-bls/src/poly.rs#L355-L604
    use super::*;
    use crate::bls12381::primitives::group::{Scalar, G2};
    use commonware_codec::{Decode, Encode};

    /// Raises `base` to the power `exp` by repeated multiplication.
    fn pow(base: Scalar, exp: usize) -> Scalar {
        (0..exp).fold(Scalar::one(), |mut acc, _| {
            acc.mul(&base);
            acc
        })
    }

    #[test]
    fn poly_degree() {
        let degree = 5;
        assert_eq!(new(degree).degree(), degree);
    }

    #[test]
    fn add_zero() {
        // p + 0 == p
        let poly = new(3);
        let mut sum = poly.clone();
        sum.add(&Poly::<Scalar>::zero());
        assert_eq!(sum, poly);

        // 0 + p == p
        let poly = new(3);
        let mut sum = Poly::<Scalar>::zero();
        sum.add(&poly);
        assert_eq!(sum, poly);
    }

    #[test]
    fn interpolation_insufficient_shares() {
        let degree = 4;
        let threshold = degree + 1;
        let poly = new(degree);

        // One evaluation short of the threshold.
        let shares: Vec<_> = (0..threshold - 1).map(|i| poly.evaluate(i)).collect();
        assert!(Poly::recover(threshold, &shares).is_err());
    }

    #[test]
    fn commit() {
        let secret = new(5);
        let expected = Poly::from(
            secret
                .0
                .iter()
                .map(|coeff| {
                    let mut point = G2::one();
                    point.mul(coeff);
                    point
                })
                .collect::<Vec<_>>(),
        );
        assert_eq!(expected, Poly::commit(secret));
    }

    #[test]
    fn addition() {
        for deg1 in 0..100u32 {
            for deg2 in 0..100u32 {
                let p1 = new(deg1);
                let p2 = new(deg2);
                let mut res = p1.clone();
                res.add(&p2);

                let (larger, smaller) = match p1.degree() > p2.degree() {
                    true => (&p1, &p2),
                    false => (&p2, &p1),
                };

                let overlap = smaller.0.len();
                for i in 0..=larger.degree() as usize {
                    if i < overlap {
                        // Both polynomials contribute to this coefficient.
                        let mut expected = p1.0[i].clone();
                        expected.add(&p2.0[i]);
                        assert_eq!(res.0[i], expected);
                    } else {
                        // Only the longer polynomial has this coefficient.
                        assert_eq!(res.0[i], larger.0[i]);
                    }
                }

                assert_eq!(
                    res.degree(),
                    larger.degree(),
                    "deg1={}, deg2={}",
                    deg1,
                    deg2
                );
            }
        }
    }

    #[test]
    fn interpolation() {
        for degree in 0..100u32 {
            for num_evals in 0..100u32 {
                let poly = new(degree);
                let expected = poly.0[0].clone();

                let shares: Vec<_> = (0..num_evals).map(|i| poly.evaluate(i)).collect();
                let recovered_constant = Poly::recover(num_evals, &shares).unwrap();

                // Recovery only succeeds with more evaluations than the degree.
                let enough = num_evals > degree;
                if enough {
                    assert_eq!(
                        expected, recovered_constant,
                        "degree={}, num_evals={}",
                        degree, num_evals
                    );
                } else {
                    assert_ne!(
                        expected, recovered_constant,
                        "degree={}, num_evals={}",
                        degree, num_evals
                    );
                }
            }
        }
    }

    #[test]
    fn evaluate() {
        for d in 0..100u32 {
            for idx in 0..100_u32 {
                // `evaluate(idx)` evaluates at x = idx + 1.
                let mut x = Scalar::zero();
                x.set_int(idx + 1);

                let p1 = new(d);
                let evaluation = p1.evaluate(idx).value;

                // Naive evaluation: c_0 + sum_{k >= 1} c_k * x^k.
                let coeffs = p1.0;
                let sum = coeffs.iter().enumerate().skip(1).fold(
                    coeffs[0].clone(),
                    |mut acc, (k, coeff)| {
                        let mut term = coeff.clone();
                        term.mul(&pow(x.clone(), k));
                        acc.add(&term);
                        acc
                    },
                );

                assert_eq!(sum, evaluation, "degree={}, idx={}", d, idx);
            }
        }
    }

    #[test]
    fn test_codec() {
        let original = new(5);
        let expected_len = original.required() as usize;
        let decoded = Poly::<Scalar>::decode_cfg(original.encode(), &expected_len).unwrap();
        assert_eq!(original, decoded);
    }
}