commonware_cryptography/bls12381/primitives/
sharing.rs

1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4use cfg_if::cfg_if;
5use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
6use commonware_math::poly::{Interpolator, Poly};
7use commonware_utils::{ordered::Set, quorum, NZU32};
8use core::{iter, num::NonZeroU32};
9#[cfg(feature = "std")]
10use std::sync::{Arc, OnceLock};
11
/// Configures how participants are assigned shares of a secret.
///
/// Concretely, a mode decides which evaluation point of the sharing
/// polynomial each participant index is mapped to.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Mode {
    // TODO (https://github.com/commonware-xyz/monorepo/issues/1836): Add a mode for sub O(N^2) interpolation
    /// Participant `i` (0-based) is assigned the evaluation point `i + 1`.
    ///
    /// This is the default mode; its wire tag is `0`.
    #[default]
    NonZeroCounter = 0,
}
23
24impl Mode {
25    /// Compute the scalar for one participant.
26    ///
27    /// This will return `None` only if `i >= total`.
28    pub(crate) fn scalar(self, total: NonZeroU32, i: u32) -> Option<Scalar> {
29        if i >= total.get() {
30            return None;
31        }
32        match self {
33            Self::NonZeroCounter => {
34                // Adding 1 is critical, because f(0) will contain the secret.
35                Some(Scalar::from_u64(i as u64 + 1))
36            }
37        }
38    }
39
40    /// Compute the scalars for all participants.
41    pub(crate) fn all_scalars(self, total: NonZeroU32) -> impl Iterator<Item = Scalar> {
42        (0..total.get()).map(move |i| self.scalar(total, i).expect("i < total"))
43    }
44
45    /// Create an interpolator for this mode, given a set of indices.
46    ///
47    /// This will return `None` if:
48    /// - any `to_index` call on the provided `indices` returns `None`,
49    /// - any index returned by `to_index` is >= `total`.
50    ///
51    /// To be generic over different use cases, we need:
52    /// - the total number of participants,
53    /// - a set of indices (of any type),
54    /// - a means to convert indices to u32 values.
55    fn interpolator<I: Clone + Ord>(
56        self,
57        total: NonZeroU32,
58        indices: &Set<I>,
59        to_index: impl Fn(&I) -> Option<u32>,
60    ) -> Option<Interpolator<I, Scalar>> {
61        let mut count = 0;
62        let iter = indices
63            .iter()
64            .filter_map(|i| {
65                let scalar = self.scalar(total, to_index(i)?)?;
66                Some((i.clone(), scalar))
67            })
68            .inspect(|_| {
69                count += 1;
70            });
71        let out = Interpolator::new(iter);
72        // If any indices failed to produce a scalar, reject.
73        if count != indices.len() {
74            return None;
75        }
76        Some(out)
77    }
78
79    /// Create an interpolator for this mode, given a set, and a subset.
80    ///
81    /// The set determines the total number of participants to use for interpolation,
82    /// and the indices that will get assigned to the subset.
83    ///
84    /// This function will return `None` only if `subset` contains elements
85    /// not in `set`.
86    pub(crate) fn subset_interpolator<I: Clone + Ord>(
87        self,
88        set: &Set<I>,
89        subset: &Set<I>,
90    ) -> Option<Interpolator<I, Scalar>> {
91        let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
92            return Some(Interpolator::new(iter::empty()));
93        };
94        self.interpolator(total, subset, |i| set.position(i).map(|x| x as u32))
95    }
96}
97
98impl FixedSize for Mode {
99    const SIZE: usize = 1;
100}
101
102impl Write for Mode {
103    fn write(&self, buf: &mut impl bytes::BufMut) {
104        buf.put_u8(*self as u8);
105    }
106}
107
108impl Read for Mode {
109    type Cfg = ();
110
111    fn read_cfg(
112        buf: &mut impl bytes::Buf,
113        _cfg: &Self::Cfg,
114    ) -> Result<Self, commonware_codec::Error> {
115        let tag: u8 = ReadExt::read(buf)?;
116        match tag {
117            0 => Ok(Self::NonZeroCounter),
118            o => Err(commonware_codec::Error::InvalidEnum(o)),
119        }
120    }
121}
122
123/// Represents the public output of a polynomial secret sharing.
124///
125/// This does not contain any secret information.
126#[derive(Clone, Debug)]
127pub struct Sharing<V: Variant> {
128    mode: Mode,
129    total: NonZeroU32,
130    poly: Arc<Poly<V::Public>>,
131    #[cfg(feature = "std")]
132    evals: Arc<Vec<OnceLock<V::Public>>>,
133}
134
135impl<V: Variant> PartialEq for Sharing<V> {
136    fn eq(&self, other: &Self) -> bool {
137        self.mode == other.mode && self.total == other.total && self.poly == other.poly
138    }
139}
140
141impl<V: Variant> Eq for Sharing<V> {}
142
143impl<V: Variant> Sharing<V> {
144    pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
145        Self {
146            mode,
147            total,
148            poly: Arc::new(poly),
149            #[cfg(feature = "std")]
150            evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
151        }
152    }
153
154    /// Get the mode used for this sharing.
155    pub(crate) const fn mode(&self) -> Mode {
156        self.mode
157    }
158
159    pub(crate) fn scalar(&self, i: u32) -> Option<Scalar> {
160        self.mode.scalar(self.total, i)
161    }
162
163    fn all_scalars(&self) -> impl Iterator<Item = Scalar> {
164        self.mode.all_scalars(self.total)
165    }
166
167    /// Return the number of participants required to recover the secret.
168    pub fn required(&self) -> u32 {
169        quorum(self.total.get())
170    }
171
172    /// Return the total number of participants in this sharing.
173    pub const fn total(&self) -> NonZeroU32 {
174        self.total
175    }
176
177    /// Create an interpolator over some indices.
178    ///
179    /// This will return an error if any of the indices are >= [`Self::total`].
180    pub(crate) fn interpolator(
181        &self,
182        indices: &Set<u32>,
183    ) -> Result<Interpolator<u32, Scalar>, Error> {
184        self.mode
185            .interpolator(self.total, indices, |&x| Some(x))
186            .ok_or(Error::InvalidIndex)
187    }
188
189    /// Call this to pre-compute the results of [`Self::partial_public`].
190    ///
191    /// This should be used if you expect to access many of the partial public
192    /// keys, e.g. if verifying several public signatures.
193    ///
194    /// The first time this method is called can be expensive, but subsequent
195    /// calls are idempotent, and cheap.
196    #[cfg(feature = "std")]
197    pub fn precompute_partial_publics(&self) {
198        // NOTE: once we add more interpolation methods, this can be smarter.
199        self.evals
200            .iter()
201            .zip(self.all_scalars())
202            .for_each(|(e, s)| {
203                e.get_or_init(|| self.poly.eval_msm(&s));
204            })
205    }
206
207    /// Get the partial public key associated with a given participant.
208    ///
209    /// This will return `None` if the index is greater >= [`Self::total`].
210    pub fn partial_public(&self, i: u32) -> Result<V::Public, Error> {
211        cfg_if! {
212            if #[cfg(feature = "std")] {
213                self.evals
214                    .get(i as usize)
215                    .map(|e| {
216                        *e.get_or_init(|| self.poly.eval_msm(&self.scalar(i).expect("i < total")))
217                    })
218                    .ok_or(Error::InvalidIndex)
219            } else {
220                Ok(self.poly.eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?))
221            }
222        }
223    }
224
225    /// Get the group public key of this sharing.
226    ///
227    /// In other words, the public key associated with the shared secret.
228    pub fn public(&self) -> &V::Public {
229        self.poly.constant()
230    }
231}
232
233impl<V: Variant> EncodeSize for Sharing<V> {
234    fn encode_size(&self) -> usize {
235        self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
236    }
237}
238
239impl<V: Variant> Write for Sharing<V> {
240    fn write(&self, buf: &mut impl bytes::BufMut) {
241        self.mode.write(buf);
242        self.total.get().write(buf);
243        self.poly.write(buf);
244    }
245}
246
247impl<V: Variant> Read for Sharing<V> {
248    type Cfg = NonZeroU32;
249
250    fn read_cfg(
251        buf: &mut impl bytes::Buf,
252        cfg: &Self::Cfg,
253    ) -> Result<Self, commonware_codec::Error> {
254        let mode = ReadExt::read(buf)?;
255        // We bound total to the config, in order to prevent doing arbitrary
256        // computation if we precompute public keys.
257        let total = {
258            let out: u32 = ReadExt::read(buf)?;
259            if out == 0 || out > cfg.get() {
260                return Err(commonware_codec::Error::Invalid(
261                    "Sharing",
262                    "total not in range",
263                ));
264            }
265            // This will not panic, because we checked != 0 above.
266            NZU32!(out)
267        };
268        let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*cfg), ()))?;
269        Ok(Self::new(mode, total, poly))
270    }
271}
272
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_utils::NZU32;
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            // Only tag 0 is currently defined.
            let tag = u.int_in_range(0u8..=0)?;
            if tag == 0 {
                Ok(Self::NonZeroCounter)
            } else {
                Err(arbitrary::Error::IncorrectFormat)
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            // The draw order (total, mode, seed) determines how the fuzz
            // input bytes are consumed.
            let total: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;

            let mut rng = StdRng::seed_from_u64(seed);
            let private_poly = Poly::new(&mut rng, quorum(total) - 1);
            let public_poly = Poly::<V::Public>::commit(private_poly);
            Ok(Self::new(mode, NZU32!(total), public_poly))
        }
    }
}