commonware_cryptography/bls12381/primitives/
sharing.rs

use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
#[cfg(not(feature = "std"))]
use alloc::{sync::Arc, vec::Vec};
use cfg_if::cfg_if;
use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
use commonware_math::poly::{Interpolator, Poly};
use commonware_parallel::Sequential;
use commonware_utils::{ordered::Set, Faults, Participant, NZU32};
use core::{iter, num::NonZeroU32};
#[cfg(feature = "std")]
use std::sync::{Arc, OnceLock};
12
/// Configures how participants are assigned shares of a secret.
///
/// Concretely, a mode decides which evaluation point of the sharing
/// polynomial each participant identity is given.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    // TODO (https://github.com/commonware-xyz/monorepo/issues/1836): Add a mode for sub O(N^2) interpolation
    /// Participant `i` is assigned the evaluation point `i + 1`, so no
    /// participant is ever handed the evaluation at zero.
    ///
    /// This is the default mode; its wire tag is `0`.
    #[default]
    NonZeroCounter = 0,
}
24
25impl Mode {
26    /// Compute the scalar for one participant.
27    ///
28    /// This will return `None` only if `i >= total`.
29    pub(crate) fn scalar(self, total: NonZeroU32, i: Participant) -> Option<Scalar> {
30        if i.get() >= total.get() {
31            return None;
32        }
33        match self {
34            Self::NonZeroCounter => {
35                // Adding 1 is critical, because f(0) will contain the secret.
36                Some(Scalar::from_u64(i.get() as u64 + 1))
37            }
38        }
39    }
40
41    /// Compute the scalars for all participants.
42    pub(crate) fn all_scalars(self, total: NonZeroU32) -> impl Iterator<Item = Scalar> {
43        (0..total.get()).map(move |i| self.scalar(total, Participant::new(i)).expect("i < total"))
44    }
45
46    /// Create an interpolator for this mode, given a set of indices.
47    ///
48    /// This will return `None` if:
49    /// - any `to_index` call on the provided `indices` returns `None`,
50    /// - any index returned by `to_index` is >= `total`.
51    ///
52    /// To be generic over different use cases, we need:
53    /// - the total number of participants,
54    /// - a set of indices (of any type),
55    /// - a means to convert indices to Participant values.
56    fn interpolator<I: Clone + Ord>(
57        self,
58        total: NonZeroU32,
59        indices: &Set<I>,
60        to_index: impl Fn(&I) -> Option<Participant>,
61    ) -> Option<Interpolator<I, Scalar>> {
62        let mut count = 0;
63        let iter = indices
64            .iter()
65            .filter_map(|i| {
66                let scalar = self.scalar(total, to_index(i)?)?;
67                Some((i.clone(), scalar))
68            })
69            .inspect(|_| {
70                count += 1;
71            });
72        let out = Interpolator::new(iter);
73        // If any indices failed to produce a scalar, reject.
74        if count != indices.len() {
75            return None;
76        }
77        Some(out)
78    }
79
80    /// Create an interpolator for this mode, given a set, and a subset.
81    ///
82    /// The set determines the total number of participants to use for interpolation,
83    /// and the indices that will get assigned to the subset.
84    ///
85    /// This function will return `None` only if `subset` contains elements
86    /// not in `set`.
87    pub(crate) fn subset_interpolator<I: Clone + Ord>(
88        self,
89        set: &Set<I>,
90        subset: &Set<I>,
91    ) -> Option<Interpolator<I, Scalar>> {
92        let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
93            return Some(Interpolator::new(iter::empty()));
94        };
95        self.interpolator(total, subset, |i| {
96            set.position(i).map(Participant::from_usize)
97        })
98    }
99}
100
101impl FixedSize for Mode {
102    const SIZE: usize = 1;
103}
104
105impl Write for Mode {
106    fn write(&self, buf: &mut impl bytes::BufMut) {
107        buf.put_u8(*self as u8);
108    }
109}
110
111impl Read for Mode {
112    type Cfg = ();
113
114    fn read_cfg(
115        buf: &mut impl bytes::Buf,
116        _cfg: &Self::Cfg,
117    ) -> Result<Self, commonware_codec::Error> {
118        let tag: u8 = ReadExt::read(buf)?;
119        match tag {
120            0 => Ok(Self::NonZeroCounter),
121            o => Err(commonware_codec::Error::InvalidEnum(o)),
122        }
123    }
124}
125
126/// Represents the public output of a polynomial secret sharing.
127///
128/// This does not contain any secret information.
129#[derive(Clone, Debug)]
130pub struct Sharing<V: Variant> {
131    mode: Mode,
132    total: NonZeroU32,
133    poly: Arc<Poly<V::Public>>,
134    #[cfg(feature = "std")]
135    evals: Arc<Vec<OnceLock<V::Public>>>,
136}
137
138impl<V: Variant> PartialEq for Sharing<V> {
139    fn eq(&self, other: &Self) -> bool {
140        self.mode == other.mode && self.total == other.total && self.poly == other.poly
141    }
142}
143
144impl<V: Variant> Eq for Sharing<V> {}
145
146impl<V: Variant> Sharing<V> {
147    pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
148        Self {
149            mode,
150            total,
151            poly: Arc::new(poly),
152            #[cfg(feature = "std")]
153            evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
154        }
155    }
156
157    /// Get the mode used for this sharing.
158    pub(crate) const fn mode(&self) -> Mode {
159        self.mode
160    }
161
162    pub(crate) fn scalar(&self, i: Participant) -> Option<Scalar> {
163        self.mode.scalar(self.total, i)
164    }
165
166    fn all_scalars(&self) -> impl Iterator<Item = Scalar> {
167        self.mode.all_scalars(self.total)
168    }
169
170    /// Return the number of participants required to recover the secret
171    /// using the given fault model.
172    pub fn required<M: Faults>(&self) -> u32 {
173        M::quorum(self.total.get())
174    }
175
176    /// Return the total number of participants in this sharing.
177    pub const fn total(&self) -> NonZeroU32 {
178        self.total
179    }
180
181    /// Create an interpolator over some indices.
182    ///
183    /// This will return an error if any of the indices are >= [`Self::total`].
184    pub(crate) fn interpolator(
185        &self,
186        indices: &Set<Participant>,
187    ) -> Result<Interpolator<Participant, Scalar>, Error> {
188        self.mode
189            .interpolator(self.total, indices, |&x| Some(x))
190            .ok_or(Error::InvalidIndex)
191    }
192
193    /// Call this to pre-compute the results of [`Self::partial_public`].
194    ///
195    /// This should be used if you expect to access many of the partial public
196    /// keys, e.g. if verifying several public signatures.
197    ///
198    /// The first time this method is called can be expensive, but subsequent
199    /// calls are idempotent, and cheap.
200    #[cfg(feature = "std")]
201    pub fn precompute_partial_publics(&self) {
202        // NOTE: once we add more interpolation methods, this can be smarter.
203        self.evals
204            .iter()
205            .zip(self.all_scalars())
206            .for_each(|(e, s)| {
207                e.get_or_init(|| self.poly.eval_msm(&s, &Sequential));
208            })
209    }
210
211    /// Get the partial public key associated with a given participant.
212    ///
213    /// This will return `None` if the index is greater >= [`Self::total`].
214    pub fn partial_public(&self, i: Participant) -> Result<V::Public, Error> {
215        cfg_if! {
216            if #[cfg(feature = "std")] {
217                self.evals
218                    .get(usize::from(i))
219                    .map(|e| {
220                        *e.get_or_init(|| self.poly.eval_msm(&self.scalar(i).expect("i < total"), &Sequential))
221                    })
222                    .ok_or(Error::InvalidIndex)
223            } else {
224                Ok(self.poly.eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?, &Sequential))
225            }
226        }
227    }
228
229    /// Get the group public key of this sharing.
230    ///
231    /// In other words, the public key associated with the shared secret.
232    pub fn public(&self) -> &V::Public {
233        self.poly.constant()
234    }
235}
236
237impl<V: Variant> EncodeSize for Sharing<V> {
238    fn encode_size(&self) -> usize {
239        self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
240    }
241}
242
243impl<V: Variant> Write for Sharing<V> {
244    fn write(&self, buf: &mut impl bytes::BufMut) {
245        self.mode.write(buf);
246        self.total.get().write(buf);
247        self.poly.write(buf);
248    }
249}
250
251impl<V: Variant> Read for Sharing<V> {
252    type Cfg = NonZeroU32;
253
254    fn read_cfg(
255        buf: &mut impl bytes::Buf,
256        cfg: &Self::Cfg,
257    ) -> Result<Self, commonware_codec::Error> {
258        let mode = ReadExt::read(buf)?;
259        // We bound total to the config, in order to prevent doing arbitrary
260        // computation if we precompute public keys.
261        let total = {
262            let out: u32 = ReadExt::read(buf)?;
263            if out == 0 || out > cfg.get() {
264                return Err(commonware_codec::Error::Invalid(
265                    "Sharing",
266                    "total not in range",
267                ));
268            }
269            // This will not panic, because we checked != 0 above.
270            NZU32!(out)
271        };
272        let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*cfg), ()))?;
273        Ok(Self::new(mode, total, poly))
274    }
275}
276
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;
    use commonware_utils::{N3f1, NZU32};
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            // Draw a tag from the range of currently valid mode tags.
            let tag = u.int_in_range(0u8..=0)?;
            if tag == 0 {
                Ok(Self::NonZeroCounter)
            } else {
                Err(arbitrary::Error::IncorrectFormat)
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            // Values are drawn in a fixed order (total, mode, seed) so a given
            // fuzz input always maps to the same sharing.
            let total: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;

            // Degree `quorum - 1`, so that `quorum` evaluations determine it.
            let degree = N3f1::quorum(total) - 1;
            let mut rng = StdRng::seed_from_u64(seed);
            let secret = Poly::new(&mut rng, degree);

            Ok(Self::new(mode, NZU32!(total), Poly::<V::Public>::commit(secret)))
        }
    }
}