commonware_cryptography/bls12381/primitives/
sharing.rs1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4use cfg_if::cfg_if;
5use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
6use commonware_math::poly::{Interpolator, Poly};
7use commonware_utils::{ordered::Set, quorum, NZU32};
8use core::{iter, num::NonZeroU32};
9#[cfg(feature = "std")]
10use std::sync::{Arc, OnceLock};
11
/// Strategy that maps a participant index to the scalar at which the shared
/// polynomial is evaluated for that participant.
///
/// The discriminant (`repr(u8)`) is the tag used on the wire.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Mode {
    /// Participant `i` is evaluated at the scalar `i + 1`, so no participant is
    /// ever assigned the point zero.
    #[default]
    NonZeroCounter = 0,
}
23
24impl Mode {
25 pub(crate) fn scalar(self, total: NonZeroU32, i: u32) -> Option<Scalar> {
29 if i >= total.get() {
30 return None;
31 }
32 match self {
33 Self::NonZeroCounter => {
34 Some(Scalar::from_u64(i as u64 + 1))
36 }
37 }
38 }
39
40 pub(crate) fn all_scalars(self, total: NonZeroU32) -> impl Iterator<Item = Scalar> {
42 (0..total.get()).map(move |i| self.scalar(total, i).expect("i < total"))
43 }
44
45 fn interpolator<I: Clone + Ord>(
56 self,
57 total: NonZeroU32,
58 indices: &Set<I>,
59 to_index: impl Fn(&I) -> Option<u32>,
60 ) -> Option<Interpolator<I, Scalar>> {
61 let mut count = 0;
62 let iter = indices
63 .iter()
64 .filter_map(|i| {
65 let scalar = self.scalar(total, to_index(i)?)?;
66 Some((i.clone(), scalar))
67 })
68 .inspect(|_| {
69 count += 1;
70 });
71 let out = Interpolator::new(iter);
72 if count != indices.len() {
74 return None;
75 }
76 Some(out)
77 }
78
79 pub(crate) fn subset_interpolator<I: Clone + Ord>(
87 self,
88 set: &Set<I>,
89 subset: &Set<I>,
90 ) -> Option<Interpolator<I, Scalar>> {
91 let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
92 return Some(Interpolator::new(iter::empty()));
93 };
94 self.interpolator(total, subset, |i| set.position(i).map(|x| x as u32))
95 }
96}
97
98impl FixedSize for Mode {
99 const SIZE: usize = 1;
100}
101
102impl Write for Mode {
103 fn write(&self, buf: &mut impl bytes::BufMut) {
104 buf.put_u8(*self as u8);
105 }
106}
107
108impl Read for Mode {
109 type Cfg = ();
110
111 fn read_cfg(
112 buf: &mut impl bytes::Buf,
113 _cfg: &Self::Cfg,
114 ) -> Result<Self, commonware_codec::Error> {
115 let tag: u8 = ReadExt::read(buf)?;
116 match tag {
117 0 => Ok(Self::NonZeroCounter),
118 o => Err(commonware_codec::Error::InvalidEnum(o)),
119 }
120 }
121}
122
/// A polynomial commitment (`Poly<V::Public>`) shared among `total`
/// participants, together with the [`Mode`] that assigns each participant its
/// evaluation point.
#[derive(Clone, Debug)]
pub struct Sharing<V: Variant> {
    /// How participant indices map to evaluation points.
    mode: Mode,
    /// Number of participants.
    total: NonZeroU32,
    /// The committed polynomial; `public()` returns its constant term.
    poly: Arc<Poly<V::Public>>,
    /// Lazily filled evaluations of `poly`, one slot per participant (sized to
    /// `total` in `new`). Held in an `Arc`, so clones share the same cache.
    #[cfg(feature = "std")]
    evals: Arc<Vec<OnceLock<V::Public>>>,
}
134
135impl<V: Variant> PartialEq for Sharing<V> {
136 fn eq(&self, other: &Self) -> bool {
137 self.mode == other.mode && self.total == other.total && self.poly == other.poly
138 }
139}
140
141impl<V: Variant> Eq for Sharing<V> {}
142
143impl<V: Variant> Sharing<V> {
144 pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
145 Self {
146 mode,
147 total,
148 poly: Arc::new(poly),
149 #[cfg(feature = "std")]
150 evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
151 }
152 }
153
154 pub(crate) const fn mode(&self) -> Mode {
156 self.mode
157 }
158
159 pub(crate) fn scalar(&self, i: u32) -> Option<Scalar> {
160 self.mode.scalar(self.total, i)
161 }
162
163 fn all_scalars(&self) -> impl Iterator<Item = Scalar> {
164 self.mode.all_scalars(self.total)
165 }
166
167 pub fn required(&self) -> u32 {
169 quorum(self.total.get())
170 }
171
172 pub const fn total(&self) -> NonZeroU32 {
174 self.total
175 }
176
177 pub(crate) fn interpolator(
181 &self,
182 indices: &Set<u32>,
183 ) -> Result<Interpolator<u32, Scalar>, Error> {
184 self.mode
185 .interpolator(self.total, indices, |&x| Some(x))
186 .ok_or(Error::InvalidIndex)
187 }
188
189 #[cfg(feature = "std")]
197 pub fn precompute_partial_publics(&self) {
198 self.evals
200 .iter()
201 .zip(self.all_scalars())
202 .for_each(|(e, s)| {
203 e.get_or_init(|| self.poly.eval_msm(&s));
204 })
205 }
206
207 pub fn partial_public(&self, i: u32) -> Result<V::Public, Error> {
211 cfg_if! {
212 if #[cfg(feature = "std")] {
213 self.evals
214 .get(i as usize)
215 .map(|e| {
216 *e.get_or_init(|| self.poly.eval_msm(&self.scalar(i).expect("i < total")))
217 })
218 .ok_or(Error::InvalidIndex)
219 } else {
220 Ok(self.poly.eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?))
221 }
222 }
223 }
224
225 pub fn public(&self) -> &V::Public {
229 self.poly.constant()
230 }
231}
232
233impl<V: Variant> EncodeSize for Sharing<V> {
234 fn encode_size(&self) -> usize {
235 self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
236 }
237}
238
239impl<V: Variant> Write for Sharing<V> {
240 fn write(&self, buf: &mut impl bytes::BufMut) {
241 self.mode.write(buf);
242 self.total.get().write(buf);
243 self.poly.write(buf);
244 }
245}
246
247impl<V: Variant> Read for Sharing<V> {
248 type Cfg = NonZeroU32;
249
250 fn read_cfg(
251 buf: &mut impl bytes::Buf,
252 cfg: &Self::Cfg,
253 ) -> Result<Self, commonware_codec::Error> {
254 let mode = ReadExt::read(buf)?;
255 let total = {
258 let out: u32 = ReadExt::read(buf)?;
259 if out == 0 || out > cfg.get() {
260 return Err(commonware_codec::Error::Invalid(
261 "Sharing",
262 "total not in range",
263 ));
264 }
265 NZU32!(out)
267 };
268 let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*cfg), ()))?;
269 Ok(Self::new(mode, total, poly))
270 }
271}
272
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_utils::NZU32;
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        /// Picks one of the supported modes (currently only `NonZeroCounter`).
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            let tag: u8 = u.int_in_range(0..=0)?;
            match tag {
                0 => Ok(Self::NonZeroCounter),
                _ => Err(arbitrary::Error::IncorrectFormat),
            }
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        /// Builds a sharing of a seeded random polynomial of degree
        /// `quorum(total) - 1`, with `total` drawn from `1..=100`.
        fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
            // Inputs are drawn in a fixed order (total, mode, seed) so the same
            // bytes always produce the same value.
            let total: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;

            let mut rng = StdRng::seed_from_u64(seed);
            let degree = quorum(total) - 1;
            let secret = Poly::new(&mut rng, degree);
            let commitment = Poly::<V::Public>::commit(secret);
            Ok(Self::new(mode, NZU32!(total), commitment))
        }
    }
}