Skip to main content

commonware_cryptography/bls12381/primitives/
sharing.rs

1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4#[cfg(not(feature = "std"))]
5use alloc::vec::Vec;
6use cfg_if::cfg_if;
7use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
8use commonware_macros::stability;
9#[stability(ALPHA)]
10use commonware_math::algebra::{FieldNTT, Ring};
11use commonware_math::poly::{Interpolator, Poly};
12use commonware_parallel::Sequential;
13#[stability(ALPHA)]
14use commonware_utils::{ordered::BiMap, TryFromIterator};
15use commonware_utils::{ordered::Set, Faults, Participant, NZU32};
16#[cfg(feature = "std")]
17use core::iter;
18use core::num::NonZeroU32;
19#[cfg(feature = "std")]
20use std::sync::{Arc, OnceLock};
21#[cfg(feature = "std")]
22use std::vec::Vec;
23
/// Configures how participants are assigned shares of a secret.
///
/// More specifically, this configures how evaluation points of a polynomial
/// are assigned to participant identities.
///
/// Each variant's explicit discriminant is also its one-byte wire encoding.
#[derive(Copy, Clone, Default, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum Mode {
    /// Assigns participant `i` the evaluation point `i + 1`.
    ///
    /// The `+ 1` offset keeps every participant away from `f(0)`, which holds
    /// the secret. This is the default mode.
    #[default]
    NonZeroCounter = 0,

    /// Assigns participants to powers of a root of unity.
    ///
    /// This mode enables sub-quadratic interpolation using NTT-based algorithms.
    #[cfg(not(any(
        commonware_stability_BETA,
        commonware_stability_GAMMA,
        commonware_stability_DELTA,
        commonware_stability_EPSILON,
        commonware_stability_RESERVED
    )))]
    RootsOfUnity = 1,
}
46
47impl Mode {
48    /// Compute the scalar for one participant.
49    ///
50    /// This will return `None` only if `i >= total`.
51    pub(crate) fn scalar(self, total: NonZeroU32, i: Participant) -> Option<Scalar> {
52        if i.get() >= total.get() {
53            return None;
54        }
55        match self {
56            Self::NonZeroCounter => {
57                // Adding 1 is critical, because f(0) will contain the secret.
58                Some(Scalar::from_u64(i.get() as u64 + 1))
59            }
60            #[cfg(not(any(
61                commonware_stability_BETA,
62                commonware_stability_GAMMA,
63                commonware_stability_DELTA,
64                commonware_stability_EPSILON,
65                commonware_stability_RESERVED
66            )))]
67            Self::RootsOfUnity => {
68                // Participant i gets w^i. Since w^i != 0 for any i, this never
69                // collides with the secret at f(0).
70                let size = (total.get() as u64).next_power_of_two();
71                let lg_size = size.ilog2() as u8;
72                let w = Scalar::root_of_unity(lg_size).expect("domain too large for NTT");
73                Some(w.exp(&[i.get() as u64]))
74            }
75        }
76    }
77
78    /// Compute the scalars for all participants.
79    #[cfg(feature = "std")]
80    pub(crate) fn all_scalars(self, total: NonZeroU32) -> Vec<Scalar> {
81        match self {
82            Self::NonZeroCounter => (0..total.get())
83                .map(|i| Scalar::from_u64(i as u64 + 1))
84                .collect(),
85            #[cfg(not(any(
86                commonware_stability_BETA,
87                commonware_stability_GAMMA,
88                commonware_stability_DELTA,
89                commonware_stability_EPSILON,
90                commonware_stability_RESERVED
91            )))]
92            Self::RootsOfUnity => {
93                let size = (total.get() as u64).next_power_of_two();
94                let lg_size = size.ilog2() as u8;
95                let w = Scalar::root_of_unity(lg_size).expect("domain too large for NTT");
96                (0..total.get())
97                    .scan(Scalar::one(), |state, _| {
98                        let val = state.clone();
99                        *state *= &w;
100                        Some(val)
101                    })
102                    .collect()
103            }
104        }
105    }
106
107    /// Create an interpolator for this mode, given a set of indices.
108    ///
109    /// This will return `None` if:
110    /// - any `to_index` call on the provided `indices` returns `None`,
111    /// - any index returned by `to_index` is >= `total`.
112    ///
113    /// To be generic over different use cases, we need:
114    /// - the total number of participants,
115    /// - a set of indices (of any type),
116    /// - a means to convert indices to Participant values.
117    fn interpolator<I: Clone + Ord>(
118        self,
119        total: NonZeroU32,
120        indices: &Set<I>,
121        to_index: impl Fn(&I) -> Option<Participant>,
122    ) -> Option<Interpolator<I, Scalar>> {
123        match self {
124            Self::NonZeroCounter => {
125                let mut count = 0;
126                let iter = indices
127                    .iter()
128                    .filter_map(|i| {
129                        let scalar = self.scalar(total, to_index(i)?)?;
130                        Some((i.clone(), scalar))
131                    })
132                    .inspect(|_| {
133                        count += 1;
134                    });
135                let out = Interpolator::new(iter);
136                // If any indices fail to produce a scalar, reject.
137                if count != indices.len() {
138                    return None;
139                }
140                Some(out)
141            }
142            #[cfg(not(any(
143                commonware_stability_BETA,
144                commonware_stability_GAMMA,
145                commonware_stability_DELTA,
146                commonware_stability_EPSILON,
147                commonware_stability_RESERVED
148            )))]
149            Self::RootsOfUnity => {
150                // For roots of unity mode, we use the fast O(n log n) interpolation.
151                // Participant i maps to exponent i, so the evaluation point is w^i.
152                let size = (total.get() as u64).next_power_of_two();
153                let ntt_total = NonZeroU32::new(u32::try_from(size).ok()?)?;
154
155                let mut count = 0;
156                let points: Vec<(I, u32)> = indices
157                    .iter()
158                    .filter_map(|i| {
159                        let participant = to_index(i)?;
160                        if participant.get() >= total.get() {
161                            return None;
162                        }
163                        count += 1;
164                        Some((i.clone(), participant.get()))
165                    })
166                    .collect();
167
168                // If any indices fail to produce a scalar, reject.
169                if count != indices.len() {
170                    return None;
171                }
172
173                let points = BiMap::try_from_iter(points).ok()?;
174                Some(Interpolator::roots_of_unity(ntt_total, points))
175            }
176        }
177    }
178
179    /// Create an interpolator for this mode, given a set, and a subset.
180    ///
181    /// The set determines the total number of participants to use for interpolation,
182    /// and the indices that will get assigned to the subset.
183    ///
184    /// This function will return `None` only if `subset` contains elements
185    /// not in `set`.
186    #[cfg(feature = "std")]
187    pub(crate) fn subset_interpolator<I: Clone + Ord>(
188        self,
189        set: &Set<I>,
190        subset: &Set<I>,
191    ) -> Option<Interpolator<I, Scalar>> {
192        let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
193            return Some(Interpolator::new(iter::empty()));
194        };
195        self.interpolator(total, subset, |i| {
196            set.position(i).map(Participant::from_usize)
197        })
198    }
199}
200
201impl FixedSize for Mode {
202    const SIZE: usize = 1;
203}
204
205impl Write for Mode {
206    fn write(&self, buf: &mut impl bytes::BufMut) {
207        buf.put_u8(*self as u8);
208    }
209}
210
/// Determines which modes can be parsed.
///
/// As modes have been added over time, this versioning mechanism helps with
/// supporting compatibility.
///
/// A newer release of the library (with more modes) can be deployed while an
/// older version value is passed here, restricting which modes are accepted
/// when decoding at runtime.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ModeVersion(u8);
220
221impl ModeVersion {
222    /// Version 0, supporting:
223    ///
224    /// - [`Mode::NonZeroCounter`]
225    pub const fn v0() -> Self {
226        Self(0)
227    }
228
229    /// Version 1, supporting v0, and:
230    ///
231    /// - [`Mode::RootsOfUnity`]
232    #[stability(ALPHA)]
233    pub const fn v1() -> Self {
234        Self(1)
235    }
236
237    const fn supports(&self, mode: &Mode) -> bool {
238        match mode {
239            Mode::NonZeroCounter => true,
240            #[cfg(not(any(
241                commonware_stability_BETA,
242                commonware_stability_GAMMA,
243                commonware_stability_DELTA,
244                commonware_stability_EPSILON,
245                commonware_stability_RESERVED
246            )))]
247            Mode::RootsOfUnity => self.0 >= 1,
248        }
249    }
250}
251
252impl Read for Mode {
253    type Cfg = ModeVersion;
254
255    fn read_cfg(
256        buf: &mut impl bytes::Buf,
257        version: &Self::Cfg,
258    ) -> Result<Self, commonware_codec::Error> {
259        let tag: u8 = ReadExt::read(buf)?;
260        let mode = match tag {
261            0 => Self::NonZeroCounter,
262            #[cfg(not(any(
263                commonware_stability_BETA,
264                commonware_stability_GAMMA,
265                commonware_stability_DELTA,
266                commonware_stability_EPSILON,
267                commonware_stability_RESERVED
268            )))]
269            1 => Self::RootsOfUnity,
270            o => return Err(commonware_codec::Error::InvalidEnum(o)),
271        };
272        if !version.supports(&mode) {
273            return Err(commonware_codec::Error::Invalid(
274                "Mode",
275                "unsupported mode for version",
276            ));
277        }
278        Ok(mode)
279    }
280}
281
/// Represents the public output of a polynomial secret sharing.
///
/// This does not contain any secret information.
#[derive(Clone, Debug)]
pub struct Sharing<V: Variant> {
    /// How participant indices are mapped to evaluation points of `poly`.
    mode: Mode,
    /// Number of participants; valid participant indices are `0..total`.
    total: NonZeroU32,
    /// The public polynomial; its constant term is the group public key.
    poly: Arc<Poly<V::Public>>,
    /// Lazily filled cache of partial public keys, one slot per participant.
    /// Stored behind an `Arc`, so clones of a sharing share the same cache.
    #[cfg(feature = "std")]
    evals: Arc<Vec<OnceLock<V::Public>>>,
}
293
294impl<V: Variant> PartialEq for Sharing<V> {
295    fn eq(&self, other: &Self) -> bool {
296        self.mode == other.mode && self.total == other.total && self.poly == other.poly
297    }
298}
299
300impl<V: Variant> Eq for Sharing<V> {}
301
302impl<V: Variant> Sharing<V> {
303    pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
304        Self {
305            mode,
306            total,
307            poly: Arc::new(poly),
308            #[cfg(feature = "std")]
309            evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
310        }
311    }
312
313    /// Get the mode used for this sharing.
314    #[cfg(feature = "std")]
315    pub(crate) const fn mode(&self) -> Mode {
316        self.mode
317    }
318
319    pub(crate) fn scalar(&self, i: Participant) -> Option<Scalar> {
320        self.mode.scalar(self.total, i)
321    }
322
323    #[cfg(feature = "std")]
324    fn all_scalars(&self) -> Vec<Scalar> {
325        self.mode.all_scalars(self.total)
326    }
327
328    /// Return the number of participants required to recover the secret
329    /// using the given fault model.
330    pub fn required<M: Faults>(&self) -> u32 {
331        M::quorum(self.total.get())
332    }
333
334    /// Return the total number of participants in this sharing.
335    pub const fn total(&self) -> NonZeroU32 {
336        self.total
337    }
338
339    /// Create an interpolator over some indices.
340    ///
341    /// This will return an error if any of the indices are >= [`Self::total`].
342    pub(crate) fn interpolator(
343        &self,
344        indices: &Set<Participant>,
345    ) -> Result<Interpolator<Participant, Scalar>, Error> {
346        self.mode
347            .interpolator(self.total, indices, |&x| Some(x))
348            .ok_or(Error::InvalidIndex)
349    }
350
351    /// Call this to pre-compute the results of [`Self::partial_public`].
352    ///
353    /// This should be used if you expect to access many of the partial public
354    /// keys, e.g. if verifying several public signatures.
355    ///
356    /// The first time this method is called can be expensive, but subsequent
357    /// calls are idempotent, and cheap.
358    #[cfg(feature = "std")]
359    pub fn precompute_partial_publics(&self) {
360        // NOTE: once we add more interpolation methods, this can be smarter.
361        self.evals
362            .iter()
363            .zip(self.all_scalars())
364            .for_each(|(e, s)| {
365                e.get_or_init(|| self.poly.eval_msm(&s, &Sequential));
366            })
367    }
368
369    /// Get the partial public key associated with a given participant.
370    ///
371    /// This will return `None` if the index is greater >= [`Self::total`].
372    pub fn partial_public(&self, i: Participant) -> Result<V::Public, Error> {
373        cfg_if! {
374            if #[cfg(feature = "std")] {
375                self.evals
376                    .get(usize::from(i))
377                    .map(|e| {
378                        *e.get_or_init(|| {
379                            self.poly
380                                .eval_msm(&self.scalar(i).expect("i < total"), &Sequential)
381                        })
382                    })
383                    .ok_or(Error::InvalidIndex)
384            } else {
385                Ok(self
386                    .poly
387                    .eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?, &Sequential))
388            }
389        }
390    }
391
392    /// Get the group public key of this sharing.
393    ///
394    /// In other words, the public key associated with the shared secret.
395    pub fn public(&self) -> &V::Public {
396        self.poly.constant()
397    }
398}
399
400impl<V: Variant> EncodeSize for Sharing<V> {
401    fn encode_size(&self) -> usize {
402        self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
403    }
404}
405
406impl<V: Variant> Write for Sharing<V> {
407    fn write(&self, buf: &mut impl bytes::BufMut) {
408        self.mode.write(buf);
409        self.total.get().write(buf);
410        self.poly.write(buf);
411    }
412}
413
414impl<V: Variant> Read for Sharing<V> {
415    type Cfg = (NonZeroU32, ModeVersion);
416
417    fn read_cfg(
418        buf: &mut impl bytes::Buf,
419        (max_participants, max_supported_mode): &Self::Cfg,
420    ) -> Result<Self, commonware_codec::Error> {
421        let mode = Read::read_cfg(buf, max_supported_mode)?;
422        // We bound total to the config, in order to prevent doing arbitrary
423        // computation if we precompute public keys.
424        let total = {
425            let out: u32 = ReadExt::read(buf)?;
426            if out == 0 || out > max_participants.get() {
427                return Err(commonware_codec::Error::Invalid(
428                    "Sharing",
429                    "total not in range",
430                ));
431            }
432            // This will not panic, because we checked != 0 above.
433            NZU32!(out)
434        };
435        let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*max_participants), ()))?;
436        Ok(Self::new(mode, total, poly))
437    }
438}
439
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use commonware_invariants::minifuzz;
    use commonware_utils::ordered::Map;
    use rand::{rngs::StdRng, SeedableRng};

    // Anything naming `Mode::RootsOfUnity` is gated with the same cfg as the
    // variant itself, so this module still compiles when ALPHA items are excluded.

    /// Roots-of-unity interpolation must reject, not panic on, an NTT domain
    /// that does not fit in a `u32`.
    #[cfg(not(any(
        commonware_stability_BETA,
        commonware_stability_GAMMA,
        commonware_stability_DELTA,
        commonware_stability_EPSILON,
        commonware_stability_RESERVED
    )))]
    #[test]
    fn test_roots_of_unity_interpolator_large_total_returns_none() {
        let total = NonZeroU32::new(u32::MAX).expect("u32::MAX is non-zero");
        let indices = Set::from_iter_dedup([Participant::new(0)]);
        let interpolator =
            Mode::RootsOfUnity.interpolator(total, &indices, |participant| Some(*participant));
        assert!(
            interpolator.is_none(),
            "domain > u32::MAX should be rejected instead of panicking"
        );
    }

    /// Decoding the roots-of-unity tag must fail under [`ModeVersion::v0`].
    #[cfg(not(any(
        commonware_stability_BETA,
        commonware_stability_GAMMA,
        commonware_stability_DELTA,
        commonware_stability_EPSILON,
        commonware_stability_RESERVED
    )))]
    #[test]
    fn test_mode_read_rejects_mode_above_max_supported_mode() {
        let encoded = [Mode::RootsOfUnity as u8];
        Mode::read_cfg(&mut &encoded[..], &ModeVersion::v0())
            .expect_err("roots mode must be rejected when max mode is counter");
    }

    /// `all_scalars` must agree element-wise with `scalar`.
    #[test]
    fn test_all_scalars_matches_scalar() {
        minifuzz::test(|u| {
            // Tag 1 selects RootsOfUnity when it is compiled in; otherwise every
            // tag falls back to NonZeroCounter.
            let mode = match u.int_in_range(0u8..=1)? {
                #[cfg(not(any(
                    commonware_stability_BETA,
                    commonware_stability_GAMMA,
                    commonware_stability_DELTA,
                    commonware_stability_EPSILON,
                    commonware_stability_RESERVED
                )))]
                1 => Mode::RootsOfUnity,
                _ => Mode::NonZeroCounter,
            };
            let total = NonZeroU32::new(u.int_in_range(1u32..=512u32)?).expect("range is non-zero");
            let index = u.int_in_range(0u32..=total.get() - 1)?;
            let participant = Participant::new(index);

            let scalars = mode.all_scalars(total);
            assert_eq!(
                scalars[usize::from(participant)].clone(),
                mode.scalar(total, participant).expect("index is in range")
            );
            Ok(())
        });
    }

    /// Interpolating any sufficiently large subset must recover `f(0)`.
    #[test]
    fn test_subset_interpolation_recovers_constant() {
        minifuzz::test(|u| {
            // Tag 1 selects RootsOfUnity when it is compiled in; otherwise every
            // tag falls back to NonZeroCounter.
            let mode = match u.int_in_range(0u8..=1)? {
                #[cfg(not(any(
                    commonware_stability_BETA,
                    commonware_stability_GAMMA,
                    commonware_stability_DELTA,
                    commonware_stability_EPSILON,
                    commonware_stability_RESERVED
                )))]
                1 => Mode::RootsOfUnity,
                _ => Mode::NonZeroCounter,
            };
            let total = NonZeroU32::new(u.int_in_range(1u32..=64u32)?).expect("range is non-zero");

            let mut subset_vec = Vec::new();
            for i in 0..total.get() {
                if u.arbitrary::<bool>()? {
                    subset_vec.push(Participant::new(i));
                }
            }
            if subset_vec.is_empty() {
                let i = u.int_in_range(0u32..=total.get() - 1)?;
                subset_vec.push(Participant::new(i));
            }
            let subset = Set::from_iter_dedup(subset_vec);

            let max_degree = u32::try_from(subset.len() - 1).expect("subset len fits in u32");
            let degree = u.int_in_range(0u32..=max_degree)?;
            let seed: u64 = u.arbitrary()?;
            let poly: Poly<Scalar> = Poly::new(&mut StdRng::seed_from_u64(seed), degree);

            let all_shares = Map::from_iter_dedup((0..total.get()).map(|i| {
                let participant = Participant::new(i);
                let scalar = mode.scalar(total, participant).expect("in range");
                let share = poly.eval(&scalar);
                (participant, share)
            }));

            let subset_evals = Map::from_iter_dedup(subset.iter().map(|participant| {
                (
                    *participant,
                    all_shares
                        .get_value(participant)
                        .expect("participant exists")
                        .clone(),
                )
            }));

            let interpolator = mode
                .interpolator(total, &subset, |participant| Some(*participant))
                .expect("subset indices are valid");
            let recovered = interpolator
                .interpolate(&subset_evals, &Sequential)
                .expect("subset should match interpolator domain");

            assert_eq!(recovered, poly.constant().clone());
            Ok(())
        });
    }
}
543
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;
    use commonware_utils::{N3f1, NZU32};
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        /// Generates a mode from one tag in `0..=1`.
        ///
        /// Tag 1 selects [`Mode::RootsOfUnity`] when that variant is compiled
        /// in. Every other tag, and every tag when the variant is excluded by
        /// the stability cfgs, yields [`Mode::NonZeroCounter`]. The same number
        /// of input bytes is consumed in both configurations.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            match u.int_in_range(0u8..=1)? {
                #[cfg(not(any(
                    commonware_stability_BETA,
                    commonware_stability_GAMMA,
                    commonware_stability_DELTA,
                    commonware_stability_EPSILON,
                    commonware_stability_RESERVED
                )))]
                1 => Ok(Self::RootsOfUnity),
                _ => Ok(Self::NonZeroCounter),
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        /// Generates a sharing over 1..=100 participants whose polynomial
        /// degree is `N3f1::quorum(total) - 1`.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let total: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;
            let poly = Poly::new(&mut StdRng::seed_from_u64(seed), N3f1::quorum(total) - 1);
            Ok(Self::new(
                mode,
                NZU32!(total),
                Poly::<V::Public>::commit(poly),
            ))
        }
    }
}