Skip to main content

commonware_cryptography/bls12381/primitives/
sharing.rs

1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4use cfg_if::cfg_if;
5use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
6use commonware_math::poly::{Interpolator, Poly};
7use commonware_parallel::Sequential;
8use commonware_utils::{ordered::Set, Faults, Participant, NZU32};
9#[cfg(feature = "std")]
10use core::iter;
11use core::num::NonZeroU32;
12#[cfg(feature = "std")]
13use std::sync::{Arc, OnceLock};
14
/// Selects how the shares of a secret are handed out to participants.
///
/// Put differently, a mode decides which evaluation point of the sharing
/// polynomial belongs to which participant identity.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
pub enum Mode {
    // TODO (https://github.com/commonware-xyz/monorepo/issues/1836): Add a mode for sub O(N^2) interpolation
    /// Participant `i` is assigned the evaluation point `i + 1`.
    #[default]
    NonZeroCounter = 0,
}
26
27impl Mode {
28    /// Compute the scalar for one participant.
29    ///
30    /// This will return `None` only if `i >= total`.
31    pub(crate) fn scalar(self, total: NonZeroU32, i: Participant) -> Option<Scalar> {
32        if i.get() >= total.get() {
33            return None;
34        }
35        match self {
36            Self::NonZeroCounter => {
37                // Adding 1 is critical, because f(0) will contain the secret.
38                Some(Scalar::from_u64(i.get() as u64 + 1))
39            }
40        }
41    }
42
43    /// Compute the scalars for all participants.
44    #[cfg(feature = "std")]
45    pub(crate) fn all_scalars(self, total: NonZeroU32) -> impl Iterator<Item = Scalar> {
46        (0..total.get()).map(move |i| self.scalar(total, Participant::new(i)).expect("i < total"))
47    }
48
49    /// Create an interpolator for this mode, given a set of indices.
50    ///
51    /// This will return `None` if:
52    /// - any `to_index` call on the provided `indices` returns `None`,
53    /// - any index returned by `to_index` is >= `total`.
54    ///
55    /// To be generic over different use cases, we need:
56    /// - the total number of participants,
57    /// - a set of indices (of any type),
58    /// - a means to convert indices to Participant values.
59    fn interpolator<I: Clone + Ord>(
60        self,
61        total: NonZeroU32,
62        indices: &Set<I>,
63        to_index: impl Fn(&I) -> Option<Participant>,
64    ) -> Option<Interpolator<I, Scalar>> {
65        let mut count = 0;
66        let iter = indices
67            .iter()
68            .filter_map(|i| {
69                let scalar = self.scalar(total, to_index(i)?)?;
70                Some((i.clone(), scalar))
71            })
72            .inspect(|_| {
73                count += 1;
74            });
75        let out = Interpolator::new(iter);
76        // If any indices failed to produce a scalar, reject.
77        if count != indices.len() {
78            return None;
79        }
80        Some(out)
81    }
82
83    /// Create an interpolator for this mode, given a set, and a subset.
84    ///
85    /// The set determines the total number of participants to use for interpolation,
86    /// and the indices that will get assigned to the subset.
87    ///
88    /// This function will return `None` only if `subset` contains elements
89    /// not in `set`.
90    #[cfg(feature = "std")]
91    pub(crate) fn subset_interpolator<I: Clone + Ord>(
92        self,
93        set: &Set<I>,
94        subset: &Set<I>,
95    ) -> Option<Interpolator<I, Scalar>> {
96        let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
97            return Some(Interpolator::new(iter::empty()));
98        };
99        self.interpolator(total, subset, |i| {
100            set.position(i).map(Participant::from_usize)
101        })
102    }
103}
104
105impl FixedSize for Mode {
106    const SIZE: usize = 1;
107}
108
109impl Write for Mode {
110    fn write(&self, buf: &mut impl bytes::BufMut) {
111        buf.put_u8(*self as u8);
112    }
113}
114
115impl Read for Mode {
116    type Cfg = ();
117
118    fn read_cfg(
119        buf: &mut impl bytes::Buf,
120        _cfg: &Self::Cfg,
121    ) -> Result<Self, commonware_codec::Error> {
122        let tag: u8 = ReadExt::read(buf)?;
123        match tag {
124            0 => Ok(Self::NonZeroCounter),
125            o => Err(commonware_codec::Error::InvalidEnum(o)),
126        }
127    }
128}
129
130/// Represents the public output of a polynomial secret sharing.
131///
132/// This does not contain any secret information.
133#[derive(Clone, Debug)]
134pub struct Sharing<V: Variant> {
135    mode: Mode,
136    total: NonZeroU32,
137    poly: Arc<Poly<V::Public>>,
138    #[cfg(feature = "std")]
139    evals: Arc<Vec<OnceLock<V::Public>>>,
140}
141
142impl<V: Variant> PartialEq for Sharing<V> {
143    fn eq(&self, other: &Self) -> bool {
144        self.mode == other.mode && self.total == other.total && self.poly == other.poly
145    }
146}
147
148impl<V: Variant> Eq for Sharing<V> {}
149
150impl<V: Variant> Sharing<V> {
151    pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
152        Self {
153            mode,
154            total,
155            poly: Arc::new(poly),
156            #[cfg(feature = "std")]
157            evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
158        }
159    }
160
161    /// Get the mode used for this sharing.
162    #[cfg(feature = "std")]
163    pub(crate) const fn mode(&self) -> Mode {
164        self.mode
165    }
166
167    pub(crate) fn scalar(&self, i: Participant) -> Option<Scalar> {
168        self.mode.scalar(self.total, i)
169    }
170
171    #[cfg(feature = "std")]
172    fn all_scalars(&self) -> impl Iterator<Item = Scalar> {
173        self.mode.all_scalars(self.total)
174    }
175
176    /// Return the number of participants required to recover the secret
177    /// using the given fault model.
178    pub fn required<M: Faults>(&self) -> u32 {
179        M::quorum(self.total.get())
180    }
181
182    /// Return the total number of participants in this sharing.
183    pub const fn total(&self) -> NonZeroU32 {
184        self.total
185    }
186
187    /// Create an interpolator over some indices.
188    ///
189    /// This will return an error if any of the indices are >= [`Self::total`].
190    pub(crate) fn interpolator(
191        &self,
192        indices: &Set<Participant>,
193    ) -> Result<Interpolator<Participant, Scalar>, Error> {
194        self.mode
195            .interpolator(self.total, indices, |&x| Some(x))
196            .ok_or(Error::InvalidIndex)
197    }
198
199    /// Call this to pre-compute the results of [`Self::partial_public`].
200    ///
201    /// This should be used if you expect to access many of the partial public
202    /// keys, e.g. if verifying several public signatures.
203    ///
204    /// The first time this method is called can be expensive, but subsequent
205    /// calls are idempotent, and cheap.
206    #[cfg(feature = "std")]
207    pub fn precompute_partial_publics(&self) {
208        // NOTE: once we add more interpolation methods, this can be smarter.
209        self.evals
210            .iter()
211            .zip(self.all_scalars())
212            .for_each(|(e, s)| {
213                e.get_or_init(|| self.poly.eval_msm(&s, &Sequential));
214            })
215    }
216
217    /// Get the partial public key associated with a given participant.
218    ///
219    /// This will return `None` if the index is greater >= [`Self::total`].
220    pub fn partial_public(&self, i: Participant) -> Result<V::Public, Error> {
221        cfg_if! {
222            if #[cfg(feature = "std")] {
223                self.evals
224                    .get(usize::from(i))
225                    .map(|e| {
226                        *e.get_or_init(|| {
227                            self.poly
228                                .eval_msm(&self.scalar(i).expect("i < total"), &Sequential)
229                        })
230                    })
231                    .ok_or(Error::InvalidIndex)
232            } else {
233                Ok(self
234                    .poly
235                    .eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?, &Sequential))
236            }
237        }
238    }
239
240    /// Get the group public key of this sharing.
241    ///
242    /// In other words, the public key associated with the shared secret.
243    pub fn public(&self) -> &V::Public {
244        self.poly.constant()
245    }
246}
247
248impl<V: Variant> EncodeSize for Sharing<V> {
249    fn encode_size(&self) -> usize {
250        self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
251    }
252}
253
254impl<V: Variant> Write for Sharing<V> {
255    fn write(&self, buf: &mut impl bytes::BufMut) {
256        self.mode.write(buf);
257        self.total.get().write(buf);
258        self.poly.write(buf);
259    }
260}
261
262impl<V: Variant> Read for Sharing<V> {
263    type Cfg = NonZeroU32;
264
265    fn read_cfg(
266        buf: &mut impl bytes::Buf,
267        cfg: &Self::Cfg,
268    ) -> Result<Self, commonware_codec::Error> {
269        let mode = ReadExt::read(buf)?;
270        // We bound total to the config, in order to prevent doing arbitrary
271        // computation if we precompute public keys.
272        let total = {
273            let out: u32 = ReadExt::read(buf)?;
274            if out == 0 || out > cfg.get() {
275                return Err(commonware_codec::Error::Invalid(
276                    "Sharing",
277                    "total not in range",
278                ));
279            }
280            // This will not panic, because we checked != 0 above.
281            NZU32!(out)
282        };
283        let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*cfg), ()))?;
284        Ok(Self::new(mode, total, poly))
285    }
286}
287
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;
    use commonware_utils::{N3f1, NZU32};
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            // Draw a tag from the range of known discriminants.
            let tag = u.int_in_range(0u8..=0)?;
            if tag == 0 {
                Ok(Self::NonZeroCounter)
            } else {
                Err(arbitrary::Error::IncorrectFormat)
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            // The order of these draws determines how fuzzer bytes are consumed.
            let participants: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;

            // Sample a secret polynomial from the seed, sized by the N3f1 quorum
            // for `participants`, and publish its commitment.
            let mut rng = StdRng::seed_from_u64(seed);
            let secret = Poly::new(&mut rng, N3f1::quorum(participants) - 1);
            let commitment = Poly::<V::Public>::commit(secret);

            Ok(Self::new(mode, NZU32!(participants), commitment))
        }
    }
}
313                Poly::<V::Public>::commit(poly),
314            ))
315        }
316    }
317}