commonware_cryptography/bls12381/primitives/
sharing.rs1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4use cfg_if::cfg_if;
5use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
6use commonware_math::poly::{Interpolator, Poly};
7use commonware_parallel::Sequential;
8use commonware_utils::{ordered::Set, Faults, Participant, NZU32};
9use core::{iter, num::NonZeroU32};
10#[cfg(feature = "std")]
11use std::sync::{Arc, OnceLock};
12
/// Rule that assigns each participant an evaluation point (a scalar) on the
/// sharing polynomial.
///
/// The explicit `u8` discriminant is what the codec writes as the mode tag.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Participant `i` (zero-based) is assigned the scalar `i + 1`.
    #[default]
    NonZeroCounter = 0,
}
24
25impl Mode {
26 pub(crate) fn scalar(self, total: NonZeroU32, i: Participant) -> Option<Scalar> {
30 if i.get() >= total.get() {
31 return None;
32 }
33 match self {
34 Self::NonZeroCounter => {
35 Some(Scalar::from_u64(i.get() as u64 + 1))
37 }
38 }
39 }
40
41 pub(crate) fn all_scalars(self, total: NonZeroU32) -> impl Iterator<Item = Scalar> {
43 (0..total.get()).map(move |i| self.scalar(total, Participant::new(i)).expect("i < total"))
44 }
45
46 fn interpolator<I: Clone + Ord>(
57 self,
58 total: NonZeroU32,
59 indices: &Set<I>,
60 to_index: impl Fn(&I) -> Option<Participant>,
61 ) -> Option<Interpolator<I, Scalar>> {
62 let mut count = 0;
63 let iter = indices
64 .iter()
65 .filter_map(|i| {
66 let scalar = self.scalar(total, to_index(i)?)?;
67 Some((i.clone(), scalar))
68 })
69 .inspect(|_| {
70 count += 1;
71 });
72 let out = Interpolator::new(iter);
73 if count != indices.len() {
75 return None;
76 }
77 Some(out)
78 }
79
80 pub(crate) fn subset_interpolator<I: Clone + Ord>(
88 self,
89 set: &Set<I>,
90 subset: &Set<I>,
91 ) -> Option<Interpolator<I, Scalar>> {
92 let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
93 return Some(Interpolator::new(iter::empty()));
94 };
95 self.interpolator(total, subset, |i| {
96 set.position(i).map(Participant::from_usize)
97 })
98 }
99}
100
101impl FixedSize for Mode {
102 const SIZE: usize = 1;
103}
104
105impl Write for Mode {
106 fn write(&self, buf: &mut impl bytes::BufMut) {
107 buf.put_u8(*self as u8);
108 }
109}
110
111impl Read for Mode {
112 type Cfg = ();
113
114 fn read_cfg(
115 buf: &mut impl bytes::Buf,
116 _cfg: &Self::Cfg,
117 ) -> Result<Self, commonware_codec::Error> {
118 let tag: u8 = ReadExt::read(buf)?;
119 match tag {
120 0 => Ok(Self::NonZeroCounter),
121 o => Err(commonware_codec::Error::InvalidEnum(o)),
122 }
123 }
124}
125
126#[derive(Clone, Debug)]
130pub struct Sharing<V: Variant> {
131 mode: Mode,
132 total: NonZeroU32,
133 poly: Arc<Poly<V::Public>>,
134 #[cfg(feature = "std")]
135 evals: Arc<Vec<OnceLock<V::Public>>>,
136}
137
138impl<V: Variant> PartialEq for Sharing<V> {
139 fn eq(&self, other: &Self) -> bool {
140 self.mode == other.mode && self.total == other.total && self.poly == other.poly
141 }
142}
143
144impl<V: Variant> Eq for Sharing<V> {}
145
146impl<V: Variant> Sharing<V> {
147 pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
148 Self {
149 mode,
150 total,
151 poly: Arc::new(poly),
152 #[cfg(feature = "std")]
153 evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
154 }
155 }
156
157 pub(crate) const fn mode(&self) -> Mode {
159 self.mode
160 }
161
162 pub(crate) fn scalar(&self, i: Participant) -> Option<Scalar> {
163 self.mode.scalar(self.total, i)
164 }
165
166 fn all_scalars(&self) -> impl Iterator<Item = Scalar> {
167 self.mode.all_scalars(self.total)
168 }
169
170 pub fn required<M: Faults>(&self) -> u32 {
173 M::quorum(self.total.get())
174 }
175
176 pub const fn total(&self) -> NonZeroU32 {
178 self.total
179 }
180
181 pub(crate) fn interpolator(
185 &self,
186 indices: &Set<Participant>,
187 ) -> Result<Interpolator<Participant, Scalar>, Error> {
188 self.mode
189 .interpolator(self.total, indices, |&x| Some(x))
190 .ok_or(Error::InvalidIndex)
191 }
192
193 #[cfg(feature = "std")]
201 pub fn precompute_partial_publics(&self) {
202 self.evals
204 .iter()
205 .zip(self.all_scalars())
206 .for_each(|(e, s)| {
207 e.get_or_init(|| self.poly.eval_msm(&s, &Sequential));
208 })
209 }
210
211 pub fn partial_public(&self, i: Participant) -> Result<V::Public, Error> {
215 cfg_if! {
216 if #[cfg(feature = "std")] {
217 self.evals
218 .get(usize::from(i))
219 .map(|e| {
220 *e.get_or_init(|| self.poly.eval_msm(&self.scalar(i).expect("i < total"), &Sequential))
221 })
222 .ok_or(Error::InvalidIndex)
223 } else {
224 Ok(self.poly.eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?, &Sequential))
225 }
226 }
227 }
228
229 pub fn public(&self) -> &V::Public {
233 self.poly.constant()
234 }
235}
236
237impl<V: Variant> EncodeSize for Sharing<V> {
238 fn encode_size(&self) -> usize {
239 self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
240 }
241}
242
243impl<V: Variant> Write for Sharing<V> {
244 fn write(&self, buf: &mut impl bytes::BufMut) {
245 self.mode.write(buf);
246 self.total.get().write(buf);
247 self.poly.write(buf);
248 }
249}
250
251impl<V: Variant> Read for Sharing<V> {
252 type Cfg = NonZeroU32;
253
254 fn read_cfg(
255 buf: &mut impl bytes::Buf,
256 cfg: &Self::Cfg,
257 ) -> Result<Self, commonware_codec::Error> {
258 let mode = ReadExt::read(buf)?;
259 let total = {
262 let out: u32 = ReadExt::read(buf)?;
263 if out == 0 || out > cfg.get() {
264 return Err(commonware_codec::Error::Invalid(
265 "Sharing",
266 "total not in range",
267 ));
268 }
269 NZU32!(out)
271 };
272 let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*cfg), ()))?;
273 Ok(Self::new(mode, total, poly))
274 }
275}
276
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;
    use commonware_utils::{N3f1, NZU32};
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        /// Draws a mode tag. Only tag `0` exists, so the draw range is `0..=0`.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let tag = u.int_in_range(0u8..=0)?;
            if tag == 0 {
                Ok(Self::NonZeroCounter)
            } else {
                Err(arbitrary::Error::IncorrectFormat)
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        /// Builds a sharing among 1..=100 participants.
        ///
        /// The polynomial has degree `N3f1::quorum(total) - 1` and is
        /// generated from a seeded RNG. Draw order (total, mode, seed) fixes
        /// how the input bytes are consumed.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let participants: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;

            let mut rng = StdRng::seed_from_u64(seed);
            let degree = N3f1::quorum(participants) - 1;
            let secret_poly = Poly::new(&mut rng, degree);
            let commitment = Poly::<V::Public>::commit(secret_poly);

            Ok(Self::new(mode, NZU32!(participants), commitment))
        }
    }
}