commonware_cryptography/bls12381/primitives/
sharing.rs1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4use cfg_if::cfg_if;
5use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
6use commonware_math::poly::{Interpolator, Poly};
7use commonware_parallel::Sequential;
8use commonware_utils::{ordered::Set, Faults, Participant, NZU32};
9#[cfg(feature = "std")]
10use core::iter;
11use core::num::NonZeroU32;
12#[cfg(feature = "std")]
13use std::sync::{Arc, OnceLock};
14
/// Strategy for assigning evaluation points (scalars) to the participants of a sharing.
///
/// The `u8` discriminant of each variant is the tag byte used when encoding a `Mode`.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Participant `i` is evaluated at the scalar `i + 1`, so no participant is ever
    /// assigned the point zero.
    #[default]
    NonZeroCounter = 0,
}
26
27impl Mode {
28 pub(crate) fn scalar(self, total: NonZeroU32, i: Participant) -> Option<Scalar> {
32 if i.get() >= total.get() {
33 return None;
34 }
35 match self {
36 Self::NonZeroCounter => {
37 Some(Scalar::from_u64(i.get() as u64 + 1))
39 }
40 }
41 }
42
43 #[cfg(feature = "std")]
45 pub(crate) fn all_scalars(self, total: NonZeroU32) -> impl Iterator<Item = Scalar> {
46 (0..total.get()).map(move |i| self.scalar(total, Participant::new(i)).expect("i < total"))
47 }
48
49 fn interpolator<I: Clone + Ord>(
60 self,
61 total: NonZeroU32,
62 indices: &Set<I>,
63 to_index: impl Fn(&I) -> Option<Participant>,
64 ) -> Option<Interpolator<I, Scalar>> {
65 let mut count = 0;
66 let iter = indices
67 .iter()
68 .filter_map(|i| {
69 let scalar = self.scalar(total, to_index(i)?)?;
70 Some((i.clone(), scalar))
71 })
72 .inspect(|_| {
73 count += 1;
74 });
75 let out = Interpolator::new(iter);
76 if count != indices.len() {
78 return None;
79 }
80 Some(out)
81 }
82
83 #[cfg(feature = "std")]
91 pub(crate) fn subset_interpolator<I: Clone + Ord>(
92 self,
93 set: &Set<I>,
94 subset: &Set<I>,
95 ) -> Option<Interpolator<I, Scalar>> {
96 let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
97 return Some(Interpolator::new(iter::empty()));
98 };
99 self.interpolator(total, subset, |i| {
100 set.position(i).map(Participant::from_usize)
101 })
102 }
103}
104
105impl FixedSize for Mode {
106 const SIZE: usize = 1;
107}
108
109impl Write for Mode {
110 fn write(&self, buf: &mut impl bytes::BufMut) {
111 buf.put_u8(*self as u8);
112 }
113}
114
115impl Read for Mode {
116 type Cfg = ();
117
118 fn read_cfg(
119 buf: &mut impl bytes::Buf,
120 _cfg: &Self::Cfg,
121 ) -> Result<Self, commonware_codec::Error> {
122 let tag: u8 = ReadExt::read(buf)?;
123 match tag {
124 0 => Ok(Self::NonZeroCounter),
125 o => Err(commonware_codec::Error::InvalidEnum(o)),
126 }
127 }
128}
129
130#[derive(Clone, Debug)]
134pub struct Sharing<V: Variant> {
135 mode: Mode,
136 total: NonZeroU32,
137 poly: Arc<Poly<V::Public>>,
138 #[cfg(feature = "std")]
139 evals: Arc<Vec<OnceLock<V::Public>>>,
140}
141
142impl<V: Variant> PartialEq for Sharing<V> {
143 fn eq(&self, other: &Self) -> bool {
144 self.mode == other.mode && self.total == other.total && self.poly == other.poly
145 }
146}
147
148impl<V: Variant> Eq for Sharing<V> {}
149
150impl<V: Variant> Sharing<V> {
151 pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
152 Self {
153 mode,
154 total,
155 poly: Arc::new(poly),
156 #[cfg(feature = "std")]
157 evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
158 }
159 }
160
161 #[cfg(feature = "std")]
163 pub(crate) const fn mode(&self) -> Mode {
164 self.mode
165 }
166
167 pub(crate) fn scalar(&self, i: Participant) -> Option<Scalar> {
168 self.mode.scalar(self.total, i)
169 }
170
171 #[cfg(feature = "std")]
172 fn all_scalars(&self) -> impl Iterator<Item = Scalar> {
173 self.mode.all_scalars(self.total)
174 }
175
176 pub fn required<M: Faults>(&self) -> u32 {
179 M::quorum(self.total.get())
180 }
181
182 pub const fn total(&self) -> NonZeroU32 {
184 self.total
185 }
186
187 pub(crate) fn interpolator(
191 &self,
192 indices: &Set<Participant>,
193 ) -> Result<Interpolator<Participant, Scalar>, Error> {
194 self.mode
195 .interpolator(self.total, indices, |&x| Some(x))
196 .ok_or(Error::InvalidIndex)
197 }
198
199 #[cfg(feature = "std")]
207 pub fn precompute_partial_publics(&self) {
208 self.evals
210 .iter()
211 .zip(self.all_scalars())
212 .for_each(|(e, s)| {
213 e.get_or_init(|| self.poly.eval_msm(&s, &Sequential));
214 })
215 }
216
217 pub fn partial_public(&self, i: Participant) -> Result<V::Public, Error> {
221 cfg_if! {
222 if #[cfg(feature = "std")] {
223 self.evals
224 .get(usize::from(i))
225 .map(|e| {
226 *e.get_or_init(|| {
227 self.poly
228 .eval_msm(&self.scalar(i).expect("i < total"), &Sequential)
229 })
230 })
231 .ok_or(Error::InvalidIndex)
232 } else {
233 Ok(self
234 .poly
235 .eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?, &Sequential))
236 }
237 }
238 }
239
240 pub fn public(&self) -> &V::Public {
244 self.poly.constant()
245 }
246}
247
248impl<V: Variant> EncodeSize for Sharing<V> {
249 fn encode_size(&self) -> usize {
250 self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
251 }
252}
253
254impl<V: Variant> Write for Sharing<V> {
255 fn write(&self, buf: &mut impl bytes::BufMut) {
256 self.mode.write(buf);
257 self.total.get().write(buf);
258 self.poly.write(buf);
259 }
260}
261
262impl<V: Variant> Read for Sharing<V> {
263 type Cfg = NonZeroU32;
264
265 fn read_cfg(
266 buf: &mut impl bytes::Buf,
267 cfg: &Self::Cfg,
268 ) -> Result<Self, commonware_codec::Error> {
269 let mode = ReadExt::read(buf)?;
270 let total = {
273 let out: u32 = ReadExt::read(buf)?;
274 if out == 0 || out > cfg.get() {
275 return Err(commonware_codec::Error::Invalid(
276 "Sharing",
277 "total not in range",
278 ));
279 }
280 NZU32!(out)
282 };
283 let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*cfg), ()))?;
284 Ok(Self::new(mode, total, poly))
285 }
286}
287
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;
    use commonware_utils::{N3f1, NZU32};
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            // The tag range covers every currently defined variant.
            let tag = u.int_in_range(0u8..=0)?;
            if tag == 0 {
                Ok(Self::NonZeroCounter)
            } else {
                Err(arbitrary::Error::IncorrectFormat)
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            // The draw order (total, mode, seed) fixes how the fuzz input is consumed.
            let total: u32 = u.int_in_range(1..=100)?;
            let mode = Mode::arbitrary(u)?;
            let seed = u64::arbitrary(u)?;

            // A random polynomial whose coefficient count equals the quorum for `total`,
            // committed to its public form.
            let mut rng = StdRng::seed_from_u64(seed);
            let secret = Poly::new(&mut rng, N3f1::quorum(total) - 1);
            let public = Poly::<V::Public>::commit(secret);

            Ok(Self::new(mode, NZU32!(total), public))
        }
    }
}